dfir_rs/util/
mod.rs

1#![warn(missing_docs)]
2//! Helper utilities for the DFIR syntax.
3
4pub mod clear;
5#[cfg(feature = "dfir_macro")]
6pub mod demux_enum;
7pub mod monotonic_map;
8pub mod multiset;
9pub mod priority_stack;
10pub mod slot_vec;
11pub mod sparse_vec;
12pub mod unsync;
13
14pub mod simulation;
15
16mod monotonic;
17pub use monotonic::*;
18
19mod udp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use udp::*;
22
23mod tcp;
24#[cfg(not(target_arch = "wasm32"))]
25pub use tcp::*;
26
27#[cfg(unix)]
28mod socket;
29#[cfg(unix)]
30pub use socket::*;
31
32#[cfg(feature = "deploy_integration")]
33pub mod deploy;
34
35use std::io::Read;
36use std::net::SocketAddr;
37use std::num::NonZeroUsize;
38use std::process::{Child, ChildStdin, ChildStdout, Stdio};
39use std::task::{Context, Poll};
40
41use futures::Stream;
42use serde::de::DeserializeOwned;
43use serde::ser::Serialize;
44
/// Persist or delete tuples.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` (each conditional on `T`) so values can be
/// printed, duplicated and compared, e.g. in `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Persistence<T> {
    /// Persist the given `T` value.
    Persist(T),
    /// Delete all values that exactly match the given `T`.
    Delete(T),
}
52
/// Persist or delete key-value pairs.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` (each conditional on `K` and `V`) so values can
/// be printed, duplicated and compared, e.g. in `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PersistenceKeyed<K, V> {
    /// Persist the given key-value pair.
    Persist(K, V),
    /// Delete all tuples that have the key `K`.
    Delete(K),
}
60
61/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
62pub fn unbounded_channel<T>() -> (
63    tokio::sync::mpsc::UnboundedSender<T>,
64    tokio_stream::wrappers::UnboundedReceiverStream<T>,
65) {
66    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
67    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
68    (send, recv)
69}
70
71/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
72pub fn unsync_channel<T>(
73    capacity: Option<NonZeroUsize>,
74) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
75    unsync::mpsc::channel(capacity)
76}
77
78/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
79pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
80where
81    S: Stream,
82{
83    let mut stream = Box::pin(stream);
84    std::iter::from_fn(move || {
85        match stream
86            .as_mut()
87            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
88        {
89            Poll::Ready(opt) => opt,
90            Poll::Pending => None,
91        }
92    })
93}
94
95/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
96///
97/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
98/// to retain ownership of your stream.
99pub fn collect_ready<C, S>(stream: S) -> C
100where
101    C: FromIterator<S::Item>,
102    S: Stream,
103{
104    assert!(
105        tokio::runtime::Handle::try_current().is_err(),
106        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
107    );
108    ready_iter(stream).collect()
109}
110
111/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
112///
113/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
114/// to retain ownership of your stream.
115pub async fn collect_ready_async<C, S>(stream: S) -> C
116where
117    C: Default + Extend<S::Item>,
118    S: Stream,
119{
120    use std::sync::atomic::Ordering;
121
122    // Yield to let any background async tasks send to the stream.
123    tokio::task::yield_now().await;
124
125    let got_any_items = std::sync::atomic::AtomicBool::new(true);
126    let mut unfused_iter =
127        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
128    let mut out = C::default();
129    while got_any_items.swap(false, Ordering::Relaxed) {
130        out.extend(unfused_iter.by_ref());
131        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
132        // that everything gets returned. That is why we yield here and loop.
133        tokio::task::yield_now().await;
134    }
135    out
136}
137
138/// Serialize a message to bytes using bincode.
139pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
140where
141    T: Serialize,
142{
143    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
144}
145
146/// Serialize a message from bytes using bincode.
147pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
148where
149    T: DeserializeOwned,
150{
151    bincode::deserialize(msg.as_ref())
152}
153
/// Resolves an IP or hostname string (in any form accepted by `ToSocketAddrs`) to its first
/// `ipv4` [`SocketAddr`].
///
/// # Errors
///
/// Propagates the error from address resolution, and returns an "other" I/O error if none of the
/// resolved addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
164
165/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
166/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
167#[cfg(not(target_arch = "wasm32"))]
168pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
169    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
170    udp_bytes(socket)
171}
172
173/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
174/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
175#[cfg(not(target_arch = "wasm32"))]
176pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
177    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
178    udp_lines(socket)
179}
180
181/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
182///
183/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
184/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
185/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
186/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
187#[cfg(not(target_arch = "wasm32"))]
188pub async fn bind_tcp_bytes(
189    addr: SocketAddr,
190) -> (
191    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
192    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
193    SocketAddr,
194) {
195    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
196        .await
197        .unwrap()
198}
199
200/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
201#[cfg(not(target_arch = "wasm32"))]
202pub async fn bind_tcp_lines(
203    addr: SocketAddr,
204) -> (
205    unsync::mpsc::Sender<(String, SocketAddr)>,
206    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
207    SocketAddr,
208) {
209    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
210        .await
211        .unwrap()
212}
213
214/// The inverse of [`bind_tcp_bytes`].
215///
216/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
217/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
218#[cfg(not(target_arch = "wasm32"))]
219pub fn connect_tcp_bytes() -> (
220    TcpFramedSink<bytes::Bytes>,
221    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
222) {
223    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
224}
225
226/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
227#[cfg(not(target_arch = "wasm32"))]
228pub fn connect_tcp_lines() -> (
229    TcpFramedSink<String>,
230    TcpFramedStream<tokio_util::codec::LinesCodec>,
231) {
232    connect_tcp(tokio_util::codec::LinesCodec::new())
233}
234
/// Sort a slice using a key fn which returns references.
///
/// Unlike `slice::sort_unstable_by_key`, the key may borrow from the element being sorted. The
/// key type may be unsized, so key fns returning `&str` or `&[U]` are accepted as well as ones
/// returning `&String` or `&u32`. The sort is unstable: equal elements may be reordered.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    // `?Sized` so that dynamically sized keys such as `str` can be compared through the reference.
    K: Ord + ?Sized,
{
    slice.sort_unstable_by(|a, b| f(a).cmp(f(b)))
}
246
247/// Waits for a specific process output before returning.
248///
249/// When a child process is spawned often you want to wait until the child process is ready before
250/// moving on. One way to do that synchronization is by waiting for the child process to output
251/// something and match regex against that output. For example, you could wait until the child
252/// process outputs "Client live!" which would indicate that it is ready to receive input now on
253/// stdin.
254pub fn wait_for_process_output(
255    output_so_far: &mut String,
256    output: &mut ChildStdout,
257    wait_for: &str,
258) {
259    let re = regex::Regex::new(wait_for).unwrap();
260
261    while !re.is_match(output_so_far) {
262        println!("waiting: {}", output_so_far);
263        let mut buffer = [0u8; 1024];
264        let bytes_read = output.read(&mut buffer).unwrap();
265
266        if bytes_read == 0 {
267            panic!();
268        }
269
270        output_so_far.push_str(&String::from_utf8_lossy(&buffer[0..bytes_read]));
271
272        println!("XXX {}", output_so_far);
273    }
274}
275
/// Wrapper that kills the inner [`Child`] process and waits for it to exit when dropped.
///
/// Dropping a plain `Child` does nothing to the process, but in unit tests you usually want to
/// terminate the child and wait for it to terminate. `DroppableChild` does that for us.
///
/// # Panics
///
/// Dropping panics if waiting on the child fails, or (on non-Windows targets only) if killing it
/// fails.
pub struct DroppableChild(Child);

