websocket refactoring
[src/app-framework-binder.git] / src / websock.c
1 /*
2  * Copyright 2016 iot.bzh
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * This work is loosely adapted from apache-websocket:
19  *   origin:  https://github.com/disconnect/apache-websocket
20  *   commit:  cfaef071223f11ba016bff7e1e4b7c9e5df45b50
21  *   Copyright 2010-2012 self.disconnect (APACHE-2)
22  */
23
24 #include <stdlib.h>
25 #include <stdint.h>
26 #include <errno.h>
27 #include <string.h>
28 #include <sys/uio.h>
29
30 #include "websock.h"
31
#define BLOCK_DATA_SIZE              4096

/* accessors of the 2 first bytes of a frame header (RFC 6455, 5.2) */
#define FRAME_GET_FIN(BYTE)         (((BYTE) >> 7) & 0x01)
#define FRAME_GET_RSV1(BYTE)        (((BYTE) >> 6) & 0x01)
#define FRAME_GET_RSV2(BYTE)        (((BYTE) >> 5) & 0x01)
#define FRAME_GET_RSV3(BYTE)        (((BYTE) >> 4) & 0x01)
#define FRAME_GET_OPCODE(BYTE)      ( (BYTE)       & 0x0F)
#define FRAME_GET_MASK(BYTE)        (((BYTE) >> 7) & 0x01)
#define FRAME_GET_PAYLOAD_LEN(BYTE) ( (BYTE)       & 0x7F)

/* builders of the 2 first bytes of a frame header */
#define FRAME_SET_FIN(BYTE)         (((BYTE) & 0x01) << 7)
#define FRAME_SET_RSV1(BYTE)        (((BYTE) & 0x01) << 6)
#define FRAME_SET_RSV2(BYTE)        (((BYTE) & 0x01) << 5)
#define FRAME_SET_RSV3(BYTE)        (((BYTE) & 0x01) << 4)
#define FRAME_SET_OPCODE(BYTE)      ((BYTE) & 0x0F)
#define FRAME_SET_MASK(BYTE)        (((BYTE) & 0x01) << 7)
/* byte IDX (0 = least significant) of the length X64; X64 must be wide
 * enough for the shift (IDX*8 must be below its width).
 * The whole expansion is parenthesized so it composes like a primary
 * expression (e.g. under sizeof or a postfix operator). */
#define FRAME_SET_LENGTH(X64, IDX)  ((unsigned char)(((X64) >> ((IDX)*8)) & 0xFF))

#define OPCODE_CONTINUATION 0x0
#define OPCODE_TEXT         0x1
#define OPCODE_BINARY       0x2
#define OPCODE_CLOSE        0x8
#define OPCODE_PING         0x9
#define OPCODE_PONG         0xA

/* states of the reception automaton of websock_dispatch */
#define STATE_INIT    0
#define STATE_START   1
#define STATE_LENGTH  2
#define STATE_DATA    3
61
/* Reception/emission state of one websocket connection. */
struct websock {
	int state;                      /* STATE_INIT, STATE_START, STATE_LENGTH or STATE_DATA */
	uint64_t maxlength;             /* largest incoming payload accepted */
	int lenhead;                    /* count of header bytes already read */
	int szhead;                     /* count of header bytes expected */
	uint64_t length;                /* payload bytes of the current frame not yet consumed */
	uint32_t mask;                  /* masking key, rotated as payload bytes are consumed */
	unsigned char header[14];       /* 2 + 8 + 4: base bytes, extended length, mask */
	const struct websock_itf *itf;  /* transport and event callbacks */
	void *closure;                  /* first argument given to every callback */
};
72
73 static ssize_t ws_writev(struct websock *ws, const struct iovec *iov, int iovcnt)
74 {
75         return ws->itf->writev(ws->closure, iov, iovcnt);
76 }
77
78 static ssize_t ws_readv(struct websock *ws, const struct iovec *iov, int iovcnt)
79 {
80         return ws->itf->readv(ws->closure, iov, iovcnt);
81 }
82
/*
 * Reads at most 'buffer_size' raw (still masked) bytes from the transport
 * into 'buffer'. Returns what the transport's readv returns.
 */
