hydro_lang/rewrites/properties.rs

1use std::collections::HashSet;
2
3use stageleft::*;
4
5use crate::ir::{HydroLeaf, HydroNode, transform_bottom_up};
6
7/// Structure for tracking expressions known to have particular algebraic properties.
8///
9/// # Schema
10///
11/// Each field in this struct corresponds to an algebraic property, and contains the list of
12/// expressions that satisfy the property. Currently only `commutative`.
13///
14/// # Interface
15///
16/// "Tag" an expression with a property and it will add it to that table. For example, [`Self::add_commutative_tag`].
17/// Can also run a check to see if an expression satisfies a property.
18#[derive(Default)]
19pub struct PropertyDatabase {
20    commutative: HashSet<syn::Expr>,
21}
22
/// Adapts a fold-style accumulator function (`f(&mut acc, item)`, the DFIR shape for
/// folds) into a binary operator, for use in the algebraic property tests.
///
/// The returned closure starts from `A::default()`, feeds its first argument and then
/// its second argument into `f`, and returns the resulting accumulator.
#[allow(clippy::allow_attributes, dead_code, reason = "staged programming")]
fn convert_hf_to_binary<I, A: Default, F: Fn(&mut A, I)>(f: F) -> impl Fn(I, I) -> A {
    move |lhs, rhs| {
        let mut accumulator = A::default();
        // Left operand first, then right: the order matters for non-commutative `f`.
        for item in [lhs, rhs] {
            f(&mut accumulator, item);
        }
        accumulator
    }
}
34
35impl PropertyDatabase {
36    /// Tags the expression as commutative.
37    pub fn add_commutative_tag<
38        'a,
39        I,
40        A,
41        F: Fn(&mut A, I),
42        Ctx,
43        Q: QuotedWithContext<'a, F, Ctx> + Clone,
44    >(
45        &mut self,
46        expr: Q,
47        ctx: &Ctx,
48    ) -> Q {
49        let expr_clone = expr.clone();
50        self.commutative.insert(expr_clone.splice_untyped_ctx(ctx));
51        expr
52    }
53
54    pub fn is_tagged_commutative(&self, expr: &syn::Expr) -> bool {
55        self.commutative.contains(expr)
56    }
57}
58
// Dataflow-graph optimization rewrite rules driven by algebraic property tags.
// TODO: add a test verifying, for each property, that the set of graphs reachable via these rewrites is correct.
61
62fn properties_optimize_node(node: &mut HydroNode, db: &mut PropertyDatabase) {
63    match node {
64        HydroNode::ReduceKeyed { f, .. } if db.is_tagged_commutative(&f.0) => {
65            dbg!("IDENTIFIED COMMUTATIVE OPTIMIZATION for {:?}", &f);
66        }
67        _ => {}
68    }
69}
70
71pub fn properties_optimize(ir: &mut [HydroLeaf], db: &mut PropertyDatabase) {
72    transform_bottom_up(ir, &mut |_| (), &mut |node| {
73        properties_optimize_node(node, db)
74    });
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FlowBuilder;
    use crate::deploy::SingleProcessGraph;
    use crate::location::Location;

    /// Before tagging, a quoted closure is not reported as commutative. After tagging
    /// it, a separately quoted copy of the same closure is reported as commutative.
    #[test]
    fn test_property_database() {
        let mut db = PropertyDatabase::default();

        // Nothing has been tagged yet.
        assert!(
            !db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
        );

        let _ = db.add_commutative_tag(q!(|a: &mut i32, b: i32| *a += b), &());

        // A fresh quote of the same closure is now found in the database.
        assert!(
            db.is_tagged_commutative(&(q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&())))
        );
    }

    /// Builds a keyed-count pipeline and tags its fold function as commutative.
    /// It then runs [`properties_optimize`] through `optimize_with` and snapshots the
    /// resulting IR, then calls `compile_no_network` on the built flow.
    #[test]
    fn test_property_optimized() {
        let flow = FlowBuilder::new();
        let mut database = PropertyDatabase::default();

        let process = flow.process::<()>();
        let tick = process.tick();

        // Tag the counting closure with `&tick` as the splice context; the same quoted
        // value is reused in `fold_keyed` below.
        // NOTE(review): this builds a `fold_keyed`, while `properties_optimize_node` only
        // inspects `HydroNode::ReduceKeyed`. Confirm the tag is actually detected here.
        let counter_func = q!(|count: &mut i32, _| *count += 1);
        let _ = database.add_commutative_tag(counter_func, &tick);

        // NOTE(review): this `unsafe` block has no `// SAFETY:` justification.
        // Presumably `tick_batch` is the unsafe call; confirm which invariant it
        // requires and why it holds for this test.
        unsafe {
            process
                .source_iter(q!(vec![]))
                .map(q!(|string: String| (string, ())))
                .tick_batch(&tick)
        }
        .fold_keyed(q!(|| 0), counter_func)
        .all_ticks()
        .for_each(q!(|(string, count)| println!("{}: {}", string, count)));

        let built = flow
            .optimize_with(|ir| properties_optimize(ir, &mut database))
            .with_default_optimize::<SingleProcessGraph>();

        insta::assert_debug_snapshot!(built.ir());

        // The returned value is discarded; this call is made only for its side effects.
        let _ = built.compile_no_network();
    }
}