impl Drop for DroppableChild {
    fn drop(&mut self) {
        let child = &mut self.0;

        let kill_result = child.kill();
        // Windows throws `PermissionDenied` if the process has already exited, so a failed kill
        // is only treated as fatal on other platforms.
        if cfg!(not(target_family = "windows")) {
            kill_result.unwrap();
        }

        child.wait().unwrap();
    }
}
292
293/// Run a rust example as a test.
294///
295/// Rust examples are meant to be run by people and have a natural interface for that. This makes
296/// unit testing them cumbersome. This function wraps calling cargo run and piping the stdin/stdout
297/// of the example to easy to handle returned objects. The function also returns a `DroppableChild`
298/// which will ensure that the child processes will be cleaned up appropriately.
299pub fn run_cargo_example(test_name: &str, args: &str) -> (DroppableChild, ChildStdin, ChildStdout) {
300    let mut server = if args.is_empty() {
301        std::process::Command::new("cargo")
302            .args(["run", "-p", "dfir_rs", "--example"])
303            .arg(test_name)
304            .stdin(Stdio::piped())
305            .stdout(Stdio::piped())
306            .spawn()
307            .unwrap()
308    } else {
309        std::process::Command::new("cargo")
310            .args(["run", "-p", "dfir_rs", "--example"])
311            .arg(test_name)
312            .arg("--")
313            .args(args.split(' '))
314            .stdin(Stdio::piped())
315            .stdout(Stdio::piped())
316            .spawn()
317            .unwrap()
318    };
319
320    let stdin = server.stdin.take().unwrap();
321    let stdout = server.stdout.take().unwrap();
322
323    (DroppableChild(server), stdin, stdout)
324}
325
326/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
327///
328/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
329/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
330/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
331/// are loops in the stratum).
332pub fn iter_batches_stream<I>(
333    iter: I,
334    n: usize,
335) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
336where
337    I: IntoIterator + Unpin,
338{
339    let mut count = 0;
340    let mut iter = iter.into_iter();
341    futures::stream::poll_fn(move |ctx| {
342        count += 1;
343        if n < count {
344            count = 0;
345            ctx.waker().wake_by_ref();
346            Poll::Pending
347        } else {
348            Poll::Ready(iter.next())
349        }
350    })
351}
352
#[cfg(test)]
mod test {
    use super::*;

    /// Number of items sent through the channel in each test.
    const NUM_ITEMS: usize = 1000;

    /// `collect_ready` returns every item already sent to the channel.
    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..NUM_ITEMS).for_each(|x| send.send(x).unwrap());

        let collected: Vec<_> = collect_ready(&mut recv);
        assert_eq!(NUM_ITEMS, collected.len());
    }

    /// `collect_ready_async` returns every item already sent to the channel.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Tokio unbounded channel returns items in 128 item long chunks, so we have to be careful
        // that everything gets returned.
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..NUM_ITEMS).for_each(|x| send.send(x).unwrap());

        let collected: Vec<_> = collect_ready_async(&mut recv).await;
        assert_eq!(NUM_ITEMS, collected.len());
    }
}