static ssize_t ws_read(struct websock *ws, void *buffer, size_t buffer_size)
{
	struct iovec iov = { .iov_base = buffer, .iov_len = buffer_size };

	return ws_readv(ws, &iov, 1);
}
100
/*
 * Sends one unmasked frame: the header starts with the byte 'first'
 * (FIN, RSV and opcode already encoded) and the payload is the 'size'
 * bytes at 'buffer'.
 * Returns 0 when the transport accepted the write, -1 when it reported
 * an error.
 */
static int websock_send_internal(struct websock *ws, unsigned char first, const void *buffer, size_t size)
{
	struct iovec iov[2];
	unsigned char header[10];	/* 1 + 1 + 8: frames sent here carry no mask */
	/* widen explicitly: shifting a 32-bit size_t by 32 or more bits
	 * (bytes 4..7 of the 64-bit length) would be undefined behavior */
	uint64_t len = (uint64_t) size;
	size_t pos = 0;
	ssize_t rc;
	int idx;

	header[pos++] = first;
	if (len < 126) {
		header[pos++] = (unsigned char)(FRAME_SET_MASK(0) | FRAME_SET_LENGTH(len, 0));
	} else if (len < 65536) {
		/* 16-bit extended length, network byte order */
		header[pos++] = (unsigned char)(FRAME_SET_MASK(0) | 126);
		header[pos++] = FRAME_SET_LENGTH(len, 1);
		header[pos++] = FRAME_SET_LENGTH(len, 0);
	} else {
		/* 64-bit extended length, network byte order */
		header[pos++] = (unsigned char)(FRAME_SET_MASK(0) | 127);
		for (idx = 7; idx >= 0; idx--)
			header[pos++] = FRAME_SET_LENGTH(len, idx);
	}

	iov[0].iov_base = header;
	iov[0].iov_len = pos;
	iov[1].iov_base = (void *)buffer;	/* const cast required by struct iovec */
	iov[1].iov_len = size;

	/* NOTE(review): a short write is reported as success here; whether the
	 * transport's writev may return a partial count is not visible in this
	 * file — confirm against the websock_itf implementations */
	rc = ws_writev(ws, iov, 1 + !!size);

	return rc < 0 ? -1 : 0;
}
138
/*
 * Builds the first header byte from its fields and sends the frame.
 * 'last' sets FIN; 'rsv1', 'rsv2' and 'rsv3' set their own reserved bit
 * (each used to be encoded with FRAME_SET_RSV1, collapsing all three onto
 * the RSV1 bit); 'opcode' is one of the OPCODE_* values.
 */
