diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 58949157..6fb63d83 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -10,6 +10,7 @@ from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( AgentsException, + ErrorRunData, InputGuardrailTripwireTriggered, MaxTurnsExceeded, ModelBehaviorError, @@ -173,6 +174,7 @@ def enable_verbose_stdout_logging(): "Environment", "Button", "AgentsException", + "ErrorRunData", "InputGuardrailTripwireTriggered", "OutputGuardrailTripwireTriggered", "MaxTurnsExceeded", diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f01..066216b3 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -1,12 +1,35 @@ -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .agent import Agent from .guardrail import InputGuardrailResult, OutputGuardrailResult + from .items import ModelResponse, RunItem, TResponseInputItem + from .run_context import RunContextWrapper + + +@dataclass +class ErrorRunData: + """Data collected from an agent run when an exception occurs.""" + + input: str | list[TResponseInputItem] + new_items: list[RunItem] + raw_responses: list[ModelResponse] + last_agent: "Agent[Any]" + context_wrapper: "RunContextWrapper[Any]" + input_guardrail_results: list[InputGuardrailResult] + output_guardrail_results: list[OutputGuardrailResult] class AgentsException(Exception): """Base class for all exceptions in the Agents SDK.""" + run_data: ErrorRunData | None + + def __init__(self, *args: object) -> None: + super().__init__(*args) + self.run_data = None + class MaxTurnsExceeded(AgentsException): """Exception raised when the maximum number of turns is exceeded.""" @@ -15,6 +38,7 @@ class MaxTurnsExceeded(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class ModelBehaviorError(AgentsException): @@ -26,6 +50,7 @@ class ModelBehaviorError(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class UserError(AgentsException): @@ -35,6 +60,7 @@ class UserError(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class InputGuardrailTripwireTriggered(AgentsException): diff --git a/src/agents/result.py b/src/agents/result.py index 243db155..de6593b2 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -11,7 +11,12 @@ from ._run_impl import QueueCompleteSentinel from .agent import Agent from .agent_output import AgentOutputSchemaBase -from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded +from .exceptions import ( + AgentsException, + ErrorRunData, + InputGuardrailTripwireTriggered, + MaxTurnsExceeded, +) from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger @@ -208,28 +213,78 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") + exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") + exc.run_data = ErrorRunData( + input=self.input, + new_items=self.new_items, + raw_responses=self.raw_responses, + last_agent=self.current_agent, + context_wrapper=self.context_wrapper, + input_guardrail_results=self.input_guardrail_results, + output_guardrail_results=self.output_guardrail_results, + ) + self._stored_exception = exc # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) + exc = InputGuardrailTripwireTriggered(guardrail_result) + exc.run_data = ErrorRunData( + input=self.input, + new_items=self.new_items, + raw_responses=self.raw_responses, + last_agent=self.current_agent, + context_wrapper=self.context_wrapper, + input_guardrail_results=self.input_guardrail_results, + output_guardrail_results=self.output_guardrail_results, + ) + self._stored_exception = exc # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): exc = self._run_impl_task.exception() if exc and isinstance(exc, Exception): + if isinstance(exc, AgentsException) and exc.run_data is None: + exc.run_data = ErrorRunData( + input=self.input, + new_items=self.new_items, + raw_responses=self.raw_responses, + last_agent=self.current_agent, + context_wrapper=self.context_wrapper, + input_guardrail_results=self.input_guardrail_results, + output_guardrail_results=self.output_guardrail_results, + ) self._stored_exception = exc if self._input_guardrails_task and self._input_guardrails_task.done(): exc = self._input_guardrails_task.exception() if exc and isinstance(exc, Exception): + if isinstance(exc, AgentsException) and exc.run_data is None: + exc.run_data = ErrorRunData( + input=self.input, + new_items=self.new_items, + raw_responses=self.raw_responses, + last_agent=self.current_agent, + context_wrapper=self.context_wrapper, + input_guardrail_results=self.input_guardrail_results, + output_guardrail_results=self.output_guardrail_results, + ) self._stored_exception = exc if self._output_guardrails_task and self._output_guardrails_task.done(): exc = self._output_guardrails_task.exception() if exc and isinstance(exc, Exception): + if isinstance(exc, AgentsException) and exc.run_data is None: + exc.run_data = ErrorRunData( + input=self.input, + new_items=self.new_items, + raw_responses=self.raw_responses, + last_agent=self.current_agent, + context_wrapper=self.context_wrapper, + input_guardrail_results=self.input_guardrail_results, + output_guardrail_results=self.output_guardrail_results, + ) self._stored_exception = exc def _cleanup_tasks(self): diff --git a/src/agents/run.py b/src/agents/run.py index b196c3bf..3b169abf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -22,6 +22,7 @@ from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .exceptions import ( AgentsException, + ErrorRunData, InputGuardrailTripwireTriggered, MaxTurnsExceeded, ModelBehaviorError, @@ -283,6 +284,17 @@ async def run( raise AgentsException( f"Unknown next step type: {type(turn_result.next_step)}" ) + except AgentsException as exc: + exc.run_data = ErrorRunData( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + ) + raise finally: if current_span: current_span.finish(reset_current=True) @@ -609,6 +621,19 @@ async def _run_streamed_impl( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): pass + except AgentsException as exc: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + exc.run_data = ErrorRunData( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + ) + raise except Exception as e: if current_span: _error_tracing.attach_error_to_span( diff --git a/tests/test_error_run_data.py b/tests/test_error_run_data.py new file mode 100644 index 00000000..63143fbe --- /dev/null +++ b/tests/test_error_run_data.py @@ -0,0 +1,42 @@ +import json +import pytest + +from agents import Agent, Runner, MaxTurnsExceeded, ErrorRunData +from .fake_model import FakeModel +from .test_responses import get_text_message, get_function_tool, get_function_tool_call + + +@pytest.mark.asyncio +async def test_run_error_includes_data(): + model = FakeModel() + agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")]) + model.add_multiple_turn_outputs([ + [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))], + [get_text_message("done")], + ]) + with pytest.raises(MaxTurnsExceeded) as exc: + await Runner.run(agent, input="hello", max_turns=1) + data = exc.value.run_data + assert isinstance(data, ErrorRunData) + assert data.last_agent == agent + assert len(data.raw_responses) == 1 + assert len(data.new_items) > 0 + + +@pytest.mark.asyncio +async def test_streamed_run_error_includes_data(): + model = FakeModel() + agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")]) + model.add_multiple_turn_outputs([ + [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))], + [get_text_message("done")], + ]) + result = Runner.run_streamed(agent, input="hello", max_turns=1) + with pytest.raises(MaxTurnsExceeded) as exc: + async for _ in result.stream_events(): + pass + data = exc.value.run_data + assert isinstance(data, ErrorRunData) + assert data.last_agent == agent + assert len(data.raw_responses) == 1 + assert len(data.new_items) > 0