diff --git a/src/seclab_taskflow_agent/env_utils.py b/src/seclab_taskflow_agent/env_utils.py index 3ccfd77..6d75696 100644 --- a/src/seclab_taskflow_agent/env_utils.py +++ b/src/seclab_taskflow_agent/env_utils.py @@ -12,49 +12,68 @@ -def swap_env(s: str) -> str: - """Replace {{ env('VAR') }} patterns in string with environment values. +def swap_env(s: str, context: dict[str, Any] | None = None) -> str: + """Render Jinja template expressions in a string. + + Supports expressions such as ``{{ env('VAR') }}``. Template variables + like ``{{ globals.X }}`` are only available when provided by the caller + via ``context`` (e.g. ``{'globals': {...}}``). Args: - s: String potentially containing env templates + s: String potentially containing templates. + context: Optional template context. Variables such as ``globals`` + must be supplied here to be available during rendering. Returns: - String with env templates replaced + String with templates replaced. Raises: - LookupError: If required env var not found + LookupError: If a required environment variable or template + variable is not found during rendering. """ - # Quick check if templating needed - if '{{' not in s: - return s - try: - # Import here to avoid circular dependency from .template_utils import create_jinja_environment from .available_tools import AvailableTools available_tools = AvailableTools() jinja_env = create_jinja_environment(available_tools) template = jinja_env.from_string(s) - return template.render() + # Filter out keys that collide with built-in template globals + # (e.g. the env() helper) to prevent callers from breaking them. + reserved_keys = set(jinja_env.globals) + render_context = { + key: value for key, value in (context or {}).items() + if key not in reserved_keys + } + return template.render(**render_context) except jinja2.UndefinedError as e: - # Convert Jinja undefined to LookupError for compatibility raise LookupError(str(e)) - except jinja2.TemplateError: - # Not a template or failed to render, return as-is - return s + except jinja2.TemplateError as e: + raise LookupError(f"Template rendering failed for: {s!r}: {e}") class TmpEnv: """Context manager that temporarily sets environment variables.""" - def __init__(self, env: dict[str, str]) -> None: + def __init__(self, env: dict[str, str], + context: dict[str, Any] | None = None) -> None: self.env = dict(env) + self.context = context self.restore_env = dict(os.environ) def __enter__(self) -> None: - for k, v in self.env.items(): - os.environ[k] = swap_env(v) + applied: list[str] = [] + try: + for k, v in self.env.items(): + os.environ[k] = swap_env(v, self.context) + applied.append(k) + except Exception: + for k in applied: + if k in self.restore_env: + os.environ[k] = self.restore_env[k] + else: + os.environ.pop(k, None) + raise def __exit__(self, exc_type: type | None, exc_val: BaseException | None, exc_tb: Any | None) -> None: for k, v in self.env.items(): diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 5869385..7c56bea 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -579,7 +579,7 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo logging.error(f"Template rendering error: {e}") raise ValueError(f"Failed to render prompt template: {e}") from e - with TmpEnv(env): + with TmpEnv(env, context={"globals": global_variables}): prompts_to_run: list[str] = await _build_prompts_to_run( task_prompt, repeat_prompt, last_mcp_tool_results, available_tools, global_variables, inputs, diff --git a/tests/test_env_utils.py b/tests/test_env_utils.py new file mode 100644 index 0000000..01b0e6a --- /dev/null +++ b/tests/test_env_utils.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Tests for env_utils: swap_env and TmpEnv with globals context.""" + +import os + +import pytest + +from seclab_taskflow_agent.env_utils import TmpEnv, swap_env + + +class TestSwapEnv: + """Tests for swap_env template rendering.""" + + def test_plain_string_unchanged(self): + assert swap_env("no templates here") == "no templates here" + + def test_env_function_works(self): + os.environ["TEST_SWAP_ENV_VAR"] = "hello" + try: + assert swap_env('{{ env("TEST_SWAP_ENV_VAR") }}') == "hello" + finally: + del os.environ["TEST_SWAP_ENV_VAR"] + + def test_globals_with_context(self): + result = swap_env( + "key-{{ globals.ghsa_id }}", + context={"globals": {"ghsa_id": "GHSA-1234"}}, + ) + assert result == "key-GHSA-1234" + + def test_globals_without_context_raises(self): + with pytest.raises(LookupError): + swap_env("{{ globals.missing }}") + + def test_context_cannot_override_env_helper(self): + """Passing an 'env' key in context must not shadow the env() function.""" + os.environ["TEST_SWAP_RESERVED"] = "works" + try: + result = swap_env( + '{{ env("TEST_SWAP_RESERVED") }}', + context={"env": "should be filtered"}, + ) + assert result == "works" + finally: + del os.environ["TEST_SWAP_RESERVED"] + + def test_no_context_backward_compat(self): + assert swap_env("plain") == "plain" + + +class TestTmpEnv: + """Tests for TmpEnv context manager with globals.""" + + def test_globals_rendered_in_env_block(self): + env = {"MY_KEY": "pvr-{{ globals.ghsa }}"} + ctx = {"globals": {"ghsa": "GHSA-5678"}} + with TmpEnv(env, context=ctx): + assert os.environ["MY_KEY"] == "pvr-GHSA-5678" + assert "MY_KEY" not in os.environ + + def test_env_function_still_works_in_tmpenv(self): + os.environ["SOURCE_VAR"] = "value" + try: + env = {"DEST_VAR": '{{ env("SOURCE_VAR") }}'} + with TmpEnv(env): + assert os.environ["DEST_VAR"] == "value" + finally: + del os.environ["SOURCE_VAR"] + + def test_tmpenv_restores_original(self): + os.environ["RESTORE_TEST"] = "original" + env = {"RESTORE_TEST": "overwritten"} + with TmpEnv(env): + assert os.environ["RESTORE_TEST"] == "overwritten" + assert os.environ["RESTORE_TEST"] == "original" + del os.environ["RESTORE_TEST"] + + def test_tmpenv_rollback_on_error(self): + """Partial env modification is rolled back if swap_env raises.""" + env = {"GOOD_KEY": "value", "BAD_KEY": "{{ globals.missing }}"} + with pytest.raises(LookupError), TmpEnv(env): + pass + assert "GOOD_KEY" not in os.environ + assert "BAD_KEY" not in os.environ