# Source code for robonix_api.tool

# SPDX-License-Identifier: MulanPSL-2.0
"""Contract-enforced MCP tools on top of FastMCP + robonix-codegen ROS dataclasses.

The decorator stashes the codegen IO classes + contract id on the original
handler so Capability.run() can derive atlas DeclareCapability metadata
(JSON Schema, description) without re-reflecting annotations.
"""
from __future__ import annotations

import inspect
import logging
from typing import TYPE_CHECKING, Any, Callable, get_type_hints

if TYPE_CHECKING:
    from mcp.server.fastmcp import FastMCP  # noqa: F401

__all__ = ["mcp_contract"]
log = logging.getLogger(__name__)


def io_types_from_handler(
    user_fn: Callable[..., Any], *, contract_id: str,
) -> tuple[type, type | None]:
    """Resolve a handler's codegen input/output message classes from its annotations.

    The handler must take exactly one named parameter. That parameter must be
    positional, because the input message is passed positionally, and it must
    be annotated with the codegen input message class. The return type must be
    annotated either with the codegen output message class or with ``None``.

    Args:
        user_fn: The contract handler to inspect.
        contract_id: Used only to prefix error messages.

    Returns:
        ``(input_cls, output_cls)``. ``output_cls`` is ``None`` when the handler
        is annotated ``-> None``.

    Raises:
        TypeError: If the parameter list or the annotations break the rules above.
        NameError: Propagated from ``typing.get_type_hints`` when an annotation
            cannot be resolved in the handler's module globals.
    """
    globalns = getattr(user_fn, "__globals__", None) or {}
    hints = get_type_hints(user_fn, globalns=globalns, localns=globalns)
    sig = inspect.signature(user_fn)
    # *args / **kwargs are ignored; every other parameter kind counts towards
    # the "exactly one parameter" rule.
    named_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    params = [p for p in sig.parameters.values() if p.kind in named_kinds]
    if len(params) != 1:
        raise TypeError(
            f"[mcp_contract:{contract_id}] {user_fn.__name__} must have exactly one parameter "
            f"(input message), got {sig}"
        )
    p0 = params[0]
    if p0.kind is inspect.Parameter.KEYWORD_ONLY:
        # The generated shim calls the handler as user_fn(msg). A keyword-only
        # parameter would pass registration and then fail on every invocation.
        raise TypeError(
            f"[mcp_contract:{contract_id}] parameter {p0.name!r} of {user_fn.__name__} "
            "must be positional (the input message is passed positionally)"
        )
    ann_in = hints.get(p0.name)
    if ann_in is None or not isinstance(ann_in, type):
        raise TypeError(
            f"[mcp_contract:{contract_id}] parameter {p0.name!r} must be annotated with the "
            "codegen input message type"
        )
    ann_out = hints.get("return")
    if ann_out is None:
        raise TypeError(
            f"[mcp_contract:{contract_id}] {user_fn.__name__} must annotate return type"
        )
    # get_type_hints normalises a `-> None` annotation to NoneType.
    if ann_out is type(None):
        return ann_in, None
    if not isinstance(ann_out, type):
        raise TypeError(
            f"[mcp_contract:{contract_id}] return must be a codegen message class"
        )
    return ann_in, ann_out


# JSON Schema primitive type name -> Python type used in generated shim signatures.
_JSON_TYPE_TO_PY: dict[str, Any] = {
    "string": str,
    "integer": int,
    "number": float,
    "boolean": bool,
    "array": list,
    "object": dict,
}


