#include #include #include #include #include #include void rdbRequestInit(RdbRequest *req, int connfd, char *buf, size_t buflen) { req->connfd = connfd; req->outbuf = buf; req->outbuflen = buflen; req->blocking = true; rdbRequestReset(req); } void rdbRequestReset(RdbRequest *req) { req->state = read_state_header; req->inbuf.len = 0; req->inbuf.index = 0; req->chk = 0; } void rdbRequestSetBlocking(RdbRequest *req, bool blocking) { req->blocking = blocking; } static rdb_read_result_t readChar(RdbRequest *req, char *in) { if (req->inbuf.index >= req->inbuf.len) { int flags = req->blocking ? 0 : MSG_DONTWAIT; ssize_t inlen = recv(req->connfd, req->inbuf.buf, INBUF_LEN, flags); if (inlen == 0) { return read_result_disconnected; } if (inlen < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { return read_result_pending; } else { perror("could not read incoming packet"); return read_result_error; } } req->inbuf.index = 0; req->inbuf.len = inlen; } *in = req->inbuf.buf[req->inbuf.index++]; return read_result_success; } rdb_read_result_t rdbRequestRead(RdbRequest *req, CommandBuf *cmd) { rdb_read_result_t res; char in, hi, lo; switch (req->state) { case read_state_header: cmd->buf = req->outbuf; cmd->len = 0; /* read any acknowledgements and continue */ do { res = readChar(req, &in); if (res != read_result_success) return res; } while (in == '+'); if (in == '-') { fprintf(stderr, "negative ack not supported\n"); return read_result_error; } if (in == '\x03') { /* interrupt from the server */ cmd->buf[0] = in; cmd->len = 1; return read_result_success; } /* now, we should be at the start of a packet */ if (in != '$') { fprintf(stderr, "unexpected packet start \"%c\"\n", in); return read_result_error; } req->state = read_state_body; __attribute__ ((fallthrough)); case read_state_body: case read_state_body_escape: while (1) { res = readChar(req, &in); if (res != read_result_success) return res; if (req->state == read_state_body && in == '#') { /* end of packet body */ break; } req->chk += in; if (req->state == read_state_body && in == '}') { /* escape sequence */ req->state = read_state_body_escape; continue; } if (cmd->len >= req->outbuflen) { /* ran out of room in the buffer */ fprintf(stderr, "packet too big for buffer\n"); return read_result_error; } if (req->state == read_state_body_escape) { cmd->buf[cmd->len++] = in ^ 0x20; req->state = read_state_body; } else { cmd->buf[cmd->len++] = in; } } req->state = read_state_checksum_1; __attribute__ ((fallthrough)); case read_state_checksum_1: res = readChar(req, &in); if (res != read_result_success) return res; /* check the high digit of the checksum */ if (!parseHexDigit(in, &hi)) { fprintf(stderr, "invalid checksum1\n"); return read_result_error; } if (((req->chk >> 4) & 0x000f) != hi) { fprintf(stderr, "invalid checksum2\n"); return read_result_error; } req->state = read_state_checksum_2; __attribute__ ((fallthrough)); case read_state_checksum_2: res = readChar(req, &in); if (res != read_result_success) return res; /* check the high digit of the checksum */ if (!parseHexDigit(in, &lo)) { fprintf(stderr, "invalid checksum3 %c\n", in); return read_result_error; } if ((req->chk & 0x000f) != lo) { fprintf(stderr, "invalid checksum4\n"); return read_result_error; } return read_result_success; default: fprintf(stderr, "invalid state\n"); return read_result_error; } }