lemur/src/gdbserver/request.rs

203 lines
5.8 KiB
Rust
Raw Normal View History

use anyhow::{bail, Result};
2025-01-01 23:15:41 +00:00
use atoi::FromRadix16;
use tokio::io::{AsyncRead, AsyncReadExt as _};
/// The two flavours of request a client can send.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestKind {
    /// An out-of-band, single-byte request rather than a framed packet.
    Signal,
    /// A framed packet whose payload is the command text.
    Command,
}
2025-01-01 23:15:41 +00:00
impl RequestKind {
fn name(self) -> &'static str {
match self {
Self::Signal => "Signal",
Self::Command => "Command",
}
}
}
pub struct Request<'a> {
pub kind: RequestKind,
2025-01-01 23:15:41 +00:00
buffer: &'a [u8],
}
impl std::fmt::Debug for Request<'_> {
    /// Renders the request as `Kind(payload)`.
    ///
    /// A command payload is shown as a string when it is valid UTF-8; signal
    /// payloads and non-UTF-8 command payloads are shown as raw bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple(self.kind.name());
        // Only commands are ever interpreted as text.
        let text = match self.kind {
            RequestKind::Command => std::str::from_utf8(self.buffer).ok(),
            RequestKind::Signal => None,
        };
        match text {
            Some(text) => tuple.field(&text),
            None => tuple.field(&self.buffer),
        };
        tuple.finish()
    }
}
2025-01-04 06:04:21 +00:00
impl Request<'_> {
pub fn match_str(&mut self, prefix: &str) -> bool {
2025-01-01 23:15:41 +00:00
if let Some(new_buffer) = self.buffer.strip_prefix(prefix.as_bytes()) {
self.buffer = new_buffer;
return true;
}
false
}
2025-01-01 23:15:41 +00:00
2025-01-04 06:04:21 +00:00
pub fn match_some_str<'a, I: IntoIterator<Item = &'a str>>(
&mut self,
prefixes: I,
) -> Option<&'a str> {
prefixes.into_iter().find(|&prefix| self.match_str(prefix))
}
2025-01-01 23:15:41 +00:00
pub fn match_hex<I: FromRadix16>(&mut self) -> Option<I> {
match I::from_radix_16(self.buffer) {
(_, 0) => None,
(val, used) => {
self.buffer = self.buffer.split_at(used).1;
Some(val)
}
}
}
2025-01-18 05:20:57 +00:00
pub fn match_hex_bytes(&mut self, buffer: &mut [u8]) -> bool {
if self.buffer.len() < buffer.len() * 2 {
return false;
}
for (i, item) in buffer.iter_mut().enumerate() {
match u8::from_radix_16(&self.buffer[(i * 2)..(i * 2) + 2]) {
(byte, 2) => *item = byte,
_ => return false,
};
}
self.buffer = self.buffer.split_at(buffer.len()).1;
true
}
2025-01-18 06:03:22 +00:00
pub fn match_bytes(&mut self, buffer: &mut [u8]) -> bool {
if self.buffer.len() < buffer.len() {
return false;
}
buffer.copy_from_slice(&self.buffer[0..buffer.len()]);
self.buffer = self.buffer.split_at(buffer.len()).1;
true
}
}
2025-01-04 06:04:21 +00:00
/// Reads requests from an underlying byte stream.
pub struct RequestSource<R> {
    /// The stream that requests are read from.
    reader: R,
    /// Payload of the request currently being read or most recently returned.
    buffer: Vec<u8>,
    /// How far parsing of the current request has progressed.
    state: RequestReadState,
}
impl<R: AsyncRead + Unpin> RequestSource<R> {
    /// Creates a source that parses requests read from `reader`.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            buffer: vec![],
            state: RequestReadState::Header,
        }
    }

    /// Reads and returns the next request from the stream.
    ///
    /// Leading `+` bytes are skipped. A `0x03` byte yields a
    /// [`RequestKind::Signal`] request whose payload is that single byte. A
    /// `$`-prefixed packet yields a [`RequestKind::Command`] request whose
    /// payload is the text up to the unescaped `#`, with `}`-escaped bytes
    /// restored by XOR-ing them with `0x20`.
    ///
    /// Parsing progress lives in `self.state` and `self.buffer`, so if a call
    /// returns early with a read error, the next call continues from where it
    /// stopped.
    /// NOTE(review): this looks designed so a dropped `recv` future can also
    /// be resumed; confirm that `read_u8` is cancel-safe for `R`.
    ///
    /// # Errors
    ///
    /// Returns an error when reading from the stream fails, when a `-` is
    /// received, when a request starts with an unexpected byte, or when the
    /// two-digit checksum after `#` is not valid hex or does not equal the
    /// wrapping sum of the bytes received between `$` and `#`. After a
    /// checksum error the parser is reset, so the next call expects a fresh
    /// request.
    pub async fn recv(&mut self) -> Result<Request<'_>> {
        let mut byte = self.reader.read_u8().await?;

        if matches!(self.state, RequestReadState::Start) {
            // The previous request is finished; drop its payload.
            self.buffer.clear();
            self.state = RequestReadState::Header;
        }

        if matches!(self.state, RequestReadState::Header) {
            // Just ignore positive acks
            while byte == b'+' {
                byte = self.reader.read_u8().await?;
            }
            if byte == b'-' {
                bail!("no support for negative acks");
            }
            if byte == 0x03 {
                // This is how the client "cancels an in-flight request"
                self.buffer.push(byte);
                self.state = RequestReadState::Start;
                return Ok(Request {
                    kind: RequestKind::Signal,
                    buffer: &self.buffer,
                });
            }
            if byte != b'$' {
                // Messages are supposed to start with a dollar sign
                bail!("malformed message");
            }
            self.state = RequestReadState::Body {
                checksum: 0,
                escaping: false,
            };
            byte = self.reader.read_u8().await?;
        }

        while let RequestReadState::Body { checksum, escaping } = &mut self.state {
            if byte == b'#' && !*escaping {
                self.state = RequestReadState::Checksum {
                    expected: *checksum,
                    actual: 0,
                    digits: 0,
                };
                byte = self.reader.read_u8().await?;
                break;
            }
            // The checksum covers the bytes as received, before un-escaping.
            *checksum = checksum.wrapping_add(byte);
            if *escaping {
                // escaped character
                self.buffer.push(byte ^ 0x20);
                *escaping = false;
            } else if byte == b'}' {
                // next character will be escaped
                *escaping = true;
            } else {
                self.buffer.push(byte);
            }
            byte = self.reader.read_u8().await?;
        }

        while let RequestReadState::Checksum {
            expected,
            actual,
            digits,
        } = &mut self.state
        {
            let digit = match byte {
                b'0'..=b'9' => byte - b'0',
                b'a'..=b'f' => byte - b'a' + 10,
                b'A'..=b'F' => byte - b'A' + 10,
                _ => {
                    // Give up on this packet so the next call starts with a
                    // clean header instead of staying stuck mid-checksum.
                    self.state = RequestReadState::Start;
                    bail!("invalid checksum");
                }
            };
            *actual = (*actual << 4) + digit;
            *digits += 1;
            if *digits == 2 {
                let checksum_ok = *expected == *actual;
                // Whether or not the checksum matched, this packet is over;
                // leaving the state as `Checksum` would make every later call
                // keep consuming bytes as checksum digits.
                self.state = RequestReadState::Start;
                if !checksum_ok {
                    bail!("mismatched checksum");
                }
                return Ok(Request {
                    kind: RequestKind::Command,
                    buffer: &self.buffer,
                });
            }
            byte = self.reader.read_u8().await?;
        }

        // The body loop only exits into the checksum state, and the checksum
        // loop only exits by returning.
        unreachable!();
    }
}
/// Where the parser currently is within the request being read.
enum RequestReadState {
    /// The previous request was handed out; its payload must be cleared
    /// before the next one is parsed.
    Start,
    /// Waiting for the byte that begins a request.
    Header,
    /// Inside a packet body, after the `$` and before the `#`.
    Body {
        /// Wrapping sum of the body bytes received so far.
        checksum: u8,
        /// Whether the previous byte was the escape marker `}`.
        escaping: bool,
    },
    /// Reading the two hex digits that follow `#`.
    Checksum {
        /// Checksum computed over the received body.
        expected: u8,
        /// Checksum value assembled from the digits read so far.
        actual: u8,
        /// Number of checksum digits read so far.
        digits: u8,
    },
}