Skip to main content

robonix_pilot/
service.rs

1// SPDX-License-Identifier: MulanPSL-2.0
2// Author: wheatfox <wheatfox17@icloud.com>
3//
4// `RobonixSystemPilot` gRPC handler (contract `robonix/system/pilot`).
5
6use crate::pb::contracts::{
7    robonix_system_executor_client::RobonixSystemExecutorClient,
8    robonix_system_pilot_server::RobonixSystemPilot,
9};
10use crate::pb::pilot::{BatchResult, PilotEvent, Plan, SessionStatusEvent, Task};
11use crate::planner::{self, ExecutorConn};
12use crate::vlm::{Message, VlmClient};
13use anyhow::Context;
14use robonix_atlas::client::{self as atlas_client, AtlasClient};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::{Mutex, watch};
18use tokio_stream::wrappers::ReceiverStream;
19use tonic::{Request, Response, Status};
20use uuid::Uuid;
21
/// Lifecycle state of a pilot session.
///
/// The discriminants are fixed with `#[repr(u32)]` because the value is cast
/// with `as u32` into `SessionStatusEvent::state` when a status event is built.
#[derive(Clone, Copy)]
#[repr(u32)]
// Only `Active` is constructed in this file; the lint is silenced for the
// remaining variants.
#[allow(dead_code)]
pub enum SessionState {
    /// The pilot has received a task and the turn is in progress.
    Active = 0,
    /// The task is completed.
    Completed = 1,
    /// The task failed.
    Failed = 2,
}
30
// `PilotEvent` carries one of several payloads, tagged by `event_kind`. The
// message does not use a proto `oneof` here, so the discriminator is kept
// explicit; planner and service both build events through `pack`.

