1use std::any::Any;
4use std::borrow::Cow;
5use std::cell::Cell;
6use std::cmp::Ordering;
7use std::future::Future;
8use std::marker::PhantomData;
9
10#[cfg(feature = "meta")]
11use dfir_lang::diagnostic::{Diagnostic, SerdeSpan};
12#[cfg(feature = "meta")]
13use dfir_lang::graph::DfirGraph;
14use ref_cast::RefCast;
15use smallvec::SmallVec;
16use web_time::SystemTime;
17
18use super::context::Context;
19use super::handoff::handoff_list::PortList;
20use super::handoff::{Handoff, HandoffMeta, TeeingHandoff};
21use super::port::{RECV, RecvCtx, RecvPort, SEND, SendCtx, SendPort};
22use super::reactor::Reactor;
23use super::state::StateHandle;
24use super::subgraph::Subgraph;
25use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
26use crate::Never;
27use crate::scheduled::ticks::{TickDuration, TickInstant};
28use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
29
/// A dataflow graph instance: its subgraphs, the handoffs connecting them, per-loop iteration
/// state, and the scheduling [`Context`] shared with running subgraphs.
///
/// `'a` bounds the lifetime of the boxed subgraph closures stored in `subgraphs`.
#[derive(Default)]
pub struct Dfir<'a> {
    /// All subgraphs, indexed by [`SubgraphId`].
    pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,

    /// Per-loop iteration state, keyed by [`LoopId`]; entries are created by `add_loop`.
    pub(super) loop_data: SecondarySlotVec<LoopTag, LoopData>,

    /// Scheduler state (current tick/stratum, stratum queues, event channel, tasks, ...).
    pub(super) context: Context,

    /// All handoffs (edges between subgraphs), indexed by [`HandoffId`].
    handoffs: SlotVec<HandoffTag, HandoffData>,

    /// Deserialized meta graph; assigned at most once by `__assign_meta_graph`.
    #[cfg(feature = "meta")]
    meta_graph: Option<DfirGraph>,

    /// Deserialized diagnostics; assigned at most once by `__assign_diagnostics`.
    #[cfg(feature = "meta")]
    diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
