Skip to main content

robonix_atlas/
client.rs

1// SPDX-License-Identifier: MulanPSL-2.0
2// Author: wheatfox <wheatfox17@icloud.com>
3//
4// Atlas client-side helpers shared by every Robonix component that talks
5// to Atlas (pilot, executor, cli, system services, …).
6//
7// All helpers return `anyhow::Error` wrapping the underlying
8// `tonic::Status` so callers can attach context with `.with_context(...)`.
9
10use crate::pb;
11use anyhow::{Context, Result};
12use std::time::Duration;
13use tonic::transport::{Channel, Endpoint};
14
15/// Wrapped `pb::atlas_client::AtlasClient` with helpers.
16///
17/// Cheap to clone — the inner generated client wraps a `Channel`, which
18/// is itself just a handle to a connection pool. Share clones across
19/// tasks rather than wrapping in a `Mutex`.
20#[derive(Clone)]
21pub struct AtlasClient {
22    inner: pb::atlas_client::AtlasClient<Channel>,
23}
24
25impl AtlasClient {
26    /// Connect once. Accepts bare `host:port` or full `http://host:port`.
27    pub async fn connect(endpoint: impl AsRef<str>) -> Result<Self> {
28        let normalized = normalize_grpc_endpoint(endpoint.as_ref());
29        let channel = Endpoint::new(normalized.clone())
30            .with_context(|| format!("invalid Atlas endpoint '{}'", normalized))?
31            .connect()
32            .await
33            .with_context(|| format!("connect to Atlas at '{}'", normalized))?;
34        Ok(Self {
35            inner: pb::atlas_client::AtlasClient::new(channel),
36        })
37    }
38
39    /// `connect`, retrying up to `attempts` times with `delay` between tries.
40    pub async fn connect_with_retry(
41        endpoint: impl AsRef<str>,
42        attempts: u32,
43        delay: Duration,
44    ) -> Result<Self> {
45        let endpoint = endpoint.as_ref();
46        let mut last_err: Option<anyhow::Error> = None;
47        for i in 0..attempts.max(1) {
48            match Self::connect(endpoint).await {
49                Ok(c) => return Ok(c),
50                Err(e) => {
51                    log::debug!(
52                        "[atlas-client] connect attempt {}/{} failed: {e:#}",
53                        i + 1,
54                        attempts
55                    );
56                    last_err = Some(e);
57                    if i + 1 < attempts {
58                        tokio::time::sleep(delay).await;
59                    }
60                }
61            }
62        }
63        Err(last_err.unwrap_or_else(|| anyhow::anyhow!("connect_with_retry: 0 attempts")))
64    }
65
66    pub fn inner(&self) -> pb::atlas_client::AtlasClient<Channel> {
67        self.inner.clone()
68    }
69
70    // ── Registration (one RPC per kind, shared request/response type) ──────
71
72    fn build_register_req(
73        id: &str,
74        namespace: &str,
75        capability_md_path: &str,
76    ) -> pb::RegisterRequest {
77        pb::RegisterRequest {
78            id: id.to_string(),
79            namespace: namespace.to_string(),
80            capability_md_path: capability_md_path.to_string(),
81        }
82    }
83
84    /// Register a Primitive. Returns the (possibly Atlas-assigned) id.
85    pub async fn register_primitive(
86        &mut self,
87        id: &str,
88        namespace: &str,
89        capability_md_path: &str,
90    ) -> Result<String> {
91        let resp = self
92            .inner
93            .register_primitive(Self::build_register_req(id, namespace, capability_md_path))
94            .await
95            .with_context(|| format!("RegisterPrimitive '{id}'"))?;
96        Ok(resp.into_inner().id)
97    }
98
99    /// Register a Service.
100    pub async fn register_service(
101        &mut self,
102        id: &str,
103        namespace: &str,
104        capability_md_path: &str,
105    ) -> Result<String> {
106        let resp = self
107            .inner
108            .register_service(Self::build_register_req(id, namespace, capability_md_path))
109            .await
110            .with_context(|| format!("RegisterService '{id}'"))?;
111        Ok(resp.into_inner().id)
112    }
113
114    /// Register a Skill.
115    pub async fn register_skill(
116        &mut self,
117        id: &str,
118        namespace: &str,
119        capability_md_path: &str,
120    ) -> Result<String> {
121        let resp = self
122            .inner
123            .register_skill(Self::build_register_req(id, namespace, capability_md_path))
124            .await
125            .with_context(|| format!("RegisterSkill '{id}'"))?;
126        Ok(resp.into_inner().id)
127    }
128
129    /// Unregister any registered entity. Returns `true` if a record was
130    /// removed, `false` if the id was unknown (idempotent).
131    pub async fn unregister(&mut self, id: &str) -> Result<bool> {
132        let resp = self
133            .inner
134            .unregister(pb::UnregisterRequest { id: id.to_string() })
135            .await
136            .with_context(|| format!("Unregister '{id}'"))?;
137        Ok(resp.into_inner().was_present)
138    }
139
140    // ── Liveness + lifecycle ───────────────────────────────────────────────
141
142    pub async fn heartbeat(&mut self, id: &str) -> Result<()> {
143        self.inner
144            .heartbeat(pb::HeartbeatRequest { id: id.to_string() })
145            .await
146            .with_context(|| format!("Heartbeat '{id}'"))?;
147        Ok(())
148    }
149
150    /// Push a lifecycle state transition. `detail` is a free-form
151    /// human-readable note (e.g. "missing /opt/models/...") that
152    /// `rbnx caps` surfaces verbatim; pass empty when there's nothing.
153    pub async fn set_lifecycle_state(
154        &mut self,
155        id: &str,
156        new_state: pb::LifecycleState,
157        detail: &str,
158    ) -> Result<()> {
159        self.inner
160            .set_lifecycle_state(pb::SetLifecycleStateRequest {
161                id: id.to_string(),
162                state: new_state as i32,
163                detail: detail.to_string(),
164            })
165            .await
166            .with_context(|| format!("SetLifecycleState '{id}' -> {new_state:?}"))?;
167        Ok(())
168    }
169
170    // ── Capability binding ─────────────────────────────────────────────────
171
172    /// Declare one Capability (transport-bound endpoint) on an entity.
173    /// Returns the *authoritative* endpoint (may differ from the request
174    /// when Atlas rewrote to disambiguate). `description` is the optional
175    /// natural-language description for this Capability.
176    pub async fn declare_capability(
177        &mut self,
178        provider_id: &str,
179        contract_id: &str,
180        transport: pb::Transport,
181        endpoint: &str,
182        params: pb::TransportParams,
183    ) -> Result<String> {
184        self.declare_capability_with_description(
185            provider_id,
186            contract_id,
187            transport,
188            endpoint,
189            params,
190            "",
191        )
192        .await
193    }
194
195    /// Same as `declare_capability` but with the instance-specific
196    /// description string (see DeclareCapabilityRequest.description).
197    pub async fn declare_capability_with_description(
198        &mut self,
199        provider_id: &str,
200        contract_id: &str,
201        transport: pb::Transport,
202        endpoint: &str,
203        params: pb::TransportParams,
204        description: &str,
205    ) -> Result<String> {
206        let resp = self
207            .inner
208            .declare_capability(pb::DeclareCapabilityRequest {
209                provider_id: provider_id.to_string(),
210                contract_id: contract_id.to_string(),
211                transport: transport as i32,
212                endpoint: endpoint.to_string(),
213                params: Some(params),
214                description: description.to_string(),
215            })
216            .await
217            .with_context(|| {
218                format!(
219                    "DeclareCapability provider='{provider_id}' contract='{contract_id}' \
220                     transport={transport:?}"
221                )
222            })?;
223        Ok(resp.into_inner().endpoint)
224    }
225
226    // ── Discovery ──────────────────────────────────────────────────────────
227
228    /// Generic Query. `kind == Kind::Unspecified` = no kind filter (all
229    /// kinds returned; each `CapabilityProvider.kind` carries its kind).
230    /// Empty strings / `Transport::Unspecified` = no filter on that field.
231    pub async fn query(
232        &mut self,
233        kind: pb::Kind,
234        id: &str,
235        contract_id: &str,
236        namespace_prefix: &str,
237        transport: pb::Transport,
238    ) -> Result<Vec<pb::CapabilityProvider>> {
239        let resp = self
240            .inner
241            .query(pb::QueryRequest {
242                kind: kind as i32,
243                id: id.to_string(),
244                contract_id: contract_id.to_string(),
245                namespace_prefix: namespace_prefix.to_string(),
246                transport: transport as i32,
247            })
248            .await
249            .with_context(|| format!("Query kind={kind:?}"))?;
250        Ok(resp.into_inner().providers)
251    }
252
253    /// Convenience — find Primitives (kind filter applied).
254    pub async fn query_primitives(
255        &mut self,
256        id: &str,
257        contract_id: &str,
258        namespace_prefix: &str,
259        transport: pb::Transport,
260    ) -> Result<Vec<pb::CapabilityProvider>> {
261        self.query(
262            pb::Kind::Primitive,
263            id,
264            contract_id,
265            namespace_prefix,
266            transport,
267        )
268        .await
269    }
270
271    /// Convenience — find Services.
272    pub async fn query_services(
273        &mut self,
274        id: &str,
275        contract_id: &str,
276        namespace_prefix: &str,
277        transport: pb::Transport,
278    ) -> Result<Vec<pb::CapabilityProvider>> {
279        self.query(
280            pb::Kind::Service,
281            id,
282            contract_id,
283            namespace_prefix,
284            transport,
285        )
286        .await
287    }
288
289    /// Convenience — find Skills.
290    pub async fn query_skills(
291        &mut self,
292        id: &str,
293        contract_id: &str,
294        namespace_prefix: &str,
295        transport: pb::Transport,
296    ) -> Result<Vec<pb::CapabilityProvider>> {
297        self.query(
298            pb::Kind::Skill,
299            id,
300            contract_id,
301            namespace_prefix,
302            transport,
303        )
304        .await
305    }
306
307    /// Consumer-facing discovery: flat list of Capabilities across all
308    /// kinds. Walks `Query(kind=Unspecified)` and flattens each
309    /// CapabilityProvider's nested capabilities. Each returned
310    /// `Capability` already carries `provider_id` + `provider_kind`.
311    pub async fn flatten_capabilities(
312        &mut self,
313        contract_id: &str,
314        namespace_prefix: &str,
315        transport: pb::Transport,
316    ) -> Result<Vec<pb::Capability>> {
317        let providers = self
318            .query(
319                pb::Kind::Unspecified,
320                "",
321                contract_id,
322                namespace_prefix,
323                transport,
324            )
325            .await?;
326        Ok(providers
327            .into_iter()
328            .flat_map(|p| p.capabilities.into_iter())
329            .collect())
330    }
331
332    /// Back-compat alias for the legacy 3-arg signature returning a list
333    /// of CapabilityProviders. Equivalent to
334    /// `query(Kind::Unspecified, id, contract_id, "", transport)`.
335    pub async fn query_capabilities(
336        &mut self,
337        id: &str,
338        contract_id: &str,
339        transport: pb::Transport,
340    ) -> Result<Vec<pb::CapabilityProvider>> {
341        self.query(pb::Kind::Unspecified, id, contract_id, "", transport)
342            .await
343    }
344
345    // ── Channels ───────────────────────────────────────────────────────────
346
347    /// Open a channel to one (provider, contract, transport). Atlas only
348    /// providers the edge — the consumer dials the returned endpoint
349    /// itself using whatever transport-appropriate mechanism (tonic for
350    /// grpc, rclrs for ros2, fastmcp for mcp, …).
351    /// Returns `(channel_id, endpoint, params)`.
352    pub async fn connect_capability(
353        &mut self,
354        consumer_id: &str,
355        provider_id: &str,
356        contract_id: &str,
357        transport: pb::Transport,
358    ) -> Result<(String, String, pb::TransportParams)> {
359        let resp = self
360            .inner
361            .connect_capability(pb::ConnectCapabilityRequest {
362                consumer_id: consumer_id.to_string(),
363                provider_id: provider_id.to_string(),
364                contract_id: contract_id.to_string(),
365                transport: transport as i32,
366            })
367            .await
368            .with_context(|| {
369                format!(
370                    "ConnectCapability consumer='{consumer_id}' provider='{provider_id}' \
371                     contract='{contract_id}' transport={transport:?}"
372                )
373            })?;
374        let r = resp.into_inner();
375        Ok((r.channel_id, r.endpoint, r.params.unwrap_or_default()))
376    }
377
378    /// Release a previously-opened channel. Idempotent: returns `false`
379    /// when the channel_id was unknown.
380    pub async fn disconnect_capability(&mut self, channel_id: &str) -> Result<bool> {
381        let resp = self
382            .inner
383            .disconnect_capability(pb::DisconnectCapabilityRequest {
384                channel_id: channel_id.to_string(),
385            })
386            .await
387            .with_context(|| format!("DisconnectCapability '{channel_id}'"))?;
388        Ok(resp.into_inner().was_open)
389    }
390
391    // ── Contract registry ──────────────────────────────────────────────────
392
393    /// Look up one contract by id.
394    pub async fn query_contract(
395        &mut self,
396        contract_id: &str,
397    ) -> Result<Option<pb::ContractDescriptor>> {
398        let resp = self
399            .inner
400            .query_contract(pb::QueryContractRequest {
401                contract_id: contract_id.to_string(),
402            })
403            .await
404            .with_context(|| format!("QueryContract '{contract_id}'"))?;
405        let inner = resp.into_inner();
406        Ok(if inner.found { inner.contract } else { None })
407    }
408
409    pub async fn list_contracts(
410        &mut self,
411        namespace_prefix: &str,
412    ) -> Result<Vec<pb::ContractDescriptor>> {
413        let resp = self
414            .inner
415            .list_contracts(pb::ListContractsRequest {
416                namespace_prefix: namespace_prefix.to_string(),
417            })
418            .await
419            .with_context(|| format!("ListContracts prefix='{namespace_prefix}'"))?;
420        Ok(resp.into_inner().contracts)
421    }
422}
423
424/// gRPC-only convenience: pick the first Capability matching `contract_id`
425/// over gRPC, call `ConnectCapability` to register the edge, dial the
426/// returned host:port, and hand back a tonic Channel + the channel_id
427/// (so the caller can DisconnectCapability on shutdown).
428///
429/// Returns `(channel_id, provider_id, Channel)`.
430pub async fn connect_to_capability(
431    atlas: &mut AtlasClient,
432    consumer_id: &str,
433    contract_id: &str,
434) -> Result<(String, String, Channel)> {
435    let rows = atlas
436        .flatten_capabilities(contract_id, "", pb::Transport::Grpc)
437        .await?;
438    if rows.is_empty() {
439        anyhow::bail!(
440            "no Capability offering contract_id='{contract_id}' over gRPC; \
441             registered entities may not have declared this Capability yet"
442        );
443    }
444    if rows.len() > 1 {
445        log::warn!(
446            "[atlas-client] {} entities offer '{contract_id}' over gRPC; \
447             picking first ('{}'). Use query_capabilities + connect_capability \
448             for deterministic selection.",
449            rows.len(),
450            rows[0].provider_id,
451        );
452    }
453    let provider_id = rows
454        .into_iter()
455        .next()
456        .expect("non-empty checked above")
457        .provider_id;
458    let (channel_id, endpoint_str, _params) = atlas
459        .connect_capability(consumer_id, &provider_id, contract_id, pb::Transport::Grpc)
460        .await?;
461    let normalized = normalize_grpc_endpoint(&endpoint_str);
462    let channel = Endpoint::new(normalized.clone())
463        .with_context(|| {
464            format!(
465                "invalid endpoint '{}' for provider '{}'",
466                normalized, provider_id
467            )
468        })?
469        .connect()
470        .await
471        .with_context(|| {
472            format!(
473                "connect to provider '{provider_id}' at '{normalized}' for contract '{contract_id}'"
474            )
475        })?;
476    Ok((channel_id, provider_id, channel))
477}
478
479// ── TransportParams constructors ───────────────────────────────────────────
480
481pub fn grpc_params(
482    proto_file: impl Into<String>,
483    service_name: impl Into<String>,
484    method: impl Into<String>,
485) -> pb::TransportParams {
486    pb::TransportParams {
487        kind: Some(pb::transport_params::Kind::Grpc(pb::GrpcParams {
488            proto_file: proto_file.into(),
489            service_name: service_name.into(),
490            method: method.into(),
491        })),
492    }
493}
494
495pub fn ros2_params(qos_profile: impl Into<String>) -> pb::TransportParams {
496    pb::TransportParams {
497        kind: Some(pb::transport_params::Kind::Ros2(pb::Ros2Params {
498            qos_profile: qos_profile.into(),
499        })),
500    }
501}
502
503/// Build `TransportParams` for an MCP tool Capability. The natural-
504/// language description now lives on `DeclareCapabilityRequest.description`,
505/// not inside `McpParams`.
506pub fn mcp_params(input_schema_json: impl Into<String>) -> pb::TransportParams {
507    pb::TransportParams {
508        kind: Some(pb::transport_params::Kind::Mcp(pb::McpParams {
509            input_schema_json: input_schema_json.into(),
510        })),
511    }
512}
513
514// ── Helpers ────────────────────────────────────────────────────────────────
515
/// Ensure a gRPC address carries a URI scheme so tonic's `Endpoint` can
/// parse it.
///
/// Surrounding whitespace is trimmed. Inputs that already start with
/// `http://` or `https://` are returned unchanged. The comparison is ASCII
/// case-insensitive because URI schemes are case-insensitive (RFC 3986).
/// Anything else, e.g. a bare `host:port`, is prefixed with `http://`.
fn normalize_grpc_endpoint(s: &str) -> String {
    const SCHEMES: [&str; 2] = ["http://", "https://"];
    let s = s.trim();
    let has_scheme = SCHEMES.iter().any(|scheme| {
        // `get` yields None when `s` is shorter than the scheme or the cut
        // would split a multi-byte char, so this never panics.
        matches!(s.get(..scheme.len()), Some(head) if head.eq_ignore_ascii_case(scheme))
    });
    if has_scheme {
        s.to_string()
    } else {
        format!("http://{s}")
    }
}