def json_prop_to_py_type(prop: dict[str, Any]) -> type:
    """Map a JSON Schema property to the Python type used in the shim signature.

    A missing ``"type"`` is treated as ``"string"``. A list of types, such as
    ``["number", "null"]`` for a nullable field, resolves to its single
    non-``"null"`` member. Any other list, and any unknown or non-string type,
    maps to ``typing.Any``.
    """
    t = prop.get("type", "string")
    if isinstance(t, list):
        # JSON Schema allows "type" to be an array. Using it directly as a dict
        # key below would raise TypeError (unhashable), so unwrap the common
        # nullable form and fall back to Any for genuine unions.
        non_null = [x for x in t if x != "null"]
        if len(non_null) != 1:
            return Any  # type: ignore[return-value]
        t = non_null[0]
    if not isinstance(t, str):
        return Any  # type: ignore[return-value]
    return _JSON_TYPE_TO_PY.get(t, Any)  # type: ignore[return-value]


def wire_output_for_json(value: Any) -> Any:
    """Turn a handler's return value into its JSON wire form.

    These values are returned unchanged:

    * ``None`` and instances of dict, list, str, int, float or bool;
    * any other object that has no callable ``to_dict``.

    Every other object is replaced by the result of calling its ``to_dict()``
    method, which is how codegen message instances are serialised.
    """
    json_native = (dict, list, str, int, float, bool)
    if value is not None and not isinstance(value, json_native):
        to_dict = getattr(value, "to_dict", None)
        if callable(to_dict):
            return to_dict()
    return value


def make_shim(user_fn: Callable[..., Any], input_cls: type, out_cls: type | None) -> Callable[..., Any]:
    """Build the async tool entry point that wraps a contract handler.

    The shim exposes one parameter per top-level property of
    ``input_cls.json_schema()``. Presumably this is so the MCP server can
    derive the tool's argument schema from the shim's signature.

    When called, the shim:

    1. packs its arguments back into a dict;
    2. builds the input message with ``input_cls.from_dict``;
    3. calls ``user_fn`` with that message positionally, awaiting the result
       if it is awaitable;
    4. returns the result after ``wire_output_for_json``.

    Args:
        user_fn: The user's handler.
        input_cls: Codegen input message class. It must provide
            ``json_schema()`` and ``from_dict()``.
        out_cls: Codegen output class, or ``None`` when the handler returns None.

    Returns:
        The shim coroutine function. It carries ``user_fn``'s ``__name__``,
        ``__doc__``, ``__module__`` and ``__qualname__``.

    Raises:
        TypeError: If a schema property name cannot be used as a Python
            parameter name: not an identifier, a keyword, or using the
            reserved ``_rbx_`` prefix.
    """
    import keyword

    schema = input_cls.json_schema()
    props: dict[str, Any] = schema.get("properties") or {}
    if not props:
        async def shim() -> Any:
            r = user_fn(input_cls.from_dict({}))
            if inspect.isawaitable(r):
                r = await r
            return wire_output_for_json(r)
    else:
        keys = list(props)
        # Property names are spliced verbatim into the generated source below,
        # so reject anything that is not a plain, non-keyword identifier. A
        # field called "from" would otherwise fail inside exec with an opaque
        # SyntaxError, and arbitrary text would be executed as code. The
        # "_rbx_" prefix is reserved for the helpers bound in `ns`.
        bad = [
            k for k in keys
            if not isinstance(k, str)
            or not k.isidentifier()
            or keyword.iskeyword(k)
            or k.startswith("_rbx_")
        ]
        if bad:
            raise TypeError(
                f"[make_shim] {getattr(input_cls, '__name__', input_cls)} schema "
                f"properties {bad!r} cannot be used as Python parameter names"
            )
        param_decls: list[str] = []
        for k in keys:
            # "object" properties (nested ROS sub-messages such as Pose or Header)
            # map to dict here and are handed to from_dict unchanged.
            py_t = json_prop_to_py_type(props[k])
            ann = "Any" if py_t is Any else getattr(py_t, "__name__", "Any")
            param_decls.append(f"{k}: {ann}")
        dict_literal = ", ".join(f"{k!r}: {k}" for k in keys)
        # Helpers live under "_rbx_"-prefixed names so that a message field named,
        # for example, "input_cls", "user_fn" or "inspect" cannot shadow them
        # inside the generated body. The plain type names are included because
        # the generated annotations refer to them and this dict becomes
        # shim.__globals__.
        ns: dict[str, Any] = {
            "_rbx_input_cls": input_cls,
            "_rbx_user_fn": user_fn,
            "_rbx_wire": wire_output_for_json,
            "_rbx_isawaitable": inspect.isawaitable,
            "str": str, "int": int, "float": float, "bool": bool,
            "list": list, "dict": dict, "Any": Any,
        }
        code = (
            f"async def shim({', '.join(param_decls)}) -> Any:\n"
            f"    _rbx_msg = _rbx_input_cls.from_dict({{{dict_literal}}})\n"
            "    _rbx_result = _rbx_user_fn(_rbx_msg)\n"
            "    if _rbx_isawaitable(_rbx_result):\n"
            "        _rbx_result = await _rbx_result\n"
            "    return _rbx_wire(_rbx_result)\n"
        )
        exec(code, ns, ns)  # noqa: S102 — fixed template; field names validated above
        shim = ns["shim"]
    shim.__name__ = user_fn.__name__
    shim.__doc__ = user_fn.__doc__ or ""
    shim.__module__ = user_fn.__module__
    shim.__qualname__ = user_fn.__qualname__
    if out_cls is not None:
        ann = dict(getattr(shim, "__annotations__", {}))
        # The actual return value is the JSON wire form (dict / primitives)
        # produced by wire_output_for_json, not output_cls. Annotating it as
        # output_cls would make Pydantic validate dict fields against codegen
        # types (Image.data, …).
        ann["return"] = Any
        shim.__annotations__ = ann
    return shim