49
50impl Dfir<'_> {
52 pub fn teeing_handoff_tee<T>(
54 &mut self,
55 tee_parent_port: &RecvPort<TeeingHandoff<T>>,
56 ) -> RecvPort<TeeingHandoff<T>>
57 where
58 T: Clone,
59 {
60 let tee_root = self.handoffs[tee_parent_port.handoff_id].pred_handoffs[0];
62
63 let tee_root_data = &mut self.handoffs[tee_root];
65 let tee_root_data_name = tee_root_data.name.clone();
66
67 let teeing_handoff = tee_root_data
69 .handoff
70 .any_ref()
71 .downcast_ref::<TeeingHandoff<T>>()
72 .unwrap();
73 let new_handoff = teeing_handoff.tee();
74
75 let new_hoff_id = self.handoffs.insert_with_key(|new_hoff_id| {
77 let new_name = Cow::Owned(format!("{} tee {:?}", tee_root_data_name, new_hoff_id));
78 let mut new_handoff_data = HandoffData::new(new_name, new_handoff, new_hoff_id);
79 new_handoff_data.pred_handoffs = vec![tee_root];
81 new_handoff_data
82 });
83
84 let tee_root_data = &mut self.handoffs[tee_root];
86 tee_root_data.succ_handoffs.push(new_hoff_id);
87
88 assert!(
91 tee_root_data.preds.len() <= 1,
92 "Tee send side should only have one sender (or none set yet)."
93 );
94 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
95 self.subgraphs[pred_sg_id].succs.push(new_hoff_id);
96 }
97
98 let output_port = RecvPort {
99 handoff_id: new_hoff_id,
100 _marker: PhantomData,
101 };
102 output_port
103 }
104
105 pub fn teeing_handoff_drop<T>(&mut self, tee_port: RecvPort<TeeingHandoff<T>>)
110 where
111 T: Clone,
112 {
113 let data = &self.handoffs[tee_port.handoff_id];
114 let teeing_handoff = data
115 .handoff
116 .any_ref()
117 .downcast_ref::<TeeingHandoff<T>>()
118 .unwrap();
119 teeing_handoff.drop();
120
121 let tee_root = data.pred_handoffs[0];
122 let tee_root_data = &mut self.handoffs[tee_root];
123 tee_root_data
125 .succ_handoffs
126 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
127 assert!(
129 tee_root_data.preds.len() <= 1,
130 "Tee send side should only have one sender (or none set yet)."
131 );
132 if let Some(&pred_sg_id) = tee_root_data.preds.first() {
133 self.subgraphs[pred_sg_id]
134 .succs
135 .retain(|&succ_hoff| succ_hoff != tee_port.handoff_id);
136 }
137 }
138}
139
140impl<'a> Dfir<'a> {
141 pub fn new() -> Self {
143 Default::default()
144 }
145
146 #[doc(hidden)]
148 pub fn __assign_meta_graph(&mut self, _meta_graph_json: &str) {
149 #[cfg(feature = "meta")]
150 {
151 let mut meta_graph: DfirGraph =
152 serde_json::from_str(_meta_graph_json).expect("Failed to deserialize graph.");
153
154 let mut op_inst_diagnostics = Vec::new();
155 meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
156 assert!(
157 op_inst_diagnostics.is_empty(),
158 "Expected no diagnostics, got: {:#?}",
159 op_inst_diagnostics
160 );
161
162 assert!(self.meta_graph.replace(meta_graph).is_none());
163 }
164 }
165 #[doc(hidden)]
167 pub fn __assign_diagnostics(&mut self, _diagnostics_json: &'static str) {
168 #[cfg(feature = "meta")]
169 {
170 let diagnostics: Vec<Diagnostic<SerdeSpan>> = serde_json::from_str(_diagnostics_json)
171 .expect("Failed to deserialize diagnostics.");
172
173 assert!(self.diagnostics.replace(diagnostics).is_none());
174 }
175 }
176
177 #[cfg(feature = "meta")]
181 pub fn meta_graph(&self) -> Option<&DfirGraph> {
182 self.meta_graph.as_ref()
183 }
184
185 #[cfg(feature = "meta")]
190 pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
191 self.diagnostics.as_deref()
192 }
193
194 pub fn reactor(&self) -> Reactor {
197 Reactor::new(self.context.event_queue_send.clone())
198 }
199
200 pub fn current_tick(&self) -> TickInstant {
202 self.context.current_tick
203 }
204
205 pub fn current_stratum(&self) -> usize {
207 self.context.current_stratum
208 }
209
210 #[tracing::instrument(level = "trace", skip(self), ret)]
213 pub fn run_tick(&mut self) -> bool {
214 let mut work_done = false;
215 while self.next_stratum(true) {
217 work_done = true;
218 self.run_stratum();
220 }
221 work_done
222 }
223
224 #[tracing::instrument(level = "trace", skip(self), ret)]
229 pub fn run_available(&mut self) -> bool {
230 let mut work_done = false;
231 while self.next_stratum(false) {
233 work_done = true;
234 self.run_stratum();
236 }
237 work_done
238 }
239
240 #[tracing::instrument(level = "trace", skip(self), ret)]
246 pub async fn run_available_async(&mut self) -> bool {
247 let mut work_done = false;
248 while self.next_stratum(false) {
250 work_done = true;
251 self.run_stratum();
253
254 tokio::task::yield_now().await;
257 }
258 work_done
259 }
260
    /// Runs every subgraph queued on `context.current_stratum` until that stratum's queue is
    /// empty. Subgraphs scheduled onto the same stratum while this runs are also processed.
    ///
    /// After each subgraph runs, every subgraph that receives from one of its non-bottom output
    /// handoffs is scheduled. Loop reschedule/iteration requests the subgraph left in the
    /// context are then applied.
    ///
    /// Returns `true` if at least one subgraph ran. Loop subgraphs skipped because their loop
    /// may not start another iteration do not count.
    #[tracing::instrument(level = "trace", skip(self), fields(tick = u64::from(self.context.current_tick), stratum = self.context.current_stratum), ret)]
    pub fn run_stratum(&mut self) -> bool {
        // Start any tasks pending on the context (see `Context::spawn_tasks`).
        self.context.spawn_tasks();

        let mut work_done = false;

        'pop: while let Some(sg_id) =
            self.context.stratum_queues[self.context.current_stratum].pop_front()
        {
            {
                let sg_data = &mut self.subgraphs[sg_id];
                // Anything in a stratum queue must be flagged scheduled; clear the flag now that
                // it has been popped so it can be re-queued.
                assert!(sg_data.is_scheduled.take());

                // Keep `loop_nonce_stack` as deep as this subgraph's loop nesting. Entering a
                // deeper loop pushes a fresh, globally increasing nonce; leaving one pops it.
                // NOTE(review): only one level is popped per subgraph; if `loop_depth` can drop
                // by more than one between consecutive subgraphs, the stack would stay too deep.
                // Confirm that cannot happen.
                match sg_data.loop_depth.cmp(&self.context.loop_nonce_stack.len()) {
                    Ordering::Greater => {
                        self.context.loop_nonce += 1;
                        self.context.loop_nonce_stack.push(self.context.loop_nonce);
                        tracing::trace!(loop_nonce = self.context.loop_nonce, "Entered loop.");
                    }
                    Ordering::Less => {
                        self.context.loop_nonce_stack.pop();
                        tracing::trace!("Exited loop.");
                    }
                    Ordering::Equal => {}
                }

                self.context.subgraph_id = sg_id;
                // True if this subgraph has never run, or last ran in an earlier tick.
                self.context.is_first_run_this_tick = sg_data
                    .last_tick_run_in
                    .is_none_or(|last_tick| last_tick < self.context.current_tick);

                if let Some(loop_id) = sg_data.loop_id {
                    let curr_loop_nonce = self.context.loop_nonce_stack.last().copied();

                    let LoopData {
                        iter_count: loop_iter_count,
                        allow_another_iteration,
                    } = &mut self.loop_data[loop_id];

                    // Nonce and iteration count recorded the last time this subgraph ran.
                    let (prev_loop_nonce, prev_iter_count) = sg_data.last_loop_nonce;

                    let curr_iter_count =
                        // Same loop execution as last time (or no enclosing nonce)?
                        if curr_loop_nonce.is_none_or(|nonce| nonce == prev_loop_nonce) {
                            // The loop has not moved past the iteration this subgraph last ran
                            // in, so running again means starting a new iteration. That is only
                            // allowed if the flag is set, and doing so consumes the flag.
                            if loop_iter_count.is_none_or(|n| n == prev_iter_count) {
                                if !std::mem::take(allow_another_iteration) {
                                    tracing::trace!(
                                        "Loop will not continue to next iteration, skipping."
                                    );
                                    continue 'pop;
                                }
                                loop_iter_count.map_or(0, |n| n + 1)
                            } else {
                                // Another subgraph already started a later iteration; join it.
                                debug_assert!(loop_iter_count.is_some_and(|n| prev_iter_count < n));
                                loop_iter_count.unwrap()
                            }
                        } else {
                            // A new execution of the loop (new nonce) restarts at iteration 0.
                            0
                        };
                    *loop_iter_count = Some(curr_iter_count);
                    self.context.loop_iter_count = curr_iter_count;
                    sg_data.last_loop_nonce =
                        (curr_loop_nonce.unwrap_or_default(), curr_iter_count);
                }

                tracing::info!(
                    sg_id = sg_id.to_string(),
                    sg_name = &*sg_data.name,
                    sg_depth = sg_data.loop_depth,
                    sg_loop_nonce = sg_data.last_loop_nonce.0,
                    sg_iter_count = sg_data.last_loop_nonce.1,
                    "Running subgraph."
                );
                sg_data.subgraph.run(&mut self.context, &mut self.handoffs);

                sg_data.last_tick_run_in = Some(self.context.current_tick);
            }

            // Schedule the receivers of every output handoff that is not bottom.
            // NOTE(review): `is_bottom` presumably means "holds no data"; confirm in `HandoffMeta`.
            let sg_data = &self.subgraphs[sg_id];
            for &handoff_id in sg_data.succs.iter() {
                let handoff = &self.handoffs[handoff_id];
                if !handoff.handoff.is_bottom() {
                    for &succ_id in handoff.succs.iter() {
                        let succ_sg_data = &self.subgraphs[succ_id];
                        // Feeding an earlier stratum requires another tick to process, unless
                        // the producing subgraph is lazy.
                        if succ_sg_data.stratum < self.context.current_stratum && !sg_data.is_lazy {
                            self.context.can_start_tick = true;
                        }
                        if !succ_sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
                        }
                        // Receivers inside a loop are also recorded on the stratum stack so
                        // `next_stratum` revisits their stratum.
                        if 0 < succ_sg_data.loop_depth {
                            self.context
                                .stratum_stack
                                .push(succ_sg_data.loop_depth, succ_sg_data.stratum);
                        }
                    }
                }
            }

            // Consume the loop-control requests the subgraph may have set on the context.
            let reschedule = self.context.reschedule_loop_block.take();
            let allow_another = self.context.allow_another_iteration.take();

            if reschedule {
                // Re-queued later by `next_stratum` when it pops the stratum stack.
                self.context.schedule_deferred.push(sg_id);
                self.context
                    .stratum_stack
                    .push(sg_data.loop_depth, sg_data.stratum);
            }
            if reschedule || allow_another {
                if let Some(loop_id) = sg_data.loop_id {
                    self.loop_data
                        .get_mut(loop_id)
                        .unwrap()
                        .allow_another_iteration = true;
                }
            }

            work_done = true;
        }
        work_done
    }
