2025-01-01 22:02:14 +00:00
|
|
|
use anyhow::{bail, Result};
|
2025-01-01 23:15:41 +00:00
|
|
|
use atoi::FromRadix16;
|
2025-01-01 22:02:14 +00:00
|
|
|
use tokio::io::{AsyncRead, AsyncReadExt as _};
|
|
|
|
|
|
|
|
/// Classifies a decoded request by how it arrived on the wire.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestKind {
    /// A lone `0x03` byte received where a packet header was expected.
    Signal,
    /// A `$`-framed packet whose payload passed checksum validation.
    Command,
}
|
2025-01-01 23:15:41 +00:00
|
|
|
impl RequestKind {
|
|
|
|
fn name(self) -> &'static str {
|
|
|
|
match self {
|
|
|
|
Self::Signal => "Signal",
|
|
|
|
Self::Command => "Command",
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2025-01-01 22:02:14 +00:00
|
|
|
|
|
|
|
pub struct Request<'a> {
|
|
|
|
pub kind: RequestKind,
|
2025-01-01 23:15:41 +00:00
|
|
|
buffer: &'a [u8],
|
|
|
|
}
|
|
|
|
|
|
|
|
impl std::fmt::Debug for Request<'_> {
|
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
|
let mut ds = f.debug_tuple(self.kind.name());
|
|
|
|
match self.kind {
|
|
|
|
RequestKind::Signal => ds.field(&self.buffer),
|
|
|
|
RequestKind::Command => match std::str::from_utf8(self.buffer) {
|
|
|
|
Ok(str) => ds.field(&str),
|
|
|
|
Err(_) => ds.field(&self.buffer),
|
|
|
|
},
|
|
|
|
};
|
|
|
|
ds.finish()
|
|
|
|
}
|
2025-01-01 22:02:14 +00:00
|
|
|
}
|
|
|
|
|
2025-01-04 06:04:21 +00:00
|
|
|
impl Request<'_> {
|
2025-01-01 22:02:14 +00:00
|
|
|
pub fn match_str(&mut self, prefix: &str) -> bool {
|
2025-01-01 23:15:41 +00:00
|
|
|
if let Some(new_buffer) = self.buffer.strip_prefix(prefix.as_bytes()) {
|
|
|
|
self.buffer = new_buffer;
|
2025-01-01 22:02:14 +00:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
false
|
|
|
|
}
|
2025-01-01 23:15:41 +00:00
|
|
|
|
2025-01-04 06:04:21 +00:00
|
|
|
pub fn match_some_str<'a, I: IntoIterator<Item = &'a str>>(
|
|
|
|
&mut self,
|
|
|
|
prefixes: I,
|
|
|
|
) -> Option<&'a str> {
|
|
|
|
prefixes.into_iter().find(|&prefix| self.match_str(prefix))
|
|
|
|
}
|
|
|
|
|
2025-01-01 23:15:41 +00:00
|
|
|
pub fn match_hex<I: FromRadix16>(&mut self) -> Option<I> {
|
|
|
|
match I::from_radix_16(self.buffer) {
|
|
|
|
(_, 0) => None,
|
|
|
|
(val, used) => {
|
|
|
|
self.buffer = self.buffer.split_at(used).1;
|
|
|
|
Some(val)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2025-01-18 05:20:57 +00:00
|
|
|
|
|
|
|
pub fn match_hex_bytes(&mut self, buffer: &mut [u8]) -> bool {
|
|
|
|
if self.buffer.len() < buffer.len() * 2 {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
for (i, item) in buffer.iter_mut().enumerate() {
|
|
|
|
match u8::from_radix_16(&self.buffer[(i * 2)..(i * 2) + 2]) {
|
|
|
|
(byte, 2) => *item = byte,
|
|
|
|
_ => return false,
|
|
|
|
};
|
|
|
|
}
|
|
|
|
self.buffer = self.buffer.split_at(buffer.len()).1;
|
|
|
|
true
|
|
|
|
}
|
2025-01-18 06:03:22 +00:00
|
|
|
|
|
|
|
pub fn match_bytes(&mut self, buffer: &mut [u8]) -> bool {
|
|
|
|
if self.buffer.len() < buffer.len() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
buffer.copy_from_slice(&self.buffer[0..buffer.len()]);
|
|
|
|
self.buffer = self.buffer.split_at(buffer.len()).1;
|
|
|
|
true
|
|
|
|
}
|
2025-01-01 22:02:14 +00:00
|
|
|
}
|
2025-01-04 06:04:21 +00:00
|
|
|
|
|
|
|
pub struct RequestSource<R> {
|
|
|
|
reader: R,
|
|
|
|
buffer: Vec<u8>,
|
|
|
|
state: RequestReadState,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<R: AsyncRead + Unpin> RequestSource<R> {
|
|
|
|
pub fn new(reader: R) -> Self {
|
|
|
|
Self {
|
|
|
|
reader,
|
|
|
|
buffer: vec![],
|
|
|
|
state: RequestReadState::Header,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn recv(&mut self) -> Result<Request<'_>> {
|
|
|
|
let mut char = self.reader.read_u8().await?;
|
|
|
|
if matches!(self.state, RequestReadState::Start) {
|
|
|
|
self.buffer.clear();
|
|
|
|
self.state = RequestReadState::Header;
|
|
|
|
}
|
|
|
|
if matches!(self.state, RequestReadState::Header) {
|
|
|
|
// Just ignore positive acks
|
|
|
|
while char == b'+' {
|
|
|
|
char = self.reader.read_u8().await?;
|
|
|
|
}
|
|
|
|
if char == b'-' {
|
|
|
|
bail!("no support for negative acks");
|
|
|
|
}
|
|
|
|
if char == 0x03 {
|
|
|
|
// This is how the client "cancels an in-flight request"
|
|
|
|
self.buffer.push(char);
|
|
|
|
self.state = RequestReadState::Start;
|
|
|
|
return Ok(Request {
|
|
|
|
kind: RequestKind::Signal,
|
|
|
|
buffer: &self.buffer,
|
|
|
|
});
|
|
|
|
}
|
|
|
|
if char != b'$' {
|
|
|
|
// Messages are supposed to start with a dollar sign
|
|
|
|
bail!("malformed message");
|
|
|
|
}
|
|
|
|
self.state = RequestReadState::Body {
|
|
|
|
checksum: 0,
|
|
|
|
escaping: false,
|
|
|
|
};
|
|
|
|
char = self.reader.read_u8().await?;
|
|
|
|
}
|
|
|
|
while let RequestReadState::Body { checksum, escaping } = &mut self.state {
|
|
|
|
if char == b'#' && !*escaping {
|
|
|
|
self.state = RequestReadState::Checksum {
|
|
|
|
expected: *checksum,
|
|
|
|
actual: 0,
|
|
|
|
digits: 0,
|
|
|
|
};
|
|
|
|
char = self.reader.read_u8().await?;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
*checksum = checksum.wrapping_add(char);
|
|
|
|
|
|
|
|
if *escaping {
|
|
|
|
// escaped character
|
|
|
|
self.buffer.push(char ^ 0x20);
|
|
|
|
*escaping = false;
|
|
|
|
} else if char == b'}' {
|
|
|
|
// next character will be escaped
|
|
|
|
*escaping = true;
|
|
|
|
} else {
|
|
|
|
self.buffer.push(char);
|
|
|
|
}
|
|
|
|
char = self.reader.read_u8().await?;
|
|
|
|
}
|
|
|
|
while let RequestReadState::Checksum {
|
|
|
|
expected,
|
|
|
|
actual,
|
|
|
|
digits,
|
|
|
|
} = &mut self.state
|
|
|
|
{
|
|
|
|
let digit = match char {
|
|
|
|
b'0'..=b'9' => char - b'0',
|
|
|
|
b'a'..=b'f' => char - b'a' + 10,
|
|
|
|
b'A'..=b'F' => char - b'A' + 10,
|
|
|
|
_ => bail!("invalid checksum"),
|
|
|
|
};
|
|
|
|
*actual = (*actual << 4) + digit;
|
|
|
|
*digits += 1;
|
|
|
|
if *digits == 2 {
|
|
|
|
if *expected != *actual {
|
|
|
|
bail!("mismatched checksum");
|
|
|
|
}
|
|
|
|
self.state = RequestReadState::Start;
|
|
|
|
return Ok(Request {
|
|
|
|
kind: RequestKind::Command,
|
|
|
|
buffer: &self.buffer,
|
|
|
|
});
|
|
|
|
}
|
|
|
|
char = self.reader.read_u8().await?;
|
|
|
|
}
|
|
|
|
unreachable!();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Position of the request decoder within the packet it is currently reading.
enum RequestReadState {
    /// A request was just returned; the payload buffer is cleared before reading on.
    Start,
    /// Waiting for the byte that opens a packet or signals an interrupt.
    Header,
    /// Inside the payload of a packet.
    Body {
        /// Wrapping sum of the raw payload bytes seen so far.
        checksum: u8,
        /// Whether the previous byte was the `}` escape marker.
        escaping: bool,
    },
    /// Reading the hex checksum digits that follow the payload.
    Checksum {
        /// Checksum computed while reading the payload.
        expected: u8,
        /// Checksum value assembled from the digits received so far.
        actual: u8,
        /// Number of checksum digits received so far.
        digits: u8,
    },
}
|