hydro_lang/rewrites/
properties.rs1use std::collections::HashSet;
2
3use stageleft::*;
4
5use crate::ir::{HydroLeaf, HydroNode, transform_bottom_up};
6
7#[derive(Default)]
19pub struct PropertyDatabase {
20 commutative: HashSet<syn::Expr>,
21}
22
/// Adapts an in-place accumulator `f(&mut acc, item)` into a binary function.
///
/// The returned closure starts from `A::default()`, feeds its first argument and then
/// its second argument into `f`, and returns the resulting accumulator.
#[allow(clippy::allow_attributes, dead_code, reason = "staged programming")]
fn convert_hf_to_binary<I, A: Default, F: Fn(&mut A, I)>(f: F) -> impl Fn(I, I) -> A {
    move |first, second| {
        let mut state = A::default();
        // Order matters for non-commutative `f`: `first` is always folded before `second`.
        for item in [first, second] {
            f(&mut state, item);
        }
        state
    }
}
34
35impl PropertyDatabase {
36 pub fn add_commutative_tag<
38 'a,
39 I,
40 A,
41 F: Fn(&mut A, I),
42 Ctx,
43 Q: QuotedWithContext<'a, F, Ctx> + Clone,
44 >(
45 &mut self,
46 expr: Q,
47 ctx: &Ctx,
48 ) -> Q {
49 let expr_clone = expr.clone();
50 self.commutative.insert(expr_clone.splice_untyped_ctx(ctx));
51 expr
52 }
53
54 pub fn is_tagged_commutative(&self, expr: &syn::Expr) -> bool {
55 self.commutative.contains(expr)
56 }
57}
58
59fn properties_optimize_node(node: &mut HydroNode, db: &mut PropertyDatabase) {
63 match node {
64 HydroNode::ReduceKeyed { f, .. } if db.is_tagged_commutative(&f.0) => {
65 dbg!("IDENTIFIED COMMUTATIVE OPTIMIZATION for {:?}", &f);
66 }
67 _ => {}
68 }
69}
70
71pub fn properties_optimize(ir: &mut [HydroLeaf], db: &mut PropertyDatabase) {
72 transform_bottom_up(ir, &mut |_| (), &mut |node| {
73 properties_optimize_node(node, db)
74 });
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FlowBuilder;
    use crate::deploy::SingleProcessGraph;
    use crate::location::Location;

    /// A freshly quoted, structurally identical closure is not found before tagging
    /// and is found afterwards.
    #[test]
    fn test_property_database() {
        let mut db = PropertyDatabase::default();
        let probe = q!(|a: &mut i32, b: i32| *a += b).splice_untyped_ctx(&());

        assert!(!db.is_tagged_commutative(&probe));

        let _ = db.add_commutative_tag(q!(|a: &mut i32, b: i32| *a += b), &());

        assert!(db.is_tagged_commutative(&probe));
    }

    /// Builds a small keyed-count dataflow with a tagged accumulator, runs the
    /// property pass, and snapshots the resulting IR.
    #[test]
    fn test_property_optimized() {
        let builder = FlowBuilder::new();
        let mut property_db = PropertyDatabase::default();

        let node = builder.process::<()>();
        let node_tick = node.tick();

        // Tag the accumulator before it is handed to the dataflow below.
        let increment = q!(|count: &mut i32, _| *count += 1);
        let _ = property_db.add_commutative_tag(increment, &node_tick);

        // NOTE(review): no SAFETY justification is given for calling `tick_batch`
        // here; document the invariant that makes this `unsafe` block sound.
        let batched = unsafe {
            node.source_iter(q!(vec![]))
                .map(q!(|string: String| (string, ())))
                .tick_batch(&node_tick)
        };
        batched
            .fold_keyed(q!(|| 0), increment)
            .all_ticks()
            .for_each(q!(|(string, count)| println!("{}: {}", string, count)));

        let built = builder
            .optimize_with(|ir| properties_optimize(ir, &mut property_db))
            .with_default_optimize::<SingleProcessGraph>();

        insta::assert_debug_snapshot!(built.ir());

        let _ = built.compile_no_network();
    }
}