410
    /// Advances `context.current_stratum` to the next stratum that has queued work.
    ///
    /// The stratum stack (loop re-entries) is consulted first; deferred subgraphs are
    /// re-queued whenever it is popped. Otherwise the stratum index is incremented. When it
    /// runs past the last stratum, the tick ends: end-of-tick state is reset, `current_tick` is
    /// incremented, and `current_stratum` goes back to 0.
    ///
    /// Returns `true` when `current_stratum` has queued work. Returns `false` when:
    /// - the tick boundary is reached and `current_tick_only` is `true`,
    /// - the tick boundary is reached and no new tick may start (`can_start_tick` is false
    ///   after draining events), or
    /// - a full pass over the strata since the last tick start found no work.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    pub fn next_stratum(&mut self, current_tick_only: bool) -> bool {
        tracing::trace!(
            events_received_tick = self.context.events_received_tick,
            can_start_tick = self.context.can_start_tick,
            "Starting `next_stratum` call.",
        );

        // Stratum at which a full, work-less pass would end (used to detect going around once).
        let mut end_stratum = self.context.current_stratum;
        let mut new_tick_started = false;

        if 0 == self.context.current_stratum {
            new_tick_started = true;

            // Beginning of a tick: reset the flag and record the tick's wall-clock start.
            tracing::trace!("Starting tick, setting `can_start_tick = false`.");
            self.context.can_start_tick = false;
            self.context.current_tick_start = SystemTime::now();

            // Pull in pending events once per tick.
            if !self.context.events_received_tick {
                self.try_recv_events();
            }
        }

        loop {
            tracing::trace!(
                tick = u64::from(self.context.current_tick),
                stratum = self.context.current_stratum,
                "Looking for work on stratum."
            );
            if !self.context.stratum_queues[self.context.current_stratum].is_empty() {
                tracing::trace!(
                    tick = u64::from(self.context.current_tick),
                    stratum = self.context.current_stratum,
                    "Work found on stratum."
                );
                return true;
            }

            if let Some(next_stratum) = self.context.stratum_stack.pop() {
                // Jump back to a stratum pushed by `run_stratum` (loop re-entry).
                self.context.current_stratum = next_stratum;

                // Re-queue subgraphs that asked to be rescheduled (`schedule_deferred`).
                {
                    for sg_id in self.context.schedule_deferred.drain(..) {
                        let sg_data = &self.subgraphs[sg_id];
                        tracing::info!(
                            tick = u64::from(self.context.current_tick),
                            stratum = self.context.current_stratum,
                            sg_id = sg_id.to_string(),
                            sg_name = &*sg_data.name,
                            is_scheduled = sg_data.is_scheduled.get(),
                            "Rescheduling deferred subgraph."
                        );
                        if !sg_data.is_scheduled.replace(true) {
                            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
                        }
                    }
                }
            } else {
                self.context.current_stratum += 1;

                if self.context.current_stratum >= self.context.stratum_queues.len() {
                    // Ran past the last stratum: this tick is over.
                    new_tick_started = true;

                    tracing::trace!(
                        can_start_tick = self.context.can_start_tick,
                        "End of tick {}, starting tick {}.",
                        self.context.current_tick,
                        self.context.current_tick + TickDuration::SINGLE_TICK,
                    );
                    self.context.reset_state_at_end_of_tick();

                    self.context.current_stratum = 0;
                    self.context.current_tick += TickDuration::SINGLE_TICK;
                    self.context.events_received_tick = false;

                    if current_tick_only {
                        tracing::trace!(
                            "`current_tick_only` is `true`, returning `false` before receiving events."
                        );
                        return false;
                    } else {
                        self.try_recv_events();
                        // Start the new tick only if something requested it; consume the flag.
                        if std::mem::replace(&mut self.context.can_start_tick, false) {
                            tracing::trace!(
                                tick = u64::from(self.context.current_tick),
                                "`can_start_tick` is `true`, continuing."
                            );
                            end_stratum = 0;
                            continue;
                        } else {
                            // `try_recv_events` set this to `true`; undo it so the next call
                            // will receive events again at the start of the tick.
                            tracing::trace!(
                                "`can_start_tick` is `false`, re-setting `events_received_tick = false`, returning `false`."
                            );
                            self.context.events_received_tick = false;
                            return false;
                        }
                    }
                }
            }

            // Back where this search started after a tick boundary, with no work found.
            if new_tick_started && end_stratum == self.context.current_stratum {
                tracing::trace!(
                    "Made full loop around stratum, re-setting `current_stratum = 0`, returning `false`."
                );
                self.context.events_received_tick = false;
                self.context.current_stratum = 0;
                return false;
            }
        }
    }
