1use crate::pb;
11use anyhow::{Context, Result};
12use std::time::Duration;
13use tonic::transport::{Channel, Endpoint};
14
15#[derive(Clone)]
21pub struct AtlasClient {
22 inner: pb::atlas_client::AtlasClient<Channel>,
23}
24
25impl AtlasClient {
26 pub async fn connect(endpoint: impl AsRef<str>) -> Result<Self> {
28 let normalized = normalize_grpc_endpoint(endpoint.as_ref());
29 let channel = Endpoint::new(normalized.clone())
30 .with_context(|| format!("invalid Atlas endpoint '{}'", normalized))?
31 .connect()
32 .await
33 .with_context(|| format!("connect to Atlas at '{}'", normalized))?;
34 Ok(Self {
35 inner: pb::atlas_client::AtlasClient::new(channel),
36 })
37 }
38
39 pub async fn connect_with_retry(
41 endpoint: impl AsRef<str>,
42 attempts: u32,
43 delay: Duration,
44 ) -> Result<Self> {
45 let endpoint = endpoint.as_ref();
46 let mut last_err: Option<anyhow::Error> = None;
47 for i in 0..attempts.max(1) {
48 match Self::connect(endpoint).await {
49 Ok(c) => return Ok(c),
50 Err(e) => {
51 log::debug!(
52 "[atlas-client] connect attempt {}/{} failed: {e:#}",
53 i + 1,
54 attempts
55 );
56 last_err = Some(e);
57 if i + 1 < attempts {
58 tokio::time::sleep(delay).await;
59 }
60 }
61 }
62 }
63 Err(last_err.unwrap_or_else(|| anyhow::anyhow!("connect_with_retry: 0 attempts")))
64 }
65
66 pub fn inner(&self) -> pb::atlas_client::AtlasClient<Channel> {
67 self.inner.clone()
68 }
69
70 fn build_register_req(
73 id: &str,
74 namespace: &str,
75 capability_md_path: &str,
76 ) -> pb::RegisterRequest {
77 pb::RegisterRequest {
78 id: id.to_string(),
79 namespace: namespace.to_string(),
80 capability_md_path: capability_md_path.to_string(),
81 }
82 }
83
84 pub async fn register_primitive(
86 &mut self,
87 id: &str,
88 namespace: &str,
89 capability_md_path: &str,
90 ) -> Result<String> {
91 let resp = self
92 .inner
93 .register_primitive(Self::build_register_req(id, namespace, capability_md_path))
94 .await
95 .with_context(|| format!("RegisterPrimitive '{id}'"))?;
96 Ok(resp.into_inner().id)
97 }
98
99 pub async fn register_service(
101 &mut self,
102 id: &str,
103 namespace: &str,
104 capability_md_path: &str,
105 ) -> Result<String> {
106 let resp = self
107 .inner
108 .register_service(Self::build_register_req(id, namespace, capability_md_path))
109 .await
110 .with_context(|| format!("RegisterService '{id}'"))?;
111 Ok(resp.into_inner().id)
112 }
113
114 pub async fn register_skill(
116 &mut self,
117 id: &str,
118 namespace: &str,
119 capability_md_path: &str,
120 ) -> Result<String> {
121 let resp = self
122 .inner
123 .register_skill(Self::build_register_req(id, namespace, capability_md_path))
124 .await
125 .with_context(|| format!("RegisterSkill '{id}'"))?;
126 Ok(resp.into_inner().id)
127 }
128
129 pub async fn unregister(&mut self, id: &str) -> Result<bool> {
132 let resp = self
133 .inner
134 .unregister(pb::UnregisterRequest { id: id.to_string() })
135 .await
136 .with_context(|| format!("Unregister '{id}'"))?;
137 Ok(resp.into_inner().was_present)
138 }
139
140 pub async fn heartbeat(&mut self, id: &str) -> Result<()> {
143 self.inner
144 .heartbeat(pb::HeartbeatRequest { id: id.to_string() })
145 .await
146 .with_context(|| format!("Heartbeat '{id}'"))?;
147 Ok(())
148 }
149
150 pub async fn set_lifecycle_state(
154 &mut self,
155 id: &str,
156 new_state: pb::LifecycleState,
157 detail: &str,
158 ) -> Result<()> {
159 self.inner
160 .set_lifecycle_state(pb::SetLifecycleStateRequest {
161 id: id.to_string(),
162 state: new_state as i32,
163 detail: detail.to_string(),
164 })
165 .await
166 .with_context(|| format!("SetLifecycleState '{id}' -> {new_state:?}"))?;
167 Ok(())
168 }
169
170 pub async fn declare_capability(
177 &mut self,
178 provider_id: &str,
179 contract_id: &str,
180 transport: pb::Transport,
181 endpoint: &str,
182 params: pb::TransportParams,
183 ) -> Result<String> {
184 self.declare_capability_with_description(
185 provider_id,
186 contract_id,
187 transport,
188 endpoint,
189 params,
190 "",
191 )
192 .await
193 }
194
195 pub async fn declare_capability_with_description(
198 &mut self,
199 provider_id: &str,
200 contract_id: &str,
201 transport: pb::Transport,
202 endpoint: &str,
203 params: pb::TransportParams,
204 description: &str,
205 ) -> Result<String> {
206 let resp = self
207 .inner
208 .declare_capability(pb::DeclareCapabilityRequest {
209 provider_id: provider_id.to_string(),
210 contract_id: contract_id.to_string(),
211 transport: transport as i32,
212 endpoint: endpoint.to_string(),
213 params: Some(params),
214 description: description.to_string(),
215 })
216 .await
217 .with_context(|| {
218 format!(
219 "DeclareCapability provider='{provider_id}' contract='{contract_id}' \
220 transport={transport:?}"
221 )
222 })?;
223 Ok(resp.into_inner().endpoint)
224 }
225
226 pub async fn query(
232 &mut self,
233 kind: pb::Kind,
234 id: &str,
235 contract_id: &str,
236 namespace_prefix: &str,
237 transport: pb::Transport,
238 ) -> Result<Vec<pb::CapabilityProvider>> {
239 let resp = self
240 .inner
241 .query(pb::QueryRequest {
242 kind: kind as i32,
243 id: id.to_string(),
244 contract_id: contract_id.to_string(),
245 namespace_prefix: namespace_prefix.to_string(),
246 transport: transport as i32,
247 })
248 .await
249 .with_context(|| format!("Query kind={kind:?}"))?;
250 Ok(resp.into_inner().providers)
251 }
252
253 pub async fn query_primitives(
255 &mut self,
256 id: &str,
257 contract_id: &str,
258 namespace_prefix: &str,
259 transport: pb::Transport,
260 ) -> Result<Vec<pb::CapabilityProvider>> {
261 self.query(
262 pb::Kind::Primitive,
263 id,
264 contract_id,
265 namespace_prefix,
266 transport,
267 )
268 .await
269 }
270
271 pub async fn query_services(
273 &mut self,
274 id: &str,
275 contract_id: &str,
276 namespace_prefix: &str,
277 transport: pb::Transport,
278 ) -> Result<Vec<pb::CapabilityProvider>> {
279 self.query(
280 pb::Kind::Service,
281 id,
282 contract_id,
283 namespace_prefix,
284 transport,
285 )
286 .await
287 }
288
289 pub async fn query_skills(
291 &mut self,
292 id: &str,
293 contract_id: &str,
294 namespace_prefix: &str,
295 transport: pb::Transport,
296 ) -> Result<Vec<pb::CapabilityProvider>> {
297 self.query(
298 pb::Kind::Skill,
299 id,
300 contract_id,
301 namespace_prefix,
302 transport,
303 )
304 .await
305 }
306
307 pub async fn flatten_capabilities(
312 &mut self,
313 contract_id: &str,
314 namespace_prefix: &str,
315 transport: pb::Transport,
316 ) -> Result<Vec<pb::Capability>> {
317 let providers = self
318 .query(
319 pb::Kind::Unspecified,
320 "",
321 contract_id,
322 namespace_prefix,
323 transport,
324 )
325 .await?;
326 Ok(providers
327 .into_iter()
328 .flat_map(|p| p.capabilities.into_iter())
329 .collect())
330 }
331
332 pub async fn query_capabilities(
336 &mut self,
337 id: &str,
338 contract_id: &str,
339 transport: pb::Transport,
340 ) -> Result<Vec<pb::CapabilityProvider>> {
341 self.query(pb::Kind::Unspecified, id, contract_id, "", transport)
342 .await
343 }
344
345 pub async fn connect_capability(
353 &mut self,
354 consumer_id: &str,
355 provider_id: &str,
356 contract_id: &str,
357 transport: pb::Transport,
358 ) -> Result<(String, String, pb::TransportParams)> {
359 let resp = self
360 .inner
361 .connect_capability(pb::ConnectCapabilityRequest {
362 consumer_id: consumer_id.to_string(),
363 provider_id: provider_id.to_string(),
364 contract_id: contract_id.to_string(),
365 transport: transport as i32,
366 })
367 .await
368 .with_context(|| {
369 format!(
370 "ConnectCapability consumer='{consumer_id}' provider='{provider_id}' \
371 contract='{contract_id}' transport={transport:?}"
372 )
373 })?;
374 let r = resp.into_inner();
375 Ok((r.channel_id, r.endpoint, r.params.unwrap_or_default()))
376 }
377
378 pub async fn disconnect_capability(&mut self, channel_id: &str) -> Result<bool> {
381 let resp = self
382 .inner
383 .disconnect_capability(pb::DisconnectCapabilityRequest {
384 channel_id: channel_id.to_string(),
385 })
386 .await
387 .with_context(|| format!("DisconnectCapability '{channel_id}'"))?;
388 Ok(resp.into_inner().was_open)
389 }
390
391 pub async fn query_contract(
395 &mut self,
396 contract_id: &str,
397 ) -> Result<Option<pb::ContractDescriptor>> {
398 let resp = self
399 .inner
400 .query_contract(pb::QueryContractRequest {
401 contract_id: contract_id.to_string(),
402 })
403 .await
404 .with_context(|| format!("QueryContract '{contract_id}'"))?;
405 let inner = resp.into_inner();
406 Ok(if inner.found { inner.contract } else { None })
407 }
408
409 pub async fn list_contracts(
410 &mut self,
411 namespace_prefix: &str,
412 ) -> Result<Vec<pb::ContractDescriptor>> {
413 let resp = self
414 .inner
415 .list_contracts(pb::ListContractsRequest {
416 namespace_prefix: namespace_prefix.to_string(),
417 })
418 .await
419 .with_context(|| format!("ListContracts prefix='{namespace_prefix}'"))?;
420 Ok(resp.into_inner().contracts)
421 }
422}
423
424pub async fn connect_to_capability(
431 atlas: &mut AtlasClient,
432 consumer_id: &str,
433 contract_id: &str,
434) -> Result<(String, String, Channel)> {
435 let rows = atlas
436 .flatten_capabilities(contract_id, "", pb::Transport::Grpc)
437 .await?;
438 if rows.is_empty() {
439 anyhow::bail!(
440 "no Capability offering contract_id='{contract_id}' over gRPC; \
441 registered entities may not have declared this Capability yet"
442 );
443 }
444 if rows.len() > 1 {
445 log::warn!(
446 "[atlas-client] {} entities offer '{contract_id}' over gRPC; \
447 picking first ('{}'). Use query_capabilities + connect_capability \
448 for deterministic selection.",
449 rows.len(),
450 rows[0].provider_id,
451 );
452 }
453 let provider_id = rows
454 .into_iter()
455 .next()
456 .expect("non-empty checked above")
457 .provider_id;
458 let (channel_id, endpoint_str, _params) = atlas
459 .connect_capability(consumer_id, &provider_id, contract_id, pb::Transport::Grpc)
460 .await?;
461 let normalized = normalize_grpc_endpoint(&endpoint_str);
462 let channel = Endpoint::new(normalized.clone())
463 .with_context(|| {
464 format!(
465 "invalid endpoint '{}' for provider '{}'",
466 normalized, provider_id
467 )
468 })?
469 .connect()
470 .await
471 .with_context(|| {
472 format!(
473 "connect to provider '{provider_id}' at '{normalized}' for contract '{contract_id}'"
474 )
475 })?;
476 Ok((channel_id, provider_id, channel))
477}
478
479pub fn grpc_params(
482 proto_file: impl Into<String>,
483 service_name: impl Into<String>,
484 method: impl Into<String>,
485) -> pb::TransportParams {
486 pb::TransportParams {
487 kind: Some(pb::transport_params::Kind::Grpc(pb::GrpcParams {
488 proto_file: proto_file.into(),
489 service_name: service_name.into(),
490 method: method.into(),
491 })),
492 }
493}
494
495pub fn ros2_params(qos_profile: impl Into<String>) -> pb::TransportParams {
496 pb::TransportParams {
497 kind: Some(pb::transport_params::Kind::Ros2(pb::Ros2Params {
498 qos_profile: qos_profile.into(),
499 })),
500 }
501}
502
503pub fn mcp_params(input_schema_json: impl Into<String>) -> pb::TransportParams {
507 pb::TransportParams {
508 kind: Some(pb::transport_params::Kind::Mcp(pb::McpParams {
509 input_schema_json: input_schema_json.into(),
510 })),
511 }
512}
513
/// Normalizes an endpoint string into a URI that tonic can dial.
///
/// Surrounding whitespace is trimmed first. If the trimmed string already starts with
/// `http://` or `https://`, it is returned unchanged. The prefix match ignores ASCII
/// case, because URI schemes are case-insensitive (RFC 3986 §3.1). Otherwise `http://`
/// is prepended, so a bare `host:port` becomes `http://host:port`.
///
/// An empty or all-whitespace input yields `"http://"`.
fn normalize_grpc_endpoint(s: &str) -> String {
    let s = s.trim();
    // Compare raw bytes: a length check plus a byte slice can never split a UTF-8
    // character, unlike slicing the `str` at an arbitrary index.
    let has_scheme = ["http://", "https://"].iter().any(|scheme| {
        s.len() >= scheme.len()
            && s.as_bytes()[..scheme.len()].eq_ignore_ascii_case(scheme.as_bytes())
    });
    if has_scheme {
        s.to_string()
    } else {
        format!("http://{s}")
    }
}