Make request parsing reentrant

This commit is contained in:
Simon Gellis 2024-10-07 00:55:44 -04:00
parent e31ac94a8b
commit e5932da1e5
9 changed files with 229 additions and 142 deletions

116
client.c
View File

@ -34,22 +34,6 @@ bool char_to_hex_digit(char in, char *out) {
return true; return true;
} }
bool hex_digit_to_char(char in, char *out) {
if (in >= '0' && in <= '9') {
*out = in - '0';
return true;
}
if (in >= 'a' && in <= 'f') {
*out = in - 'a' + 10;
return true;
}
if (in >= 'A' && in <= 'F') {
*out = in - 'A' + 10;
return true;
}
return false;
}
bool char_to_hex_digits(char in, char *hi, char *lo) { bool char_to_hex_digits(char in, char *hi, char *lo) {
return char_to_hex_digit((in & 0xf0) >> 4, hi) return char_to_hex_digit((in & 0xf0) >> 4, hi)
&& char_to_hex_digit(in & 0x0f, lo); && char_to_hex_digit(in & 0x0f, lo);
@ -62,106 +46,6 @@ void rdb_client_init(RdbClient *self, int connfd) {
self->should_ack = true; self->should_ack = true;
} }
#define BUFFER_LEN 8096
typedef struct Buffer {
char buf[BUFFER_LEN];
size_t len;
size_t index;
} Buffer;
static Buffer INBUF = {
.len = 0,
.index = 0,
};
static ssize_t read_next_char(int connfd, char *in) {
if (INBUF.index >= INBUF.len) {
ssize_t inlen = read(connfd, INBUF.buf, BUFFER_LEN);
if (inlen < 1) {
// either we got an error (-1) or the connection closed (0)
return inlen;
}
INBUF.index = 0;
INBUF.len = inlen;
}
*in = INBUF.buf[INBUF.index++];
return 1;
}
ssize_t rdb_client_read(RdbClient *self, char *buf, size_t len) {
// read any acknowledgements and continue
char in;
do {
ssize_t res = read_next_char(self->connfd, &in);
if (res < 1) return res;
} while (in == '+');
if (in == '\x03') {
// interrupt from the server
*buf = in;
return 1;
}
if (in == '-') {
// we don't handle resending right now
return -1;
}
// now, expect to be at the start of a packet
if (in != '$') {
fprintf(stderr, "unexpected packet start \"%c\"", in);
return -1;
}
size_t outlen = 0;
char chk = 0;
while (1) {
ssize_t res = read_next_char(self->connfd, &in);
if (res < 1) return res;
if (in == '#') {
// end of packet, checksum next
break;
}
if (outlen >= len) {
// ran out of room in the buffer
fprintf(stderr, "packet too big for buffer\n");
return -1;
}
if (in == '}') {
// escape sequence
chk += in;
res = read_next_char(self->connfd, &in);
if (res < 1) return res;
chk += in;
buf[outlen++] = in ^ 0x20;
} else {
chk += in;
buf[outlen++] = in;
}
};
// validate the checksum
char hi, lo;
ssize_t res = read_next_char(self->connfd, &in);
if (res < 1) return res;
if (!hex_digit_to_char(in, &hi)) return -1;
res = read_next_char(self->connfd, &in);
if (res < 1) return res;
if (!hex_digit_to_char(in, &lo)) return -1;
char real_chk = (hi << 4) | lo;
if (real_chk != chk) {
fprintf(stderr, "invalid checksum\n");
return -1;
}
return outlen;
}
void rdb_client_begin_packet(RdbClient *self) { void rdb_client_begin_packet(RdbClient *self) {
self->len = 0; self->len = 0;
self->chk = 0; self->chk = 0;

View File

@ -1,4 +1,5 @@
#include <cmdbuf.h> #include <cmdbuf.h>
#include <hex.h>
#include <string.h> #include <string.h>
bool cmd_match_str(CommandBuf *cmd, const char *str) { bool cmd_match_str(CommandBuf *cmd, const char *str) {
@ -23,22 +24,6 @@ bool cmd_match_only_str(CommandBuf *cmd, const char *str) {
return false; return false;
} }
static bool parse_hex_digit(char digit, char *out) {
if (digit >= '0' && digit <= '9') {
*out = digit - '0';
return true;
}
if (digit >= 'a' && digit <= 'f') {
*out = digit - 'a' + 10;
return true;
}
if (digit >= 'A' && digit <= 'F') {
*out = digit - 'A' + 10;
return true;
}
return false;
}
bool cmd_match_hex_number(CommandBuf *cmd, uint32_t *value) { bool cmd_match_hex_number(CommandBuf *cmd, uint32_t *value) {
size_t read = 0; size_t read = 0;
size_t max_len = cmd->len; size_t max_len = cmd->len;

17
hex.c Normal file
View File

@ -0,0 +1,17 @@
#include <hex.h>
bool parse_hex_digit(char digit, char *out) {
if (digit >= '0' && digit <= '9') {
*out = digit - '0';
return true;
}
if (digit >= 'a' && digit <= 'f') {
*out = digit - 'a' + 10;
return true;
}
if (digit >= 'A' && digit <= 'F') {
*out = digit - 'A' + 10;
return true;
}
return false;
}

View File

@ -17,7 +17,6 @@ typedef struct RdbClient {
} RdbClient; } RdbClient;
void rdb_client_init(RdbClient *self, int connfd); void rdb_client_init(RdbClient *self, int connfd);
ssize_t rdb_client_read(RdbClient *self, char *buf, size_t len);
void rdb_client_begin_packet(RdbClient *self); void rdb_client_begin_packet(RdbClient *self);
bool rdb_client_write_str(RdbClient *self, const char *str); bool rdb_client_write_str(RdbClient *self, const char *str);
bool rdb_client_write_str_hex(RdbClient *self, const char *str); bool rdb_client_write_str_hex(RdbClient *self, const char *str);

8
include/hex.h Normal file
View File

@ -0,0 +1,8 @@
#ifndef V810_HEX_H_
#define V810_HEX_H_
#include <stdbool.h>
bool parse_hex_digit(char digit, char *out);
#endif

41
include/request.h Normal file
View File

@ -0,0 +1,41 @@
#ifndef RDBSERVER_REQUEST_H
#define RDBSERVER_REQUEST_H
#include <stddef.h>
#define INBUF_LEN 256
typedef enum rdb_read_result_t {
read_result_success,
read_result_error,
read_result_pending,
read_result_disconnected,
} rdb_read_result_t;
typedef enum rdb_read_state_t {
read_state_header,
read_state_body,
read_state_body_escape,
read_state_checksum_1,
read_state_checksum_2,
} rdb_read_state_t;
typedef struct RdbRequest {
int connfd;
struct Buffer {
char buf[INBUF_LEN];
size_t len;
size_t index;
} inbuf;
rdb_read_state_t state;
char *cmd;
size_t cmdlen;
size_t outlen;
char chk;
} RdbRequest;
void rdb_request_init(RdbRequest *req, int connfd, char *cmd, size_t cmdlen);
void rdb_request_reset(RdbRequest *req);
rdb_read_result_t rdb_request_read(RdbRequest *req, size_t *len);
#endif

23
main.c
View File

@ -1,5 +1,6 @@
#include <client.h> #include <client.h>
#include <cmdbuf.h> #include <cmdbuf.h>
#include <request.h>
#include <stdbool.h> #include <stdbool.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
@ -209,24 +210,30 @@ int handle_command(RdbClient *client, CommandBuf *cmd, VB *sim) {
} }
return rdb_client_send_packet(client); return rdb_client_send_packet(client);
} }
fprintf(stderr, "Unrecognized command."); fprintf(stderr, "Unrecognized command.\n");
return rdb_client_send_packet(client); return rdb_client_send_packet(client);
} }
int server(int connfd, VB *sim) { int server(int connfd, VB *sim) {
RdbRequest req;
RdbClient client; RdbClient client;
char buf[BUFLEN];
size_t len;
rdb_request_init(&req, connfd, buf, BUFLEN);
rdb_client_init(&client, connfd); rdb_client_init(&client, connfd);
char buf[BUFLEN];
while (1) { while (1) {
ssize_t len = rdb_client_read(&client, buf, BUFLEN); rdb_read_result_t result = rdb_request_read(&req, &len);
if (len < 0) { if (result == read_result_error) {
perror("could not read data"); return -1;
return -len; } else if (result == read_result_disconnected) {
} else if (len == 0) {
printf("client has disconnected\n"); printf("client has disconnected\n");
return 0; return 0;
} else if (result == read_result_pending) {
// TODO: should run the emulator while we wait
continue;
} else { } else {
printf("received command \"%.*s\"\n", (int) len, buf); printf("received command \"%.*s\"\n", (int) len, buf);
fflush(stdout); fflush(stdout);
@ -237,7 +244,7 @@ int server(int connfd, VB *sim) {
if (res != 0) { if (res != 0) {
return res; return res;
} }
// +$QStartNoAckMode#b0 rdb_request_reset(&req);
} }
} }

View File

@ -1,6 +1,6 @@
build: build:
@mkdir -p build @mkdir -p build
@gcc main.c client.c cmdbuf.c ../vbtest/vb.c -I include -I ../vbtest \ @gcc main.c client.c cmdbuf.c hex.c request.c ../vbtest/vb.c -I include -I ../vbtest \
-Werror -Wall -Wextra -Wpedantic \ -Werror -Wall -Wextra -Wpedantic \
-Wno-unused-parameter -Wno-unused-function \ -Wno-unused-parameter -Wno-unused-function \
-o ./build/rdb -o ./build/rdb

146
request.c Normal file
View File

@ -0,0 +1,146 @@
#include <errno.h>
#include <hex.h>
#include <request.h>
#include <stdio.h>
#include <unistd.h>
void rdb_request_init(RdbRequest *req, int connfd, char *cmd, size_t cmdlen) {
req->connfd = connfd;
req->cmd = cmd;
req->cmdlen = cmdlen;
rdb_request_reset(req);
}
void rdb_request_reset(RdbRequest *req) {
req->state = read_state_header;
req->inbuf.len = 0;
req->inbuf.index = 0;
req->outlen = 0;
req->chk = 0;
}
static rdb_read_result_t read_char(RdbRequest *req, char *in) {
if (req->inbuf.index >= req->inbuf.len) {
ssize_t inlen = read(req->connfd, req->inbuf.buf, INBUF_LEN);
if (inlen == 0) {
return read_result_disconnected;
}
if (inlen < 0) {
if (errno == EAGAIN) {
return read_result_pending;
} else {
perror("could not read incoming packet");
return read_result_error;
}
}
req->inbuf.index = 0;
req->inbuf.len = inlen;
}
*in = req->inbuf.buf[req->inbuf.index++];
return read_result_success;
}
rdb_read_result_t rdb_request_read(RdbRequest *req, size_t *len) {
rdb_read_result_t res;
char in;
switch (req->state) {
case read_state_header:
// read any acknowledgements and continue
do {
res = read_char(req, &in);
if (res != read_result_success) return res;
} while (in == '+');
if (in == '-') {
fprintf(stderr, "negative ack not supported\n");
return read_result_error;
}
if (in == '\x03') {
// interrupt from the server
req->cmd[0] = in;
req->outlen = 1;
return read_result_success;
}
// now, we should be at the start of a packet
if (in != '$') {
fprintf(stderr, "unexpected packet start \"%c\"\n", in);
return read_result_error;
}
req->state = read_state_body;
__attribute__ ((fallthrough));
case read_state_body:
case read_state_body_escape:
while (1) {
res = read_char(req, &in);
if (res != read_result_success) return res;
if (req->state == read_state_body && in == '#') {
// end of packet body
break;
}
req->chk += in;
if (req->state == read_state_body && in == '}') {
// escape sequence
req->state = read_state_body_escape;
continue;
}
if (req->outlen >= req->cmdlen) {
// ran out of room in the buffer
fprintf(stderr, "packet too big for buffer\n");
return read_result_error;
}
if (req->state == read_state_body_escape) {
req->cmd[req->outlen++] = in ^ 0x20;
req->state = read_state_body;
} else {
req->cmd[req->outlen++] = in;
}
}
req->state = read_state_checksum_1;
__attribute__ ((fallthrough));
case read_state_checksum_1:
res = read_char(req, &in);
if (res != read_result_success) return res;
// check the high digit of the checksum
char hi;
if (!parse_hex_digit(in, &hi)) {
fprintf(stderr, "invalid checksum1\n");
return read_result_error;
}
if (((req->chk >> 4) & 0x000f) != hi) {
fprintf(stderr, "invalid checksum2\n");
return read_result_error;
}
req->state = read_state_checksum_2;
__attribute__ ((fallthrough));
case read_state_checksum_2:
res = read_char(req, &in);
if (res != read_result_success) return res;
// check the high digit of the checksum
char lo;
if (!parse_hex_digit(in, &lo)) {
fprintf(stderr, "invalid checksum3 %c\n", in);
return read_result_error;
}
if ((req->chk & 0x000f) != lo) {
fprintf(stderr, "invalid checksum4\n");
return read_result_error;
}
*len = req->outlen;
return read_result_success;
default:
fprintf(stderr, "invalid state\n");
return read_result_error;
}
}