1use anyhow::{Context, Result};
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::Write as FmtWrite;
16use std::fs;
17use std::path::Path;
18
19use super::msg_parser::{MsgField, MsgResolver, MsgSpec, MsgTypeRef, RosPrimitive};
20
21fn parse_prim_or_panic(p: &str) -> RosPrimitive {
29 RosPrimitive::parse(p).unwrap_or_else(|| {
30 panic!(
31 "[robonix-codegen] mcp: primitive '{p}' is not in RosPrimitive::parse — \
32 every primitive accepted by the parser must have an MCP mapping."
33 )
34 })
35}
36
37fn python_field_type(p: &str, is_array: bool, is_blob: bool) -> String {
38 if is_blob {
39 return "bytes".to_string();
43 }
44 let base = parse_prim_or_panic(p).python_type();
45 if is_array {
46 format!("List[{}]", base)
47 } else {
48 base.to_string()
49 }
50}
51
52fn python_default(p: &str, is_array: bool, is_blob: bool) -> String {
53 if is_blob {
54 return "b\"\"".to_string();
55 }
56 if is_array {
57 return "field(default_factory=list)".to_string();
58 }
59 parse_prim_or_panic(p).python_default().to_string()
60}
61
62fn python_cast(p: &str) -> &'static str {
63 parse_prim_or_panic(p).python_cast()
64}
65
66fn cross_package_deps(spec: &MsgSpec) -> BTreeSet<String> {
69 let mut pkgs = BTreeSet::new();
70 for f in &spec.fields {
71 if let MsgTypeRef::Named { package, .. } = &f.type_ref
72 && package != &spec.package
73 {
74 pkgs.insert(package.clone());
75 }
76 }
77 pkgs
78}
79
80fn is_byte_blob(field: &MsgField) -> bool {
88 if !field.is_array {
89 return false;
90 }
91 match &field.type_ref {
92 MsgTypeRef::Primitive(p) => RosPrimitive::parse(p)
93 .map(|prim| prim.is_blob_element())
94 .unwrap_or(false),
95 _ => false,
96 }
97}
98
99fn emit_field_decl(out: &mut String, field: &MsgField, pkg: &str) {
100 let type_str = match &field.type_ref {
101 MsgTypeRef::Primitive(p) => {
102 let blob = is_byte_blob(field);
103 python_field_type(p, field.is_array && !blob, blob)
104 }
105 MsgTypeRef::Named { package, name } => {
106 let qualified = if package == pkg {
107 name.clone()
108 } else {
109 format!("{}_mcp.{}", package, name)
110 };
111 if field.is_array {
112 format!("List[{}]", qualified)
113 } else {
114 qualified
115 }
116 }
117 };
118
119 let default_str = match &field.type_ref {
120 MsgTypeRef::Primitive(p) => {
121 let blob = is_byte_blob(field);
122 python_default(p, field.is_array && !blob, blob)
123 }
124 MsgTypeRef::Named { package, name } => {
125 if field.is_array {
126 "field(default_factory=list)".to_string()
127 } else {
128 let qualified = if package == pkg {
129 name.clone()
130 } else {
131 format!("{}_mcp.{}", package, name)
132 };
133 format!("field(default_factory=lambda: {}())", qualified)
136 }
137 }
138 };
139
140 let _ = writeln!(out, " {}: {} = {}", field.name, type_str, default_str);
141}
142
143fn emit_to_dict_entry(out: &mut String, field: &MsgField) {
144 let blob = is_byte_blob(field);
145 match &field.type_ref {
146 MsgTypeRef::Primitive(_) if blob => {
147 let _ = writeln!(
148 out,
149 " \"{name}\": _b64enc(self.{name}),",
150 name = field.name
151 );
152 }
153 MsgTypeRef::Primitive(_) => {
154 let _ = writeln!(
155 out,
156 " \"{name}\": self.{name},",
157 name = field.name
158 );
159 }
160 MsgTypeRef::Named { .. } => {
161 if field.is_array {
162 let _ = writeln!(
163 out,
164 " \"{name}\": [_i.to_dict() for _i in self.{name}],",
165 name = field.name
166 );
167 } else {
168 let _ = writeln!(
169 out,
170 " \"{name}\": self.{name}.to_dict(),",
171 name = field.name
172 );
173 }
174 }
175 }
176}
177
178fn emit_from_dict_entry(out: &mut String, field: &MsgField, pkg: &str) {
179 let blob = is_byte_blob(field);
180 match &field.type_ref {
181 MsgTypeRef::Primitive(_) if blob => {
182 let _ = writeln!(
183 out,
184 " {name}=_b64dec(d.get(\"{name}\", \"\")),",
185 name = field.name
186 );
187 }
188 MsgTypeRef::Primitive(p) => {
189 let cast = python_cast(p);
190 let default = python_default(p, field.is_array, false);
191 if field.is_array {
192 let _ = writeln!(
193 out,
194 " {name}=[{cast}(_x) for _x in d.get(\"{name}\", [])],",
195 name = field.name,
196 cast = cast
197 );
198 } else {
199 let _ = writeln!(
200 out,
201 " {name}={cast}(d.get(\"{name}\", {default})),",
202 name = field.name,
203 cast = cast,
204 default = default
205 );
206 }
207 }
208 MsgTypeRef::Named {
209 package,
210 name: type_name,
211 } => {
212 let qualified = if package == pkg {
213 type_name.clone()
214 } else {
215 format!("{}_mcp.{}", package, type_name)
216 };
217 if field.is_array {
218 let _ = writeln!(
219 out,
220 " {name}=[{qual}.from_dict(_x) for _x in d.get(\"{name}\", [])],",
221 name = field.name,
222 qual = qualified
223 );
224 } else {
225 let _ = writeln!(
226 out,
227 " {name}={qual}.from_dict(d.get(\"{name}\", {{}})),",
228 name = field.name,
229 qual = qualified
230 );
231 }
232 }
233 }
234}
235
236fn emit_json_schema_prop(out: &mut String, field: &MsgField, pkg: &str, depth: usize) {
249 let indent = " ".repeat(depth + 3);
250 let blob = is_byte_blob(field);
251
252 let mut entries: Vec<String> = Vec::new();
253 match &field.type_ref {
254 MsgTypeRef::Primitive(p) if blob => {
255 entries.push(r#""type": "string""#.to_string());
256 entries.push(r#""contentEncoding": "base64""#.to_string());
257 if let Some(n) = field.array_size {
262 let _ = p; let encoded = n.div_ceil(3) * 4;
264 entries.push(format!(r#""minLength": {encoded}"#));
265 entries.push(format!(r#""maxLength": {encoded}"#));
266 }
267 }
268 MsgTypeRef::Primitive(p) => {
269 let prim = parse_prim_or_panic(p);
270 if field.is_array {
271 let mut item_entries: Vec<String> =
272 vec![format!(r#""type": "{}""#, prim.json_schema_type())];
273 push_integer_range(&mut item_entries, prim);
274 push_string_max_len(&mut item_entries, prim, field.string_max_len);
275 let item_json = format!("{{{}}}", item_entries.join(", "));
276 entries.push(r#""type": "array""#.to_string());
277 entries.push(format!(r#""items": {item_json}"#));
278 push_array_bounds(&mut entries, field.array_size);
279 } else {
280 entries.push(format!(r#""type": "{}""#, prim.json_schema_type()));
281 push_integer_range(&mut entries, prim);
282 push_string_max_len(&mut entries, prim, field.string_max_len);
283 }
284 }
285 MsgTypeRef::Named {
286 package,
287 name: type_name,
288 } => {
289 let qualified = if package == pkg {
290 format!("{}.json_schema()", type_name)
291 } else {
292 format!("{}_mcp.{}.json_schema()", package, type_name)
293 };
294 if field.is_array {
295 let mut frag = format!(r#"{{"type": "array", "items": {qualified}"#);
300 if let Some(n) = field.array_size {
301 frag.push_str(&format!(r#", "minItems": {n}, "maxItems": {n}"#));
302 }
303 if !field.description.is_empty() {
304 frag.push_str(&format!(
305 r#", "description": {}"#,
306 py_string_literal(&field.description)
307 ));
308 }
309 frag.push('}');
310 let _ = writeln!(
311 out,
312 "{indent}\"{name}\": {frag},",
313 indent = indent,
314 name = field.name,
315 );
316 return;
317 } else {
318 let mut frag = qualified;
320 if !field.description.is_empty() {
321 frag = format!(
322 "{{**{frag}, \"description\": {desc}}}",
323 frag = frag,
324 desc = py_string_literal(&field.description)
325 );
326 }
327 let _ = writeln!(
328 out,
329 "{indent}\"{name}\": {frag},",
330 indent = indent,
331 name = field.name,
332 );
333 return;
334 }
335 }
336 }
337 if !field.description.is_empty() {
338 entries.push(format!(
339 r#""description": {}"#,
340 py_string_literal(&field.description)
341 ));
342 }
343 let _ = writeln!(
344 out,
345 "{indent}\"{name}\": {{{body}}},",
346 indent = indent,
347 name = field.name,
348 body = entries.join(", ")
349 );
350}
351
352fn push_integer_range(entries: &mut Vec<String>, prim: RosPrimitive) {
353 if let Some((lo, hi)) = prim.integer_range() {
354 entries.push(format!(r#""minimum": {lo}"#));
355 entries.push(format!(r#""maximum": {hi}"#));
356 }
357}
358
359fn push_string_max_len(entries: &mut Vec<String>, prim: RosPrimitive, bound: Option<usize>) {
360 if matches!(prim, RosPrimitive::String | RosPrimitive::Wstring)
361 && let Some(n) = bound
362 {
363 entries.push(format!(r#""maxLength": {n}"#));
364 }
365}
366
/// Appends `minItems`/`maxItems` entries, both equal to `array_size`, pinning a fixed array
/// length.
///
/// Does nothing when `array_size` is `None`.
fn push_array_bounds(entries: &mut Vec<String>, array_size: Option<usize>) {
    let Some(len) = array_size else {
        return;
    };
    entries.extend([
        format!(r#""minItems": {len}"#),
        format!(r#""maxItems": {len}"#),
    ]);
}
373
/// Renders `s` as a double-quoted Python string literal.
///
/// The following characters are escaped:
/// - backslash and double quote;
/// - `\n`, `\r` and `\t` as their short escapes;
/// - any other code point below U+0020 as `\uXXXX`.
///
/// All other characters, including non-ASCII, are copied verbatim.
fn py_string_literal(s: &str) -> String {
    let mut lit = String::with_capacity(s.len() + 2);
    lit.push('"');
    for ch in s.chars() {
        let short_escape = match ch {
            '\\' => Some(r"\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        match short_escape {
            Some(seq) => lit.push_str(seq),
            None if u32::from(ch) < 0x20 => {
                let _ = write!(lit, "\\u{:04x}", u32::from(ch));
            }
            None => lit.push(ch),
        }
    }
    lit.push('"');
    lit
}
396
397fn emit_class(out: &mut String, spec: &MsgSpec) {
400 let _ = writeln!(out, "@dataclass");
401 let _ = writeln!(out, "class {}:", spec.name);
402 let _ = writeln!(
403 out,
404 " \"\"\"ROS IDL: {}/msg/{}\"\"\"",
405 spec.package, spec.name
406 );
407
408 if spec.fields.is_empty() {
409 let _ = writeln!(out, " pass");
415 let _ = writeln!(out);
416 let _ = writeln!(out, " def to_dict(self) -> dict:");
417 let _ = writeln!(out, " return {{}}");
418 let _ = writeln!(out);
419 let _ = writeln!(out, " @classmethod");
420 let _ = writeln!(
421 out,
422 " def from_dict(cls, _d: dict) -> \"{}\":",
423 spec.name
424 );
425 let _ = writeln!(out, " return cls()");
426 let _ = writeln!(out);
427 let _ = writeln!(out, " @classmethod");
428 let _ = writeln!(out, " def json_schema(cls) -> dict:");
429 let _ = writeln!(
432 out,
433 " return {{\"type\": \"object\", \"properties\": {{}}, \
434 \"additionalProperties\": False}}"
435 );
436 return;
437 }
438
439 for f in &spec.fields {
441 emit_field_decl(out, f, &spec.package);
442 }
443
444 let _ = writeln!(out);
446 let _ = writeln!(out, " def to_dict(self) -> dict:");
447 let _ = writeln!(out, " return {{");
448 for f in &spec.fields {
449 emit_to_dict_entry(out, f);
450 }
451 let _ = writeln!(out, " }}");
452
453 let _ = writeln!(out);
455 let _ = writeln!(out, " @classmethod");
456 let _ = writeln!(out, " def from_dict(cls, d: dict) -> \"{}\":", spec.name);
457 let _ = writeln!(out, " return cls(");
458 for f in &spec.fields {
459 emit_from_dict_entry(out, f, &spec.package);
460 }
461 let _ = writeln!(out, " )");
462
463 let _ = writeln!(out);
465 let _ = writeln!(out, " @classmethod");
466 let _ = writeln!(out, " def json_schema(cls) -> dict:");
467 let _ = writeln!(out, " return {{");
468 let _ = writeln!(out, " \"type\": \"object\",");
469 let _ = writeln!(out, " \"properties\": {{");
470 for f in &spec.fields {
471 emit_json_schema_prop(out, f, &spec.package, 0);
472 }
473 let _ = writeln!(out, " }},");
474 let _ = write!(out, " \"required\": [");
479 let mut first = true;
480 for f in &spec.fields {
481 if !first {
482 let _ = write!(out, ", ");
483 }
484 first = false;
485 let _ = write!(out, "\"{}\"", f.name);
486 }
487 let _ = writeln!(out, "],");
488 let _ = writeln!(out, " \"additionalProperties\": False,");
491 let _ = writeln!(out, " }}");
492}
493
494pub fn generate(resolver: &MsgResolver, out_dir: &Path, verbose: bool) -> Result<()> {
497 fs::create_dir_all(out_dir)?;
498
499 let mut all_packages: BTreeSet<String> = BTreeSet::new();
504 for spec in resolver.ordered_specs() {
505 all_packages.insert(spec.package.clone());
506 }
507 for srv in resolver.ordered_srvs() {
508 all_packages.insert(srv.package.clone());
509 }
510
511 for package in &all_packages {
512 let mut raw_specs: Vec<_> = resolver
513 .ordered_specs()
514 .into_iter()
515 .filter(|s| &s.package == package)
516 .collect();
517
518 for srv in resolver.ordered_srvs() {
523 if &srv.package != package {
524 continue;
525 }
526 raw_specs.push(&srv.request);
527 raw_specs.push(&srv.response);
528 }
529
530 if raw_specs.is_empty() {
531 continue;
532 }
533
534 let specs = topo_sort_same_package(&raw_specs);
542
543 let mut ext_pkgs: BTreeSet<String> = BTreeSet::new();
545 for spec in &specs {
546 for dep in cross_package_deps(spec) {
547 if &dep != package {
548 ext_pkgs.insert(dep);
549 }
550 }
551 }
552
553 let mut out = String::new();
554 let _ = writeln!(
555 out,
556 "# @generated by robonix-codegen --lang mcp — DO NOT EDIT"
557 );
558 let _ = writeln!(out, "# source: ROS IDL package '{}'", package);
559 let _ = writeln!(
560 out,
561 "# Python dataclasses for MCP tool use (to_dict / from_dict / json_schema)."
562 );
563 let _ = writeln!(
564 out,
565 "# Re-generate: cargo run -p robonix-codegen -- --lang mcp -I <lib> -o <out>"
566 );
567 let _ = writeln!(out);
568 let _ = writeln!(out, "from __future__ import annotations");
569 let _ = writeln!(out, "import base64");
570 let _ = writeln!(out, "from dataclasses import dataclass, field");
571 let _ = writeln!(out, "from typing import List");
572 for dep in &ext_pkgs {
573 let _ = writeln!(out, "import {}_mcp", dep);
574 }
575 let _ = writeln!(out);
576 let _ = writeln!(out, "def _b64enc(b: bytes) -> str:");
577 let _ = writeln!(out, " return base64.b64encode(b).decode(\"ascii\")");
578 let _ = writeln!(out);
579 let _ = writeln!(out, "def _b64dec(s: str) -> bytes:");
580 let _ = writeln!(
584 out,
585 " return base64.b64decode(s, validate=True) if s else b\"\""
586 );
587 let _ = writeln!(out);
588 let _ = writeln!(out);
589
590 for spec in &specs {
591 emit_class(&mut out, spec);
592 let _ = writeln!(out);
593 let _ = writeln!(out);
594 }
595
596 let filename = format!("{}_mcp.py", package);
597 let filepath = out_dir.join(&filename);
598 fs::write(&filepath, &out)
599 .with_context(|| format!("failed to write '{}'", filepath.display()))?;
600 if verbose {
601 eprintln!(
602 "[robonix-codegen] mcp: '{}' ({} msgs) -> {}",
603 package,
604 specs.len(),
605 filepath.display()
606 );
607 }
608 }
609
610 let init = out_dir.join("__init__.py");
612 if !init.exists() {
613 fs::write(&init, "# @generated by robonix-codegen --lang mcp\n")?;
614 }
615
616 Ok(())
617}
618
619fn topo_sort_same_package<'a>(specs: &[&'a MsgSpec]) -> Vec<&'a MsgSpec> {
628 if specs.is_empty() {
629 return Vec::new();
630 }
631 let pkg = specs[0].package.clone();
632 let by_name: BTreeMap<String, &MsgSpec> = specs.iter().map(|s| (s.name.clone(), *s)).collect();
633
634 enum Mark {
635 Temp,
636 Done,
637 }
638 let mut marks: BTreeMap<String, Mark> = BTreeMap::new();
639 let mut order: Vec<&MsgSpec> = Vec::new();
640
641 fn visit<'a>(
642 name: &str,
643 pkg: &str,
644 by_name: &BTreeMap<String, &'a MsgSpec>,
645 marks: &mut BTreeMap<String, Mark>,
646 order: &mut Vec<&'a MsgSpec>,
647 ) {
648 match marks.get(name) {
649 Some(Mark::Done) => return,
650 Some(Mark::Temp) => return, None => {}
652 }
653 let Some(spec) = by_name.get(name) else {
654 return;
655 };
656 marks.insert(name.to_string(), Mark::Temp);
657 for f in &spec.fields {
658 if let MsgTypeRef::Named {
659 package,
660 name: dep_name,
661 } = &f.type_ref
662 && package == pkg
663 {
664 visit(dep_name, pkg, by_name, marks, order);
665 }
666 }
667 marks.insert(name.to_string(), Mark::Done);
668 order.push(*spec);
669 }
670
671 for spec in specs {
672 visit(&spec.name, &pkg, &by_name, &mut marks, &mut order);
673 }
674 order
675}