Skip to main content

robonix_executor/
service.rs

1// SPDX-License-Identifier: MulanPSL-2.0
2// Author: wheatfox <wheatfox17@icloud.com>
3//
4// gRPC contract handler: RobonixSystemExecutor.Execute(Plan) → stream CapabilityCallEvent.
5
6use crate::dispatch;
7use crate::exec_wire;
8use crate::pb::contracts::robonix_system_executor_server::RobonixSystemExecutor;
9use crate::pb::executor::CapabilityCallEvent;
10use crate::pb::pilot::{CapabilityCall, Plan};
11use robonix_atlas::client::AtlasClient;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15use tokio::sync::mpsc::Sender;
16use tokio_stream::wrappers::ReceiverStream;
17use tonic::{Request, Response, Status};
18
/// `node_kind` of a node whose children are executed one after another.
const RTDL_SEQUENCE: u32 = 0;
/// `node_kind` of a node whose children are each spawned as their own task.
const RTDL_PARALLEL: u32 = 1;
/// `node_kind` of a leaf node that carries a single `CapabilityCall`.
const RTDL_DO: u32 = 2;
22
23/// `AtlasClient` is cheap to clone — each Execute RPC clones it so per-plan
24/// dispatch runs without serialising on a single mutex.
25#[derive(Clone)]
26pub struct ExecutorServiceImpl {
27    atlas: AtlasClient,
28    /// Executor's own provider_id. Two roles:
29    ///   1. consumer_id passed to atlas on every ConnectCapability so the
30    ///      channel record reflects who is using each downstream provider.
31    ///   2. self-detection: when a CapabilityCall in the plan targets this
32    ///      provider_id, dispatch short-circuits to the in-process builtin
33    ///      handlers instead of going through MCP loopback.
34    provider_id: String,
35}
36
37impl ExecutorServiceImpl {
38    pub fn new(atlas: AtlasClient, provider_id: String) -> Self {
39        Self { atlas, provider_id }
40    }
41}
42
43#[tonic::async_trait]
44impl RobonixSystemExecutor for ExecutorServiceImpl {
45    type ExecuteStream = ReceiverStream<Result<CapabilityCallEvent, Status>>;
46
47    async fn execute(
48        &self,
49        request: Request<Plan>,
50    ) -> Result<Response<Self::ExecuteStream>, Status> {
51        let plan = request.into_inner();
52        validate_plan(&plan).map_err(Status::invalid_argument)?;
53        let (tx, rx) = tokio::sync::mpsc::channel(64);
54        let atlas = self.atlas.clone();
55        let provider_id = self.provider_id.clone();
56
57        tokio::spawn(async move {
58            let plan_id = plan.plan_id.clone();
59            let plan = Arc::new(plan);
60            let any_failed = execute_node(
61                Arc::clone(&plan),
62                plan.root_index as usize,
63                tx.clone(),
64                atlas,
65                provider_id,
66            )
67            .await;
68
69            let _ = tx.send(Ok(exec_wire::complete(plan_id, any_failed))).await;
70        });
71
72        Ok(Response::new(ReceiverStream::new(rx)))
73    }
74}
75
76type ExecuteNodeFuture = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;
77
78fn execute_node(
79    plan: Arc<Plan>,
80    node_index: usize,
81    tx: Sender<Result<CapabilityCallEvent, Status>>,
82    atlas: AtlasClient,
83    provider_id: String,
84) -> ExecuteNodeFuture {
85    Box::pin(async move {
86        let node = &plan.nodes[node_index];
87        match node.node_kind {
88            RTDL_SEQUENCE => {
89                let mut any_failed = false;
90                for child in &node.children {
91                    any_failed |= execute_node(
92                        Arc::clone(&plan),
93                        *child as usize,
94                        tx.clone(),
95                        atlas.clone(),
96                        provider_id.clone(),
97                    )
98                    .await;
99                }
100                any_failed
101            }
102            RTDL_PARALLEL => {
103                let mut handles = Vec::with_capacity(node.children.len());
104                for child in &node.children {
105                    let child_plan = Arc::clone(&plan);
106                    let child_tx = tx.clone();
107                    let child_atlas = atlas.clone();
108                    let child_provider_id = provider_id.clone();
109                    let child_index = *child as usize;
110                    handles.push(tokio::spawn(async move {
111                        execute_node(
112                            child_plan,
113                            child_index,
114                            child_tx,
115                            child_atlas,
116                            child_provider_id,
117                        )
118                        .await
119                    }));
120                }
121                let mut any_failed = false;
122                for handle in handles {
123                    match handle.await {
124                        Ok(child_failed) => any_failed |= child_failed,
125                        Err(e) => {
126                            any_failed = true;
127                            log::warn!("[executor] parallel branch task failed: {e}");
128                        }
129                    }
130                }
131                any_failed
132            }
133            RTDL_DO => {
134                let call = node
135                    .call
136                    .as_ref()
137                    .expect("validated do node must contain call");
138                execute_call(call, tx, atlas, provider_id).await
139            }
140            _ => {
141                log::warn!(
142                    "[executor] invalid node_kind={} reached after validation",
143                    node.node_kind
144                );
145                true
146            }
147        }
148    })
149}
150
151/// Dispatch one RTDL `do` node and stream its started/result events.
152async fn execute_call(
153    call: &CapabilityCall,
154    tx: Sender<Result<CapabilityCallEvent, Status>>,
155    atlas: AtlasClient,
156    provider_id: String,
157) -> bool {
158    // `tokio::sync::mpsc::Sender` is concurrency-safe when cloned (parallel branches).
159    // Outbound events ride the Execute server-stream to the gRPC client (Pilot).
160    let _ = tx
161        .send(Ok(exec_wire::started(
162            call.call_id.clone(),
163            call.provider_id.clone(),
164            call.contract_id.clone(),
165        )))
166        .await;
167
168    log::info!(
169        "[executor] dispatching call_id={} provider='{}' contract='{}'",
170        call.call_id,
171        call.provider_id,
172        call.contract_id,
173    );
174    let mut atlas_for_call = atlas.clone();
175    let result = dispatch::dispatch(call, &provider_id, &mut atlas_for_call).await;
176    let failed = !result.success;
177
178    if result.success {
179        let preview: String = result.output.chars().take(120).collect();
180        let ellipsis = if result.output.len() > 120 { "..." } else { "" };
181        log::info!(
182            "[executor] '{}' ok: {}{}",
183            call.contract_id,
184            preview,
185            ellipsis
186        );
187    } else {
188        log::warn!("[executor] '{}' failed: {}", call.contract_id, result.error);
189    }
190
191    let _ = tx.send(Ok(exec_wire::result(result))).await;
192    failed
193}
194
195/// Validate Plan arena shape before spawning execution work.
196fn validate_plan(plan: &Plan) -> Result<(), String> {
197    if plan.nodes.is_empty() {
198        return Err("Plan.nodes must not be empty".to_string());
199    }
200    let root = plan.root_index as usize;
201    if root >= plan.nodes.len() {
202        return Err(format!(
203            "Plan.root_index {} is out of bounds for {} nodes",
204            plan.root_index,
205            plan.nodes.len()
206        ));
207    }
208
209    for (idx, node) in plan.nodes.iter().enumerate() {
210        match node.node_kind {
211            RTDL_SEQUENCE | RTDL_PARALLEL => {
212                for child in &node.children {
213                    if *child as usize >= plan.nodes.len() {
214                        return Err(format!("node {idx} child index {child} is out of bounds"));
215                    }
216                }
217            }
218            RTDL_DO => {
219                if !node.children.is_empty() {
220                    return Err(format!("do node {idx} must not have children"));
221                }
222                let Some(call) = node.call.as_ref() else {
223                    return Err(format!("do node {idx} must contain a call"));
224                };
225                validate_call(idx, call)?;
226            }
227            other => return Err(format!("node {idx} has invalid node_kind {other}")),
228        }
229    }
230
231    let mut colors = vec![VisitColor::White; plan.nodes.len()];
232    visit_for_cycles(root, plan, &mut colors)
233}
234
235fn validate_call(node_index: usize, call: &CapabilityCall) -> Result<(), String> {
236    if call.call_id.is_empty() {
237        return Err(format!("do node {node_index} call_id must not be empty"));
238    }
239    if call.provider_id.is_empty() {
240        return Err(format!(
241            "do node {node_index} provider_id must not be empty"
242        ));
243    }
244    if call.contract_id.is_empty() {
245        return Err(format!(
246            "do node {node_index} contract_id must not be empty"
247        ));
248    }
249    Ok(())
250}
251
/// Per-node mark used by the depth-first cycle check.
#[derive(Clone, Copy, PartialEq, Eq)]
enum VisitColor {
    /// Not visited yet.
    White,
    /// Entered but not finished: the node is on the current DFS path.
    Gray,
    /// Finished: the node and everything reachable from it were checked.
    Black,
}
258
259/// DFS cycle check on the plan arena following only sequence/parallel child edges.
260///
261/// Uses White/Gray/Black marks: entering a Gray node means a back-edge to an ancestor.
262/// `RTDL_DO` nodes have no children in this graph. Returns `Ok` when the subgraph from
263/// `index` is acyclic; otherwise an error naming the node where the cycle was found.
264fn visit_for_cycles(index: usize, plan: &Plan, colors: &mut [VisitColor]) -> Result<(), String> {
265    match colors[index] {
266        VisitColor::Gray => return Err(format!("cycle detected at node {index}")),
267        VisitColor::Black => return Ok(()),
268        VisitColor::White => {}
269    }
270    colors[index] = VisitColor::Gray;
271    let node = &plan.nodes[index];
272    if matches!(node.node_kind, RTDL_SEQUENCE | RTDL_PARALLEL) {
273        for child in &node.children {
274            visit_for_cycles(*child as usize, plan, colors)?;
275        }
276    }
277    colors[index] = VisitColor::Black;
278    Ok(())
279}
280
#[cfg(test)]
mod tests {
    use super::{RTDL_DO, RTDL_PARALLEL, RTDL_SEQUENCE, validate_plan};
    use crate::pb::pilot::{CapabilityCall, Plan, RtdlNode};