545
546 #[tracing::instrument(level = "trace", skip(self), ret)]
550 pub fn run(&mut self) -> Option<Never> {
551 loop {
552 self.run_tick();
553 }
554 }
555
556 #[tracing::instrument(level = "trace", skip(self), ret)]
560 pub async fn run_async(&mut self) -> Option<Never> {
561 loop {
562 self.run_available_async().await;
564 self.recv_events_async().await;
566 }
567 }
568
569 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
573 pub fn try_recv_events(&mut self) -> usize {
574 let mut enqueued_count = 0;
575 while let Ok((sg_id, is_external)) = self.context.event_queue_recv.try_recv() {
576 let sg_data = &self.subgraphs[sg_id];
577 tracing::trace!(
578 sg_id = sg_id.to_string(),
579 is_external = is_external,
580 sg_stratum = sg_data.stratum,
581 "Event received."
582 );
583 if !sg_data.is_scheduled.replace(true) {
584 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
585 enqueued_count += 1;
586 }
587 if is_external {
588 if !self.context.events_received_tick
591 || sg_data.stratum < self.context.current_stratum
592 {
593 tracing::trace!(
594 current_stratum = self.context.current_stratum,
595 sg_stratum = sg_data.stratum,
596 "External event, setting `can_start_tick = true`."
597 );
598 self.context.can_start_tick = true;
599 }
600 }
601 }
602 self.context.events_received_tick = true;
603
604 enqueued_count
605 }
606
607 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
610 pub fn recv_events(&mut self) -> Option<usize> {
611 let mut count = 0;
612 loop {
613 let (sg_id, is_external) = self.context.event_queue_recv.blocking_recv()?;
614 let sg_data = &self.subgraphs[sg_id];
615 tracing::trace!(
616 sg_id = sg_id.to_string(),
617 is_external = is_external,
618 sg_stratum = sg_data.stratum,
619 "Event received."
620 );
621 if !sg_data.is_scheduled.replace(true) {
622 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
623 count += 1;
624 }
625 if is_external {
626 if !self.context.events_received_tick
629 || sg_data.stratum < self.context.current_stratum
630 {
631 tracing::trace!(
632 current_stratum = self.context.current_stratum,
633 sg_stratum = sg_data.stratum,
634 "External event, setting `can_start_tick = true`."
635 );
636 self.context.can_start_tick = true;
637 }
638 break;
639 }
640 }
641 self.context.events_received_tick = true;
642
643 let extra_count = self.try_recv_events();
645 Some(count + extra_count)
646 }
647
648 #[tracing::instrument(level = "trace", skip(self), fields(events_received_tick = self.context.events_received_tick), ret)]
654 pub async fn recv_events_async(&mut self) -> Option<usize> {
655 let mut count = 0;
656 loop {
657 tracing::trace!("Awaiting events (`event_queue_recv`).");
658 let (sg_id, is_external) = self.context.event_queue_recv.recv().await?;
659 let sg_data = &self.subgraphs[sg_id];
660 tracing::trace!(
661 sg_id = sg_id.to_string(),
662 is_external = is_external,
663 sg_stratum = sg_data.stratum,
664 "Event received."
665 );
666 if !sg_data.is_scheduled.replace(true) {
667 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
668 count += 1;
669 }
670 if is_external {
671 if !self.context.events_received_tick
674 || sg_data.stratum < self.context.current_stratum
675 {
676 tracing::trace!(
677 current_stratum = self.context.current_stratum,
678 sg_stratum = sg_data.stratum,
679 "External event, setting `can_start_tick = true`."
680 );
681 self.context.can_start_tick = true;
682 }
683 break;
684 }
685 }
686 self.context.events_received_tick = true;
687
688 let extra_count = self.try_recv_events();
690 Some(count + extra_count)
691 }
692
693 pub fn schedule_subgraph(&mut self, sg_id: SubgraphId) -> bool {
695 let sg_data = &self.subgraphs[sg_id];
696 let already_scheduled = sg_data.is_scheduled.replace(true);
697 if !already_scheduled {
698 self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
699 true
700 } else {
701 false
702 }
703 }
704
705 pub fn add_subgraph<Name, R, W, F>(
707 &mut self,
708 name: Name,
709 recv_ports: R,
710 send_ports: W,
711 subgraph: F,
712 ) -> SubgraphId
713 where
714 Name: Into<Cow<'static, str>>,
715 R: 'static + PortList<RECV>,
716 W: 'static + PortList<SEND>,
717 F: 'static + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
718 {
719 self.add_subgraph_stratified(name, 0, recv_ports, send_ports, false, subgraph)
720 }
721
722 pub fn add_subgraph_stratified<Name, R, W, F>(
726 &mut self,
727 name: Name,
728 stratum: usize,
729 recv_ports: R,
730 send_ports: W,
731 laziness: bool,
732 subgraph: F,
733 ) -> SubgraphId
734 where
735 Name: Into<Cow<'static, str>>,
736 R: 'static + PortList<RECV>,
737 W: 'static + PortList<SEND>,
738 F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
739 {
740 self.add_subgraph_full(
741 name, stratum, recv_ports, send_ports, laziness, None, subgraph,
742 )
743 }
744
    /// Adds a subgraph with every option exposed: stratum, laziness, and enclosing loop.
    ///
    /// - `recv_ports` / `send_ports`: port lists, checked with `assert_is_from` against this
    ///   graph's handoffs.
    /// - `laziness`: if `true`, the subgraph does not set `can_start_tick` when feeding a
    ///   subgraph on an earlier stratum.
    /// - `loop_id`: the loop the subgraph belongs to. Its depth is looked up in
    ///   `context.loop_depth`; without a loop the depth is 0.
    ///
    /// The subgraph is registered with its handoffs via `set_graph_meta`, its stratum is
    /// initialized, and it is queued to run immediately. Returns the new subgraph's id.
    #[expect(clippy::too_many_arguments, reason = "Mainly for internal use.")]
    pub fn add_subgraph_full<Name, R, W, F>(
        &mut self,
        name: Name,
        stratum: usize,
        recv_ports: R,
        send_ports: W,
        laziness: bool,
        loop_id: Option<LoopId>,
        mut subgraph: F,
    ) -> SubgraphId
    where
        Name: Into<Cow<'static, str>>,
        R: 'static + PortList<RECV>,
        W: 'static + PortList<SEND>,
        F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
    {
        // Verify up front that the ports belong to this graph; the `unsafe` below relies on it.
        recv_ports.assert_is_from(&self.handoffs);
        send_ports.assert_is_from(&self.handoffs);

        let loop_depth = loop_id
            .and_then(|loop_id| self.context.loop_depth.get(loop_id))
            .copied()
            .unwrap_or(0);

        let sg_id = self.subgraphs.insert_with_key(|sg_id| {
            // Record the handoff ids this subgraph reads/writes, and link the handoffs back to
            // `sg_id` (`true` = receive side, `false` = send side).
            let (mut subgraph_preds, mut subgraph_succs) = Default::default();
            recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
            send_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_succs, sg_id, false);

            let subgraph =
                move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
                    let (recv, send) = unsafe {
                        // SAFETY (review note): relies on the `assert_is_from` checks at the top
                        // of `add_subgraph_full`. Presumably it also requires that the handoffs
                        // those ports refer to are not removed or replaced afterwards; confirm
                        // against the safety contract of `PortList::make_ctx`.
                        (
                            recv_ports.make_ctx(&*handoffs),
                            send_ports.make_ctx(&*handoffs),
                        )
                    };
                    (subgraph)(context, recv, send);
                };
            SubgraphData::new(
                name.into(),
                stratum,
                subgraph,
                subgraph_preds,
                subgraph_succs,
                true,
                laziness,
                loop_id,
                loop_depth,
            )
        });
        // Ensure the stratum's queue exists, then queue the subgraph (created with
        // `is_scheduled = true` above) so it runs at least once.
        self.context.init_stratum(stratum);
        self.context.stratum_queues[stratum].push_back(sg_id);

        sg_id
    }
