diff --git a/.env.example b/.env.example
index 97ee9a2..9c2cd11 100644
--- a/.env.example
+++ b/.env.example
@@ -1,11 +1,3 @@
-TAVILY_API_KEY=...
-
+ANTHROPIC_API_KEY=...
 # To separate your traces from other applications
-LANGSMITH_PROJECT=retrieval-agent
-
-# The following depend on your selected configuration
-
-## LLM choice:
-ANTHROPIC_API_KEY=....
-FIREWORKS_API_KEY=...
-OPENAI_API_KEY=...
+LANGSMITH_PROJECT=new-agent
diff --git a/pyproject.toml b/pyproject.toml
index a7320f5..b9d4251 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,11 @@ authors = [
 readme = "README.md"
 license = { text = "MIT" }
 requires-python = ">=3.9"
-dependencies = ["langgraph>=0.2.6", "python-dotenv>=1.0.1"]
+dependencies = [
+    "anthropic>=0.34.2",
+    "langgraph>=0.2.6",
+    "python-dotenv>=1.0.1",
+]

 [project.optional-dependencies]
diff --git a/src/agent/configuration.py b/src/agent/configuration.py
index 54df5eb..9629b66 100644
--- a/src/agent/configuration.py
+++ b/src/agent/configuration.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 from dataclasses import dataclass, field, fields
 from typing import Annotated, Optional

-from langchain_core.runnables import RunnableConfig, ensure_config
+from langchain_core.runnables import RunnableConfig

 from agent import prompts

@@ -20,30 +20,21 @@ class Configuration:
     This prompt sets the context and behavior for the agent.
     """

-    model_name: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = (
-        "anthropic/claude-3-5-sonnet-20240620"
-    )
-    """The name of the language model to use for the agent's main interactions.
-
-    Should be in the form: provider/model-name.
-    """
-
-    scraper_tool_model_name: Annotated[
-        str, {"__template_metadata__": {"kind": "llm"}}
-    ] = "accounts/fireworks/models/firefunction-v2"
-    """The name of the language model to use for the web scraping tool.
-
-    This model is specifically used for summarizing and extracting information from web pages.
-    """
-
-    max_search_results: int = 10
-    """The maximum number of search results to return for each search query."""
+    model_name: Annotated[
+        str,
+        {
+            "__template_metadata__": {
+                "kind": "llm",
+            }
+        },
+    ] = "claude-3-5-sonnet-20240620"
+    """The name of the language model to use for our chatbot."""

     @classmethod
     def from_runnable_config(
         cls, config: Optional[RunnableConfig] = None
     ) -> Configuration:
         """Create a Configuration instance from a RunnableConfig object."""
-        config = ensure_config(config)
-        configurable = config.get("configurable") or {}
+        configurable = (config.get("configurable") or {}) if config else {}
         _fields = {f.name for f in fields(cls) if f.init}
         return cls(**{k: v for k, v in configurable.items() if k in _fields})
diff --git a/src/agent/graph.py b/src/agent/graph.py
index dd1e3be..c2b857c 100644
--- a/src/agent/graph.py
+++ b/src/agent/graph.py
@@ -4,25 +4,20 @@ Works with a chat model with tool calling support.
""" from datetime import datetime, timezone -from typing import Dict, List, Literal, cast +from typing import Any, Dict, List -from langchain_core.messages import AIMessage -from langchain_core.prompts import ChatPromptTemplate +import anthropic +from agent.configuration import Configuration +from agent.state import State from langchain_core.runnables import RunnableConfig from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode - -from agent.configuration import Configuration -from agent.state import InputState, State -from agent.tools import TOOLS -from agent.utils import load_chat_model # Define the function that calls the model async def call_model( state: State, config: RunnableConfig -) -> Dict[str, List[AIMessage]]: +) -> Dict[str, List[Dict[str, Any]]]: """Call the LLM powering our "agent". This function prepares the prompt, initializes the model, and processes the response. @@ -35,93 +30,44 @@ async def call_model( dict: A dictionary containing the model's response message. """ configuration = Configuration.from_runnable_config(config) - - # Create a prompt template. Customize this to change the agent's behavior. - prompt = ChatPromptTemplate.from_messages( - [("system", configuration.system_prompt), ("placeholder", "{messages}")] + system_prompt = configuration.system_prompt.format( + system_time=datetime.now(tz=timezone.utc).isoformat() ) - - # Initialize the model with tool binding. Change the model or add more tools here. - model = load_chat_model(configuration.model_name).bind_tools(TOOLS) - - # Prepare the input for the model, including the current system time - message_value = await prompt.ainvoke( - { - "messages": state.messages, - "system_time": datetime.now(tz=timezone.utc).isoformat(), - }, - config, - ) - - # Get the model's response - response = cast(AIMessage, await model.ainvoke(message_value, config)) - - # Handle the case when it's the last step and the model still wants to use a tool - if state.is_last_step and response.tool_calls: - return { - "messages": [ - AIMessage( - id=response.id, - content="Sorry, I could not find an answer to your question in the specified number of steps.", - ) - ] - } + toks = [] + async with anthropic.AsyncAnthropic() as client: + async with client.messages.stream( + model=configuration.model_name, + max_tokens=1024, + system=system_prompt, + messages=state.messages, + ) as stream: + async for text in stream.text_stream: + toks.append(text) # Return the model's response as a list to be added to existing messages - return {"messages": [response]} + return { + "messages": [ + {"role": "assistant", "content": [{"type": "text", "text": "".join(toks)}]} + ] + } # Define a new graph -workflow = StateGraph(State, input=InputState, config_schema=Configuration) +workflow = StateGraph(State, config_schema=Configuration) # Define the two nodes we will cycle between workflow.add_node(call_model) -workflow.add_node("tools", ToolNode(TOOLS)) # Set the entrypoint as `call_model` # This means that this node is the first one called workflow.add_edge("__start__", "call_model") -def route_model_output(state: State) -> Literal["__end__", "tools"]: - """Determine the next node based on the model's output. - - This function checks if the model's last message contains tool calls. - - Args: - state (State): The current state of the conversation. - - Returns: - str: The name of the next node to call ("__end__" or "tools"). 
- """ - last_message = state.messages[-1] - if not isinstance(last_message, AIMessage): - raise ValueError( - f"Expected AIMessage in output edges, but got {type(last_message).__name__}" - ) - # If there is no tool call, then we finish - if not last_message.tool_calls: - return "__end__" - # Otherwise we execute the requested actions - return "tools" - - -# Add a conditional edge to determine the next step after `call_model` -workflow.add_conditional_edges( - "call_model", - # After call_model finishes running, the next node(s) are scheduled - # based on the output from route_model_output - route_model_output, -) - -# Add a normal edge from `tools` to `call_model` -# This creates a cycle: after using tools, we always return to the model -workflow.add_edge("tools", "call_model") - # Compile the workflow into an executable graph # You can customize this by adding interrupt points for state updates graph = workflow.compile( interrupt_before=[], # Add node names here to update state before they're called interrupt_after=[], # Add node names here to update state after they're called ) +graph.name = "My New Graph" # This defines the custom name in LangSmith diff --git a/src/agent/prompts.py b/src/agent/prompts.py index b7d8d46..c88bcf6 100644 --- a/src/agent/prompts.py +++ b/src/agent/prompts.py @@ -1,5 +1,5 @@ -"""Default prompts used by the agent.""" +"""Default prompts used by the chatbot.""" -SYSTEM_PROMPT = """You are a helpful AI assistant. +SYSTEM_PROMPT = """You are a helpful (if not sassy) personal assistant. System time: {system_time}""" diff --git a/src/agent/state.py b/src/agent/state.py index 703bcf9..7db40d4 100644 --- a/src/agent/state.py +++ b/src/agent/state.py @@ -2,59 +2,23 @@ from __future__ import annotations +import operator from dataclasses import dataclass, field from typing import Sequence -from langchain_core.messages import AnyMessage -from langgraph.graph import add_messages -from langgraph.managed import IsLastStep from typing_extensions import Annotated @dataclass -class InputState: +class State: """Defines the input state for the agent, representing a narrower interface to the outside world. This class is used to define the initial state and structure of incoming data. """ - messages: Annotated[Sequence[AnyMessage], add_messages] = field( - default_factory=list - ) + messages: Annotated[Sequence[dict], operator.add] = field(default_factory=list) """ Messages tracking the primary execution state of the agent. - Typically accumulates a pattern of: - 1. HumanMessage - user input - 2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information - 3. ToolMessage(s) - the responses (or errors) from the executed tools - 4. AIMessage without .tool_calls - agent responding in unstructured format to the user - 5. HumanMessage - user responds with the next conversational turn - - Steps 2-5 may repeat as needed. - - The `add_messages` annotation ensures that new messages are merged with existing ones, - updating by ID to maintain an "append-only" state unless a message with the same ID is provided. + Typically accumulates a pattern of user, assistant, user, ... etc. messages. """ - - -@dataclass -class State(InputState): - """Represents the complete state of the agent, extending InputState with additional attributes. - - This class can be used to store any information needed throughout the agent's lifecycle. - """ - - is_last_step: IsLastStep = field(default=False) - """ - Indicates whether the current step is the last one before the graph raises an error. 
-
-    This is a 'managed' variable, controlled by the state machine rather than user code.
-    It is set to 'True' when the step count reaches recursion_limit - 1.
-    """
-
-    # Additional attributes can be added here as needed.
-    # Common examples include:
-    # retrieved_documents: List[Document] = field(default_factory=list)
-    # extracted_entities: Dict[str, Any] = field(default_factory=dict)
-    # api_connections: Dict[str, Any] = field(default_factory=dict)
diff --git a/src/agent/utils.py b/src/agent/utils.py
deleted file mode 100644
index d17b53f..0000000
--- a/src/agent/utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Utility & helper functions."""
-
-from langchain.chat_models import init_chat_model
-from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage
-
-
-def get_message_text(msg: BaseMessage) -> str:
-    """Get the text content of a message."""
-    content = msg.content
-    if isinstance(content, str):
-        return content
-    elif isinstance(content, dict):
-        return content.get("text", "")
-    else:
-        txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
-        return "".join(txts).strip()
-
-
-def load_chat_model(fully_specified_name: str) -> BaseChatModel:
-    """Load a chat model from a fully specified name.
-
-    Args:
-        fully_specified_name (str): String in the format 'provider/model'.
-    """
-    provider, model = fully_specified_name.split("/", maxsplit=1)
-    return init_chat_model(model, model_provider=provider)
diff --git a/tests/integration_tests/test_graph.py b/tests/integration_tests/test_graph.py
index fe52f69..778fee2 100644
--- a/tests/integration_tests/test_graph.py
+++ b/tests/integration_tests/test_graph.py
@@ -1,14 +1,12 @@
 import pytest
-from langsmith import unit
-
 from agent import graph
+from langsmith import expect, unit


 @pytest.mark.asyncio
 @unit
 async def test_agent_simple_passthrough() -> None:
     res = await graph.ainvoke(
-        {"messages": [("user", "Who is the founder of LangChain?")]}
+        {"messages": [{"role": "user", "content": "What's 62 - 19?"}]}
     )
-
-    assert "harrison" in str(res["messages"][-1].content).lower()
+    expect(res["messages"][-1]["content"][0]["text"]).to_contain("43")
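
Note: for a quick smoke test of the rewired graph outside pytest, a minimal sketch along these lines should work; it assumes the package is installed locally and ANTHROPIC_API_KEY is exported, and simply mirrors the updated integration test above.

    import asyncio

    from agent import graph


    async def main() -> None:
        # Messages are plain dicts now, matching the Anthropic API shape.
        res = await graph.ainvoke(
            {"messages": [{"role": "user", "content": "What's 62 - 19?"}]}
        )
        # The assistant reply is a list of content blocks, so the text lives
        # at content[0]["text"] rather than on an AIMessage object.
        print(res["messages"][-1]["content"][0]["text"])


    asyncio.run(main())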