static inline int websock_send(struct websock *ws, int last, int rsv1, int rsv2, int rsv3, int opcode, const void *buffer, size_t size)
{
	unsigned char first = (unsigned char)(FRAME_SET_FIN(last)
				| FRAME_SET_RSV1(rsv1)
				| FRAME_SET_RSV2(rsv2)
				| FRAME_SET_RSV3(rsv3)
				| FRAME_SET_OPCODE(opcode));
	return websock_send_internal(ws, first, buffer, size);
}
148
149 int websock_close_empty(struct websock *ws)
150 {
151         return websock_close(ws, WEBSOCKET_CODE_NOT_SET, NULL, 0);
152 }
153
154 int websock_close(struct websock *ws, uint16_t code, const void *data, size_t length)
155 {
156         unsigned char buffer[125];
157
158         if (code == WEBSOCKET_CODE_NOT_SET && length == 0)
159                 return websock_send(ws, 1, 0, 0, 0, OPCODE_CLOSE, NULL, 0);
160
161         /* checks the length */
162         if (length > 123) {
163                 errno = EINVAL;
164                 return -1;
165         }
166
167         /* prepare the buffer */
168         buffer[0] = (unsigned char)((code >> 8) & 0xFF);
169         buffer[1] = (unsigned char)(code & 0xFF);
170         if (length != 0)
171                 memcpy(&buffer[2], data, length);
172
173         /* Send server-side closing handshake */
174         return websock_send(ws, 1, 0, 0, 0, OPCODE_CLOSE, buffer, 2 + length);
175 }
176
177 int websock_ping(struct websock *ws, const void *data, size_t length)
178 {
179         /* checks the length */
180         if (length > 125) {
181                 errno = EINVAL;
182                 return -1;
183         }
184
185         return websock_send(ws, 1, 0, 0, 0, OPCODE_PING, data, length);
186 }
187
188 int websock_pong(struct websock *ws, const void *data, size_t length)
189 {
190         /* checks the length */
191         if (length > 125) {
192                 errno = EINVAL;
193                 return -1;
194         }
195
196         return websock_send(ws, 1, 0, 0, 0, OPCODE_PONG, data, length);
197 }
198
199 int websock_text(struct websock *ws, int last, const char *text, size_t length)
200 {
201         return websock_send(ws, last, 0, 0, 0, OPCODE_TEXT, text, length);
202 }
203
204 int websock_binary(struct websock *ws, int last, const void *data, size_t length)
205 {
206         return websock_send(ws, last, 0, 0, 0, OPCODE_BINARY, data, length);
207 }
208
209 int websock_error(struct websock *ws, uint16_t code, const void *data, size_t size)
210 {
211         int rc = websock_close(ws, code, data, size);
212         if (ws->itf->on_error != NULL)
213                 ws->itf->on_error(ws->closure, code, data, size);
214         return rc;
215 }
216
217 static int read_header(struct websock *ws)
218 {
219         if (ws->lenhead < ws->szhead) {
220                 ssize_t rbc =
221                     ws_read(ws, &ws->header[ws->lenhead], (size_t)(ws->szhead - ws->lenhead));
222                 if (rbc < 0)
223                         return -1;
224                 ws->lenhead += (int)rbc;
225         }
226         return 0;
227 }
228
229 static int check_control_header(struct websock *ws)
230 {
231         /* sanity checks */
232         if (FRAME_GET_RSV1(ws->header[0]) != 0)
233                 return 0;
234         if (FRAME_GET_RSV2(ws->header[0]) != 0)
235                 return 0;
236         if (FRAME_GET_RSV3(ws->header[0]) != 0)
237                 return 0;
238         if (FRAME_GET_PAYLOAD_LEN(ws->header[1]) > 125)
239                 return 0;
240         if (FRAME_GET_OPCODE(ws->header[0]) == OPCODE_CLOSE)
241                 return FRAME_GET_PAYLOAD_LEN(ws->header[1]) != 1;
242         return 1;
243 }
244
245 int websock_dispatch(struct websock *ws)
246 {
247         uint16_t code;
248 loop:
249         switch (ws->state) {
250         case STATE_INIT:
251                 ws->lenhead = 0;
252                 ws->szhead = 2;
253                 ws->state = STATE_START;
254
255         case STATE_START:
256                 /* read the header */
257                 if (read_header(ws))
258                         return -1;
259                 else if (ws->lenhead < ws->szhead)
260                         return 0;
261                 /* fast track */
262                 switch (FRAME_GET_OPCODE(ws->header[0])) {
263                 case OPCODE_CONTINUATION:
264                 case OPCODE_TEXT:
265                 case OPCODE_BINARY:
266                         break;
267                 case OPCODE_CLOSE:
268                         if (!check_control_header(ws))
269                                 goto protocol_error;
270                         if (FRAME_GET_PAYLOAD_LEN(ws->header[1]))
271                                 ws->szhead += 2;
272                         break;
273                 case OPCODE_PING:
274                 case OPCODE_PONG:
275                         if (!check_control_header(ws))
276                                 goto protocol_error;
277                 default:
278                         break;
279                 }
280                 /* update heading size */
281                 switch (FRAME_GET_PAYLOAD_LEN(ws->header[1])) {
282                 case 127:
283                         ws->szhead += 6;
284                 case 126:
285                         ws->szhead += 2;
286                 default:
287                         ws->szhead += 4 * FRAME_GET_MASK(ws->header[1]);
288                 }
289                 ws->state = STATE_LENGTH;
290
291         case STATE_LENGTH:
292                 /* continue to read the header */
293                 if (read_header(ws))
294                         return -1;
295                 else if (ws->lenhead < ws->szhead)
296                         return 0;
297
298                 /* compute length */
299                 switch (FRAME_GET_PAYLOAD_LEN(ws->header[1])) {
300                 case 127:
301                         ws->length = (((uint64_t) ws->header[2]) << 56)
302                             | (((uint64_t) ws->header[3]) << 48)
303                             | (((uint64_t) ws->header[4]) << 40)
304                             | (((uint64_t) ws->header[5]) << 32)
305                             | (((uint64_t) ws->header[6]) << 24)
306                             | (((uint64_t) ws->header[7]) << 16)
307                             | (((uint64_t) ws->header[8]) << 8)
308                             | (uint64_t) ws->header[9];
309                         break;
310                 case 126:
311                         ws->length = (((uint64_t) ws->header[2]) << 8)
312                             | (uint64_t) ws->header[3];
313                         break;
314                 default:
315                         ws->length = FRAME_GET_PAYLOAD_LEN(ws->header[1]);
316                         break;
317                 }
318                 if (FRAME_GET_OPCODE(ws->header[0]) == OPCODE_CLOSE && ws->length != 0)
319                         ws->length -= 2;
320                 if (ws->length > ws->maxlength)
321                         goto too_long_error;
322
323                 /* compute mask */
324                 if (FRAME_GET_MASK(ws->header[1])) {
325                         ((unsigned char *)&ws->mask)[0] = ws->header[ws->szhead - 4];
326                         ((unsigned char *)&ws->mask)[1] = ws->header[ws->szhead - 3];
327                         ((unsigned char *)&ws->mask)[2] = ws->header[ws->szhead - 2];
328                         ((unsigned char *)&ws->mask)[3] = ws->header[ws->szhead - 1];
329                 } else
330                         ws->mask = 0;
331
332                 /* all heading fields are known, process */
333                 ws->state = STATE_DATA;
334                 if (ws->itf->on_extension != NULL) {
335                         if (ws->itf->on_extension(ws->closure,
336                                         FRAME_GET_FIN(ws->header[0]),
337                                         FRAME_GET_RSV1(ws->header[0]),
338                                         FRAME_GET_RSV2(ws->header[0]),
339                                         FRAME_GET_RSV3(ws->header[0]),
340                                         FRAME_GET_OPCODE(ws->header[0]),
341                                         (size_t) ws->length)) {
342                                 return 0;
343                         }
344                 }
345
346                 /* not an extension case */
347                 if (FRAME_GET_RSV1(ws->header[0]) != 0)
348                         goto protocol_error;
349                 if (FRAME_GET_RSV2(ws->header[0]) != 0)
350                         goto protocol_error;
351                 if (FRAME_GET_RSV3(ws->header[0]) != 0)
352                         goto protocol_error;
353
354                 /* handle */
355                 switch (FRAME_GET_OPCODE(ws->header[0])) {
356                 case OPCODE_CONTINUATION:
357                         ws->itf->on_continue(ws->closure,
358                                              FRAME_GET_FIN(ws->header[0]),
359                                              (size_t) ws->length);
360                         break;
361                 case OPCODE_TEXT:
362                         ws->itf->on_text(ws->closure,
363                                          FRAME_GET_FIN(ws->header[0]),
364                                          (size_t) ws->length);
365                         break;
366                 case OPCODE_BINARY:
367                         ws->itf->on_binary(ws->closure,
368                                            FRAME_GET_FIN(ws->header[0]),
369                                            (size_t) ws->length);
370                         break;
371                 case OPCODE_CLOSE:
372                         if (ws->length == 0)
373                                 code = WEBSOCKET_CODE_NOT_SET;
374                         else {
375                                 code = (uint16_t)(ws->header[ws->szhead - 2] & 0xff);
376                                 code = (uint16_t)(code << 8);
377                                 code = (uint16_t)(code | (uint16_t)(ws->header[ws->szhead - 1] & 0xff));
378                         }
379                         ws->itf->on_close(ws->closure, code, (size_t) ws->length);
380                         return 0;
381                 case OPCODE_PING:
382                         if (ws->itf->on_ping)
383                                 ws->itf->on_ping(ws->closure, ws->length);
384                         else {
385                                 websock_drop(ws);
386                                 websock_pong(ws, NULL, 0);
387                         }
388                         ws->state = STATE_INIT;
389                         break;
390                 case OPCODE_PONG:
391                         if (ws->itf->on_pong)
392                                 ws->itf->on_pong(ws->closure, ws->length);
393                         else
394                                 websock_drop(ws);
395                         ws->state = STATE_INIT;
396                         break;
397                 default:
398                         goto protocol_error;
399                 }
400                 break;
401
402         case STATE_DATA:
403                 if (ws->length)
404                         return 0;
405                 ws->state = STATE_INIT;
406                 break;
407         }
408         goto loop;
409
410  too_long_error:
411         websock_error(ws, WEBSOCKET_CODE_MESSAGE_TOO_LARGE, NULL, 0);
412         return 0;
413
414  protocol_error:
415         websock_error(ws, WEBSOCKET_CODE_PROTOCOL_ERROR, NULL, 0);
416         return 0;
417 }
418
419 ssize_t websock_read(struct websock * ws, void *buffer, size_t size)
420 {
421         uint32_t mask, *b32;
422         uint8_t m, *b8;
423         ssize_t rc;
424
425         if (ws->state != STATE_DATA)
426                 return 0;
427
428         if (size > ws->length)
429                 size = (size_t) ws->length;
430
431         rc = ws_read(ws, buffer, size);
432         if (rc > 0) {
433                 size = (size_t) rc;
434                 ws->length -= size;
435
436                 if (ws->mask) {
437                         mask = ws->mask;
438                         b8 = buffer;
439                         while (size && ((sizeof(uint32_t) - 1) & (uintptr_t) b8)) {
440                                 m = ((uint8_t *) & mask)[0];
441                                 ((uint8_t *) & mask)[0] = ((uint8_t *) & mask)[1];
442                                 ((uint8_t *) & mask)[1] = ((uint8_t *) & mask)[2];
443                                 ((uint8_t *) & mask)[2] = ((uint8_t *) & mask)[3];
444                                 ((uint8_t *) & mask)[3] = m;
445                                 *b8++ ^= m;
446                                 size--;
447                         }
448                         b32 = (uint32_t *) b8;
449                         while (size >= sizeof(uint32_t)) {
450                                 *b32++ ^= mask;
451                                 size -= sizeof(uint32_t);
452                         }
453                         b8 = (uint8_t *) b32;
454                         while (size) {
455                                 m = ((uint8_t *) & mask)[0];
456                                 ((uint8_t *) & mask)[0] = ((uint8_t *) & mask)[1];
457                                 ((uint8_t *) & mask)[1] = ((uint8_t *) & mask)[2];
458                                 ((uint8_t *) & mask)[2] = ((uint8_t *) & mask)[3];
459                                 ((uint8_t *) & mask)[3] = m;
460                                 *b8++ ^= m;
461                                 size--;
462                         }
463                         ws->mask = mask;
464                 }
465         }
466         return rc;
467 }
468
469 int websock_drop(struct websock *ws)
470 {
471         char buffer[8000];
472
473         while (ws->length)
474                 if (ws_read(ws, buffer, sizeof buffer) < 0)
475                         return -1;
476         return 0;
477 }
478
479 struct websock *websock_create_v13(const struct websock_itf *itf, void *closure)
480 {
481         struct websock *result = calloc(1, sizeof *result);
482         if (result) {
483                 result->itf = itf;
484                 result->closure = closure;
485                 result->maxlength = 65000;
486         }
487         return result;
488 }
489
/* Releases a handler obtained from websock_create_v13 (NULL is accepted). */
void websock_destroy(struct websock *ws)
{
	void *memory = ws;

	free(memory);
}