def mcp_contract(
    mcp: "FastMCP",
    *,
    contract_id: str,
    structured_output: bool | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Register an MCP tool bound to a contract.

    Use this directly with your own FastMCP app. Alternatively, use the
    provider's ``@<provider>.mcp(...)`` decorator (Primitive / Service / Skill)
    for one-stop registration that also auto-declares to atlas and manages
    uvicorn.

    The MCP-server-side tool name is **always** the leaf segment of
    ``contract_id`` (the text after the last "/"). The executor's dispatch
    derives the same value from ``contract_id`` (see
    executor/dispatch/mcp.rs ``mcp_tool_name``). The two must match, and
    deriving both from a single source enforces that.

    The decorator stashes the codegen IO classes and the contract id on the
    original handler:

    * ``fn._robonix_input_cls``
    * ``fn._robonix_output_cls``
    * ``fn._robonix_contract_id``

    The provider framework picks these up via attribute reflection during run().

    Args:
        mcp: The FastMCP app the tool is added to.
        contract_id: Contract identifier; its leaf segment becomes the tool name.
        structured_output: Forwarded unchanged to ``mcp.add_tool``.

    Returns:
        A decorator. It registers the handler's shim and returns the original
        handler, annotated with the attributes listed above.

    Raises:
        ValueError: Immediately, if ``contract_id`` has an empty leaf segment
            (e.g. a trailing "/"). This avoids registering a tool under an
            empty name.
    """
    tool_name = contract_id.rsplit("/", 1)[-1]
    if not tool_name:
        raise ValueError(
            f"[mcp_contract:{contract_id}] contract_id has an empty leaf segment; "
            "cannot derive the MCP tool name"
        )

    def decorator(user_fn: Callable[..., Any]) -> Callable[..., Any]:
        input_cls, output_cls = io_types_from_handler(user_fn, contract_id=contract_id)
        shim = make_shim(user_fn, input_cls, output_cls)
        mcp.add_tool(
            shim,
            name=tool_name,
            description=(user_fn.__doc__ or "").strip(),
            structured_output=structured_output,
        )
        user_fn._robonix_input_cls = input_cls  # type: ignore[attr-defined]
        user_fn._robonix_output_cls = output_cls  # type: ignore[attr-defined]
        user_fn._robonix_contract_id = contract_id  # type: ignore[attr-defined]
        return user_fn

    return decorator