Skip to main content

robonix_codegen/codegen/
mcp_python_gen.rs

1// SPDX-License-Identifier: MulanPSL-2.0
2// mcp_python_gen.rs — Python @dataclass generator for MCP tool use (--lang mcp)
3//
4// Generates one `{package}_mcp.py` per ROS package, containing:
5//   - @dataclass for each .msg type
6//   - to_dict() / from_dict() for JSON serialization over MCP wire
7//   - json_schema() classmethod for MCP tool registration (input_schema / output_schema)
8//
9// Nested types appear as nested objects in both Python and JSON Schema
10// (no flattening). `uint8[]` binary blobs are represented as `bytes` and
11// base64-encoded in the JSON form.
12
13use anyhow::{Context, Result};
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt::Write as FmtWrite;
16use std::fs;
17use std::path::Path;
18
19use super::msg_parser::{MsgField, MsgResolver, MsgSpec, MsgTypeRef, RosPrimitive};
20
21// ── Primitive type mappings ──────────────────────────────────────────────────
22//
23// Every primitive flows through `RosPrimitive`, the codegen's single
24// source of truth. Unknown primitives panic at codegen time rather
25// than silently turning into `int` (the previous catch-all behaviour
26// would coerce e.g. a typo'd `uint128` into `int` on the wire).
27
28fn parse_prim_or_panic(p: &str) -> RosPrimitive {
29    RosPrimitive::parse(p).unwrap_or_else(|| {
30        panic!(
31            "[robonix-codegen] mcp: primitive '{p}' is not in RosPrimitive::parse — \
32             every primitive accepted by the parser must have an MCP mapping."
33        )
34    })
35}
36
37fn python_field_type(p: &str, is_array: bool, is_blob: bool) -> String {
38    if is_blob {
39        // Both `uint8[]` and `uint8[N]` are bytes; the fixed-size
40        // case carries its length in JSON Schema (handled later) but
41        // the in-memory type is the same.
42        return "bytes".to_string();
43    }
44    let base = parse_prim_or_panic(p).python_type();
45    if is_array {
46        format!("List[{}]", base)
47    } else {
48        base.to_string()
49    }
50}
51
52fn python_default(p: &str, is_array: bool, is_blob: bool) -> String {
53    if is_blob {
54        return "b\"\"".to_string();
55    }
56    if is_array {
57        return "field(default_factory=list)".to_string();
58    }
59    parse_prim_or_panic(p).python_default().to_string()
60}
61
62fn python_cast(p: &str) -> &'static str {
63    parse_prim_or_panic(p).python_cast()
64}
65
66// ── Cross-package dependency collection ──────────────────────────────────────
67
68fn cross_package_deps(spec: &MsgSpec) -> BTreeSet<String> {
69    let mut pkgs = BTreeSet::new();
70    for f in &spec.fields {
71        if let MsgTypeRef::Named { package, .. } = &f.type_ref
72            && package != &spec.package
73        {
74            pkgs.insert(package.clone());
75        }
76    }
77    pkgs
78}
79
80// ── Field code emitters ───────────────────────────────────────────────────────
81
82/// Returns true when this field's wire shape is "raw byte buffer".
83/// Only `uint8[]` and `uint8[N]` qualify — `byte` is signed 8-bit and
84/// must NOT be conflated with `uint8` (unsigned) on the wire. Both
85/// the unbounded and fixed-size cases hit this path; the JSON Schema
86/// emitter encodes the length distinction.
87fn is_byte_blob(field: &MsgField) -> bool {
88    if !field.is_array {
89        return false;
90    }
91    match &field.type_ref {
92        MsgTypeRef::Primitive(p) => RosPrimitive::parse(p)
93            .map(|prim| prim.is_blob_element())
94            .unwrap_or(false),
95        _ => false,
96    }
97}
98
99fn emit_field_decl(out: &mut String, field: &MsgField, pkg: &str) {
100    let type_str = match &field.type_ref {
101        MsgTypeRef::Primitive(p) => {
102            let blob = is_byte_blob(field);
103            python_field_type(p, field.is_array && !blob, blob)
104        }
105        MsgTypeRef::Named { package, name } => {
106            let qualified = if package == pkg {
107                name.clone()
108            } else {
109                format!("{}_mcp.{}", package, name)
110            };
111            if field.is_array {
112                format!("List[{}]", qualified)
113            } else {
114                qualified
115            }
116        }
117    };
118
119    let default_str = match &field.type_ref {
120        MsgTypeRef::Primitive(p) => {
121            let blob = is_byte_blob(field);
122            python_default(p, field.is_array && !blob, blob)
123        }
124        MsgTypeRef::Named { package, name } => {
125            if field.is_array {
126                "field(default_factory=list)".to_string()
127            } else {
128                let qualified = if package == pkg {
129                    name.clone()
130                } else {
131                    format!("{}_mcp.{}", package, name)
132                };
133                // Defer constructor to runtime: same-package msgs are emitted in
134                // arbitrary order; `default_factory=Foo` evaluates `Foo` at class body time.
135                format!("field(default_factory=lambda: {}())", qualified)
136            }
137        }
138    };
139
140    let _ = writeln!(out, "    {}: {} = {}", field.name, type_str, default_str);
141}
142
143fn emit_to_dict_entry(out: &mut String, field: &MsgField) {
144    let blob = is_byte_blob(field);
145    match &field.type_ref {
146        MsgTypeRef::Primitive(_) if blob => {
147            let _ = writeln!(
148                out,
149                "            \"{name}\": _b64enc(self.{name}),",
150                name = field.name
151            );
152        }
153        MsgTypeRef::Primitive(_) => {
154            let _ = writeln!(
155                out,
156                "            \"{name}\": self.{name},",
157                name = field.name
158            );
159        }
160        MsgTypeRef::Named { .. } => {
161            if field.is_array {
162                let _ = writeln!(
163                    out,
164                    "            \"{name}\": [_i.to_dict() for _i in self.{name}],",
165                    name = field.name
166                );
167            } else {
168                let _ = writeln!(
169                    out,
170                    "            \"{name}\": self.{name}.to_dict(),",
171                    name = field.name
172                );
173            }
174        }
175    }
176}
177
178fn emit_from_dict_entry(out: &mut String, field: &MsgField, pkg: &str) {
179    let blob = is_byte_blob(field);
180    match &field.type_ref {
181        MsgTypeRef::Primitive(_) if blob => {
182            let _ = writeln!(
183                out,
184                "            {name}=_b64dec(d.get(\"{name}\", \"\")),",
185                name = field.name
186            );
187        }
188        MsgTypeRef::Primitive(p) => {
189            let cast = python_cast(p);
190            let default = python_default(p, field.is_array, false);
191            if field.is_array {
192                let _ = writeln!(
193                    out,
194                    "            {name}=[{cast}(_x) for _x in d.get(\"{name}\", [])],",
195                    name = field.name,
196                    cast = cast
197                );
198            } else {
199                let _ = writeln!(
200                    out,
201                    "            {name}={cast}(d.get(\"{name}\", {default})),",
202                    name = field.name,
203                    cast = cast,
204                    default = default
205                );
206            }
207        }
208        MsgTypeRef::Named {
209            package,
210            name: type_name,
211        } => {
212            let qualified = if package == pkg {
213                type_name.clone()
214            } else {
215                format!("{}_mcp.{}", package, type_name)
216            };
217            if field.is_array {
218                let _ = writeln!(
219                    out,
220                    "            {name}=[{qual}.from_dict(_x) for _x in d.get(\"{name}\", [])],",
221                    name = field.name,
222                    qual = qualified
223                );
224            } else {
225                let _ = writeln!(
226                    out,
227                    "            {name}={qual}.from_dict(d.get(\"{name}\", {{}})),",
228                    name = field.name,
229                    qual = qualified
230                );
231            }
232        }
233    }
234}
235
236/// Build the JSON Schema fragment for one .msg field. The fragment is
237/// emitted as Python source — for primitive scalars it's a literal
238/// dict, for nested types it's a function call to the dependency's
239/// `json_schema()`. The embedded constraints come from the parser:
240///
241///   - integer primitives → `minimum` / `maximum` from
242///     `RosPrimitive::integer_range`
243///   - fixed-size arrays → `minItems` = `maxItems` = N
244///   - `string<=N` / `wstring<=N` → `maxLength` N
245///   - `uint8[N]` blob → base64 string with `minLength`/`maxLength`
246///     reflecting the encoded length
247///   - trailing `# comment` → `description`
248fn emit_json_schema_prop(out: &mut String, field: &MsgField, pkg: &str, depth: usize) {
249    let indent = "    ".repeat(depth + 3);
250    let blob = is_byte_blob(field);
251
252    let mut entries: Vec<String> = Vec::new();
253    match &field.type_ref {
254        MsgTypeRef::Primitive(p) if blob => {
255            entries.push(r#""type": "string""#.to_string());
256            entries.push(r#""contentEncoding": "base64""#.to_string());
257            // Fixed-size `uint8[N]` carries N bytes → base64-encoded
258            // length is exactly ceil(N/3)*4 with up to two `=` pads.
259            // We emit min == max so the LLM can't produce a shorter
260            // or longer payload.
261            if let Some(n) = field.array_size {
262                let _ = p; // primitive identity already validated by is_byte_blob
263                let encoded = n.div_ceil(3) * 4;
264                entries.push(format!(r#""minLength": {encoded}"#));
265                entries.push(format!(r#""maxLength": {encoded}"#));
266            }
267        }
268        MsgTypeRef::Primitive(p) => {
269            let prim = parse_prim_or_panic(p);
270            if field.is_array {
271                let mut item_entries: Vec<String> =
272                    vec![format!(r#""type": "{}""#, prim.json_schema_type())];
273                push_integer_range(&mut item_entries, prim);
274                push_string_max_len(&mut item_entries, prim, field.string_max_len);
275                let item_json = format!("{{{}}}", item_entries.join(", "));
276                entries.push(r#""type": "array""#.to_string());
277                entries.push(format!(r#""items": {item_json}"#));
278                push_array_bounds(&mut entries, field.array_size);
279            } else {
280                entries.push(format!(r#""type": "{}""#, prim.json_schema_type()));
281                push_integer_range(&mut entries, prim);
282                push_string_max_len(&mut entries, prim, field.string_max_len);
283            }
284        }
285        MsgTypeRef::Named {
286            package,
287            name: type_name,
288        } => {
289            let qualified = if package == pkg {
290                format!("{}.json_schema()", type_name)
291            } else {
292                format!("{}_mcp.{}.json_schema()", package, type_name)
293            };
294            if field.is_array {
295                // Note: we can't push items + size + description in
296                // one literal dict because `items` is a runtime call.
297                // We assemble the dict using a Python helper that
298                // shallow-merges; keeps generated code uniform.
299                let mut frag = format!(r#"{{"type": "array", "items": {qualified}"#);
300                if let Some(n) = field.array_size {
301                    frag.push_str(&format!(r#", "minItems": {n}, "maxItems": {n}"#));
302                }
303                if !field.description.is_empty() {
304                    frag.push_str(&format!(
305                        r#", "description": {}"#,
306                        py_string_literal(&field.description)
307                    ));
308                }
309                frag.push('}');
310                let _ = writeln!(
311                    out,
312                    "{indent}\"{name}\": {frag},",
313                    indent = indent,
314                    name = field.name,
315                );
316                return;
317            } else {
318                // Same shape: nested object schema is a runtime call.
319                let mut frag = qualified;
320                if !field.description.is_empty() {
321                    frag = format!(
322                        "{{**{frag}, \"description\": {desc}}}",
323                        frag = frag,
324                        desc = py_string_literal(&field.description)
325                    );
326                }
327                let _ = writeln!(
328                    out,
329                    "{indent}\"{name}\": {frag},",
330                    indent = indent,
331                    name = field.name,
332                );
333                return;
334            }
335        }
336    }
337    if !field.description.is_empty() {
338        entries.push(format!(
339            r#""description": {}"#,
340            py_string_literal(&field.description)
341        ));
342    }
343    let _ = writeln!(
344        out,
345        "{indent}\"{name}\": {{{body}}},",
346        indent = indent,
347        name = field.name,
348        body = entries.join(", ")
349    );
350}
351
352fn push_integer_range(entries: &mut Vec<String>, prim: RosPrimitive) {
353    if let Some((lo, hi)) = prim.integer_range() {
354        entries.push(format!(r#""minimum": {lo}"#));
355        entries.push(format!(r#""maximum": {hi}"#));
356    }
357}
358
359fn push_string_max_len(entries: &mut Vec<String>, prim: RosPrimitive, bound: Option<usize>) {
360    if matches!(prim, RosPrimitive::String | RosPrimitive::Wstring)
361        && let Some(n) = bound
362    {
363        entries.push(format!(r#""maxLength": {n}"#));
364    }
365}
366
/// Append `"minItems"` and `"maxItems"` entries, both equal to the
/// fixed array size, when one is given; an unbounded array (`None`)
/// adds nothing.
fn push_array_bounds(entries: &mut Vec<String>, array_size: Option<usize>) {
    let Some(count) = array_size else {
        return;
    };
    for keyword in ["minItems", "maxItems"] {
        entries.push(format!("\"{keyword}\": {count}"));
    }
}
373
/// Render `s` as a double-quoted Python string literal.
///
/// Backslashes and double quotes are backslash-escaped, newline /
/// carriage return / tab use their short escapes, and any other
/// character below U+0020 becomes a `\uXXXX` escape. Everything else is
/// copied through unchanged.
fn py_string_literal(s: &str) -> String {
    let mut literal = String::with_capacity(s.len() + 2);
    literal.push('"');
    for ch in s.chars() {
        let short_escape = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(escape) = short_escape {
            literal.push_str(escape);
        } else if u32::from(ch) < 0x20 {
            // Remaining control characters have no short escape.
            let _ = write!(literal, "\\u{:04x}", u32::from(ch));
        } else {
            literal.push(ch);
        }
    }
    literal.push('"');
    literal
}
396
397// ── Class emitter ─────────────────────────────────────────────────────────────
398
399fn emit_class(out: &mut String, spec: &MsgSpec) {
400    let _ = writeln!(out, "@dataclass");
401    let _ = writeln!(out, "class {}:", spec.name);
402    let _ = writeln!(
403        out,
404        "    \"\"\"ROS IDL: {}/msg/{}\"\"\"",
405        spec.package, spec.name
406    );
407
408    if spec.fields.is_empty() {
409        // Empty message → empty dataclass body (`pass`). Previously
410        // we synthesized a `_empty: bool` field which leaked into
411        // `__eq__` / `__repr__` and shadowed any user-declared field
412        // named `_empty`. `pass` is the right "no fields" body for a
413        // @dataclass.
414        let _ = writeln!(out, "    pass");
415        let _ = writeln!(out);
416        let _ = writeln!(out, "    def to_dict(self) -> dict:");
417        let _ = writeln!(out, "        return {{}}");
418        let _ = writeln!(out);
419        let _ = writeln!(out, "    @classmethod");
420        let _ = writeln!(
421            out,
422            "    def from_dict(cls, _d: dict) -> \"{}\":",
423            spec.name
424        );
425        let _ = writeln!(out, "        return cls()");
426        let _ = writeln!(out);
427        let _ = writeln!(out, "    @classmethod");
428        let _ = writeln!(out, "    def json_schema(cls) -> dict:");
429        // additionalProperties=false even on empty messages so an
430        // LLM can't sneak extra keys into a no-arg call.
431        let _ = writeln!(
432            out,
433            "        return {{\"type\": \"object\", \"properties\": {{}}, \
434             \"additionalProperties\": False}}"
435        );
436        return;
437    }
438
439    // ── Field declarations ──────────────────────────────────────────────────
440    for f in &spec.fields {
441        emit_field_decl(out, f, &spec.package);
442    }
443
444    // ── to_dict() ───────────────────────────────────────────────────────────
445    let _ = writeln!(out);
446    let _ = writeln!(out, "    def to_dict(self) -> dict:");
447    let _ = writeln!(out, "        return {{");
448    for f in &spec.fields {
449        emit_to_dict_entry(out, f);
450    }
451    let _ = writeln!(out, "        }}");
452
453    // ── from_dict() ─────────────────────────────────────────────────────────
454    let _ = writeln!(out);
455    let _ = writeln!(out, "    @classmethod");
456    let _ = writeln!(out, "    def from_dict(cls, d: dict) -> \"{}\":", spec.name);
457    let _ = writeln!(out, "        return cls(");
458    for f in &spec.fields {
459        emit_from_dict_entry(out, f, &spec.package);
460    }
461    let _ = writeln!(out, "        )");
462
463    // ── json_schema() ────────────────────────────────────────────────────────
464    let _ = writeln!(out);
465    let _ = writeln!(out, "    @classmethod");
466    let _ = writeln!(out, "    def json_schema(cls) -> dict:");
467    let _ = writeln!(out, "        return {{");
468    let _ = writeln!(out, "            \"type\": \"object\",");
469    let _ = writeln!(out, "            \"properties\": {{");
470    for f in &spec.fields {
471        emit_json_schema_prop(out, f, &spec.package, 0);
472    }
473    let _ = writeln!(out, "            }},");
474    // Every field is required: ROS messages have no concept of
475    // "optional" (every declared field has a default and must round-
476    // trip), so an LLM should always provide every field. Listing
477    // them as `required` makes that explicit to schema validators.
478    let _ = write!(out, "            \"required\": [");
479    let mut first = true;
480    for f in &spec.fields {
481        if !first {
482            let _ = write!(out, ", ");
483        }
484        first = false;
485        let _ = write!(out, "\"{}\"", f.name);
486    }
487    let _ = writeln!(out, "],");
488    // Reject extra properties — an LLM passing a typo'd field name
489    // is a bug, not silently-ignored noise.
490    let _ = writeln!(out, "            \"additionalProperties\": False,");
491    let _ = writeln!(out, "        }}");
492}
493
494// ── Public entry point ────────────────────────────────────────────────────────
495
496pub fn generate(resolver: &MsgResolver, out_dir: &Path, verbose: bool) -> Result<()> {
497    fs::create_dir_all(out_dir)?;
498
499    // Cover every package that has at least one msg OR srv. srv files
500    // emit two classes each: <Name>_Request and <Name>_Response, derived
501    // directly from the SrvSpec's request / response MsgSpec halves.
502    // Packages with srv-only (no msg) still need a *_mcp.py generated.
503    let mut all_packages: BTreeSet<String> = BTreeSet::new();
504    for spec in resolver.ordered_specs() {
505        all_packages.insert(spec.package.clone());
506    }
507    for srv in resolver.ordered_srvs() {
508        all_packages.insert(srv.package.clone());
509    }
510
511    for package in &all_packages {
512        let mut raw_specs: Vec<_> = resolver
513            .ordered_specs()
514            .into_iter()
515            .filter(|s| &s.package == package)
516            .collect();
517
518        // Append the synthetic Request/Response specs for every srv in
519        // this package. They look like ordinary msg dataclasses to the
520        // emitter — `request.name` / `response.name` are already
521        // "<Srv>_Request" / "<Srv>_Response" courtesy of parse_srv_file.
522        for srv in resolver.ordered_srvs() {
523            if &srv.package != package {
524                continue;
525            }
526            raw_specs.push(&srv.request);
527            raw_specs.push(&srv.response);
528        }
529
530        if raw_specs.is_empty() {
531            continue;
532        }
533
534        // Topo-sort same-package types so that a class whose default
535        // factory references a sibling (`field(default_factory=lambda:
536        // Vector3())`) is emitted AFTER its dependency. `from
537        // __future__ import annotations` only defers type-annotation
538        // resolution — the lambda body runs at call time and would
539        // hit `NameError: Vector3 is not defined` if Vector3's class
540        // statement hadn't executed yet.
541        let specs = topo_sort_same_package(&raw_specs);
542
543        // Collect cross-package imports needed by any spec in this package
544        let mut ext_pkgs: BTreeSet<String> = BTreeSet::new();
545        for spec in &specs {
546            for dep in cross_package_deps(spec) {
547                if &dep != package {
548                    ext_pkgs.insert(dep);
549                }
550            }
551        }
552
553        let mut out = String::new();
554        let _ = writeln!(
555            out,
556            "# @generated by robonix-codegen --lang mcp — DO NOT EDIT"
557        );
558        let _ = writeln!(out, "# source: ROS IDL package '{}'", package);
559        let _ = writeln!(
560            out,
561            "# Python dataclasses for MCP tool use (to_dict / from_dict / json_schema)."
562        );
563        let _ = writeln!(
564            out,
565            "# Re-generate: cargo run -p robonix-codegen -- --lang mcp -I <lib> -o <out>"
566        );
567        let _ = writeln!(out);
568        let _ = writeln!(out, "from __future__ import annotations");
569        let _ = writeln!(out, "import base64");
570        let _ = writeln!(out, "from dataclasses import dataclass, field");
571        let _ = writeln!(out, "from typing import List");
572        for dep in &ext_pkgs {
573            let _ = writeln!(out, "import {}_mcp", dep);
574        }
575        let _ = writeln!(out);
576        let _ = writeln!(out, "def _b64enc(b: bytes) -> str:");
577        let _ = writeln!(out, "    return base64.b64encode(b).decode(\"ascii\")");
578        let _ = writeln!(out);
579        let _ = writeln!(out, "def _b64dec(s: str) -> bytes:");
580        // validate=True makes b64decode raise on garbage characters
581        // instead of silently returning truncated/junk bytes. A
582        // typo'd payload over MCP should fail loudly.
583        let _ = writeln!(
584            out,
585            "    return base64.b64decode(s, validate=True) if s else b\"\""
586        );
587        let _ = writeln!(out);
588        let _ = writeln!(out);
589
590        for spec in &specs {
591            emit_class(&mut out, spec);
592            let _ = writeln!(out);
593            let _ = writeln!(out);
594        }
595
596        let filename = format!("{}_mcp.py", package);
597        let filepath = out_dir.join(&filename);
598        fs::write(&filepath, &out)
599            .with_context(|| format!("failed to write '{}'", filepath.display()))?;
600        if verbose {
601            eprintln!(
602                "[robonix-codegen] mcp: '{}' ({} msgs) -> {}",
603                package,
604                specs.len(),
605                filepath.display()
606            );
607        }
608    }
609
610    // Package __init__.py
611    let init = out_dir.join("__init__.py");
612    if !init.exists() {
613        fs::write(&init, "# @generated by robonix-codegen --lang mcp\n")?;
614    }
615
616    Ok(())
617}
618
619/// Order the specs of one package so dependencies come before
620/// dependents. Cross-package references don't matter (they go through
621/// `import other_mcp` and resolve at runtime via attribute lookup);
622/// only same-package siblings need ordering.
623///
624/// On unresolvable cycle (A→B→A within one package — extremely rare in
625/// real ROS IDL), falls back to the input order and lets Python raise
626/// at instantiation time. We don't try to break cycles silently.
627fn topo_sort_same_package<'a>(specs: &[&'a MsgSpec]) -> Vec<&'a MsgSpec> {
628    if specs.is_empty() {
629        return Vec::new();
630    }
631    let pkg = specs[0].package.clone();
632    let by_name: BTreeMap<String, &MsgSpec> = specs.iter().map(|s| (s.name.clone(), *s)).collect();
633
634    enum Mark {
635        Temp,
636        Done,
637    }
638    let mut marks: BTreeMap<String, Mark> = BTreeMap::new();
639    let mut order: Vec<&MsgSpec> = Vec::new();
640
641    fn visit<'a>(
642        name: &str,
643        pkg: &str,
644        by_name: &BTreeMap<String, &'a MsgSpec>,
645        marks: &mut BTreeMap<String, Mark>,
646        order: &mut Vec<&'a MsgSpec>,
647    ) {
648        match marks.get(name) {
649            Some(Mark::Done) => return,
650            Some(Mark::Temp) => return, // cycle — give up on ordering
651            None => {}
652        }
653        let Some(spec) = by_name.get(name) else {
654            return;
655        };
656        marks.insert(name.to_string(), Mark::Temp);
657        for f in &spec.fields {
658            if let MsgTypeRef::Named {
659                package,
660                name: dep_name,
661            } = &f.type_ref
662                && package == pkg
663            {
664                visit(dep_name, pkg, by_name, marks, order);
665            }
666        }
667        marks.insert(name.to_string(), Mark::Done);
668        order.push(*spec);
669    }
670
671    for spec in specs {
672        visit(&spec.name, &pkg, &by_name, &mut marks, &mut order);
673    }
674    order
675}