Skip to main content

robonix_pilot/
vlm.rs

1// SPDX-License-Identifier: MulanPSL-2.0
2// Author: wheatfox <wheatfox17@icloud.com>
3//
4// Embedded OpenAI-compatible chat-completions client.
5// TODO: maybe we will support Google/Anthropic/etc. in the future :D
6use crate::config::VlmConfig;
7use anyhow::{Context, Result};
8use async_openai::Client;
9use async_openai::config::OpenAIConfig;
10use async_openai::types::chat::{
11    ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
12    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
13    ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
14    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
15    ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
16    ChatCompletionRequestUserMessageContentPart, ChatCompletionTool, ChatCompletionTools,
17    CreateChatCompletionRequestArgs, FunctionCall, FunctionObject, FunctionObjectArgs, ImageDetail,
18    ImageUrl, ResponseFormat,
19};
20use futures_util::stream::{Stream, StreamExt};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::collections::BTreeMap;
24use std::pin::Pin;
25
/// One message in an OpenAI Chat Completions conversation.
/// Spec: https://platform.openai.com/docs/api-reference/chat/create#chat/create-messages
///
/// One struct, four roles (`system` / `user` / `assistant` / `tool`); each
/// role uses a different subset of the optional fields:
///
///   system    → role + content
///   user      → role + content (+ optional `image_base64`, see below)
///   assistant → role + content (may be None when only tool_calls are emitted)
///                              + optional tool_calls[]
///   tool      → role + content + tool_call_id (must match an id in the
///                                              preceding assistant.tool_calls)
///
/// Two JSON shapes exist for this type:
/// - The request body sent upstream is NOT this type's `Serialize` output.
///   `build_openai_messages` in this file converts each `Message` into
///   `async_openai` request types field by field.
/// - The serde derives here (with `skip_serializing_if` dropping every `None`)
///   give a compact, role-specific JSON form. NOTE(review): presumably used
///   for history persistence / logging elsewhere — confirm.
///
/// We use a flat struct with optional fields rather than a tagged enum because
/// the planner does a lot of generic Vec<Message> manipulation (trim, sanitize,
/// sliding-window slicing) that's awkward to express through `match` on every
/// access. Type-safety for "tool messages must have tool_call_id" is delegated
/// to runtime checks (`history::sanitize_for_vlm`) and the OpenAI server's
/// own validation.
///
/// `image_base64` is a robonix-side simplification, NOT part of the OpenAI
/// wire format. Callers set it on a `user` message. `build_openai_messages`
/// repackages content + image into OpenAI's multimodal `content` array
/// (`[{type:"text",...}, {type:"image_url",...}]`) at request time.
#[derive(Serialize, Deserialize, Clone)]
pub struct Message {
    /// "system" / "user" / "assistant" / "tool". Determines which other
    /// fields are meaningful. `build_openai_messages` returns an error for
    /// any other value.
    pub role: String,

    /// Optional sender name. Robonix doesn't set it today, and
    /// `build_openai_messages` does not forward it to the request, so it
    /// only survives this type's own serde round-trip.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Message text. Always present except on `assistant` messages whose
    /// only output is tool calls. `build_openai_messages` treats `None` as
    /// an empty string for system/user/tool messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Tool calls the LLM decided to make. Present only on `assistant`
    /// messages. Each entry carries id + function.{name, arguments};
    /// the corresponding `tool` message links back via `tool_call_id`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Correlates a `tool` result back to the `assistant.tool_calls[].id`
    /// that produced it. Required on `tool` messages; absent on others.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// Inline image attached to a `user` message (base64 bytes, sent
    /// upstream with a `data:image/jpeg;base64,` prefix). Robonix-only field,
    /// rewritten into OpenAI's multimodal content array by
    /// `build_openai_messages`. It is ignored on non-`user` roles there.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_base64: Option<String>,
}
85
/// A function call requested by the model, as stored in `Message::tool_calls`
/// and yielded by `VlmStreamItem::ToolCall`.
///
/// Tool calls produced by `VlmClient::chat_stream` always have
/// `kind == "function"`. `build_openai_messages` ignores `kind` and always
/// forwards the call as a function call.
#[derive(Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Upstream tool-call id, echoed back as `Message::tool_call_id` on the
    /// matching `tool` result. May be empty if the upstream stream never
    /// supplied one for this call.
    pub id: String,
    /// Call kind; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub kind: String,
    /// The function name and its arguments.
    pub function: FnCall,
}

