#include #include #include #include #include #include void rdb_request_init(RdbRequest *req, int connfd, char *cmd, size_t cmdlen) { req->connfd = connfd; req->cmd = cmd; req->cmdlen = cmdlen; req->blocking = true; rdb_request_reset(req); } void rdb_request_reset(RdbRequest *req) { req->state = read_state_header; req->inbuf.len = 0; req->inbuf.index = 0; req->outlen = 0; req->chk = 0; } void rdb_request_set_blocking(RdbRequest *req, bool blocking) { req->blocking = blocking; } static rdb_read_result_t read_char(RdbRequest *req, char *in) { if (req->inbuf.index >= req->inbuf.len) { int flags = req->blocking ? 0 : MSG_DONTWAIT; ssize_t inlen = recv(req->connfd, req->inbuf.buf, INBUF_LEN, flags); if (inlen == 0) { return read_result_disconnected; } if (inlen < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { return read_result_pending; } else { perror("could not read incoming packet"); return read_result_error; } } req->inbuf.index = 0; req->inbuf.len = inlen; } *in = req->inbuf.buf[req->inbuf.index++]; return read_result_success; } rdb_read_result_t rdb_request_read(RdbRequest *req, size_t *len) { rdb_read_result_t res; char in; switch (req->state) { case read_state_header: // read any acknowledgements and continue do { res = read_char(req, &in); if (res != read_result_success) return res; } while (in == '+'); if (in == '-') { fprintf(stderr, "negative ack not supported\n"); return read_result_error; } if (in == '\x03') { // interrupt from the server req->cmd[0] = in; req->outlen = 1; return read_result_success; } // now, we should be at the start of a packet if (in != '$') { fprintf(stderr, "unexpected packet start \"%c\"\n", in); return read_result_error; } req->state = read_state_body; __attribute__ ((fallthrough)); case read_state_body: case read_state_body_escape: while (1) { res = read_char(req, &in); if (res != read_result_success) return res; if (req->state == read_state_body && in == '#') { // end of packet body break; } req->chk += in; if (req->state == read_state_body && in == '}') { // escape sequence req->state = read_state_body_escape; continue; } if (req->outlen >= req->cmdlen) { // ran out of room in the buffer fprintf(stderr, "packet too big for buffer\n"); return read_result_error; } if (req->state == read_state_body_escape) { req->cmd[req->outlen++] = in ^ 0x20; req->state = read_state_body; } else { req->cmd[req->outlen++] = in; } } req->state = read_state_checksum_1; __attribute__ ((fallthrough)); case read_state_checksum_1: res = read_char(req, &in); if (res != read_result_success) return res; // check the high digit of the checksum char hi; if (!parse_hex_digit(in, &hi)) { fprintf(stderr, "invalid checksum1\n"); return read_result_error; } if (((req->chk >> 4) & 0x000f) != hi) { fprintf(stderr, "invalid checksum2\n"); return read_result_error; } req->state = read_state_checksum_2; __attribute__ ((fallthrough)); case read_state_checksum_2: res = read_char(req, &in); if (res != read_result_success) return res; // check the high digit of the checksum char lo; if (!parse_hex_digit(in, &lo)) { fprintf(stderr, "invalid checksum3 %c\n", in); return read_result_error; } if ((req->chk & 0x000f) != lo) { fprintf(stderr, "invalid checksum4\n"); return read_result_error; } *len = req->outlen; return read_result_success; default: fprintf(stderr, "invalid state\n"); return read_result_error; } }