1use crate::config::VlmConfig;
7use anyhow::{Context, Result};
8use async_openai::Client;
9use async_openai::config::OpenAIConfig;
10use async_openai::types::chat::{
11 ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
12 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
13 ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
14 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
15 ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
16 ChatCompletionRequestUserMessageContentPart, ChatCompletionTool, ChatCompletionTools,
17 CreateChatCompletionRequestArgs, FunctionCall, FunctionObject, FunctionObjectArgs, ImageDetail,
18 ImageUrl, ResponseFormat,
19};
20use futures_util::stream::{Stream, StreamExt};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::collections::BTreeMap;
24use std::pin::Pin;
25
/// One chat message in the conversation history.
///
/// Every optional field is left out of the serialized form when it is `None`
/// (`skip_serializing_if = "Option::is_none"`).
#[derive(Serialize, Deserialize, Clone)]
pub struct Message {
    /// Speaker role as a plain string. `build_openai_messages` accepts `"system"`,
    /// `"user"`, `"assistant"` and `"tool"` and rejects anything else.
    pub role: String,

    /// Optional participant name. The constructors in this file always set it to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Text content of the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Tool calls attached to the message; `build_openai_messages` only reads this for
    /// `"assistant"` messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Id of the tool call this message answers; `build_openai_messages` only reads this
    /// for `"tool"` messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// Base64-encoded image payload. For `"user"` messages `build_openai_messages` wraps
    /// it in a `data:` URL and sends it as an image content part.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_base64: Option<String>,
}
85
/// A single tool invocation requested by the model.
#[derive(Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Identifier of the call; a `"tool"` message refers back to it via `tool_call_id`.
    pub id: String,
    /// Kind of tool call, (de)serialized under the key `type`. `VlmClient::chat_stream`
    /// always sets it to `"function"`.
    #[serde(rename = "type")]
    pub kind: String,
    /// The function name and its raw argument string.
    pub function: FnCall,
}
93
/// Name and arguments of a function the model wants to call.
#[derive(Serialize, Deserialize, Clone)]
pub struct FnCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments as a raw string, exactly as concatenated from the streamed fragments.
    /// Presumably JSON text; it is not parsed or validated in this file.
    pub arguments: String,
}
99
/// Definition of a tool offered to the model.
///
/// Both fields are private, so values cannot be built with a struct literal from outside
/// this module.
#[derive(Serialize, Clone)]
pub struct ToolDef {
    /// Kind of tool, serialized under the key `type`.
    #[serde(rename = "type")]
    kind: String,
    /// Name, description and parameter schema of the function.
    function: FnDef,
}
106
/// Function description carried inside a [`ToolDef`].
#[derive(Serialize, Clone)]
struct FnDef {
    /// Function name; forwarded by `build_openai_tools`.
    name: String,
    /// Human-readable description; forwarded by `build_openai_tools`.
    description: String,
    /// Parameter description as an arbitrary JSON value, forwarded unchanged.
    /// Presumably a JSON Schema object; it is not validated in this file.
    parameters: Value,
}
113
114impl Message {
115 pub fn system(content: &str) -> Self {
116 Self {
117 role: "system".into(),
118 name: None,
119 content: Some(content.into()),
120 tool_calls: None,
121 tool_call_id: None,
122 image_base64: None,
123 }
124 }
125 pub fn user(content: &str) -> Self {
126 Self {
127 role: "user".into(),
128 name: None,
129 content: Some(content.into()),
130 tool_calls: None,
131 tool_call_id: None,
132 image_base64: None,
133 }
134 }
135 pub fn user_with_image(content: &str, image_base64: String) -> Self {
136 Self {
137 role: "user".into(),
138 name: None,
139 content: Some(content.into()),
140 tool_calls: None,
141 tool_call_id: None,
142 image_base64: Some(image_base64),
143 }
144 }
145 pub fn assistant(content: &str) -> Self {
146 Self {
147 role: "assistant".into(),
148 name: None,
149 content: Some(content.into()),
150 tool_calls: None,
151 tool_call_id: None,
152 image_base64: None,
153 }
154 }
155 pub fn tool_result(id: &str, content: &str) -> Self {
156 Self {
157 role: "tool".into(),
158 name: None,
159 content: Some(content.into()),
160 tool_calls: None,
161 tool_call_id: Some(id.into()),
162 image_base64: None,
163 }
164 }
165}
166
/// One item produced by the stream returned from `VlmClient::chat_stream`.
pub enum VlmStreamItem {
    /// A non-empty fragment of assistant text, forwarded as soon as it arrives.
    TextDelta(String),
    /// A complete tool call, assembled from streamed fragments and emitted only after
    /// the upstream stream has ended.
    ToolCall(ToolCall),
    /// Sent last when the upstream stream ended without a chunk error.
    Finish,
}
177
/// Client for a chat-completions endpoint, built on `async_openai`.
#[derive(Clone)]
pub struct VlmClient {
    /// Underlying `async_openai` client, configured with the API base and key.
    inner: Client<OpenAIConfig>,
    /// Model name sent with every request.
    model: String,
}
186
187impl VlmClient {
188 pub fn new(cfg: &VlmConfig) -> Self {
189 let oa = OpenAIConfig::new()
190 .with_api_base(cfg.upstream.trim_end_matches('/'))
191 .with_api_key(&cfg.api_key);
192 Self {
193 inner: Client::with_config(oa),
194 model: cfg.model.clone(),
195 }
196 }
197
198 pub async fn chat_stream(
204 &self,
205 messages: &[Message],
206 tools: &[ToolDef],
207 ) -> Result<Pin<Box<dyn Stream<Item = Result<VlmStreamItem>> + Send>>> {
208 let oai_messages = build_openai_messages(messages)?;
209 let oai_tools = build_openai_tools(tools)?;
210
211 let mut req_builder = CreateChatCompletionRequestArgs::default();
212 req_builder
213 .model(&self.model)
214 .messages(oai_messages)
215 .stream(true)
216 .response_format(ResponseFormat::JsonObject);
217 if !oai_tools.is_empty() {
218 req_builder.tools(oai_tools);
219 }
220 let request = req_builder
221 .build()
222 .context("build chat completion request")?;
223
224 let mut upstream = self
225 .inner
226 .chat()
227 .create_stream(request)
228 .await
229 .context("open VLM chat stream")?;
230
231 let (tx, rx) = tokio::sync::mpsc::channel::<Result<VlmStreamItem>>(64);
236 tokio::spawn(async move {
237 let mut tc_acc: BTreeMap<u32, AccumulatedToolCall> = BTreeMap::new();
238 let mut finish = "stop".to_string();
239 while let Some(chunk) = upstream.next().await {
240 match chunk {
241 Ok(resp) => {
242 let Some(choice) = resp.choices.into_iter().next() else {
243 continue;
244 };
245 let delta = choice.delta;
246 if let Some(content) = delta.content
247 && !content.is_empty()
248 && tx
249 .send(Ok(VlmStreamItem::TextDelta(content)))
250 .await
251 .is_err()
252 {
253 return;
254 }
255 if let Some(tc_chunks) = delta.tool_calls {
256 for tc in tc_chunks {
257 let entry = tc_acc.entry(tc.index).or_default();
258 if let Some(id) = tc.id {
259 entry.id = id;
260 }
261 if let Some(func) = tc.function {
262 if let Some(name) = func.name {
263 entry.name.push_str(&name);
264 }
265 if let Some(args) = func.arguments {
266 entry.arguments.push_str(&args);
267 }
268 }
269 }
270 }
271 if let Some(fr) = choice.finish_reason {
272 finish = format!("{fr:?}").to_lowercase();
273 }
274 }
275 Err(e) => {
276 let _ = tx
277 .send(Err(anyhow::anyhow!("VLM stream chunk error: {e}")))
278 .await;
279 return;
280 }
281 }
282 }
283
284 for (_, tc) in tc_acc {
285 if tc.id.is_empty() && tc.name.is_empty() {
286 continue;
287 }
288 let item = VlmStreamItem::ToolCall(ToolCall {
289 id: tc.id,
290 kind: "function".to_string(),
291 function: FnCall {
292 name: tc.name,
293 arguments: tc.arguments,
294 },
295 });
296 if tx.send(Ok(item)).await.is_err() {
297 return;
298 }
299 }
300 let _ = finish; let _ = tx.send(Ok(VlmStreamItem::Finish)).await;
302 });
303
304 Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
305 }
306}
307
/// Scratch buffer for one tool call while its fragments are still being streamed.
///
/// All fields start out empty (`Default`).
#[derive(Default)]
struct AccumulatedToolCall {
    /// Tool-call id; overwritten whenever a fragment carries one.
    id: String,
    /// Function name, built by appending the name fragments in arrival order.
    name: String,
    /// Raw argument string, built by appending the argument fragments in arrival order.
    arguments: String,
}
314
315fn build_openai_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
316 let mut out = Vec::with_capacity(messages.len());
317 for m in messages {
318 let msg = match m.role.as_str() {
319 "system" => ChatCompletionRequestSystemMessageArgs::default()
320 .content(m.content.clone().unwrap_or_default())
321 .build()?
322 .into(),
323 "user" => {
324 if let Some(image) = &m.image_base64 {
325 let text = m.content.clone().unwrap_or_default();
326 let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
327 if !text.is_empty() {
328 parts.push(ChatCompletionRequestUserMessageContentPart::Text(
329 ChatCompletionRequestMessageContentPartText { text },
330 ));
331 }
332 let url = format!("data:image/jpeg;base64,{image}");
333 parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
334 ChatCompletionRequestMessageContentPartImage {
335 image_url: ImageUrl {
336 url,
337 detail: Some(ImageDetail::Auto),
338 },
339 },
340 ));
341 ChatCompletionRequestUserMessageArgs::default()
342 .content(ChatCompletionRequestUserMessageContent::Array(parts))
343 .build()?
344 .into()
345 } else {
346 ChatCompletionRequestUserMessageArgs::default()
347 .content(m.content.clone().unwrap_or_default())
348 .build()?
349 .into()
350 }
351 }
352 "assistant" => {
353 let mut b = ChatCompletionRequestAssistantMessageArgs::default();
354 if let Some(c) = &m.content
355 && !c.is_empty()
356 {
357 b.content(c.clone());
358 }
359 if let Some(tcs) = &m.tool_calls {
360 let oai_tcs: Vec<ChatCompletionMessageToolCalls> = tcs
361 .iter()
362 .map(|tc| {
363 ChatCompletionMessageToolCalls::Function(
364 ChatCompletionMessageToolCall {
365 id: tc.id.clone(),
366 function: FunctionCall {
367 name: tc.function.name.clone(),
368 arguments: tc.function.arguments.clone(),
369 },
370 },
371 )
372 })
373 .collect();
374 b.tool_calls(oai_tcs);
375 }
376 b.build()?.into()
377 }
378 "tool" => {
379 let id = m.tool_call_id.clone().unwrap_or_default();
380 ChatCompletionRequestToolMessageArgs::default()
381 .tool_call_id(id)
382 .content(m.content.clone().unwrap_or_default())
383 .build()?
384 .into()
385 }
386 other => anyhow::bail!("unknown message role '{other}'"),
387 };
388 out.push(msg);
389 }
390 Ok(out)
391}
392
393fn build_openai_tools(tools: &[ToolDef]) -> Result<Vec<ChatCompletionTools>> {
394 tools
395 .iter()
396 .map(|t| -> Result<ChatCompletionTools> {
397 let func: FunctionObject = FunctionObjectArgs::default()
398 .name(&t.function.name)
399 .description(&t.function.description)
400 .parameters(t.function.parameters.clone())
401 .build()?;
402 Ok(ChatCompletionTools::Function(ChatCompletionTool {
403 function: func,
404 }))
405 })
406 .collect()
407}