Skip to main content

robonix_codegen/codegen/
proto_gen.rs

1// SPDX-License-Identifier: MulanPSL-2.0
2// Protobuf (.proto) code generator — ROS IDL -> proto3
3//
4// Emits **one `{package}.proto` per ROS package** that has content:
5//   - Every `.msg` -> a `message` (always).
6//   - `*_Request` / `*_Response` from `.srv` **only** when `--contracts` is passed and that srv is listed in a contract `[io.srv].srv`.
7// Does **not** emit per-package `service FooService { rpc ... }`; gRPC facades are **only** in `robonix_contracts.proto`.
8
9use anyhow::{Context, Result};
10use std::collections::BTreeSet;
11use std::fmt::Write as FmtWrite;
12use std::fs;
13use std::path::Path;
14
15use super::msg_parser::{MsgField, MsgResolver, MsgSpec, MsgTypeRef, RosPrimitive, SrvSpec};
16
17fn proto_primitive_type(p: &str) -> &'static str {
18    // Single source of truth for primitive → proto3 mapping is
19    // RosPrimitive::proto_type. Panic on a primitive the parser
20    // wouldn't have accepted — that's a codegen bug, not a runtime
21    // condition we want to silently swallow as `bytes` (which used to
22    // be the catch-all and produced wrong wire formats).
23    let prim = RosPrimitive::parse(p).unwrap_or_else(|| {
24        panic!(
25            "[robonix-codegen] proto: primitive '{p}' is not in RosPrimitive::parse — \
26             every primitive accepted by the parser must have a proto mapping; \
27             extend RosPrimitive if this is a new ROS primitive."
28        )
29    });
30    prim.proto_type()
31}
32
33fn proto_field_type(field: &MsgField, current_package: &str) -> String {
34    let base = match &field.type_ref {
35        MsgTypeRef::Primitive(p) => {
36            // `uint8[]` (unsigned 8-bit, unbounded) is the canonical
37            // raw-byte buffer and becomes proto `bytes`. `uint8[N]`
38            // (fixed-size) does NOT collapse to `bytes` because proto3
39            // bytes has no length constraint, so a fixed-size array
40            // would lose its size invariant on the wire — keep it as
41            // `repeated uint32` and let the consumer enforce length.
42            // `byte[]` is signed 8-bit and stays `repeated int32`.
43            let prim = RosPrimitive::parse(p).unwrap_or_else(|| {
44                panic!("[robonix-codegen] proto: primitive '{p}' is not in RosPrimitive::parse")
45            });
46            if field.is_array && field.array_size.is_none() && prim.is_blob_element() {
47                return "bytes".to_string();
48            }
49            proto_primitive_type(p).to_string()
50        }
51        MsgTypeRef::Named { package, name } => {
52            if package == current_package {
53                name.clone()
54            } else {
55                format!("{}.{}", proto_package_name(package), name)
56            }
57        }
58    };
59    if field.is_array {
60        format!("repeated {}", base)
61    } else {
62        base
63    }
64}
65
/// Convert a ROS package name into its protobuf package by prefixing `robonix.`.
///
/// For example, `prm_base` becomes `robonix.prm_base` and `sensor_msgs`
/// becomes `robonix.sensor_msgs`.
pub fn proto_package_name(ros_package: &str) -> String {
    const PREFIX: &str = "robonix.";
    let mut qualified = String::with_capacity(PREFIX.len() + ros_package.len());
    qualified.push_str(PREFIX);
    qualified.push_str(ros_package);
    qualified
}
70
71fn emit_message(out: &mut String, spec: &MsgSpec) {
72    let _ = writeln!(out, "message {} {{", spec.name);
73    for (i, field) in spec.fields.iter().enumerate() {
74        let proto_type = proto_field_type(field, &spec.package);
75        let _ = writeln!(out, "  {} {} = {};", proto_type, field.name, i + 1);
76    }
77    let _ = writeln!(out, "}}");
78}
79
80fn emit_srv_messages(out: &mut String, srv: &SrvSpec) {
81    emit_message(out, &srv.request);
82    let _ = writeln!(out);
83    emit_message(out, &srv.response);
84}
85
86fn import_named_type(imports: &mut BTreeSet<String>, current_package: &str, tr: &MsgTypeRef) {
87    if let MsgTypeRef::Named { package, .. } = tr
88        && package != current_package
89    {
90        imports.insert(package.clone());
91    }
92}
93
94fn collect_imports(
95    specs: &[&MsgSpec],
96    srvs: &[&SrvSpec],
97    current_package: &str,
98) -> BTreeSet<String> {
99    let mut imports = BTreeSet::new();
100    for spec in specs {
101        for field in &spec.fields {
102            import_named_type(&mut imports, current_package, &field.type_ref);
103        }
104    }
105    for srv in srvs {
106        for field in srv.request.fields.iter().chain(srv.response.fields.iter()) {
107            import_named_type(&mut imports, current_package, &field.type_ref);
108        }
109    }
110    imports
111}
112
113pub fn generate(
114    resolver: &MsgResolver,
115    out_dir: &Path,
116    contract_srvs: Option<&BTreeSet<(String, String)>>,
117    verbose: bool,
118) -> Result<()> {
119    fs::create_dir_all(out_dir)?;
120
121    let mut all_packages = BTreeSet::new();
122    for spec in resolver.ordered_specs() {
123        all_packages.insert(spec.package.clone());
124    }
125    if let Some(set) = contract_srvs {
126        for (pkg, _) in set.iter() {
127            all_packages.insert(pkg.clone());
128        }
129    }
130
131    for package in &all_packages {
132        let specs: Vec<_> = resolver
133            .ordered_specs()
134            .into_iter()
135            .filter(|s| &s.package == package)
136            .collect();
137        let srvs: Vec<_> = match contract_srvs {
138            Some(set) => resolver
139                .ordered_srvs()
140                .into_iter()
141                .filter(|s| {
142                    &s.package == package && set.contains(&(s.package.clone(), s.name.clone()))
143                })
144                .collect(),
145            None => Vec::new(),
146        };
147
148        if specs.is_empty() && srvs.is_empty() {
149            continue;
150        }
151
152        let mut out = String::new();
153        let _ = writeln!(out, "// @generated by robonix-codegen --lang proto");
154        let _ = writeln!(out, "// source: ROS IDL package '{}'", package);
155        let _ = writeln!(out, "syntax = \"proto3\";");
156        let _ = writeln!(out);
157        let _ = writeln!(out, "package {};", proto_package_name(package));
158        let _ = writeln!(out);
159
160        let imports = collect_imports(&specs, &srvs, package);
161        for imp in &imports {
162            let _ = writeln!(out, "import \"{}.proto\";", imp);
163        }
164        if !imports.is_empty() {
165            let _ = writeln!(out);
166        }
167
168        for spec in &specs {
169            emit_message(&mut out, spec);
170            let _ = writeln!(out);
171        }
172
173        for srv in &srvs {
174            emit_srv_messages(&mut out, srv);
175            let _ = writeln!(out);
176        }
177
178        let filename = format!("{}.proto", package);
179        let filepath = out_dir.join(&filename);
180        fs::write(&filepath, &out)
181            .with_context(|| format!("failed to write proto file to '{}'", filepath.display()))?;
182        if verbose {
183            eprintln!(
184                "[robonix-codegen] generated proto for '{}' ({} msgs, {} contract srvs) -> {}",
185                package,
186                specs.len(),
187                srvs.len(),
188                filepath.display()
189            );
190        }
191    }
192
193    Ok(())
194}