use anyhow::{bail, Result}; use atoi::FromRadix16; use tokio::io::{AsyncRead, AsyncReadExt as _}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum RequestKind { Signal, Command, } impl RequestKind { fn name(self) -> &'static str { match self { Self::Signal => "Signal", Self::Command => "Command", } } } pub struct Request<'a> { pub kind: RequestKind, buffer: &'a [u8], } impl std::fmt::Debug for Request<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut ds = f.debug_tuple(self.kind.name()); match self.kind { RequestKind::Signal => ds.field(&self.buffer), RequestKind::Command => match std::str::from_utf8(self.buffer) { Ok(str) => ds.field(&str), Err(_) => ds.field(&self.buffer), }, }; ds.finish() } } impl Request<'_> { pub fn match_str(&mut self, prefix: &str) -> bool { if let Some(new_buffer) = self.buffer.strip_prefix(prefix.as_bytes()) { self.buffer = new_buffer; return true; } false } pub fn match_some_str<'a, I: IntoIterator>( &mut self, prefixes: I, ) -> Option<&'a str> { prefixes.into_iter().find(|&prefix| self.match_str(prefix)) } pub fn match_hex(&mut self) -> Option { match I::from_radix_16(self.buffer) { (_, 0) => None, (val, used) => { self.buffer = self.buffer.split_at(used).1; Some(val) } } } pub fn match_hex_bytes(&mut self, buffer: &mut [u8]) -> bool { if self.buffer.len() < buffer.len() * 2 { return false; } for (i, item) in buffer.iter_mut().enumerate() { match u8::from_radix_16(&self.buffer[(i * 2)..(i * 2) + 2]) { (byte, 2) => *item = byte, _ => return false, }; } self.buffer = self.buffer.split_at(buffer.len()).1; true } pub fn match_bytes(&mut self, buffer: &mut [u8]) -> bool { if self.buffer.len() < buffer.len() { return false; } buffer.copy_from_slice(&self.buffer[0..buffer.len()]); self.buffer = self.buffer.split_at(buffer.len()).1; true } } pub struct RequestSource { reader: R, buffer: Vec, state: RequestReadState, } impl RequestSource { pub fn new(reader: R) -> Self { Self { reader, buffer: vec![], state: RequestReadState::Header, } } pub async fn recv(&mut self) -> Result> { let mut char = self.reader.read_u8().await?; if matches!(self.state, RequestReadState::Start) { self.buffer.clear(); self.state = RequestReadState::Header; } if matches!(self.state, RequestReadState::Header) { // Just ignore positive acks while char == b'+' { char = self.reader.read_u8().await?; } if char == b'-' { bail!("no support for negative acks"); } if char == 0x03 { // This is how the client "cancels an in-flight request" self.buffer.push(char); self.state = RequestReadState::Start; return Ok(Request { kind: RequestKind::Signal, buffer: &self.buffer, }); } if char != b'$' { // Messages are supposed to start with a dollar sign bail!("malformed message"); } self.state = RequestReadState::Body { checksum: 0, escaping: false, }; char = self.reader.read_u8().await?; } while let RequestReadState::Body { checksum, escaping } = &mut self.state { if char == b'#' && !*escaping { self.state = RequestReadState::Checksum { expected: *checksum, actual: 0, digits: 0, }; char = self.reader.read_u8().await?; break; } *checksum = checksum.wrapping_add(char); if *escaping { // escaped character self.buffer.push(char ^ 0x20); *escaping = false; } else if char == b'}' { // next character will be escaped *escaping = true; } else { self.buffer.push(char); } char = self.reader.read_u8().await?; } while let RequestReadState::Checksum { expected, actual, digits, } = &mut self.state { let digit = match char { b'0'..=b'9' => char - b'0', b'a'..=b'f' => char - b'a' + 10, b'A'..=b'F' => char - b'A' + 10, _ => bail!("invalid checksum"), }; *actual = (*actual << 4) + digit; *digits += 1; if *digits == 2 { if *expected != *actual { bail!("mismatched checksum"); } self.state = RequestReadState::Start; return Ok(Request { kind: RequestKind::Command, buffer: &self.buffer, }); } char = self.reader.read_u8().await?; } unreachable!(); } } enum RequestReadState { Start, Header, Body { checksum: u8, escaping: bool, }, Checksum { expected: u8, actual: u8, digits: u8, }, }