dfir_lang/graph/ops/
mod.rs

1//! DFIR's operators
2
3use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use proc_macro2::{Ident, Literal, Span, TokenStream};
9use quote::quote_spanned;
10use serde::{Deserialize, Serialize};
11use slotmap::Key;
12use syn::punctuated::Punctuated;
13use syn::{Expr, Token, parse_quote_spanned};
14
15use super::{
16    GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
17    PortIndexValue,
18};
19use crate::diagnostic::Diagnostic;
20use crate::parse::{Operator, PortIndex};
21
/// The kind of delay (soft barrier) that an operator may require on one of its inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DelayType {
    /// The input has to be collected over the preceding stratum.
    Stratum,
    /// Monotone accumulation: delaying reduces the flow rate, but emitting "early" is also
    /// correct.
    MonotoneAccum,
    /// The input has to be collected over the previous tick.
    Tick,
    /// The input has to be collected over the previous tick, without causing a new tick to
    /// occur.
    TickLazy,
}
34
35/// Specification of the named (or unnamed) ports for an operator's inputs or outputs.
36pub enum PortListSpec {
37    /// Any number of unnamed (or optionally named) ports.
38    Variadic,
39    /// A specific number of named ports.
40    Fixed(Punctuated<PortIndex, Token![,]>),
41}
42
43/// An instance of this struct represents a single dfir operator.
44pub struct OperatorConstraints {
45    /// Operator's name.
46    pub name: &'static str,
47    /// Operator categories, for docs.
48    pub categories: &'static [OperatorCategory],
49
50    // TODO: generic argument ranges.
51    /// Input argument range required to not show an error.
52    pub hard_range_inn: &'static dyn RangeTrait<usize>,
53    /// Input argument range required to not show a warning.
54    pub soft_range_inn: &'static dyn RangeTrait<usize>,
55    /// Output argument range required to not show an error.
56    pub hard_range_out: &'static dyn RangeTrait<usize>,
57    /// Output argument range required to not show an warning.
58    pub soft_range_out: &'static dyn RangeTrait<usize>,
59    /// Number of arguments i.e. `operator(a, b, c)` has `num_args = 3`.
60    pub num_args: usize,
61    /// How many persistence lifetime arguments can be provided.
62    pub persistence_args: &'static dyn RangeTrait<usize>,
63    // /// How many (non-persistence) lifetime arguments can be provided.
64    // pub lifetime_args: &'static dyn RangeTrait<usize>,
65    /// How many generic type arguments can be provided.
66    pub type_args: &'static dyn RangeTrait<usize>,
67    /// If this operator receives external inputs and therefore must be in
68    /// stratum 0.
69    pub is_external_input: bool,
70    /// If this operator has a singleton reference output. For stateful operators.
71    /// If true, [`WriteContextArgs::singleton_output_ident`] will be set to a meaningful value in
72    /// the [`Self::write_fn`] invocation.
73    pub has_singleton_output: bool,
74    /// Flo semantics type.
75    pub flo_type: Option<FloType>,
76
77    /// What named or numbered input ports to expect?
78    pub ports_inn: Option<fn() -> PortListSpec>,
79    /// What named or numbered output ports to expect?
80    pub ports_out: Option<fn() -> PortListSpec>,
81
82    /// Determines if this input must be preceeded by a stratum barrier.
83    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
84    /// The operator's codegen. Returns code that is emited is several different locations. See [`OperatorWriteOutput`].
85    pub write_fn: WriteFn,
86}
87
88/// Type alias for [`OperatorConstraints::write_fn`]'s type.
89pub type WriteFn =
90    fn(&WriteContextArgs<'_>, &mut Vec<Diagnostic>) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("OperatorConstraints")
95            .field("name", &self.name)
96            .field("hard_range_inn", &self.hard_range_inn)
97            .field("soft_range_inn", &self.soft_range_inn)
98            .field("hard_range_out", &self.hard_range_out)
99            .field("soft_range_out", &self.soft_range_out)
100            .field("num_args", &self.num_args)
101            .field("persistence_args", &self.persistence_args)
102            .field("type_args", &self.type_args)
103            .field("is_external_input", &self.is_external_input)
104            .field("ports_inn", &self.ports_inn)
105            .field("ports_out", &self.ports_out)
106            // .field("input_delaytype_fn", &self.input_delaytype_fn)
107            // .field("flow_prop_fn", &self.flow_prop_fn)
108            // .field("write_fn", &self.write_fn)
109            .finish()
110    }
111}
112
113/// The code generated and returned by a [`OperatorConstraints::write_fn`].
114#[derive(Default)]
115#[non_exhaustive]
116pub struct OperatorWriteOutput {
117    /// Code which runs once outside the subgraph to set up any external stuff
118    /// like state API stuff, external chanels, network connections, etc.
119    pub write_prologue: TokenStream,
120    /// Iterator (or pusherator) code inside the subgraphs. The code for each
121    /// operator is emitted in order.
122    ///
123    /// Emitted code should assign to [`WriteContextArgs.ident`] and use
124    /// [`WriteIteratorArgs.inputs`] (pull iterators) or
125    /// [`WriteIteratorArgs.outputs`] (pusherators).
126    pub write_iterator: TokenStream,
127    /// Code which runs after iterators have been run. Mainly for flushing IO.
128    pub write_iterator_after: TokenStream,
129}
130
131/// Convenience range: zero or more (any number).
132pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
133/// Convenience range: exactly zero.
134pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
135/// Convenience range: exactly one.
136pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
137
138/// Helper to write the `write_iterator` portion of [`OperatorConstraints::write_fn`] output for
139/// unary identity operators.
140pub fn identity_write_iterator_fn(
141    &WriteContextArgs {
142        root,
143        op_span,
144        ident,
145        inputs,
146        outputs,
147        is_pull,
148        op_inst:
149            OperatorInstance {
150                generics: OpInstGenerics { type_args, .. },
151                ..
152            },
153        ..
154    }: &WriteContextArgs,
155) -> TokenStream {
156    let generic_type = type_args
157        .first()
158        .map(quote::ToTokens::to_token_stream)
159        .unwrap_or(quote_spanned!(op_span=> _));
160
161    if is_pull {
162        let input = &inputs[0];
163        quote_spanned! {op_span=>
164            let #ident = {
165                fn check_input<Iter: ::std::iter::Iterator<Item = Item>, Item>(iter: Iter) -> impl ::std::iter::Iterator<Item = Item> { iter }
166                check_input::<_, #generic_type>(#input)
167            };
168        }
169    } else {
170        let output = &outputs[0];
171        quote_spanned! {op_span=>
172            let #ident = {
173                fn check_output<Push: #root::pusherator::Pusherator<Item = Item>, Item>(push: Push) -> impl #root::pusherator::Pusherator<Item = Item> { push }
174                check_output::<_, #generic_type>(#output)
175            };
176        }
177    }
178}
179
180/// [`OperatorConstraints::write_fn`] for unary identity operators.
181pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
182    let write_iterator = identity_write_iterator_fn(write_context_args);
183    Ok(OperatorWriteOutput {
184        write_iterator,
185        ..Default::default()
186    })
187};
188
189/// Helper to write the `write_iterator` portion of [`OperatorConstraints::write_fn`] output for
190/// the null operator - an operator that ignores all inputs and produces no output.
191pub fn null_write_iterator_fn(
192    &WriteContextArgs {
193        root,
194        op_span,
195        ident,
196        inputs,
197        outputs,
198        is_pull,
199        op_inst:
200            OperatorInstance {
201                generics: OpInstGenerics { type_args, .. },
202                ..
203            },
204        ..
205    }: &WriteContextArgs,
206) -> TokenStream {
207    let default_type = parse_quote_spanned! {op_span=> _};
208    let iter_type = type_args.first().unwrap_or(&default_type);
209
210    if is_pull {
211        quote_spanned! {op_span=>
212            #(
213                #inputs.for_each(std::mem::drop);
214            )*
215            let #ident = std::iter::empty::<#iter_type>();
216        }
217    } else {
218        quote_spanned! {op_span=>
219            #[allow(clippy::let_unit_value)]
220            let _ = (#(#outputs),*);
221            let #ident = #root::pusherator::null::Null::<#iter_type>::new();
222        }
223    }
224}
225
226/// [`OperatorConstraints::write_fn`] for the null operator - an operator that ignores all inputs
227/// and produces no output.
228pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
229    let write_iterator = null_write_iterator_fn(write_context_args);
230    Ok(OperatorWriteOutput {
231        write_iterator,
232        ..Default::default()
233    })
234};
235
// Declares the operator submodules and the `OPERATORS` table in one go.
//
// Takes a comma-terminated list of `module::CONSTANT` entries. For every entry it declares
// `pub(crate) mod module;`, and it collects all `module::CONSTANT` values, in the order
// given, into the `OPERATORS` slice.
macro_rules! declare_ops {
    ( $( $module:ident :: $constraints:ident, )* ) => {
        $( pub(crate) mod $module; )*

        /// All DFIR operators.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $module :: $constraints, )*
        ];
    };
}
245declare_ops![
246    all_iterations::ALL_ITERATIONS,
247    all_once::ALL_ONCE,
248    anti_join::ANTI_JOIN,
249    anti_join_multiset::ANTI_JOIN_MULTISET,
250    assert::ASSERT,
251    assert_eq::ASSERT_EQ,
252    batch::BATCH,
253    chain::CHAIN,
254    _counter::_COUNTER,
255    cross_join::CROSS_JOIN,
256    cross_join_multiset::CROSS_JOIN_MULTISET,
257    cross_singleton::CROSS_SINGLETON,
258    demux::DEMUX,
259    demux_enum::DEMUX_ENUM,
260    dest_file::DEST_FILE,
261    dest_sink::DEST_SINK,
262    dest_sink_serde::DEST_SINK_SERDE,
263    difference::DIFFERENCE,
264    difference_multiset::DIFFERENCE_MULTISET,
265    enumerate::ENUMERATE,
266    filter::FILTER,
267    filter_map::FILTER_MAP,
268    flat_map::FLAT_MAP,
269    flatten::FLATTEN,
270    fold::FOLD,
271    for_each::FOR_EACH,
272    identity::IDENTITY,
273    initialize::INITIALIZE,
274    inspect::INSPECT,
275    join::JOIN,
276    join_fused::JOIN_FUSED,
277    join_fused_lhs::JOIN_FUSED_LHS,
278    join_fused_rhs::JOIN_FUSED_RHS,
279    join_multiset::JOIN_MULTISET,
280    fold_keyed::FOLD_KEYED,
281    reduce_keyed::REDUCE_KEYED,
282    repeat_n::REPEAT_N,
283    // last_iteration::LAST_ITERATION,
284    lattice_bimorphism::LATTICE_BIMORPHISM,
285    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
286    lattice_fold::LATTICE_FOLD,
287    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
288    lattice_reduce::LATTICE_REDUCE,
289    map::MAP,
290    union::UNION,
291    multiset_delta::MULTISET_DELTA,
292    next_iteration::NEXT_ITERATION,
293    next_stratum::NEXT_STRATUM,
294    defer_signal::DEFER_SIGNAL,
295    defer_tick::DEFER_TICK,
296    defer_tick_lazy::DEFER_TICK_LAZY,
297    null::NULL,
298    partition::PARTITION,
299    persist::PERSIST,
300    persist_mut::PERSIST_MUT,
301    persist_mut_keyed::PERSIST_MUT_KEYED,
302    prefix::PREFIX,
303    py_udf::PY_UDF,
304    reduce::REDUCE,
305    spin::SPIN,
306    sort::SORT,
307    sort_by_key::SORT_BY_KEY,
308    source_file::SOURCE_FILE,
309    source_interval::SOURCE_INTERVAL,
310    source_iter::SOURCE_ITER,
311    source_json::SOURCE_JSON,
312    source_stdin::SOURCE_STDIN,
313    source_stream::SOURCE_STREAM,
314    source_stream_serde::SOURCE_STREAM_SERDE,
315    state::STATE,
316    state_by::STATE_BY,
317    tee::TEE,
318    unique::UNIQUE,
319    unzip::UNZIP,
320    zip::ZIP,
321    zip_longest::ZIP_LONGEST,
322];
323
324/// Get the operator lookup table, generating it if needed.
325pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
326    pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
327        OnceLock::new();
328    OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
329}
330/// Find an operator by [`GraphNode`].
331pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
332    if let GraphNode::Operator(operator) = node {
333        find_op_op_constraints(operator)
334    } else {
335        None
336    }
337}
338/// Find an operator by an AST [`Operator`].
339pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
340    let name = &*operator.name_string();
341    operator_lookup().get(name).copied()
342}
343
344/// Context arguments provided to [`OperatorConstraints::write_fn`].
345#[derive(Clone)]
346pub struct WriteContextArgs<'a> {
347    /// `dfir` crate name for `use #root::something`.
348    pub root: &'a TokenStream,
349    /// `context` ident, the name of the provided
350    /// [`dfir_rs::scheduled::Context`](https://hydro.run/rustdoc/dfir_rs/scheduled/context/struct.Context.html).
351    pub context: &'a Ident,
352    /// `df` ident, the name of the
353    /// [`dfir_rs::scheduled::graph::Dfir`](https://hydro.run/rustdoc/dfir_rs/scheduled/graph/struct.Dfir.html)
354    /// instance.
355    pub df_ident: &'a Ident,
356    /// Subgraph ID in which this operator is contained.
357    pub subgraph_id: GraphSubgraphId,
358    /// Node ID identifying this operator in the flat or partitioned graph meta-datastructure.
359    pub node_id: GraphNodeId,
360    /// Loop ID in which this operator is contained, or `None` if not in a loop.
361    pub loop_id: Option<GraphLoopId>,
362    /// The source span of this operator.
363    pub op_span: Span,
364    /// Tag for this operator appended to the generated identifier.
365    pub op_tag: Option<String>,
366    /// Identifier for a function to call when doing work outside the iterator.
367    pub work_fn: &'a Ident,
368
369    /// Ident the iterator or pullerator should be assigned to.
370    pub ident: &'a Ident,
371    /// If a pull iterator (true) or pusherator (false) should be used.
372    pub is_pull: bool,
373    /// Input operator idents (or ref idents; used for pull).
374    pub inputs: &'a [Ident],
375    /// Output operator idents (or ref idents; used for push).
376    pub outputs: &'a [Ident],
377    /// Ident for the singleton output of this operator, if any.
378    pub singleton_output_ident: &'a Ident,
379
380    /// Operator name.
381    pub op_name: &'static str,
382    /// Operator instance arguments object.
383    pub op_inst: &'a OperatorInstance,
384    /// Arguments provided by the user into the operator as arguments.
385    /// I.e. the `a, b, c` in `-> my_op(a, b, c) -> `.
386    ///
387    /// These arguments include singleton postprocessing codegen, with
388    /// [`std::cell::RefCell::borrow_mut`] code pre-generated.
389    pub arguments: &'a Punctuated<Expr, Token![,]>,
390    /// Same as [`Self::arguments`] but with only `StateHandle`s, no borrowing code.
391    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
392}
393impl WriteContextArgs<'_> {
394    /// Generate a (almost certainly) unique identifier with the given suffix.
395    ///
396    /// Includes the subgraph and node IDs in the generated identifier.
397    ///
398    /// This will always return the same identifier for a given `suffix`.
399    pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
400        Ident::new(
401            &format!(
402                "sg_{:?}_node_{:?}_{}",
403                self.subgraph_id.data(),
404                self.node_id.data(),
405                suffix.as_ref(),
406            ),
407            self.op_span,
408        )
409    }
410}
411
/// An object-safe counterpart of [`RangeBounds`].
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// The start (lower) bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// The end (upper) bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Whether `item` lies within this range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Renders this range as a human-readable phrase, e.g. `exactly 1` or
    /// `at least 2 and less than 5`.
    ///
    /// A range without any bound renders as `any number of`; a range whose two inclusive
    /// bounds are equal renders as `exactly N`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let (start, end) = (self.start_bound(), self.end_bound());

        // Special case: both ends inclusive and equal.
        if let (Bound::Included(lo), Bound::Included(hi)) = (start, end) {
            if lo == hi {
                return format!("exactly {}", lo);
            }
        }

        let lower = match start {
            Bound::Included(n) => Some(format!("at least {}", n)),
            Bound::Excluded(n) => Some(format!("more than {}", n)),
            Bound::Unbounded => None,
        };
        let upper = match end {
            Bound::Included(x) => Some(format!("at most {}", x)),
            Bound::Excluded(x) => Some(format!("less than {}", x)),
            Bound::Unbounded => None,
        };

        match (lower, upper) {
            (Some(lo), Some(hi)) => format!("{} and {}", lo, hi),
            (Some(only), None) | (None, Some(only)) => only,
            (None, None) => "any number of".to_owned(),
        }
    }
}
456
457impl<R, T> RangeTrait<T> for R
458where
459    R: RangeBounds<T> + Send + Sync + Debug,
460{
461    fn start_bound(&self) -> Bound<&T> {
462        self.start_bound()
463    }
464
465    fn end_bound(&self) -> Bound<&T> {
466        self.end_bound()
467    }
468
469    fn contains(&self, item: &T) -> bool
470    where
471        T: PartialOrd<T>,
472    {
473        self.contains(item)
474    }
475}
476
477/// Persistence lifetimes: `'tick`, `'static`, or `'mutable`.
478#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
479pub enum Persistence {
480    /// No persistence, for within a loop iteration.
481    None,
482    /// Persistence for one tick at-a-time only.
483    Tick,
484    /// Persistene across all ticks.
485    Static,
486    /// Mutability.
487    Mutable,
488}
489
490/// Helper which creates a error message string literal for when the Tokio runtime is not found.
491fn make_missing_runtime_msg(op_name: &str) -> Literal {
492    Literal::string(&format!(
493        "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
494        op_name
495    ))
496}
497
/// The categories operators are grouped into, for docs.
///
/// The variants are described in the source of [`Self::description`].
#[allow(
    clippy::allow_attributes,
    missing_docs,
    reason = "see `Self::description`"
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}