/// `event_kind` for an event whose payload is `text_chunk`.
pub const EVT_TEXT_CHUNK: u32 = 0;
/// `event_kind` for an event whose payload is `plan`.
pub const EVT_PLAN: u32 = 1;
/// `event_kind` for an event whose payload is `batch_result`.
pub const EVT_BATCH_RESULT: u32 = 2;
/// `event_kind` for an event whose payload is `status`.
pub const EVT_STATUS: u32 = 3;
/// `event_kind` for an event whose payload is `final_text`.
pub const EVT_FINAL_TEXT: u32 = 4;
39
/// Typed payload for a single `PilotEvent`.
///
/// Each variant corresponds to exactly one `EVT_*` discriminator; `pack`
/// turns a value of this enum into the flat wire message.
pub enum PilotStreamBody {
    /// Text stored in `PilotEvent::text_chunk` (`EVT_TEXT_CHUNK`).
    TextChunk(String),
    /// Text stored in `PilotEvent::final_text` (`EVT_FINAL_TEXT`).
    FinalText(String),
    /// Plan stored in `PilotEvent::plan` (`EVT_PLAN`).
    Plan(Plan),
    /// Batch result stored in `PilotEvent::batch_result` (`EVT_BATCH_RESULT`).
    BatchResult(BatchResult),
    /// Session status stored in `PilotEvent::status` (`EVT_STATUS`).
    Status(SessionStatusEvent),
}
47
48pub fn pack(session_id: &str, body: PilotStreamBody) -> PilotEvent {
49    let mut e = PilotEvent {
50        session_id: session_id.to_string(),
51        ..Default::default()
52    };
53    match body {
54        PilotStreamBody::TextChunk(s) => {
55            e.event_kind = EVT_TEXT_CHUNK;
56            e.text_chunk = s;
57        }
58        PilotStreamBody::Plan(g) => {
59            e.event_kind = EVT_PLAN;
60            e.plan = Some(g);
61        }
62        PilotStreamBody::BatchResult(b) => {
63            e.event_kind = EVT_BATCH_RESULT;
64            e.batch_result = Some(b);
65        }
66        PilotStreamBody::Status(s) => {
67            e.event_kind = EVT_STATUS;
68            e.status = Some(s);
69        }
70        PilotStreamBody::FinalText(s) => {
71            e.event_kind = EVT_FINAL_TEXT;
72            e.final_text = s;
73        }
74    }
75    e
76}
77
/// LLM conversation history per `session_id`.
///
/// The outer mutex guards the session map; each inner mutex guards one
/// session's message list. Histories grow across turns and entries are never
/// expired (turns trim themselves at MAX_HISTORY in planner).
type Histories = Arc<Mutex<HashMap<String, Arc<Mutex<Vec<Message>>>>>>;
81
/// State behind the `RobonixSystemPilot` gRPC service.
///
/// One instance serves all sessions; per-session data lives in the
/// `histories` and `cancels` maps, keyed by `session_id`.
pub struct PilotServiceImpl {
    /// `AtlasClient` is cheap to clone (its inner channel is just a handle);
    /// each Stream RPC clones it to discover executor concurrently without
    /// serialising on a single mutex.
    atlas: AtlasClient,
    /// Pilot's own provider_id; passed to atlas as `consumer_id` on every
    /// `ConnectCapability` so the channel record reflects who is using
    /// the executor.
    provider_id: String,
    /// VLM client; cloned into each spawned turn and handed to the planner.
    vlm: VlmClient,
    /// Conversation histories, one entry per session (see `Histories`).
    histories: Histories,
    /// Per-session cancellation senders. `abort_turn` Task signals this
    /// without holding the history lock.
    cancels: Arc<Mutex<HashMap<String, watch::Sender<bool>>>>,
}
97
98impl PilotServiceImpl {
99    pub fn new(atlas: AtlasClient, provider_id: String, vlm: VlmClient) -> Self {
100        Self {
101            atlas,
102            provider_id,
103            vlm,
104            histories: Arc::new(Mutex::new(HashMap::new())),
105            cancels: Arc::new(Mutex::new(HashMap::new())),
106        }
107    }
108
109    async fn get_or_create_history(&self, session_id: &str) -> Arc<Mutex<Vec<Message>>> {
110        let mut map = self.histories.lock().await;
111        map.entry(session_id.to_string())
112            .or_insert_with(|| Arc::new(Mutex::new(Vec::new())))
113            .clone()
114    }
115}
116
117fn task_is_abort_turn(task: &Task) -> bool {
118    let j = task.context_json.trim();
119    if j.is_empty() {
120        return false;
121    }
122    serde_json::from_str::<serde_json::Value>(j)
123        .ok()
124        .and_then(|v| v.get("abort_turn").and_then(|x| x.as_bool()))
125        .unwrap_or(false)
126}
127
128#[tonic::async_trait]
129impl RobonixSystemPilot for PilotServiceImpl {
130    type SubmitTaskStream = ReceiverStream<Result<PilotEvent, Status>>;
131
132    async fn submit_task(
133        &self,
134        request: Request<Task>,
135    ) -> Result<Response<Self::SubmitTaskStream>, Status> {
136        let mut task = request.into_inner();
137
138        if task_is_abort_turn(&task) {
139            let id = task.session_id.clone();
140            let ok = if let Some(tx) = self.cancels.lock().await.get(&id) {
141                let _ = tx.send(true);
142                true
143            } else {
144                false
145            };
146            log::debug!("[pilot] abort_turn task for session {id} (signaled={ok})");
147            let (_tx, rx) = tokio::sync::mpsc::channel::<Result<PilotEvent, Status>>(1);
148            return Ok(Response::new(ReceiverStream::new(rx)));
149        }
150
151        if task.session_id.is_empty() {
152            task.session_id = Uuid::new_v4().to_string();
153        }
154
155        let history_arc = self.get_or_create_history(&task.session_id).await;
156        // what is tokio's tx and rx:
157        // https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.Sender.html
158        // https://tokio.rs/tokio/tutorial/channels
159        // MPSC: Multiple Producer Single Consumer
160        let (tx, rx) = tokio::sync::mpsc::channel::<Result<PilotEvent, Status>>(64);
161        let atlas = self.atlas.clone();
162        let provider_id = self.provider_id.clone();
163        let vlm = self.vlm.clone();
164        let session_id = task.session_id.clone();
165        let cancels = Arc::clone(&self.cancels);
166
167        let (cancel_tx, cancel_rx) = watch::channel(false);
168        cancels.lock().await.insert(session_id.clone(), cancel_tx);
169
170        tokio::spawn(async move {
171            let _ = tx
172                .send(Ok(pack(
173                    &session_id,
174                    PilotStreamBody::Status(SessionStatusEvent {
175                        session_id: session_id.clone(),
176                        state: SessionState::Active as u32,
177                        message: String::new(),
178                    }),
179                )))
180                .await;
181
182            let mut atlas_for_turn = atlas.clone();
183            let mut executor = match build_executor_conn(atlas, &provider_id).await {
184                Ok(e) => e,
185                Err(e) => {
186                    let _ = tx
187                        .send(Err(Status::unavailable(format!(
188                            "cannot reach Executor via atlas: {e:#}"
189                        ))))
190                        .await;
191                    cancels.lock().await.remove(&session_id);
192                    return;
193                }
194            };
195
196            let mut history = history_arc.lock().await;
197            if let Err(e) = planner::run_turn(
198                &task,
199                &mut history,
200                &vlm,
201                &mut executor,
202                &mut atlas_for_turn,
203                &provider_id,
204                &tx,
205                cancel_rx,
206            )
207            .await
208            {
209                log::error!("[pilot] turn error for session '{session_id}': {e:#}");
210                let _ = tx.send(Err(Status::internal(e.to_string()))).await;
211            }
212
213            cancels.lock().await.remove(&session_id);
214        });
215
216        Ok(Response::new(ReceiverStream::new(rx)))
217    }
218}
219
220/// Connect to executor's Execute RPC. Capability discovery (what's available
221/// for the LLM to call) is done directly against atlas, not through executor.
222async fn build_executor_conn(
223    mut atlas: AtlasClient,
224    consumer_id: &str,
225) -> anyhow::Result<ExecutorConn> {
226    let (_, _, exec_ch) =
227        atlas_client::connect_to_capability(&mut atlas, consumer_id, "robonix/system/executor")
228            .await
229            .context("connect_to_capability robonix/system/executor")?;
230    Ok(ExecutorConn {
231        graph: RobonixSystemExecutorClient::new(exec_ch),
232    })
233}
234
#[cfg(test)]
mod tests {
    use super::task_is_abort_turn;
    use crate::pb::pilot::Task;

    /// Build a `Task` whose only varying field is `context_json`.
    fn task(ctx: &str) -> Task {
        Task {
            context_json: ctx.into(),
            task_id: "t".into(),
            session_id: "s".into(),
            text: String::new(),
            audio_data: Vec::new(),
            source: 0,
            timestamp_ms: 0,
        }
    }

    /// Only a JSON object with `"abort_turn": true` counts as an abort.
    #[test]
    fn abort_turn_detected() {
        let cases = [
            (r#"{"abort_turn":true}"#, true),
            (r#"{"abort_turn":false}"#, false),
            (r#"{"foo":1}"#, false),
            ("", false),
            ("not json", false),
        ];

        for &(ctx, expected) in &cases {
            assert_eq!(
                task_is_abort_turn(&task(ctx)),
                expected,
                "context_json = {ctx:?}"
            );
        }
    }
}