1#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
2
3use std::cell::RefCell;
4use std::collections::HashMap;
5
6pub use hydro_deploy_integration::*;
7use serde::de::DeserializeOwned;
8
9use crate::scheduled::graph::Dfir;
10
/// Expands to an `async` block that performs the deployment handshake, builds a flow,
/// acknowledges the start and then runs the flow.
///
/// `$f` is invoked as `$f(&ports)`, where `ports` is the value produced by
/// `init_no_ack_start`; whatever it returns is handed to `launch_flow`. The
/// `ack start` line is printed only after `$f` has returned, i.e. once the flow
/// has been constructed.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let dfir_flow = $f(&deploy_ports);

            println!("ack start");

            $crate::util::deploy::launch_flow(dfir_flow).await
        }
    };
}
24
25pub use crate::launch;
26
27pub async fn launch_flow(mut flow: Dfir<'_>) {
28 let stop = tokio::sync::oneshot::channel();
29 tokio::task::spawn_blocking(|| {
30 let mut line = String::new();
31 std::io::stdin().read_line(&mut line).unwrap();
32 if line.starts_with("stop") {
33 stop.0.send(()).unwrap();
34 } else {
35 eprintln!("Unexpected stdin input: {:?}", line);
36 }
37 });
38
39 let local_set = tokio::task::LocalSet::new();
40 let flow = local_set.run_until(flow.run_async());
41
42 tokio::select! {
43 _ = stop.1 => {},
44 _ = flow => {}
45 }
46}
47
/// The named connections set up during initialization, together with a metadata value.
///
/// `T` is the metadata type and defaults to `Option<()>`.
pub struct DeployPorts<T = Option<()>> {
    /// Connections keyed by port name. Kept in a `RefCell` so that `port` can remove
    /// an entry while only holding `&self`.
    ports: RefCell<HashMap<String, Connection>>,
    /// Metadata value. `init_no_ack_start` fills it by deserializing the optional
    /// second field of the init config, or with `T::default()` when that is absent.
    pub meta: T,
}
54
55impl<T> DeployPorts<T> {
56 pub fn port(&self, name: &str) -> Connection {
57 self.ports
58 .try_borrow_mut()
59 .unwrap()
60 .remove(name)
61 .unwrap_or_else(|| panic!("port {} not found", name))
62 }
63}
64
65pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
66 let mut input = String::new();
67 std::io::stdin().read_line(&mut input).unwrap();
68 let trimmed = input.trim();
69
70 let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
71
72 let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
74 let mut binds = HashMap::new();
75 for (name, config) in bind_config.0 {
76 let bound = config.bind().await;
77 bind_results.insert(name.clone(), bound.server_port());
78 binds.insert(name.clone(), bound);
79 }
80
81 let bind_serialized = serde_json::to_string(&bind_results).unwrap();
82 println!("ready: {bind_serialized}");
83
84 let mut start_buf = String::new();
85 std::io::stdin().read_line(&mut start_buf).unwrap();
86 let connection_defns = if start_buf.starts_with("start: ") {
87 serde_json::from_str::<HashMap<String, ServerPort>>(
88 start_buf.trim_start_matches("start: ").trim(),
89 )
90 .unwrap()
91 } else {
92 panic!("expected start");
93 };
94
95 let mut all_connected = HashMap::new();
96 for (name, defn) in connection_defns {
97 all_connected.insert(name, Connection::AsClient(defn.connect()));
98 }
99
100 for (name, defn) in binds {
101 all_connected.insert(name, Connection::AsServer(defn));
102 }
103
104 DeployPorts {
105 ports: RefCell::new(all_connected),
106 meta: bind_config
107 .1
108 .map(|b| serde_json::from_str(&b).unwrap())
109 .unwrap_or_default(),
110 }
111}
112
113pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
114 let ret = init_no_ack_start::<T>().await;
115
116 println!("ack start");
117
118 ret
119}