807
808 pub fn add_subgraph_n_m<Name, R, W, F>(
810 &mut self,
811 name: Name,
812 recv_ports: Vec<RecvPort<R>>,
813 send_ports: Vec<SendPort<W>>,
814 subgraph: F,
815 ) -> SubgraphId
816 where
817 Name: Into<Cow<'static, str>>,
818 R: 'static + Handoff,
819 W: 'static + Handoff,
820 F: 'static
821 + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
822 {
823 self.add_subgraph_stratified_n_m(name, 0, recv_ports, send_ports, subgraph)
824 }
825
826 pub fn add_subgraph_stratified_n_m<Name, R, W, F>(
828 &mut self,
829 name: Name,
830 stratum: usize,
831 recv_ports: Vec<RecvPort<R>>,
832 send_ports: Vec<SendPort<W>>,
833 mut subgraph: F,
834 ) -> SubgraphId
835 where
836 Name: Into<Cow<'static, str>>,
837 R: 'static + Handoff,
838 W: 'static + Handoff,
839 F: 'static
840 + for<'ctx> FnMut(&'ctx mut Context, &'ctx [&'ctx RecvCtx<R>], &'ctx [&'ctx SendCtx<W>]),
841 {
842 let sg_id = self.subgraphs.insert_with_key(|sg_id| {
843 let subgraph_preds = recv_ports.iter().map(|port| port.handoff_id).collect();
844 let subgraph_succs = send_ports.iter().map(|port| port.handoff_id).collect();
845
846 for recv_port in recv_ports.iter() {
847 self.handoffs[recv_port.handoff_id].succs.push(sg_id);
848 }
849 for send_port in send_ports.iter() {
850 self.handoffs[send_port.handoff_id].preds.push(sg_id);
851 }
852
853 let subgraph =
854 move |context: &mut Context, handoffs: &mut SlotVec<HandoffTag, HandoffData>| {
855 let recvs: Vec<&RecvCtx<R>> = recv_ports
856 .iter()
857 .map(|hid| hid.handoff_id)
858 .map(|hid| handoffs.get(hid).unwrap())
859 .map(|h_data| {
860 h_data
861 .handoff
862 .any_ref()
863 .downcast_ref()
864 .expect("Attempted to cast handoff to wrong type.")
865 })
866 .map(RefCast::ref_cast)
867 .collect();
868
869 let sends: Vec<&SendCtx<W>> = send_ports
870 .iter()
871 .map(|hid| hid.handoff_id)
872 .map(|hid| handoffs.get(hid).unwrap())
873 .map(|h_data| {
874 h_data
875 .handoff
876 .any_ref()
877 .downcast_ref()
878 .expect("Attempted to cast handoff to wrong type.")
879 })
880 .map(RefCast::ref_cast)
881 .collect();
882
883 (subgraph)(context, &recvs, &sends)
884 };
885 SubgraphData::new(
886 name.into(),
887 stratum,
888 subgraph,
889 subgraph_preds,
890 subgraph_succs,
891 true,
892 false,
893 None,
894 0,
895 )
896 });
897
898 self.context.init_stratum(stratum);
899 self.context.stratum_queues[stratum].push_back(sg_id);
900
901 sg_id
902 }
903
904 pub fn make_edge<Name, H>(&mut self, name: Name) -> (SendPort<H>, RecvPort<H>)
906 where
907 Name: Into<Cow<'static, str>>,
908 H: 'static + Handoff,
909 {
910 let handoff = H::default();
912 let handoff_id = self
913 .handoffs
914 .insert_with_key(|hoff_id| HandoffData::new(name.into(), handoff, hoff_id));
915
916 let input_port = SendPort {
918 handoff_id,
919 _marker: PhantomData,
920 };
921 let output_port = RecvPort {
922 handoff_id,
923 _marker: PhantomData,
924 };
925 (input_port, output_port)
926 }
927
928 pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
933 where
934 T: Any,
935 {
936 self.context.add_state(state)
937 }
938
939 pub fn set_state_tick_hook<T>(
943 &mut self,
944 handle: StateHandle<T>,
945 tick_hook_fn: impl 'static + FnMut(&mut T),
946 ) where
947 T: Any,
948 {
949 self.context.set_state_tick_hook(handle, tick_hook_fn)
950 }
951
952 pub fn context_mut(&mut self, sg_id: SubgraphId) -> &mut Context {
954 self.context.subgraph_id = sg_id;
955 &mut self.context
956 }
957
958 pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
963 let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
964 let loop_id = self.context.loop_depth.insert(depth);
965 self.loop_data.insert(
966 loop_id,
967 LoopData {
968 iter_count: None,
969 allow_another_iteration: true,
970 },
971 );
972 loop_id
973 }
974}
975
976impl Dfir<'_> {
977 pub fn request_task<Fut>(&mut self, future: Fut)
979 where
980 Fut: Future<Output = ()> + 'static,
981 {
982 self.context.request_task(future);
983 }
984
985 pub fn abort_tasks(&mut self) {
987 self.context.abort_tasks()
988 }
989
990 pub fn join_tasks(&mut self) -> impl use<'_> + Future {
992 self.context.join_tasks()
993 }
994}
995
996impl Drop for Dfir<'_> {
997 fn drop(&mut self) {
998 self.abort_tasks();
999 }
1000}
1001
/// Per-handoff bookkeeping stored in the graph's handoff slot vector.
///
/// Holds the type-erased handoff together with the subgraphs that send into it (`preds`) and
/// receive from it (`succs`). It also records the tee relationships maintained by
/// `teeing_handoff_tee` / `teeing_handoff_drop`.
#[doc(hidden)]
pub struct HandoffData {
    /// Handoff name; tee outputs derive their names from the tee root's name.
    pub(super) name: Cow<'static, str>,
    /// The type-erased handoff; code here reaches the concrete type through `any_ref()` plus
    /// a downcast.
    pub(super) handoff: Box<dyn HandoffMeta>,
    /// Subgraphs that send into this handoff.
    pub(super) preds: SmallVec<[SubgraphId; 1]>,
    /// Subgraphs that receive from this handoff; `run_stratum` schedules them after a sender
    /// runs and the handoff is not bottom.
    pub(super) succs: SmallVec<[SubgraphId; 1]>,

