1use crate::pb::contracts::{
7 robonix_system_executor_client::RobonixSystemExecutorClient,
8 robonix_system_pilot_server::RobonixSystemPilot,
9};
10use crate::pb::pilot::{BatchResult, PilotEvent, Plan, SessionStatusEvent, Task};
11use crate::planner::{self, ExecutorConn};
12use crate::vlm::{Message, VlmClient};
13use anyhow::Context;
14use robonix_atlas::client::{self as atlas_client, AtlasClient};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::{Mutex, watch};
18use tokio_stream::wrappers::ReceiverStream;
19use tonic::{Request, Response, Status};
20use uuid::Uuid;
21
/// Lifecycle state of a pilot session.
///
/// The `u32` discriminants are what travel on the wire in
/// `SessionStatusEvent.state` (converted with `as u32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
// NOTE(review): only `Active` is constructed in this file; the allow presumably
// keeps the other variants warning-free until something here uses them — confirm.
#[allow(dead_code)]
pub enum SessionState {
    /// Sent as the first event when a turn starts.
    Active = 0,
    /// Wire value 1.
    Completed = 1,
    /// Wire value 2.
    Failed = 2,
}
30
// Values stored in `PilotEvent.event_kind`; each one says which payload field
// of the event is populated.

/// Incremental text: `PilotEvent.text_chunk` is set.
pub const EVT_TEXT_CHUNK: u32 = 0;
/// Plan: `PilotEvent.plan` is set.
pub const EVT_PLAN: u32 = 1;
/// Batch result: `PilotEvent.batch_result` is set.
pub const EVT_BATCH_RESULT: u32 = 2;
/// Session status: `PilotEvent.status` is set.
pub const EVT_STATUS: u32 = 3;
/// Final text: `PilotEvent.final_text` is set.
pub const EVT_FINAL_TEXT: u32 = 4;
39
40pub enum PilotStreamBody {
41 TextChunk(String),
42 FinalText(String),
43 Plan(Plan),
44 BatchResult(BatchResult),
45 Status(SessionStatusEvent),
46}
47
48pub fn pack(session_id: &str, body: PilotStreamBody) -> PilotEvent {
49 let mut e = PilotEvent {
50 session_id: session_id.to_string(),
51 ..Default::default()
52 };
53 match body {
54 PilotStreamBody::TextChunk(s) => {
55 e.event_kind = EVT_TEXT_CHUNK;
56 e.text_chunk = s;
57 }
58 PilotStreamBody::Plan(g) => {
59 e.event_kind = EVT_PLAN;
60 e.plan = Some(g);
61 }
62 PilotStreamBody::BatchResult(b) => {
63 e.event_kind = EVT_BATCH_RESULT;
64 e.batch_result = Some(b);
65 }
66 PilotStreamBody::Status(s) => {
67 e.event_kind = EVT_STATUS;
68 e.status = Some(s);
69 }
70 PilotStreamBody::FinalText(s) => {
71 e.event_kind = EVT_FINAL_TEXT;
72 e.final_text = s;
73 }
74 }
75 e
76}
77
78type Histories = Arc<Mutex<HashMap<String, Arc<Mutex<Vec<Message>>>>>>;
81
82pub struct PilotServiceImpl {
83 atlas: AtlasClient,
87 provider_id: String,
91 vlm: VlmClient,
92 histories: Histories,
93 cancels: Arc<Mutex<HashMap<String, watch::Sender<bool>>>>,
96}
97
98impl PilotServiceImpl {
99 pub fn new(atlas: AtlasClient, provider_id: String, vlm: VlmClient) -> Self {
100 Self {
101 atlas,
102 provider_id,
103 vlm,
104 histories: Arc::new(Mutex::new(HashMap::new())),
105 cancels: Arc::new(Mutex::new(HashMap::new())),
106 }
107 }
108
109 async fn get_or_create_history(&self, session_id: &str) -> Arc<Mutex<Vec<Message>>> {
110 let mut map = self.histories.lock().await;
111 map.entry(session_id.to_string())
112 .or_insert_with(|| Arc::new(Mutex::new(Vec::new())))
113 .clone()
114 }
115}
116
117fn task_is_abort_turn(task: &Task) -> bool {
118 let j = task.context_json.trim();
119 if j.is_empty() {
120 return false;
121 }
122 serde_json::from_str::<serde_json::Value>(j)
123 .ok()
124 .and_then(|v| v.get("abort_turn").and_then(|x| x.as_bool()))
125 .unwrap_or(false)
126}
127
128#[tonic::async_trait]
129impl RobonixSystemPilot for PilotServiceImpl {
130 type SubmitTaskStream = ReceiverStream<Result<PilotEvent, Status>>;
131
132 async fn submit_task(
133 &self,
134 request: Request<Task>,
135 ) -> Result<Response<Self::SubmitTaskStream>, Status> {
136 let mut task = request.into_inner();
137
138 if task_is_abort_turn(&task) {
139 let id = task.session_id.clone();
140 let ok = if let Some(tx) = self.cancels.lock().await.get(&id) {
141 let _ = tx.send(true);
142 true
143 } else {
144 false
145 };
146 log::debug!("[pilot] abort_turn task for session {id} (signaled={ok})");
147 let (_tx, rx) = tokio::sync::mpsc::channel::<Result<PilotEvent, Status>>(1);
148 return Ok(Response::new(ReceiverStream::new(rx)));
149 }
150
151 if task.session_id.is_empty() {
152 task.session_id = Uuid::new_v4().to_string();
153 }
154
155 let history_arc = self.get_or_create_history(&task.session_id).await;
156 let (tx, rx) = tokio::sync::mpsc::channel::<Result<PilotEvent, Status>>(64);
161 let atlas = self.atlas.clone();
162 let provider_id = self.provider_id.clone();
163 let vlm = self.vlm.clone();
164 let session_id = task.session_id.clone();
165 let cancels = Arc::clone(&self.cancels);
166
167 let (cancel_tx, cancel_rx) = watch::channel(false);
168 cancels.lock().await.insert(session_id.clone(), cancel_tx);
169
170 tokio::spawn(async move {
171 let _ = tx
172 .send(Ok(pack(
173 &session_id,
174 PilotStreamBody::Status(SessionStatusEvent {
175 session_id: session_id.clone(),
176 state: SessionState::Active as u32,
177 message: String::new(),
178 }),
179 )))
180 .await;
181
182 let mut atlas_for_turn = atlas.clone();
183 let mut executor = match build_executor_conn(atlas, &provider_id).await {
184 Ok(e) => e,
185 Err(e) => {
186 let _ = tx
187 .send(Err(Status::unavailable(format!(
188 "cannot reach Executor via atlas: {e:#}"
189 ))))
190 .await;
191 cancels.lock().await.remove(&session_id);
192 return;
193 }
194 };
195
196 let mut history = history_arc.lock().await;
197 if let Err(e) = planner::run_turn(
198 &task,
199 &mut history,
200 &vlm,
201 &mut executor,
202 &mut atlas_for_turn,
203 &provider_id,
204 &tx,
205 cancel_rx,
206 )
207 .await
208 {
209 log::error!("[pilot] turn error for session '{session_id}': {e:#}");
210 let _ = tx.send(Err(Status::internal(e.to_string()))).await;
211 }
212
213 cancels.lock().await.remove(&session_id);
214 });
215
216 Ok(Response::new(ReceiverStream::new(rx)))
217 }
218}
219
220async fn build_executor_conn(
223 mut atlas: AtlasClient,
224 consumer_id: &str,
225) -> anyhow::Result<ExecutorConn> {
226 let (_, _, exec_ch) =
227 atlas_client::connect_to_capability(&mut atlas, consumer_id, "robonix/system/executor")
228 .await
229 .context("connect_to_capability robonix/system/executor")?;
230 Ok(ExecutorConn {
231 graph: RobonixSystemExecutorClient::new(exec_ch),
232 })
233}
234
#[cfg(test)]
mod tests {
    use super::task_is_abort_turn;
    use crate::pb::pilot::Task;

    /// Builds a task whose only interesting field is `context_json`.
    ///
    /// Remaining fields come from `Default`, so new proto fields don't break
    /// this helper.
    fn task(ctx: &str) -> Task {
        Task {
            task_id: "t".into(),
            session_id: "s".into(),
            context_json: ctx.into(),
            ..Default::default()
        }
    }

    #[test]
    fn abort_turn_detected() {
        let cases: &[(&str, bool)] = &[
            (r#"{"abort_turn":true}"#, true),
            // Surrounding whitespace is trimmed before parsing.
            ("  {\"abort_turn\": true}\n", true),
            (r#"{"abort_turn":false}"#, false),
            // Only a JSON boolean counts; strings and numbers do not.
            (r#"{"abort_turn":"true"}"#, false),
            (r#"{"abort_turn":1}"#, false),
            (r#"{"foo":1}"#, false),
            // Valid JSON that is not an object has no keys.
            ("true", false),
            ("", false),
            ("   ", false),
            ("not json", false),
        ];
        for &(ctx, expected) in cases {
            assert_eq!(
                task_is_abort_turn(&task(ctx)),
                expected,
                "context_json = {ctx:?}"
            );
        }
    }
}