diff --git a/pyproject.toml b/pyproject.toml index 54f92dd67..5cb82af57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" dependencies = [ "uipath>=2.10.79, <2.12.0", - "uipath-core>=0.5.20, <0.6.0", + "uipath-core==0.5.23.dev1017616946", "uipath-platform>=0.1.71, <0.2.0", "uipath-runtime>=0.11.0, <0.12.0", "langgraph>=1.1.8, <2.0.0", @@ -169,3 +169,6 @@ name = "testpypi" url = "https://test.pypi.org/simple/" publish-url = "https://test.pypi.org/legacy/" explicit = true + +[tool.uv.sources] +uipath-core = { index = "testpypi" } diff --git a/src/uipath_langchain/governance/__init__.py b/src/uipath_langchain/governance/__init__.py new file mode 100644 index 000000000..c7bfd29bc --- /dev/null +++ b/src/uipath_langchain/governance/__init__.py @@ -0,0 +1,16 @@ +"""Governance integration for ``uipath-langchain``. + +Exposes :class:`GovernanceCallbackHandler` — a LangChain callback +handler that calls an :class:`~uipath.core.adapters.EvaluatorProtocol` +on the model and tool lifecycle. Wired into a run by passing an +``evaluator`` to :class:`UiPathLangGraphRuntimeFactory`; the factory +builds the handler and hands it to the runtime through the existing +``callbacks`` channel. + +Importing this module has no side effects: no adapter is registered, +no global state is mutated. +""" + +from .callbacks import GovernanceCallbackHandler + +__all__ = ["GovernanceCallbackHandler"] diff --git a/src/uipath_langchain/governance/callbacks.py b/src/uipath_langchain/governance/callbacks.py new file mode 100644 index 000000000..5d796eac7 --- /dev/null +++ b/src/uipath_langchain/governance/callbacks.py @@ -0,0 +1,390 @@ +"""LangChain governance callback handler. + +A :class:`langchain_core.callbacks.BaseCallbackHandler` that calls a +framework-agnostic :class:`~uipath.core.adapters.EvaluatorProtocol` +on the model and tool lifecycle. 
class GovernanceCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler that fires governance evaluation.

    Each ``on_*`` hook extracts the relevant payload and calls the
    matching ``evaluate_*`` method on the evaluator passed to
    ``__init__``. :class:`GovernanceBlockException` is re-raised from
    every hook; any other exception is logged (with traceback) and
    swallowed so a governance bug does not break an agent run.

    Hooks and the evaluator method they call:

    - ``on_llm_start`` / ``on_chat_model_start`` -> ``evaluate_before_model``
    - ``on_llm_end`` -> ``evaluate_after_model``
    - ``on_tool_start`` -> ``evaluate_tool_call``
    - ``on_tool_end`` -> ``evaluate_after_tool``
    """

    run_inline: bool = True
    # Must be True for enforcement: LangChain's callback dispatcher logs
    # and swallows exceptions raised by a handler unless the handler
    # sets ``raise_error``. Every hook below already swallows non-block
    # errors itself, so the only exception that can escape is the block.
    # NOTE(review): assumes GovernanceBlockException derives from
    # Exception (not BaseException) - confirm in uipath-core.
    raise_error: bool = True
    ignore_llm: bool = False
    # Chain-level events (BEFORE_AGENT / AFTER_AGENT) are owned by the
    # governance host, so this handler skips them to avoid duplicate
    # boundary evaluations.
    ignore_chain: bool = True
    ignore_agent: bool = False
    ignore_retriever: bool = True
    ignore_retry: bool = True
    ignore_chat_model: bool = False
    ignore_custom_event: bool = True

    def __init__(
        self,
        evaluator: EvaluatorProtocol,
        agent_name: str,
        session_id: str,
    ) -> None:
        """Bind the handler to one evaluator and one governed session.

        Args:
            evaluator: Object whose ``evaluate_*`` methods are called.
            agent_name: Forwarded as ``agent_name`` on every evaluation.
            session_id: Forwarded as ``runtime_id`` on every evaluation.
        """
        super().__init__()
        self._evaluator = evaluator
        self._agent_name = agent_name
        self._session_id = session_id
        # One trace id per handler instance, shared by all evaluations.
        self._trace_id = str(uuid4())
        self._session_state: Dict[str, Any] = {"tool_calls": 0, "llm_calls": 0}
        # Tool name lookup keyed by LangChain ``run_id`` so ``on_tool_end``
        # can report the actual tool name to AFTER_TOOL evaluation.
        self._tool_runs: Dict[str, str] = {}

    # ----- LLM callbacks ---------------------------------------------------

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: list[str],
        **kwargs: Any,
    ) -> None:
        """Evaluate BEFORE_MODEL rules at LLM start (non-chat completion).

        Only the last entry of ``prompts`` is scanned, truncated to
        ``_BEFORE_MODEL_TEXT_CAP`` characters; an empty list scans ``""``.

        Raises:
            GovernanceBlockException: re-raised from the evaluator.
        """
        try:
            self._session_state["llm_calls"] = (
                self._session_state.get("llm_calls", 0) + 1
            )
            # NOTE(review): if ``prompts`` ever carries several
            # independent prompts (a true batch), every entry except the
            # last is never scanned - confirm this is acceptable.
            latest_prompt = prompts[-1] if prompts else ""
            self._evaluator.evaluate_before_model(
                model_input=latest_prompt[:_BEFORE_MODEL_TEXT_CAP],
                agent_name=self._agent_name,
                runtime_id=self._session_id,
                trace_id=self._trace_id,
            )
        except GovernanceBlockException:
            raise
        except Exception as e:
            logger.warning(
                "on_llm_start governance check failed (continuing): %s",
                e,
                exc_info=True,
            )

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: list[list[Any]],
        **kwargs: Any,
    ) -> None:
        """Evaluate BEFORE_MODEL rules for chat models.

        Scans only the latest message of the prompt (see
        :meth:`_latest_message_input`), not the whole history, so content
        that stays in the prompt for context is not re-evaluated on every
        later call. ``messages`` itself is never mutated.

        Raises:
            GovernanceBlockException: re-raised from the evaluator.
        """
        try:
            self._session_state["llm_calls"] = (
                self._session_state.get("llm_calls", 0) + 1
            )
            model_input = self._latest_message_input(messages)
            self._evaluator.evaluate_before_model(
                model_input=model_input,
                agent_name=self._agent_name,
                runtime_id=self._session_id,
                trace_id=self._trace_id,
            )
        except GovernanceBlockException:
            raise
        except Exception as e:
            logger.warning(
                "on_chat_model_start governance check failed (continuing): %s",
                e,
                exc_info=True,
            )

    @staticmethod
    def _latest_message_input(messages: list[list[Any]]) -> str:
        """Extract the scan text from the most recent message.

        ``messages`` is nested: the last entry of the last inner list is
        used. The message may be an object exposing ``.content`` or a
        dict with a ``"content"`` key. String content is truncated to the
        cap; list content goes through :meth:`_blocks_to_text`.

        Returns:
            The extracted text, or ``""`` when ``messages`` or its last
            inner list is empty, or the content is neither ``str`` nor
            ``list``.
        """
        if not messages:
            return ""
        last_batch = messages[-1]
        if not last_batch:
            return ""
        last_msg = last_batch[-1]
        content = getattr(last_msg, "content", None)
        if content is None and isinstance(last_msg, dict):
            content = last_msg.get("content")
        if isinstance(content, str):
            return content[:_BEFORE_MODEL_TEXT_CAP]
        if isinstance(content, list):
            return GovernanceCallbackHandler._blocks_to_text(content)
        return ""

    @staticmethod
    def _blocks_to_text(content: list[Any]) -> str:
        """Join the text of every dict block in ``content``, capped.

        Each dict entry is reduced with :meth:`_extract_block_text`; the
        pieces are joined with newlines and truncated to
        ``_BEFORE_MODEL_TEXT_CAP`` characters.
        """
        # NOTE(review): non-dict entries (including plain ``str`` items)
        # are skipped, so such text is never scanned. The existing tests
        # pin this behaviour; confirm it is intended.
        pieces = (
            GovernanceCallbackHandler._extract_block_text(block)
            for block in content
            if isinstance(block, dict)
        )
        return GovernanceCallbackHandler._join_within_cap(pieces, "\n")

    @staticmethod
    def _join_within_cap(pieces: Iterable[str], sep: str) -> str:
        """Join non-empty ``pieces`` with ``sep``, capped at the text budget.

        Stops consuming ``pieces`` as soon as the joined length reaches
        ``_BEFORE_MODEL_TEXT_CAP``. The result equals the first
        ``_BEFORE_MODEL_TEXT_CAP`` characters of the full join.
        """
        kept: list[str] = []
        remaining = _BEFORE_MODEL_TEXT_CAP
        for piece in pieces:
            if not piece:
                continue
            # A separator is only paid *between* pieces, never before the
            # first one, so the budget matches the real joined length.
            if kept:
                remaining -= len(sep)
            kept.append(piece)
            remaining -= len(piece)
            if remaining <= 0:
                break
        return sep.join(kept)[:_BEFORE_MODEL_TEXT_CAP]

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Evaluate AFTER_MODEL rules at LLM end.

        The text of every generation is concatenated and capped at
        ``_BEFORE_MODEL_TEXT_CAP`` (same budget as BEFORE_MODEL).

        Raises:
            GovernanceBlockException: re-raised from the evaluator.
        """
        try:
            model_output = self._collect_generations_text(response)
            self._evaluator.evaluate_after_model(
                model_output=model_output,
                agent_name=self._agent_name,
                runtime_id=self._session_id,
                trace_id=self._trace_id,
            )
        except GovernanceBlockException:
            raise
        except Exception as e:
            logger.warning(
                "on_llm_end governance check failed (continuing): %s",
                e,
                exc_info=True,
            )

    def _collect_generations_text(self, response: Any) -> str:
        """Concatenate text across all generations, capped at the budget.

        Returns ``""`` when ``response`` has no ``generations`` attribute
        or when it is ``None`` / empty.
        """
        generations = getattr(response, "generations", None)
        if not generations:
            return ""
        pieces = (
            self._extract_generation_text(gen)
            for gen_list in generations
            for gen in (gen_list or ())
        )
        return self._join_within_cap(pieces, "")

    @staticmethod
    def _extract_generation_text(gen: Any) -> str:
        """Return the text payload of one generation object.

        Uses ``gen.text`` when it is non-empty. Otherwise falls back to
        ``gen.message.content``: a string is returned as is, a list is
        reduced with :meth:`_blocks_to_text`. Anything else yields ``""``.
        """
        text = getattr(gen, "text", "") or ""
        if text:
            return text
        message = getattr(gen, "message", None)
        if message is None:
            return ""
        content = getattr(message, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Same block walk as the BEFORE_MODEL path.
            return GovernanceCallbackHandler._blocks_to_text(content)
        return ""

    @staticmethod
    def _extract_block_text(block: Dict[str, Any]) -> str:
        """Return the governance-relevant text of one content block.

        Collects, in this order and only when present as described:

        - ``block["text"]`` if it is a string (plain text block);
        - ``block["arguments"]`` if it is a string (function-call
          arguments, typically a JSON-encoded string);
        - ``block["thinking"]`` if it is a string (thinking block);
        - every string value of ``block["input"]`` if it is a dict
          (tool-use input); non-string values are ignored.

        Other keys (``id``, ``name``, ``type``, ...) are never read, so
        identifiers do not pad the scanned text. Empty parts are dropped
        and the rest joined with newlines.
        """
        parts: list[str] = []
        for key in ("text", "arguments", "thinking"):
            value = block.get(key)
            if isinstance(value, str):
                parts.append(value)
        input_value = block.get("input")
        if isinstance(input_value, dict):
            parts.extend(v for v in input_value.values() if isinstance(v, str))
        return "\n".join(p for p in parts if p)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log an LLM error; no evaluation is performed."""
        logger.warning("LLM error in governed session %s: %s", self._session_id, error)

    # ----- Tool callbacks --------------------------------------------------

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        inputs: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Evaluate TOOL_CALL rules at tool start.

        Records ``run_id -> tool_name`` (when ``run_id`` is supplied in
        ``kwargs``) so ``on_tool_end`` can report the tool name. If the
        evaluator blocks, the mapping is dropped again before re-raising
        so blocked calls do not accumulate entries.

        Raises:
            GovernanceBlockException: re-raised from the evaluator.
        """
        run_id = kwargs.get("run_id")
        run_id_str = str(run_id) if run_id is not None else None
        try:
            self._session_state["tool_calls"] = (
                self._session_state.get("tool_calls", 0) + 1
            )
            # ``or`` also covers an explicit ``{"name": None}``.
            tool_name = (serialized or {}).get("name") or "unknown"
            if run_id_str is not None:
                self._tool_runs[run_id_str] = tool_name
            # Always hand the evaluator a dict: fall back to the raw
            # string when ``inputs`` is missing, empty, or not a dict.
            if isinstance(inputs, dict) and inputs:
                tool_args: Dict[str, Any] = inputs
            else:
                tool_args = {"input": input_str}
            self._evaluator.evaluate_tool_call(
                tool_name=tool_name,
                tool_args=tool_args,
                agent_name=self._agent_name,
                runtime_id=self._session_id,
                trace_id=self._trace_id,
                session_state=self._session_state,
            )
        except GovernanceBlockException:
            # Tool will not run -> no on_tool_end is coming. Drop the
            # mapping so it does not accumulate across blocked turns.
            if run_id_str is not None:
                self._tool_runs.pop(run_id_str, None)
            raise
        except Exception as e:
            logger.warning(
                "on_tool_start governance check failed (continuing): %s",
                e,
                exc_info=True,
            )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Evaluate AFTER_TOOL rules at tool end.

        The tool name is looked up (and removed) by ``run_id``; it is
        ``"unknown"`` when no ``run_id`` is given or none was recorded.
        ``output`` is passed as ``str(output)``, or ``""`` for ``None``.

        Raises:
            GovernanceBlockException: re-raised from the evaluator.
        """
        try:
            run_id = kwargs.get("run_id")
            tool_name = "unknown"
            if run_id is not None:
                tool_name = self._tool_runs.pop(str(run_id), "unknown")
            tool_result = str(output) if output is not None else ""
            self._evaluator.evaluate_after_tool(
                tool_name=tool_name,
                tool_result=tool_result,
                agent_name=self._agent_name,
                runtime_id=self._session_id,
                trace_id=self._trace_id,
            )
        except GovernanceBlockException:
            raise
        except Exception as e:
            logger.warning(
                "on_tool_end governance check failed (continuing): %s",
                e,
                exc_info=True,
            )

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log a tool error and forget the run's tool-name mapping."""
        # Tool errored out - on_tool_end will not fire. Pop the mapping
        # so a session with many failing tool calls does not leak.
        run_id = kwargs.get("run_id")
        if run_id is not None:
            self._tool_runs.pop(str(run_id), None)
        logger.warning("Tool error in governed session %s: %s", self._session_id, error)
Returns: Configured runtime instance @@ -271,10 +278,23 @@ async def _create_runtime_instance( storage = SqliteResumableStorage(memory) trigger_manager = UiPathResumeTriggerHandler() + callbacks: list[BaseCallbackHandler] | None = ( + [ + GovernanceCallbackHandler( + evaluator=evaluator, + agent_name=entrypoint, + session_id=runtime_id, + ) + ] + if evaluator is not None + else None + ) + base_runtime = UiPathLangGraphRuntime( graph=compiled_graph, runtime_id=runtime_id, entrypoint=entrypoint, + callbacks=callbacks, storage=storage, ) @@ -286,7 +306,11 @@ async def _create_runtime_instance( ) async def new_runtime( - self, entrypoint: str, runtime_id: str, **kwargs + self, + entrypoint: str, + runtime_id: str, + evaluator: EvaluatorProtocol | None = None, + **kwargs, ) -> UiPathRuntimeProtocol: """ Create a new LangGraph runtime instance. @@ -294,6 +318,9 @@ async def new_runtime( Args: entrypoint: Graph name from langgraph.json runtime_id: Unique identifier for the runtime instance + evaluator: Optional governance evaluator. When supplied, the + factory wires a :class:`GovernanceCallbackHandler` into + the runtime's callback list. 
Returns: Configured runtime instance with compiled graph @@ -309,6 +336,7 @@ async def new_runtime( compiled_graph=compiled_graph, runtime_id=runtime_id, entrypoint=entrypoint, + evaluator=evaluator, **kwargs, ) diff --git a/tests/governance/__init__.py b/tests/governance/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/governance/test_callbacks.py b/tests/governance/test_callbacks.py new file mode 100644 index 000000000..d634548a0 --- /dev/null +++ b/tests/governance/test_callbacks.py @@ -0,0 +1,572 @@ +"""Tests for the LangChain governance callback handler.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.callbacks import BaseCallbackHandler +from uipath.core.governance.exceptions import GovernanceBlockException + +from uipath_langchain.governance import GovernanceCallbackHandler +from uipath_langchain.governance.callbacks import _BEFORE_MODEL_TEXT_CAP + +LOGGER_PATH = "uipath_langchain.governance.callbacks.logger" + + +@pytest.fixture +def evaluator() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def handler(evaluator: MagicMock) -> GovernanceCallbackHandler: + return GovernanceCallbackHandler( + evaluator=evaluator, + agent_name="test-agent", + session_id="test-session", + ) + + +class TestSubclassesBaseCallbackHandler: + def test_is_base_callback_handler(self, handler: GovernanceCallbackHandler) -> None: + # Closes governance-architecture-review §3.2: the handler must be + # a real LangChain BaseCallbackHandler so LangChain's dispatch / + # tracer wiring treats it natively. + assert isinstance(handler, BaseCallbackHandler) + + def test_ignore_flags_override_parent_properties( + self, handler: GovernanceCallbackHandler + ) -> None: + # Chain notifications skipped — the governance host owns + # BEFORE_AGENT / AFTER_AGENT and would otherwise double-fire. 
+ assert handler.ignore_chain is True + assert handler.ignore_retriever is True + assert handler.ignore_retry is True + assert handler.ignore_custom_event is True + # LLM / chat model / tool / agent events stay on. + assert handler.ignore_llm is False + assert handler.ignore_chat_model is False + assert handler.ignore_agent is False + + +class TestCallbackHandlerLLM: + def test_on_llm_start_invokes_evaluator_with_latest_prompt( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """Only the latest prompt feeds BEFORE_MODEL — prior prompts in a + batched call would re-fire rules on content the LLM has + already responded to in earlier batches.""" + handler.on_llm_start({"name": "m"}, ["a", "b"]) + evaluator.evaluate_before_model.assert_called_once() + kwargs = evaluator.evaluate_before_model.call_args.kwargs + assert kwargs["model_input"] == "b" + assert kwargs["agent_name"] == "test-agent" + assert kwargs["runtime_id"] == "test-session" + assert kwargs["trace_id"] == handler._trace_id + + def test_on_llm_start_increments_counter( + self, handler: GovernanceCallbackHandler + ) -> None: + handler.on_llm_start({}, ["p"]) + handler.on_llm_start({}, ["p"]) + assert handler._session_state["llm_calls"] == 2 + + def test_on_llm_start_empty_prompts( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_llm_start({}, []) + assert evaluator.evaluate_before_model.call_args.kwargs["model_input"] == "" + + def test_on_llm_start_propagates_block( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + evaluator.evaluate_before_model.side_effect = GovernanceBlockException( + "blocked" + ) + with pytest.raises(GovernanceBlockException): + handler.on_llm_start({}, ["p"]) + + def test_on_llm_start_swallows_other_exceptions( + self, + handler: GovernanceCallbackHandler, + evaluator: MagicMock, + ) -> None: + evaluator.evaluate_before_model.side_effect = RuntimeError("nope") + with 
patch(LOGGER_PATH) as mock_logger: + handler.on_llm_start({}, ["p"]) # must not raise + mock_logger.warning.assert_called_once() + assert "on_llm_start" in mock_logger.warning.call_args.args[0] + + def test_on_chat_model_start_latest_message_only( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """Only the LAST message in the prompt is scanned. + + Without this scoping, a violation in turn 3's user message + would keep re-firing on every subsequent LLM call because + that text stays in the prompt for context. + """ + handler.on_chat_model_start( + {}, + [[SimpleNamespace(content="hello"), SimpleNamespace(content="world")]], + ) + model_input = evaluator.evaluate_before_model.call_args.kwargs["model_input"] + assert model_input == "world" + assert "hello" not in model_input + + def test_on_chat_model_start_dict_messages_latest_only( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """Dict-shaped (LangGraph state) messages: latest is extracted.""" + handler.on_chat_model_start( + {}, + [[{"content": "from dict"}, {"role": "user", "content": "another"}]], + ) + model_input = evaluator.evaluate_before_model.call_args.kwargs["model_input"] + assert model_input == "another" + assert "from dict" not in model_input + + def test_on_chat_model_start_dict_message_missing_content( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_chat_model_start({}, [[{"role": "user"}]]) + assert evaluator.evaluate_before_model.call_args.kwargs["model_input"] == "" + + def test_on_chat_model_start_list_of_blocks_content( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """Multi-block content (text + function_call) is extracted cleanly. + + Regression for the prior ``str(msg.content)`` path which produced + ``[{'type': ..., 'text': ...}]`` dict-repr noise instead of + clean text. Field-precise rules can't navigate that shape. 
+ """ + latest = SimpleNamespace( + content=[ + {"type": "text", "text": "Here's the answer:"}, + { + "type": "function_call", + "name": "end_execution", + "arguments": '{"content":"Cost: $1,200"}', + "id": "fc_abc", + }, + ] + ) + handler.on_chat_model_start({}, [[SimpleNamespace(content="old"), latest]]) + model_input = evaluator.evaluate_before_model.call_args.kwargs["model_input"] + assert "Here's the answer:" in model_input + assert "Cost: $1,200" in model_input + # No dict-syntax noise from str(list). + assert "{'type'" not in model_input + + def test_on_chat_model_start_empty_messages( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_chat_model_start({}, []) + assert evaluator.evaluate_before_model.call_args.kwargs["model_input"] == "" + + def test_on_chat_model_start_empty_inner_batch( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_chat_model_start({}, [[]]) + assert evaluator.evaluate_before_model.call_args.kwargs["model_input"] == "" + + def test_on_chat_model_start_caps_model_input( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """``model_input`` is bounded so a runaway prompt can't dominate scan time.""" + huge = SimpleNamespace(content="x" * (_BEFORE_MODEL_TEXT_CAP + 1000)) + handler.on_chat_model_start({}, [[huge]]) + model_input = evaluator.evaluate_before_model.call_args.kwargs["model_input"] + assert len(model_input) == _BEFORE_MODEL_TEXT_CAP + + def test_on_chat_model_start_block_list_stops_at_remaining_budget( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """The block walk exits early once the per-call cap is exhausted.""" + first = "a" * _BEFORE_MODEL_TEXT_CAP # consumes the entire budget + latest = SimpleNamespace( + content=[ + {"type": "text", "text": first}, + {"type": "text", "text": "MUST_NOT_APPEAR"}, + ] + ) + handler.on_chat_model_start({}, [[latest]]) + model_input = 
evaluator.evaluate_before_model.call_args.kwargs["model_input"] + assert "MUST_NOT_APPEAR" not in model_input + assert len(model_input) == _BEFORE_MODEL_TEXT_CAP + + def test_on_chat_model_start_block_list_skips_non_dict_entries( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """Non-dict entries inside a content list are silently skipped.""" + latest = SimpleNamespace( + content=[ + "ignored-string-block", + {"type": "text", "text": "kept"}, + 42, + None, + ] + ) + handler.on_chat_model_start({}, [[latest]]) + model_input = evaluator.evaluate_before_model.call_args.kwargs["model_input"] + assert model_input == "kept" + + def test_on_chat_model_start_propagates_block( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + evaluator.evaluate_before_model.side_effect = GovernanceBlockException("x") + with pytest.raises(GovernanceBlockException): + handler.on_chat_model_start({}, [[SimpleNamespace(content="x")]]) + + def test_on_chat_model_start_swallows_other_exceptions( + self, + handler: GovernanceCallbackHandler, + evaluator: MagicMock, + ) -> None: + evaluator.evaluate_before_model.side_effect = RuntimeError("oops") + with patch(LOGGER_PATH) as mock_logger: + handler.on_chat_model_start({}, [[SimpleNamespace(content="x")]]) + mock_logger.warning.assert_called_once() + assert "on_chat_model_start" in mock_logger.warning.call_args.args[0] + + def test_on_llm_end_extracts_plain_text( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + gen = SimpleNamespace(text="output", message=None) + response = SimpleNamespace(generations=[[gen]]) + handler.on_llm_end(response) + kwargs = evaluator.evaluate_after_model.call_args.kwargs + assert kwargs["model_output"] == "output" + + def test_on_llm_end_response_without_generations( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_llm_end(SimpleNamespace()) + assert 
evaluator.evaluate_after_model.call_args.kwargs["model_output"] == "" + + def test_on_llm_end_propagates_block( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + evaluator.evaluate_after_model.side_effect = GovernanceBlockException("x") + with pytest.raises(GovernanceBlockException): + handler.on_llm_end(SimpleNamespace(generations=[])) + + def test_on_llm_end_caps_model_output( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """A runaway / batched response is capped so the AFTER_MODEL + scan budget matches BEFORE_MODEL and the runtime side's cap. + """ + # Many large generations across batched gen_lists. + gen = SimpleNamespace(text="y" * 50_000, message=None) + response = SimpleNamespace(generations=[[gen], [gen, gen]]) + handler.on_llm_end(response) + model_output = evaluator.evaluate_after_model.call_args.kwargs["model_output"] + assert len(model_output) == _BEFORE_MODEL_TEXT_CAP + + def test_on_llm_end_skips_empty_generation_text( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """Generations with no extractable text don't bloat the output.""" + empty = SimpleNamespace(text="", message=None) + keep = SimpleNamespace(text="kept", message=None) + response = SimpleNamespace(generations=[[empty, keep]]) + handler.on_llm_end(response) + assert evaluator.evaluate_after_model.call_args.kwargs["model_output"] == "kept" + + def test_on_llm_end_swallows_other_exceptions( + self, + handler: GovernanceCallbackHandler, + evaluator: MagicMock, + ) -> None: + evaluator.evaluate_after_model.side_effect = RuntimeError("nope") + with patch(LOGGER_PATH) as mock_logger: + handler.on_llm_end(SimpleNamespace()) + mock_logger.warning.assert_called_once() + assert "on_llm_end" in mock_logger.warning.call_args.args[0] + + def test_on_llm_error_logs( + self, + handler: GovernanceCallbackHandler, + ) -> None: + with patch(LOGGER_PATH) as mock_logger: + handler.on_llm_error(RuntimeError("boom")) 
+ mock_logger.warning.assert_called_once() + assert "LLM error" in mock_logger.warning.call_args.args[0] + + +class TestExtractGenerationText: + def test_returns_text_when_present(self) -> None: + gen = SimpleNamespace(text="hello", message=None) + assert GovernanceCallbackHandler._extract_generation_text(gen) == "hello" + + def test_falls_back_to_message_string_content(self) -> None: + gen = SimpleNamespace(text="", message=SimpleNamespace(content="rich")) + assert GovernanceCallbackHandler._extract_generation_text(gen) == "rich" + + def test_returns_empty_when_message_missing(self) -> None: + class G: + text = "" + + assert GovernanceCallbackHandler._extract_generation_text(G()) == "" + + def test_returns_empty_when_message_is_none(self) -> None: + gen = SimpleNamespace(text="", message=None) + assert GovernanceCallbackHandler._extract_generation_text(gen) == "" + + def test_extracts_from_block_list_content(self) -> None: + gen = SimpleNamespace( + text="", + message=SimpleNamespace( + content=[ + {"type": "text", "text": "alpha"}, + {"type": "tool_use", "input": {"q": "beta"}}, + ] + ), + ) + out = GovernanceCallbackHandler._extract_generation_text(gen) + assert "alpha" in out + assert "beta" in out + + def test_block_list_skips_non_dict_entries(self) -> None: + gen = SimpleNamespace( + text="", + message=SimpleNamespace( + content=["string-entry", {"type": "text", "text": "kept"}] + ), + ) + assert GovernanceCallbackHandler._extract_generation_text(gen) == "kept" + + def test_unknown_content_shape_returns_empty(self) -> None: + gen = SimpleNamespace(text="", message=SimpleNamespace(content=123)) + assert GovernanceCallbackHandler._extract_generation_text(gen) == "" + + +class TestExtractBlockText: + def test_plain_text_block(self) -> None: + assert ( + GovernanceCallbackHandler._extract_block_text( + {"type": "text", "text": "hello"} + ) + == "hello" + ) + + def test_function_call_arguments_block(self) -> None: + assert ( + 
GovernanceCallbackHandler._extract_block_text( + {"type": "function_call", "arguments": '{"a":1}'} + ) + == '{"a":1}' + ) + + def test_thinking_block(self) -> None: + assert ( + GovernanceCallbackHandler._extract_block_text( + {"type": "thinking", "thinking": "step by step"} + ) + == "step by step" + ) + + def test_tool_use_input_extracts_string_values(self) -> None: + result = GovernanceCallbackHandler._extract_block_text( + {"type": "tool_use", "input": {"query": "search", "id": "ignored"}} + ) + assert "search" in result + assert "ignored" in result # both are strings; metadata filtering is by key + + def test_input_ignores_non_string_values(self) -> None: + result = GovernanceCallbackHandler._extract_block_text( + {"input": {"a": 123, "b": ["nested"], "c": "kept"}} + ) + assert result == "kept" + + def test_metadata_only_block_returns_empty(self) -> None: + assert ( + GovernanceCallbackHandler._extract_block_text( + {"type": "tool_use", "id": "abc", "name": "search", "status": "ok"} + ) + == "" + ) + + def test_combined_fields_all_collected(self) -> None: + result = GovernanceCallbackHandler._extract_block_text( + { + "type": "tool_use", + "text": "T", + "arguments": "A", + "thinking": "Th", + "input": {"k": "I"}, + } + ) + for token in ("T", "A", "Th", "I"): + assert token in result + + def test_empty_block(self) -> None: + assert GovernanceCallbackHandler._extract_block_text({}) == "" + + +class TestCallbackHandlerTools: + def test_on_tool_start_with_inputs( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_tool_start({"name": "search"}, "fallback", inputs={"q": "v"}) + kwargs = evaluator.evaluate_tool_call.call_args.kwargs + assert kwargs["tool_name"] == "search" + assert kwargs["tool_args"] == {"q": "v"} + assert kwargs["session_state"] is handler._session_state + + def test_on_tool_start_without_inputs_uses_input_str( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + 
handler.on_tool_start({"name": "calc"}, "1+2") + kwargs = evaluator.evaluate_tool_call.call_args.kwargs + assert kwargs["tool_args"] == {"input": "1+2"} + + def test_on_tool_start_unknown_name_when_missing( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_tool_start({}, "x") + assert evaluator.evaluate_tool_call.call_args.kwargs["tool_name"] == "unknown" + + def test_on_tool_start_increments_counter( + self, handler: GovernanceCallbackHandler + ) -> None: + handler.on_tool_start({}, "x") + handler.on_tool_start({}, "y") + assert handler._session_state["tool_calls"] == 2 + + def test_on_tool_start_propagates_block( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + evaluator.evaluate_tool_call.side_effect = GovernanceBlockException("no") + with pytest.raises(GovernanceBlockException): + handler.on_tool_start({}, "x") + + def test_on_tool_start_swallows_other_exceptions( + self, + handler: GovernanceCallbackHandler, + evaluator: MagicMock, + ) -> None: + evaluator.evaluate_tool_call.side_effect = RuntimeError("nope") + with patch(LOGGER_PATH) as mock_logger: + handler.on_tool_start({}, "x") + mock_logger.warning.assert_called_once() + assert "on_tool_start" in mock_logger.warning.call_args.args[0] + + def test_on_tool_end_with_output( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_tool_end({"answer": 42}) + kwargs = evaluator.evaluate_after_tool.call_args.kwargs + assert "42" in kwargs["tool_result"] + assert kwargs["tool_name"] == "unknown" + + def test_on_tool_end_uses_tool_name_from_run_id( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_tool_start({"name": "search"}, "q", run_id="run-1") + handler.on_tool_end("result", run_id="run-1") + assert evaluator.evaluate_after_tool.call_args.kwargs["tool_name"] == "search" + # The run_id mapping is cleaned up so a stale entry isn't reused. 
+ assert "run-1" not in handler._tool_runs + + def test_on_tool_end_unknown_when_run_id_not_recorded( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_tool_end("r", run_id="never-started") + assert evaluator.evaluate_after_tool.call_args.kwargs["tool_name"] == "unknown" + + def test_on_tool_start_handles_none_serialized( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_tool_start(None, "x") # type: ignore[arg-type] + assert evaluator.evaluate_tool_call.call_args.kwargs["tool_name"] == "unknown" + + def test_on_tool_end_with_none_output( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + handler.on_tool_end(None) + assert evaluator.evaluate_after_tool.call_args.kwargs["tool_result"] == "" + + def test_on_tool_end_propagates_block( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + evaluator.evaluate_after_tool.side_effect = GovernanceBlockException("x") + with pytest.raises(GovernanceBlockException): + handler.on_tool_end("out") + + def test_on_tool_end_swallows_other_exceptions( + self, + handler: GovernanceCallbackHandler, + evaluator: MagicMock, + ) -> None: + evaluator.evaluate_after_tool.side_effect = RuntimeError("err") + with patch(LOGGER_PATH) as mock_logger: + handler.on_tool_end("out") + mock_logger.warning.assert_called_once() + assert "on_tool_end" in mock_logger.warning.call_args.args[0] + + def test_on_tool_error_logs( + self, + handler: GovernanceCallbackHandler, + ) -> None: + with patch(LOGGER_PATH) as mock_logger: + handler.on_tool_error(RuntimeError("broke")) + mock_logger.warning.assert_called_once() + assert "Tool error" in mock_logger.warning.call_args.args[0] + + def test_on_tool_error_pops_run_id_mapping( + self, handler: GovernanceCallbackHandler + ) -> None: + """``on_tool_error`` cleans up ``_tool_runs`` so failed tool calls + don't accumulate over the lifetime of a governed session. 
+ """ + handler.on_tool_start({"name": "search"}, "q", run_id="run-err") + assert handler._tool_runs.get("run-err") == "search" + handler.on_tool_error(RuntimeError("boom"), run_id="run-err") + assert "run-err" not in handler._tool_runs + + def test_on_tool_error_without_run_id_does_not_crash( + self, handler: GovernanceCallbackHandler + ) -> None: + # No run_id kwargs — should still log and not raise. + handler.on_tool_error(RuntimeError("boom")) + assert handler._tool_runs == {} + + def test_on_tool_start_block_pops_run_id_mapping( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """If BEFORE_TOOL evaluation BLOCKS, the recorded mapping is + dropped — the tool never runs and ``on_tool_end`` will not fire. + Leaving the entry would leak across blocked turns. + """ + evaluator.evaluate_tool_call.side_effect = GovernanceBlockException("nope") + with pytest.raises(GovernanceBlockException): + handler.on_tool_start({"name": "search"}, "q", run_id="run-blocked") + assert "run-blocked" not in handler._tool_runs + + def test_on_tool_start_swallowed_error_preserves_mapping( + self, handler: GovernanceCallbackHandler, evaluator: MagicMock + ) -> None: + """When the evaluator raises a non-block exception, we swallow + and the tool still runs — the mapping must survive so + ``on_tool_end`` can resolve the tool name. 
+ """ + evaluator.evaluate_tool_call.side_effect = RuntimeError("flaky") + with patch(LOGGER_PATH): + handler.on_tool_start({"name": "search"}, "q", run_id="run-flaky") + assert handler._tool_runs.get("run-flaky") == "search" + + +class TestCallbackHandlerInit: + def test_session_state_initialized(self, evaluator: MagicMock) -> None: + h = GovernanceCallbackHandler( + evaluator=evaluator, agent_name="a", session_id="s" + ) + assert h._session_state == {"tool_calls": 0, "llm_calls": 0} + assert h._agent_name == "a" + assert h._session_id == "s" + assert h._trace_id # uuid4 string diff --git a/tests/runtime/test_factory_governance.py b/tests/runtime/test_factory_governance.py new file mode 100644 index 000000000..5058a10da --- /dev/null +++ b/tests/runtime/test_factory_governance.py @@ -0,0 +1,116 @@ +"""Factory-level governance wiring: evaluator -> callbacks plumbing.""" + +from __future__ import annotations + +import os +import tempfile +from typing import Any, TypedDict +from unittest.mock import MagicMock + +import pytest +from langgraph.graph import END, START, StateGraph +from uipath.core.adapters import EvaluatorProtocol +from uipath.runtime import UiPathRuntimeContext + +from uipath_langchain.governance import GovernanceCallbackHandler +from uipath_langchain.runtime.factory import UiPathLangGraphRuntimeFactory + + +class _State(TypedDict): + v: int + + +def _build_graph() -> StateGraph[Any, Any, Any]: + g = StateGraph(_State) + g.add_node("noop", lambda s: s) + g.add_edge(START, "noop") + g.add_edge("noop", END) + return g + + +@pytest.fixture +def context() -> UiPathRuntimeContext: + tmpdir = tempfile.mkdtemp() + ctx = UiPathRuntimeContext( + runtime_dir=tmpdir, + state_file=os.path.join(tmpdir, "state.db"), + ) + return ctx + + +@pytest.fixture +def factory(context: UiPathRuntimeContext) -> UiPathLangGraphRuntimeFactory: + return UiPathLangGraphRuntimeFactory(context) + + +class TestEvaluatorWiring: + """Passing ``evaluator`` to ``new_runtime`` should 
attach a + :class:`GovernanceCallbackHandler` to the underlying LangGraph + runtime's callback list. This is the entire surface change — the + previous adapter / register-on-import path is gone. + """ + + async def test_no_evaluator_means_no_callbacks( + self, factory: UiPathLangGraphRuntimeFactory + ) -> None: + compiled = _build_graph().compile() + await factory._get_memory() + runtime = await factory._create_runtime_instance( + compiled_graph=compiled, + runtime_id="rt-1", + entrypoint="ep", + ) + # The resumable runtime wraps the langgraph runtime as ``delegate``. + assert runtime.delegate.callbacks == [] # type: ignore[attr-defined] + await factory.dispose() + + async def test_evaluator_attaches_governance_handler( + self, factory: UiPathLangGraphRuntimeFactory + ) -> None: + evaluator: EvaluatorProtocol = MagicMock(spec=EvaluatorProtocol) + compiled = _build_graph().compile() + await factory._get_memory() # ensure memory is initialized + runtime = await factory._create_runtime_instance( + compiled_graph=compiled, + runtime_id="rt-1", + entrypoint="ep", + evaluator=evaluator, + ) + callbacks = runtime.delegate.callbacks # type: ignore[attr-defined] + assert len(callbacks) == 1 + handler = callbacks[0] + assert isinstance(handler, GovernanceCallbackHandler) + # Identity / session_id / agent_name come from the factory args. 
+ assert handler._evaluator is evaluator + assert handler._agent_name == "ep" + assert handler._session_id == "rt-1" + await factory.dispose() + + async def test_handler_built_per_runtime_instance( + self, factory: UiPathLangGraphRuntimeFactory + ) -> None: + """Two factory calls with the same evaluator yield two distinct + handler instances — each runtime gets its own trace_id and + session_state, so concurrent sessions don't share counters.""" + evaluator: EvaluatorProtocol = MagicMock(spec=EvaluatorProtocol) + compiled = _build_graph().compile() + await factory._get_memory() + first = await factory._create_runtime_instance( + compiled_graph=compiled, + runtime_id="rt-a", + entrypoint="ep", + evaluator=evaluator, + ) + second = await factory._create_runtime_instance( + compiled_graph=compiled, + runtime_id="rt-b", + entrypoint="ep", + evaluator=evaluator, + ) + h1 = first.delegate.callbacks[0] # type: ignore[attr-defined] + h2 = second.delegate.callbacks[0] # type: ignore[attr-defined] + assert h1 is not h2 + assert h1._trace_id != h2._trace_id + assert h1._session_id == "rt-a" + assert h2._session_id == "rt-b" + await factory.dispose() diff --git a/uv.lock b/uv.lock index 51ceb50ae..ff51689ac 100644 --- a/uv.lock +++ b/uv.lock @@ -4374,16 +4374,16 @@ wheels = [ [[package]] name = "uipath-core" -version = "0.5.20" -source = { registry = "https://pypi.org/simple" } +version = "0.5.23.dev1017616946" +source = { registry = "https://test.pypi.org/simple/" } dependencies = [ { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-sdk" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/93/da/011ced5af57363caf7ca6d263261fc4b64f19bf7f7b2a5e54132906a36a6/uipath_core-0.5.20.tar.gz", hash = "sha256:2a2430185522869b10c05273128c23a81fbb8c53ee5dd8686c8b5089ea270fa7", size = 132363, upload-time = "2026-06-19T12:01:37.545Z" } +sdist = { url = 
"https://test-files.pythonhosted.org/packages/49/a7/5683de37c2e12168e1b9d0da80c17411e77ca6e600134de0e1584b36eab6/uipath_core-0.5.23.dev1017616946.tar.gz", hash = "sha256:46b36d45ea984d6ebce134e5ee171b615a4d6f53bff7a0aa2e9c26f0d87da117", size = 129588, upload-time = "2026-06-25T03:15:30.649Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/02/b518bb5569c9c35f4ed19a7f23c744c471352a1fa5233071dd75a4cc9324/uipath_core-0.5.20-py3-none-any.whl", hash = "sha256:2be116ec68d034348ea58fd541675863d73e96649f528648d533206b8c8853cc", size = 55005, upload-time = "2026-06-19T12:01:36.08Z" }, + { url = "https://test-files.pythonhosted.org/packages/85/f9/f3e7bea29d70951167907171117fa4494abf64f9ccfd9b66d22b8604e7d6/uipath_core-0.5.23.dev1017616946-py3-none-any.whl", hash = "sha256:c988ee6b2ca2d7093f2cd7d46b7c9bb3bb504a09a0d76aa3c61d9ce1003dac16", size = 54265, upload-time = "2026-06-25T03:15:29.663Z" }, ] [[package]] @@ -4464,7 +4464,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.6.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "uipath", specifier = ">=2.10.79,<2.12.0" }, - { name = "uipath-core", specifier = ">=0.5.20,<0.6.0" }, + { name = "uipath-core", specifier = "==0.5.23.dev1017616946", index = "https://test.pypi.org/simple/" }, { name = "uipath-langchain-client", extras = ["all"], marker = "extra == 'all'", specifier = ">=1.14.0,<1.15.0" }, { name = "uipath-langchain-client", extras = ["anthropic"], marker = "extra == 'anthropic'", specifier = ">=1.14.0,<1.15.0" }, { name = "uipath-langchain-client", extras = ["bedrock"], marker = "extra == 'bedrock'", specifier = ">=1.14.0,<1.15.0" },