    /// Tee predecessors. `HandoffData::new` initializes this to `[own id]`; a tee output
    /// instead holds `[tee root id]`. Entry `0` is read as the tee root.
    pub(super) pred_handoffs: Vec<HandoffId>,
    /// Tee successors. `HandoffData::new` initializes this to `[own id]`; on a tee root, each
    /// new tee output's id is appended, and removed again when that output is dropped.
    pub(super) succ_handoffs: Vec<HandoffId>,
}
1031impl std::fmt::Debug for HandoffData {
1032 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
1033 f.debug_struct("HandoffData")
1034 .field("preds", &self.preds)
1035 .field("succs", &self.succs)
1036 .finish_non_exhaustive()
1037 }
1038}
1039impl HandoffData {
1040 pub fn new(
1042 name: Cow<'static, str>,
1043 handoff: impl 'static + HandoffMeta,
1044 hoff_id: HandoffId,
1045 ) -> Self {
1046 let (preds, succs) = Default::default();
1047 Self {
1048 name,
1049 handoff: Box::new(handoff),
1050 preds,
1051 succs,
1052 pred_handoffs: vec![hoff_id],
1053 succ_handoffs: vec![hoff_id],
1054 }
1055 }
1056}
1057
/// Per-subgraph bookkeeping stored in the graph's subgraph slot vector.
pub(super) struct SubgraphData<'a> {
    /// Subgraph name, included in tracing output.
    pub(super) name: Cow<'static, str>,
    /// Stratum whose queue this subgraph is pushed onto when scheduled.
    pub(super) stratum: usize,
    /// The type-erased subgraph body, invoked by `run_stratum`.
    subgraph: Box<dyn Subgraph + 'a>,

    /// Handoffs this subgraph receives from (currently not read).
    #[expect(dead_code, reason = "may be useful in the future")]
    preds: Vec<HandoffId>,
    /// Handoffs this subgraph sends into; after it runs, their receivers get scheduled.
    succs: Vec<HandoffId>,

    /// Whether this subgraph currently sits in a stratum queue. Set when it is pushed and
    /// cleared when `run_stratum` pops it, which prevents duplicate queue entries.
    is_scheduled: Cell<bool>,

    /// Tick in which this subgraph last ran; `None` if it has never run.
    last_tick_run_in: Option<TickInstant>,
    /// `(loop nonce, iteration count)` recorded the last time this subgraph ran inside a loop.
    last_loop_nonce: (usize, usize),

    /// If `true`, scheduling a receiver on an earlier stratum does not set `can_start_tick`.
    is_lazy: bool,

    /// Loop this subgraph belongs to, if any.
    loop_id: Option<LoopId>,
    /// Nesting depth of `loop_id`; 0 when the subgraph is not in a loop.
    loop_depth: usize,
}
1096impl<'a> SubgraphData<'a> {
1097 #[expect(clippy::too_many_arguments, reason = "internal use")]
1098 pub(crate) fn new(
1099 name: Cow<'static, str>,
1100 stratum: usize,
1101 subgraph: impl Subgraph + 'a,
1102 preds: Vec<HandoffId>,
1103 succs: Vec<HandoffId>,
1104 is_scheduled: bool,
1105 is_lazy: bool,
1106 loop_id: Option<LoopId>,
1107 loop_depth: usize,
1108 ) -> Self {
1109 Self {
1110 name,
1111 stratum,
1112 subgraph: Box::new(subgraph),
1113 preds,
1114 succs,
1115 is_scheduled: Cell::new(is_scheduled),
1116 last_tick_run_in: None,
1117 last_loop_nonce: (0, 0),
1118 is_lazy,
1119 loop_id,
1120 loop_depth,
1121 }
1122 }
1123}
1124
/// Per-loop iteration state, stored in the graph's `loop_data`.
pub(crate) struct LoopData {
    /// Iteration count most recently assigned by a subgraph of this loop; `None` until a loop
    /// subgraph first runs (`add_loop` starts it at `None`).
    iter_count: Option<usize>,
    /// Whether the loop may begin another iteration.
    ///
    /// `add_loop` starts it at `true`. `run_stratum` consumes it (resets it to `false`) when
    /// a subgraph starts a new iteration, and sets it back to `true` when a subgraph requests
    /// rescheduling or another iteration.
    allow_another_iteration: bool,
}