impl OperatorCategory {
    /// The human-readable heading for this category, for docs.
    pub fn name(self) -> &'static str {
        match self {
            Self::Map => "Maps",
            Self::Filter => "Filters",
            Self::Flatten => "Flattens",
            Self::Fold => "Folds",
            Self::KeyedFold => "Keyed Folds",
            Self::LatticeFold => "Lattice Folds",
            Self::Persistence => "Persistent Operators",
            Self::MultiIn => "Multi-Input Operators",
            Self::MultiOut => "Multi-Output Operators",
            Self::Source => "Sources",
            Self::Sink => "Sinks",
            Self::Control => "Control Flow Operators",
            Self::CompilerFusionOperator => "Compiler Fusion Operators",
            Self::Windowing => "Windowing Operator",
            Self::Unwindowing => "Un-Windowing Operator",
        }
    }

    /// A one-sentence human-readable description of this category, for docs.
    pub fn description(self) -> &'static str {
        match self {
            Self::Map => "Simple one-in-one-out operators.",
            Self::Filter => "One-in zero-or-one-out operators.",
            Self::Flatten => "One-in multiple-out operators.",
            Self::Fold => "Operators which accumulate elements together.",
            Self::KeyedFold => "Keyed fold operators.",
            Self::LatticeFold => "Folds based on lattice-merge.",
            Self::Persistence => "Persistent (stateful) operators.",
            Self::MultiIn => "Operators with multiple inputs.",
            Self::MultiOut => "Operators with multiple outputs.",
            Self::Source => "Operators which produce output elements (and consume no inputs).",
            Self::Sink => "Operators which consume input elements (and produce no outputs).",
            Self::Control => "Operators which affect control flow/scheduling.",
            Self::CompilerFusionOperator => {
                "Operators which are necessary to implement certain optimizations and rewrite rules"
            }
            Self::Windowing => "Operators for windowing `loop` inputs.",
            Self::Unwindowing => "Operators for collecting `loop` outputs.",
        }
    }
}
572
/// The type of an operator under Flo semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum FloType {
    /// A source operator; must be at the top level.
    Source,
    /// A windowing operator, which moves data into a loop context.
    Windowing,
    /// An un-windowing operator, which moves data out of a loop context.
    Unwindowing,
    /// Moves data into the next loop iteration, within a loop context.
    NextIteration,
}