/// The function part of a [`ToolCall`].
#[derive(Serialize, Deserialize, Clone)]
pub struct FnCall {
    /// Name of the tool/function the model chose.
    pub name: String,
    /// Raw argument string, concatenated from the streamed argument deltas.
    /// Presumably JSON, but nothing in this file parses or validates it.
    pub arguments: String,
}
99
/// A tool the model may call, passed to `VlmClient::chat_stream`.
///
/// `build_openai_tools` converts it into an `async_openai` function tool. That
/// conversion reads only `function`; `kind` is not consulted there.
///
/// NOTE(review): the fields are private and no constructor is visible in this
/// part of the file. Presumably instances are built elsewhere in this module —
/// confirm.
#[derive(Serialize, Clone)]
pub struct ToolDef {
    /// Tool kind; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    kind: String,
    /// Name, description and parameter schema of the tool.
    function: FnDef,
}

/// Function metadata of a [`ToolDef`].
#[derive(Serialize, Clone)]
struct FnDef {
    /// Function name the model uses to call this tool.
    name: String,
    /// Human-readable description shown to the model.
    description: String,
    /// Parameter schema, forwarded verbatim to `FunctionObjectArgs::parameters`.
    /// Presumably a JSON Schema object — TODO confirm with the callers that
    /// build it.
    parameters: Value,
}
113
114impl Message {
115    pub fn system(content: &str) -> Self {
116        Self {
117            role: "system".into(),
118            name: None,
119            content: Some(content.into()),
120            tool_calls: None,
121            tool_call_id: None,
122            image_base64: None,
123        }
124    }
125    pub fn user(content: &str) -> Self {
126        Self {
127            role: "user".into(),
128            name: None,
129            content: Some(content.into()),
130            tool_calls: None,
131            tool_call_id: None,
132            image_base64: None,
133        }
134    }
135    pub fn user_with_image(content: &str, image_base64: String) -> Self {
136        Self {
137            role: "user".into(),
138            name: None,
139            content: Some(content.into()),
140            tool_calls: None,
141            tool_call_id: None,
142            image_base64: Some(image_base64),
143        }
144    }
145    pub fn assistant(content: &str) -> Self {
146        Self {
147            role: "assistant".into(),
148            name: None,
149            content: Some(content.into()),
150            tool_calls: None,
151            tool_call_id: None,
152            image_base64: None,
153        }
154    }
155    pub fn tool_result(id: &str, content: &str) -> Self {
156        Self {
157            role: "tool".into(),
158            name: None,
159            content: Some(content.into()),
160            tool_calls: None,
161            tool_call_id: Some(id.into()),
162            image_base64: None,
163        }
164    }
165}
166
/// Item yielded by the chat completion stream. `planner.rs` matches on this
/// enum to drive token streaming, tool dispatch, and finish handling.
///
/// `VlmClient::chat_stream` yields these in a fixed order:
/// - zero or more `TextDelta`s,
/// - then zero or more `ToolCall`s,
/// - then exactly one `Finish`.
///
/// If an upstream chunk fails, a single `Err` item is yielded instead and the
/// stream ends without `ToolCall`s or `Finish`.
pub enum VlmStreamItem {
    /// A non-empty chunk of assistant text, forwarded as soon as it arrives.
    TextDelta(String),
    /// One fully accumulated tool call. These are emitted only after the
    /// upstream stream has ended.
    ToolCall(ToolCall),
    /// Stream complete. Finish reason ("stop" / "tool_calls" / "error") is
    /// not surfaced to consumers yet — add a field here when the planner or
    /// downstream PilotEvent grows a use for it.
    Finish,
}
177
/// Direct HTTP client for an OpenAI-compatible chat-completions endpoint.
/// Cheap to clone — `async_openai::Client` wraps a `reqwest::Client` (an
/// `Arc<...>` internally). No mutex needed when sharing across tasks.
#[derive(Clone)]
pub struct VlmClient {
    /// `async_openai` client configured with the upstream API base and key.
    inner: Client<OpenAIConfig>,
    /// Model name sent with every chat completion request.
    model: String,
}
186
187impl VlmClient {
188    pub fn new(cfg: &VlmConfig) -> Self {
189        let oa = OpenAIConfig::new()
190            .with_api_base(cfg.upstream.trim_end_matches('/'))
191            .with_api_key(&cfg.api_key);
192        Self {
193            inner: Client::with_config(oa),
194            model: cfg.model.clone(),
195        }
196    }
197
198    /// Open a streaming chat completion. Yields:
199    ///   - `TextDelta` for every assistant content chunk
200    ///   - `ToolCall` once per accumulated function call (after the upstream
201    ///     finishes streaming all argument deltas)
202    ///   - one final `Finish`
203    pub async fn chat_stream(
204        &self,
205        messages: &[Message],
206        tools: &[ToolDef],
207    ) -> Result<Pin<Box<dyn Stream<Item = Result<VlmStreamItem>> + Send>>> {
208        let oai_messages = build_openai_messages(messages)?;
209        let oai_tools = build_openai_tools(tools)?;
210
211        let mut req_builder = CreateChatCompletionRequestArgs::default();
212        req_builder
213            .model(&self.model)
214            .messages(oai_messages)
215            .stream(true)
216            .response_format(ResponseFormat::JsonObject);
217        if !oai_tools.is_empty() {
218            req_builder.tools(oai_tools);
219        }
220        let request = req_builder
221            .build()
222            .context("build chat completion request")?;
223
224        let mut upstream = self
225            .inner
226            .chat()
227            .create_stream(request)
228            .await
229            .context("open VLM chat stream")?;
230
231        // Walk the upstream chunk-by-chunk, accumulating tool-call deltas by
232        // index until the upstream finishes; then emit one ToolCall per index
233        // and a final Finish event. Use mpsc + spawn so we can return the
234        // boxed Stream while the polling runs in the background.
235        let (tx, rx) = tokio::sync::mpsc::channel::<Result<VlmStreamItem>>(64);
236        tokio::spawn(async move {
237            let mut tc_acc: BTreeMap<u32, AccumulatedToolCall> = BTreeMap::new();
238            let mut finish = "stop".to_string();
239            while let Some(chunk) = upstream.next().await {
240                match chunk {
241                    Ok(resp) => {
242                        let Some(choice) = resp.choices.into_iter().next() else {
243                            continue;
244                        };
245                        let delta = choice.delta;
246                        if let Some(content) = delta.content
247                            && !content.is_empty()
248                            && tx
249                                .send(Ok(VlmStreamItem::TextDelta(content)))
250                                .await
251                                .is_err()
252                        {
253                            return;
254                        }
255                        if let Some(tc_chunks) = delta.tool_calls {
256                            for tc in tc_chunks {
257                                let entry = tc_acc.entry(tc.index).or_default();
258                                if let Some(id) = tc.id {
259                                    entry.id = id;
260                                }
261                                if let Some(func) = tc.function {
262                                    if let Some(name) = func.name {
263                                        entry.name.push_str(&name);
264                                    }
265                                    if let Some(args) = func.arguments {
266                                        entry.arguments.push_str(&args);
267                                    }
268                                }
269                            }
270                        }
271                        if let Some(fr) = choice.finish_reason {
272                            finish = format!("{fr:?}").to_lowercase();
273                        }
274                    }
275                    Err(e) => {
276                        let _ = tx
277                            .send(Err(anyhow::anyhow!("VLM stream chunk error: {e}")))
278                            .await;
279                        return;
280                    }
281                }
282            }
283
284            for (_, tc) in tc_acc {
285                if tc.id.is_empty() && tc.name.is_empty() {
286                    continue;
287                }
288                let item = VlmStreamItem::ToolCall(ToolCall {
289                    id: tc.id,
290                    kind: "function".to_string(),
291                    function: FnCall {
292                        name: tc.name,
293                        arguments: tc.arguments,
294                    },
295                });
296                if tx.send(Ok(item)).await.is_err() {
297                    return;
298                }
299            }
300            let _ = finish; // surface to PilotEvent later if needed
301            let _ = tx.send(Ok(VlmStreamItem::Finish)).await;
302        });
303
304        Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
305    }
306}
307
/// Per-index accumulator for streamed tool-call deltas in
/// `VlmClient::chat_stream`.
///
/// `id` is overwritten by each delta that carries one. `name` and `arguments`
/// are concatenated across deltas in arrival order.
#[derive(Default)]
struct AccumulatedToolCall {
    /// Tool-call id; stays empty if no delta for this index supplied one.
    id: String,
    /// Function-name fragments, concatenated.
    name: String,
    /// Raw argument-string fragments, concatenated.
    arguments: String,
}
314
315fn build_openai_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
316    let mut out = Vec::with_capacity(messages.len());
317    for m in messages {
318        let msg = match m.role.as_str() {
319            "system" => ChatCompletionRequestSystemMessageArgs::default()
320                .content(m.content.clone().unwrap_or_default())
321                .build()?
322                .into(),
323            "user" => {
324                if let Some(image) = &m.image_base64 {
325                    let text = m.content.clone().unwrap_or_default();
326                    let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
327                    if !text.is_empty() {
328                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(
329                            ChatCompletionRequestMessageContentPartText { text },
330                        ));
331                    }
332                    let url = format!("data:image/jpeg;base64,{image}");
333                    parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
334                        ChatCompletionRequestMessageContentPartImage {
335                            image_url: ImageUrl {
336                                url,
337                                detail: Some(ImageDetail::Auto),
338                            },
339                        },
340                    ));
341                    ChatCompletionRequestUserMessageArgs::default()
342                        .content(ChatCompletionRequestUserMessageContent::Array(parts))
343                        .build()?
344                        .into()
345                } else {
346                    ChatCompletionRequestUserMessageArgs::default()
347                        .content(m.content.clone().unwrap_or_default())
348                        .build()?
349                        .into()
350                }
351            }
352            "assistant" => {
353                let mut b = ChatCompletionRequestAssistantMessageArgs::default();
354                if let Some(c) = &m.content
355                    && !c.is_empty()
356                {
357                    b.content(c.clone());
358                }
359                if let Some(tcs) = &m.tool_calls {
360                    let oai_tcs: Vec<ChatCompletionMessageToolCalls> = tcs
361                        .iter()
362                        .map(|tc| {
363                            ChatCompletionMessageToolCalls::Function(
364                                ChatCompletionMessageToolCall {
365                                    id: tc.id.clone(),
366                                    function: FunctionCall {
367                                        name: tc.function.name.clone(),
368                                        arguments: tc.function.arguments.clone(),
369                                    },
370                                },
371                            )
372                        })
373                        .collect();
374                    b.tool_calls(oai_tcs);
375                }
376                b.build()?.into()
377            }
378            "tool" => {
379                let id = m.tool_call_id.clone().unwrap_or_default();
380                ChatCompletionRequestToolMessageArgs::default()
381                    .tool_call_id(id)
382                    .content(m.content.clone().unwrap_or_default())
383                    .build()?
384                    .into()
385            }
386            other => anyhow::bail!("unknown message role '{other}'"),
387        };
388        out.push(msg);
389    }
390    Ok(out)
391}
392
393fn build_openai_tools(tools: &[ToolDef]) -> Result<Vec<ChatCompletionTools>> {
394    tools
395        .iter()
396        .map(|t| -> Result<ChatCompletionTools> {
397            let func: FunctionObject = FunctionObjectArgs::default()
398                .name(&t.function.name)
399                .description(&t.function.description)
400                .parameters(t.function.parameters.clone())
401                .build()?;
402            Ok(ChatCompletionTools::Function(ChatCompletionTool {
403                function: func,
404            }))
405        })
406        .collect()
407}