From 60ba2a11c1d86cd6f296a227178ccc0b32f9e03e Mon Sep 17 00:00:00 2001 From: bracesproul Date: Thu, 6 Mar 2025 11:42:39 -0800 Subject: [PATCH 1/3] fix: Tool calls for trip planner --- agent/find-tool-call.ts | 2 +- agent/stockbroker/nodes/tools.tsx | 3 +- agent/trip-planner/nodes/tools.tsx | 117 ++++++++---------- .../accommodations-list/index.tsx | 2 +- 4 files changed, 57 insertions(+), 67 deletions(-) diff --git a/agent/find-tool-call.ts b/agent/find-tool-call.ts index 0372551..7c9cc40 100644 --- a/agent/find-tool-call.ts +++ b/agent/find-tool-call.ts @@ -10,5 +10,5 @@ interface ToolCall { export function findToolCall(name: Name) { return ( x: ToolCall, - ): x is { name: Name; args: z.infer } => x.name === name; + ): x is { name: Name; args: z.infer; id?: string } => x.name === name; } diff --git a/agent/stockbroker/nodes/tools.tsx b/agent/stockbroker/nodes/tools.tsx index 03bc0e5..1b1f252 100644 --- a/agent/stockbroker/nodes/tools.tsx +++ b/agent/stockbroker/nodes/tools.tsx @@ -158,8 +158,7 @@ export async function callTools( buyStockToolCall.args.ticker, ); ui.write("buy-stock", { - toolCallId: - message.tool_calls?.find((tc) => tc.name === "buy-stock")?.id ?? "", + toolCallId: buyStockToolCall.id ?? "", snapshot, quantity: buyStockToolCall.args.quantity, }); diff --git a/agent/trip-planner/nodes/tools.tsx b/agent/trip-planner/nodes/tools.tsx index 3a60fda..8966bbc 100644 --- a/agent/trip-planner/nodes/tools.tsx +++ b/agent/trip-planner/nodes/tools.tsx @@ -5,46 +5,39 @@ import type ComponentMap from "../../uis/index"; import { z } from "zod"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { getAccommodationsListProps } from "../utils/get-accommodations"; +import { findToolCall } from "../../find-tool-call"; -const schema = z.object({ - listAccommodations: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested a list of accommodations for their trip.", - ), - bookAccommodation: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested to book a reservation for an accommodation. If true, you MUST also set the 'accommodationName' field", - ), - accommodationName: z - .string() - .optional() - .describe( - "The name of the accommodation to book a reservation for. Only required if the 'bookAccommodation' field is true.", - ), +const listAccommodationsSchema = z.object({}).describe("A tool to list accommodations for the user") +const bookAccommodationSchema = z.object({ + accommodationName: z.string().describe("The name of the accommodation to book a reservation for"), +}).describe("A tool to book a reservation for an accommodation"); +const listRestaurantsSchema = z.object({}).describe("A tool to list restaurants for the user"); +const bookRestaurantSchema = z.object({ + restaurantName: z.string().describe("The name of the restaurant to book a reservation for"), +}).describe("A tool to book a reservation for a restaurant"); - listRestaurants: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested a list of restaurants for their trip.", - ), - bookRestaurant: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested to book a reservation for a restaurant. If true, you MUST also set the 'restaurantName' field", - ), - restaurantName: z - .string() - .optional() - .describe( - "The name of the restaurant to book a reservation for. Only required if the 'bookRestaurant' field is true.", - ), -}); +const ACCOMMODATIONS_TOOLS = [ + { + name: "list-accommodations", + description: "A tool to list accommodations for the user", + schema: listAccommodationsSchema, + }, + { + name: "book-accommodation", + description: "A tool to book a reservation for an accommodation", + schema: bookAccommodationSchema, + }, + { + name: "list-restaurants", + description: "A tool to list restaurants for the user", + schema: listRestaurantsSchema, + }, + { + name: "book-restaurant", + description: "A tool to book a reservation for a restaurant", + schema: bookRestaurantSchema, + }, +]; export async function callTools( state: TripPlannerState, @@ -56,18 +49,7 @@ export async function callTools( const ui = typedUi(config); - const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools( - [ - { - name: "trip-planner", - description: "A series of actions to take for planning a trip", - schema, - }, - ], - { - tool_choice: "trip-planner", - }, - ); + const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(ACCOMMODATIONS_TOOLS); const response = await llm.invoke([ { @@ -78,35 +60,44 @@ export async function callTools( ...state.messages, ]); - const tripPlan = response.tool_calls?.[0]?.args as - | z.infer - | undefined; - const toolCallId = response.tool_calls?.[0]?.id; - if (!tripPlan || !toolCallId) { - throw new Error("No trip plan found"); + const listAccommodationsToolCall = response.tool_calls?.find( + findToolCall("list-accommodations"), + ); + const bookAccommodationToolCall = response.tool_calls?.find( + findToolCall("book-accommodation"), + ); + const listRestaurantsToolCall = response.tool_calls?.find( + findToolCall("list-restaurants"), + ); + const bookRestaurantToolCall = response.tool_calls?.find( + findToolCall("book-restaurant"), + ); + + if (!listAccommodationsToolCall && !bookAccommodationToolCall && !listRestaurantsToolCall && !bookRestaurantToolCall) { + throw new Error("No tool calls found"); } - if (tripPlan.listAccommodations) { + if (listAccommodationsToolCall) { ui.write("accommodations-list", { - toolCallId, + toolCallId: listAccommodationsToolCall.id ?? "", ...getAccommodationsListProps(state.tripDetails), }); } - if (tripPlan.bookAccommodation && tripPlan.accommodationName) { + if (bookAccommodationToolCall && bookAccommodationToolCall.args.accommodationName) { ui.write("book-accommodation", { tripDetails: state.tripDetails, - accommodationName: tripPlan.accommodationName, + accommodationName: bookAccommodationToolCall.args.accommodationName, }); } - if (tripPlan.listRestaurants) { + if (listRestaurantsToolCall) { ui.write("restaurants-list", { tripDetails: state.tripDetails }); } - if (tripPlan.bookRestaurant && tripPlan.restaurantName) { + if (bookRestaurantToolCall && bookRestaurantToolCall.args.restaurantName) { ui.write("book-restaurant", { tripDetails: state.tripDetails, - restaurantName: tripPlan.restaurantName, + restaurantName: bookRestaurantToolCall.args.restaurantName, }); } diff --git a/agent/uis/trip-planner/accommodations-list/index.tsx b/agent/uis/trip-planner/accommodations-list/index.tsx index 87e28f2..0a46ede 100644 --- a/agent/uis/trip-planner/accommodations-list/index.tsx +++ b/agent/uis/trip-planner/accommodations-list/index.tsx @@ -275,7 +275,7 @@ export default function AccommodationsList({ type: "tool", tool_call_id: toolCallId, id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, - name: "trip-planner", + name: "book-accommodation", content: JSON.stringify(orderDetails), }, { From 5eb600f60d63c2b7ac391be7694ef4c59a83da11 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Thu, 6 Mar 2025 11:51:11 -0800 Subject: [PATCH 2/3] nit --- agent/trip-planner/nodes/extraction.tsx | 27 +++++++++++++++---------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/agent/trip-planner/nodes/extraction.tsx b/agent/trip-planner/nodes/extraction.tsx index 719997a..f578a54 100644 --- a/agent/trip-planner/nodes/extraction.tsx +++ b/agent/trip-planner/nodes/extraction.tsx @@ -1,7 +1,10 @@ +import { v4 as uuidv4 } from "uuid"; import { ChatOpenAI } from "@langchain/openai"; import { TripDetails, TripPlannerState, TripPlannerUpdate } from "../types"; import { z } from "zod"; import { formatMessages } from "agent/utils/format-messages"; +import { ToolMessage } from "@langchain/langgraph-sdk"; +import { DO_NOT_RENDER_ID_PREFIX } from "@/lib/ensure-tool-responses"; function calculateDates( startDate: string | undefined, @@ -60,9 +63,7 @@ export async function extraction( .describe("The end date of the trip. Should be in YYYY-MM-DD format"), numberOfGuests: z .number() - .optional() - .default(2) - .describe("The number of guests for the trip"), + .describe("The number of guests for the trip. Should default to 2 if not specified"), }); const model = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools([ @@ -96,15 +97,13 @@ Extract only what is specified by the user. It is okay to leave fields blank if { role: "human", content: humanMessage }, ]); - const extractedDetails = response.tool_calls?.[0]?.args as - | z.infer - | undefined; - - if (!extractedDetails) { + const toolCall = response.tool_calls?.[0]; + if (!toolCall) { return { messages: [response], }; } + const extractedDetails = toolCall.args as z.infer; const { startDate, endDate } = calculateDates( extractedDetails.startDate, @@ -114,13 +113,19 @@ Extract only what is specified by the user. It is okay to leave fields blank if const extractionDetailsWithDefaults: TripDetails = { startDate, endDate, - numberOfGuests: extractedDetails.numberOfGuests - ? extractedDetails.numberOfGuests - : 2, + numberOfGuests: extractedDetails.numberOfGuests ?? 2, location: extractedDetails.location, }; + const extractToolResponse: ToolMessage = { + type: "tool", + id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, + tool_call_id: toolCall.id ?? "", + content: "Successfully extracted trip details", + }; + return { tripDetails: extractionDetailsWithDefaults, + messages: [response, extractToolResponse] }; } From 16beab0ebc78e61dcb4757601ff996d61b7e0d8e Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 10 Mar 2025 10:42:09 -0700 Subject: [PATCH 3/3] cr --- agent/uis/index.tsx | 4 ---- 1 file changed, 4 deletions(-) diff --git a/agent/uis/index.tsx b/agent/uis/index.tsx index f932995..11b2163 100644 --- a/agent/uis/index.tsx +++ b/agent/uis/index.tsx @@ -1,9 +1,7 @@ import StockPrice from "./stockbroker/stock-price"; import PortfolioView from "./stockbroker/portfolio-view"; import AccommodationsList from "./trip-planner/accommodations-list"; -import BookAccommodation from "./trip-planner/book-accommodation"; import RestaurantsList from "./trip-planner/restaurants-list"; -import BookRestaurant from "./trip-planner/book-restaurant"; import BuyStock from "./stockbroker/buy-stock"; import Plan from "./open-code/plan"; import ProposedChange from "./open-code/proposed-change"; @@ -12,9 +10,7 @@ const ComponentMap = { "stock-price": StockPrice, portfolio: PortfolioView, "accommodations-list": AccommodationsList, - "book-accommodation": BookAccommodation, "restaurants-list": RestaurantsList, - "book-restaurant": BookRestaurant, "buy-stock": BuyStock, "code-plan": Plan, "proposed-change": ProposedChange,