    /// Well-formed call whose only varying field is `call_id`.
    fn call(id: &str) -> CapabilityCall {
        CapabilityCall {
            call_id: String::from(id),
            provider_id: String::from("provider"),
            contract_id: String::from("robonix/test/cap"),
            args_json: String::from("{}"),
        }
    }

    /// Shorthand constructor for one arena node.
    fn node(kind: u32, children: Vec<u32>, call: Option<CapabilityCall>) -> RtdlNode {
        RtdlNode {
            node_kind: kind,
            children,
            call,
        }
    }

    /// Wraps `nodes` in a plan with fixed ids, rooted at `root_index`.
    fn plan(nodes: Vec<RtdlNode>, root_index: u32) -> Plan {
        Plan {
            plan_id: String::from("p"),
            session_id: String::from("s"),
            round: 0,
            nodes,
            root_index,
        }
    }

    /// Validation error for a plan that is expected to be rejected.
    fn rejection(candidate: &Plan) -> String {
        validate_plan(candidate).unwrap_err()
    }

    #[test]
    fn validates_sequence_and_parallel_nodes() {
        let nodes = vec![
            node(RTDL_SEQUENCE, vec![1, 2], None),
            node(RTDL_DO, vec![], Some(call("p:0"))),
            node(RTDL_PARALLEL, vec![3, 4], None),
            node(RTDL_DO, vec![], Some(call("p:1"))),
            node(RTDL_DO, vec![], Some(call("p:2"))),
        ];
        validate_plan(&plan(nodes, 0)).unwrap();
    }

    #[test]
    fn rejects_invalid_root() {
        let candidate = plan(vec![node(RTDL_SEQUENCE, vec![], None)], 3);
        assert!(rejection(&candidate).contains("root_index"));
    }

    #[test]
    fn rejects_out_of_bounds_child() {
        let candidate = plan(vec![node(RTDL_SEQUENCE, vec![9], None)], 0);
        assert!(rejection(&candidate).contains("out of bounds"));
    }

    #[test]
    fn rejects_cycle() {
        let nodes = vec![
            node(RTDL_SEQUENCE, vec![1], None),
            node(RTDL_PARALLEL, vec![0], None),
        ];
        assert!(rejection(&plan(nodes, 0)).contains("cycle"));
    }

    #[test]
    fn rejects_do_without_call() {
        let candidate = plan(vec![node(RTDL_DO, vec![], None)], 0);
        assert!(rejection(&candidate).contains("must contain a call"));
    }
}
361}