improves websockets
[src/app-framework-binder.git] / src / websock.c
1 /*
2  * Copyright 2016 iot.bzh
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * This work is a far adaptation of apache-websocket:
19  *   origin:  https://github.com/disconnect/apache-websocket
20  *   commit:  cfaef071223f11ba016bff7e1e4b7c9e5df45b50
21  *   Copyright 2010-2012 self.disconnect (APACHE-2)
22  */
23
24 #include <stdlib.h>
25 #include <stdint.h>
26 #include <errno.h>
27 #include <string.h>
28 #include <sys/uio.h>
29
30 #include "websock.h"
31
32 #define BLOCK_DATA_SIZE              4096
33
34 #define FRAME_GET_FIN(BYTE)         (((BYTE) >> 7) & 0x01)
35 #define FRAME_GET_RSV1(BYTE)        (((BYTE) >> 6) & 0x01)
36 #define FRAME_GET_RSV2(BYTE)        (((BYTE) >> 5) & 0x01)
37 #define FRAME_GET_RSV3(BYTE)        (((BYTE) >> 4) & 0x01)
38 #define FRAME_GET_OPCODE(BYTE)      ( (BYTE)       & 0x0F)
39 #define FRAME_GET_MASK(BYTE)        (((BYTE) >> 7) & 0x01)
40 #define FRAME_GET_PAYLOAD_LEN(BYTE) ( (BYTE)       & 0x7F)
41
42 #define FRAME_SET_FIN(BYTE)         (((BYTE) & 0x01) << 7)
43 #define FRAME_SET_OPCODE(BYTE)       ((BYTE) & 0x0F)
44 #define FRAME_SET_MASK(BYTE)        (((BYTE) & 0x01) << 7)
45 #define FRAME_SET_LENGTH(X64, IDX)  (unsigned char)(((X64) >> ((IDX)*8)) & 0xFF)
46
47 #define OPCODE_CONTINUATION 0x0
48 #define OPCODE_TEXT         0x1
49 #define OPCODE_BINARY       0x2
50 #define OPCODE_CLOSE        0x8
51 #define OPCODE_PING         0x9
52 #define OPCODE_PONG         0xA
53
54 #define STATE_INIT    0
55 #define STATE_START   1
56 #define STATE_LENGTH  2
57 #define STATE_DATA    3
58 #define STATE_CLOSED  4
59
60 struct websock {
61         int state;
62         uint64_t maxlength;
63         int lenhead, szhead;
64         uint64_t length;
65         uint32_t mask;
66         unsigned char header[14];       /* 2 + 8 + 4 */
67         const struct websock_itf *itf;
68         void *closure;
69 };
70
71 static ssize_t ws_writev(struct websock *ws, const struct iovec *iov, int iovcnt)
72 {
73         return ws->itf->writev(ws->closure, iov, iovcnt);
74 }
75
76 static ssize_t ws_readv(struct websock *ws, const struct iovec *iov, int iovcnt)
77 {
78         return ws->itf->readv(ws->closure, iov, iovcnt);
79 }
80
81 #if 0
82 static ssize_t ws_write(struct websock *ws, const void *buffer, size_t buffer_size)
83 {
84         struct iovec iov;
85         iov.iov_base = (void *)buffer;  /* const cast */
86         iov.iov_len = buffer_size;
87         return ws_writev(ws, &iov, 1);
88 }
89 #endif
90
91 static ssize_t ws_read(struct websock *ws, void *buffer, size_t buffer_size)
92 {
93         struct iovec iov;
94         iov.iov_base = buffer;
95         iov.iov_len = buffer_size;
96         return ws_readv(ws, &iov, 1);
97 }
98
99 static ssize_t websock_send(struct websock *ws, unsigned char opcode,
100                             const void *buffer, size_t buffer_size)
101 {
102         struct iovec iov[2];
103         size_t pos;
104         ssize_t rc;
105         unsigned char header[32];
106
107         if (ws->state == STATE_CLOSED)
108                 return 0;
109
110         pos = 0;
111         header[pos++] = (unsigned char)(FRAME_SET_FIN(1) | FRAME_SET_OPCODE(opcode));
112         buffer_size = (uint64_t) buffer_size;
113         if (buffer_size < 126) {
114                 header[pos++] =
115                     FRAME_SET_MASK(0) | FRAME_SET_LENGTH(buffer_size, 0);
116         } else {
117                 if (buffer_size < 65536) {
118                         header[pos++] = FRAME_SET_MASK(0) | 126;
119                 } else {
120                         header[pos++] = FRAME_SET_MASK(0) | 127;
121                         header[pos++] = FRAME_SET_LENGTH(buffer_size, 7);
122                         header[pos++] = FRAME_SET_LENGTH(buffer_size, 6);
123                         header[pos++] = FRAME_SET_LENGTH(buffer_size, 5);
124                         header[pos++] = FRAME_SET_LENGTH(buffer_size, 4);
125                         header[pos++] = FRAME_SET_LENGTH(buffer_size, 3);
126                         header[pos++] = FRAME_SET_LENGTH(buffer_size, 2);
127                 }
128                 header[pos++] = FRAME_SET_LENGTH(buffer_size, 1);
129                 header[pos++] = FRAME_SET_LENGTH(buffer_size, 0);
130         }
131
132         iov[0].iov_base = header;
133         iov[0].iov_len = pos;
134         iov[1].iov_base = (void *)buffer;       /* const cast */
135         iov[1].iov_len = buffer_size;
136
137         rc = ws_writev(ws, iov, 1 + !!buffer_size);
138
139         if (opcode == OPCODE_CLOSE) {
140                 ws->length = 0;
141                 ws->state = STATE_CLOSED;
142                 ws->itf->disconnect(ws->closure);
143         }
144         return rc;
145 }
146
147 void websock_close(struct websock *ws)
148 {
149         websock_send(ws, OPCODE_CLOSE, NULL, 0);
150 }
151
152 void websock_close_code(struct websock *ws, uint16_t code)
153 {
154         unsigned char buffer[2];
155         /* Send server-side closing handshake */
156         buffer[0] = (unsigned char)((code >> 8) & 0xFF);
157         buffer[1] = (unsigned char)(code & 0xFF);
158         websock_send(ws, OPCODE_CLOSE, buffer, 2);
159 }
160
161 void websock_ping(struct websock *ws)
162 {
163         websock_send(ws, OPCODE_PING, NULL, 0);
164 }
165
166 void websock_pong(struct websock *ws)
167 {
168         websock_send(ws, OPCODE_PONG, NULL, 0);
169 }
170
171 void websock_text(struct websock *ws, const char *text, size_t length)
172 {
173         websock_send(ws, OPCODE_TEXT, text, length);
174 }
175
176 void websock_binary(struct websock *ws, const void *data, size_t length)
177 {
178         websock_send(ws, OPCODE_BINARY, data, length);
179 }
180
181 static int read_header(struct websock *ws)
182 {
183         if (ws->lenhead < ws->szhead) {
184                 ssize_t rbc =
185                     ws_read(ws, &ws->header[ws->lenhead], (size_t)(ws->szhead - ws->lenhead));
186                 if (rbc < 0)
187                         return -1;
188                 ws->lenhead += (int)rbc;
189         }
190         return 0;
191 }
192
193 static int check_control_header(struct websock *ws)
194 {
195         /* sanity checks */
196         if (FRAME_GET_RSV1(ws->header[0]) != 0)
197                 return 0;
198         if (FRAME_GET_RSV2(ws->header[0]) != 0)
199                 return 0;
200         if (FRAME_GET_RSV3(ws->header[0]) != 0)
201                 return 0;
202         if (FRAME_GET_MASK(ws->header[1]))
203                 return 0;
204         if (FRAME_GET_OPCODE(ws->header[0]) == OPCODE_CLOSE)
205                 return FRAME_GET_PAYLOAD_LEN(ws->header[1]) != 1;
206         return FRAME_GET_PAYLOAD_LEN(ws->header[1]) == 0;
207 }
208
209 int websock_dispatch(struct websock *ws)
210 {
211 loop:
212         switch (ws->state) {
213         case STATE_INIT:
214                 ws->lenhead = 0;
215                 ws->szhead = 2;
216                 ws->state = STATE_START;
217
218         case STATE_START:
219                 /* read the header */
220                 if (read_header(ws))
221                         return -1;
222                 else if (ws->lenhead < ws->szhead)
223                         return 0;
224                 /* fast track */
225                 switch (FRAME_GET_OPCODE(ws->header[0])) {
226                 case OPCODE_CONTINUATION:
227                 case OPCODE_TEXT:
228                 case OPCODE_BINARY:
229                         break;
230                 case OPCODE_CLOSE:
231                         if (!check_control_header(ws))
232                                 goto protocol_error;
233                         if (FRAME_GET_PAYLOAD_LEN(ws->header[1]))
234                                 ws->szhead += 2;
235                         break;
236                 case OPCODE_PING:
237                         if (!check_control_header(ws))
238                                 goto protocol_error;
239                         if (ws->itf->on_ping)
240                                 ws->itf->on_ping(ws->closure);
241                         else
242                                 websock_pong(ws);
243                         ws->state = STATE_INIT;
244                         goto loop;
245                 case OPCODE_PONG:
246                         if (!check_control_header(ws))
247                                 goto protocol_error;
248                         if (ws->itf->on_pong)
249                                 ws->itf->on_pong(ws->closure);
250                         ws->state = STATE_INIT;
251                         goto loop;
252                 default:
253                         break;
254                 }
255                 /* update heading size */
256                 switch (FRAME_GET_PAYLOAD_LEN(ws->header[1])) {
257                 case 127:
258                         ws->szhead += 6;
259                 case 126:
260                         ws->szhead += 2;
261                 default:
262                         ws->szhead += 4 * FRAME_GET_MASK(ws->header[1]);
263                 }
264                 ws->state = STATE_LENGTH;
265
266         case STATE_LENGTH:
267                 /* continue to read the header */
268                 if (read_header(ws))
269                         return -1;
270                 else if (ws->lenhead < ws->szhead)
271                         return 0;
272                 /* compute header values */
273                 switch (FRAME_GET_PAYLOAD_LEN(ws->header[1])) {
274                 case 127:
275                         ws->length = (((uint64_t) ws->header[2]) << 56)
276                             | (((uint64_t) ws->header[3]) << 48)
277                             | (((uint64_t) ws->header[4]) << 40)
278                             | (((uint64_t) ws->header[5]) << 32)
279                             | (((uint64_t) ws->header[6]) << 24)
280                             | (((uint64_t) ws->header[7]) << 16)
281                             | (((uint64_t) ws->header[8]) << 8)
282                             | (uint64_t) ws->header[9];
283                         break;
284                 case 126:
285                         ws->length = (((uint64_t) ws->header[2]) << 8)
286                             | (uint64_t) ws->header[3];
287                         break;
288                 default:
289                         ws->length = FRAME_GET_PAYLOAD_LEN(ws->header[1]);
290                         break;
291                 }
292                 if (ws->length > ws->maxlength)
293                         goto too_long_error;
294                 if (FRAME_GET_MASK(ws->header[1])) {
295                         ((unsigned char *)&ws->mask)[0] = ws->header[ws->szhead - 4];
296                         ((unsigned char *)&ws->mask)[1] = ws->header[ws->szhead - 3];
297                         ((unsigned char *)&ws->mask)[2] = ws->header[ws->szhead - 2];
298                         ((unsigned char *)&ws->mask)[3] = ws->header[ws->szhead - 1];
299                 } else
300                         ws->mask = 0;
301
302                 /* all heading fields are known, process */
303                 ws->state = STATE_DATA;
304                 if (ws->itf->on_extension != NULL) {
305                         if (ws->itf->on_extension(ws->closure,
306                                         FRAME_GET_FIN(ws->header[0]),
307                                         FRAME_GET_RSV1(ws->header[0]),
308                                         FRAME_GET_RSV2(ws->header[0]),
309                                         FRAME_GET_RSV3(ws->header[0]),
310                                         FRAME_GET_OPCODE(ws->header[0]),
311                                         (size_t) ws->length)) {
312                                 return 0;
313                         }
314                 }
315
316                 /* not an extension case */
317                 if (FRAME_GET_RSV1(ws->header[0]) != 0)
318                         goto protocol_error;
319                 if (FRAME_GET_RSV2(ws->header[0]) != 0)
320                         goto protocol_error;
321                 if (FRAME_GET_RSV3(ws->header[0]) != 0)
322                         goto protocol_error;
323
324                 /* handle */
325                 switch (FRAME_GET_OPCODE(ws->header[0])) {
326                 case OPCODE_CONTINUATION:
327                         ws->itf->on_continue(ws->closure,
328                                              FRAME_GET_FIN(ws->header[0]),
329                                              (size_t) ws->length);
330                         break;
331                 case OPCODE_TEXT:
332                         ws->itf->on_text(ws->closure,
333                                          FRAME_GET_FIN(ws->header[0]),
334                                          (size_t) ws->length);
335                         break;
336                 case OPCODE_BINARY:
337                         ws->itf->on_binary(ws->closure,
338                                            FRAME_GET_FIN(ws->header[0]),
339                                            (size_t) ws->length);
340                         break;
341                 case OPCODE_CLOSE:
342                         ws->state = STATE_CLOSED;
343                         if (ws->length)
344                                 ws->itf->on_close(ws->closure,
345                                                   (uint16_t)((((uint16_t) ws-> header[2]) << 8) | ((uint16_t) ws->header[3])),
346                                                   (size_t) ws->length);
347                         else
348                                 ws->itf->on_close(ws->closure,
349                                                   WEBSOCKET_CODE_UNSET, 0);
350                         ws->itf->disconnect(ws->closure);
351                         return 0;
352                 default:
353                         goto protocol_error;
354                 }
355                 break;
356
357         case STATE_DATA:
358                 if (ws->length)
359                         return 0;
360                 ws->state = STATE_INIT;
361                 break;
362
363         case STATE_CLOSED:
364                 return 0;
365         }
366         goto loop;
367
368  too_long_error:
369         websock_close_code(ws, WEBSOCKET_CODE_MESSAGE_TOO_LARGE);
370         return 0;
371
372  protocol_error:
373         websock_close_code(ws, WEBSOCKET_CODE_PROTOCOL_ERROR);
374         return 0;
375 }
376
377 ssize_t websock_read(struct websock * ws, void *buffer, size_t size)
378 {
379         uint32_t mask, *b32;
380         uint8_t m, *b8;
381         ssize_t rc;
382
383         if (ws->state != STATE_DATA && ws->state != STATE_CLOSED)
384                 return 0;
385
386         if (size > ws->length)
387                 size = (size_t) ws->length;
388
389         rc = ws_read(ws, buffer, size);
390         if (rc > 0) {
391                 size = (size_t) rc;
392                 ws->length -= size;
393
394                 if (ws->mask) {
395                         mask = ws->mask;
396                         b8 = buffer;
397                         while (size && ((sizeof(uint32_t) - 1) & (uintptr_t) b8)) {
398                                 m = ((uint8_t *) & mask)[0];
399                                 ((uint8_t *) & mask)[0] = ((uint8_t *) & mask)[1];
400                                 ((uint8_t *) & mask)[1] = ((uint8_t *) & mask)[2];
401                                 ((uint8_t *) & mask)[2] = ((uint8_t *) & mask)[3];
402                                 ((uint8_t *) & mask)[3] = m;
403                                 *b8++ ^= m;
404                                 size--;
405                         }
406                         b32 = (uint32_t *) b8;
407                         while (size >= sizeof(uint32_t)) {
408                                 *b32++ ^= mask;
409                                 size -= sizeof(uint32_t);
410                         }
411                         b8 = (uint8_t *) b32;
412                         while (size) {
413                                 m = ((uint8_t *) & mask)[0];
414                                 ((uint8_t *) & mask)[0] = ((uint8_t *) & mask)[1];
415                                 ((uint8_t *) & mask)[1] = ((uint8_t *) & mask)[2];
416                                 ((uint8_t *) & mask)[2] = ((uint8_t *) & mask)[3];
417                                 ((uint8_t *) & mask)[3] = m;
418                                 *b8++ ^= m;
419                                 size--;
420                         }
421                         ws->mask = mask;
422                 }
423         }
424         return rc;
425 }
426
427 void websock_drop(struct websock *ws)
428 {
429         char buffer[8000];
430
431         while (ws->length && ws_read(ws, buffer, sizeof buffer) >= 0) ;
432 }
433
434 struct websock *websock_create_v13(const struct websock_itf *itf, void *closure)
435 {
436         struct websock *result = calloc(1, sizeof *result);
437         if (result) {
438                 result->itf = itf;
439                 result->closure = closure;
440                 result->maxlength = 65000;
441         }
442         return result;
443 }
444
445 void websock_destroy(struct websock *ws)
446 {
447         free(ws);
448 }