diff --git a/src/components/thread/agent-inbox/hooks/use-interrupted-actions.tsx b/src/components/thread/agent-inbox/hooks/use-interrupted-actions.tsx index 0e371c7..58b7775 100644 --- a/src/components/thread/agent-inbox/hooks/use-interrupted-actions.tsx +++ b/src/components/thread/agent-inbox/hooks/use-interrupted-actions.tsx @@ -11,8 +11,8 @@ import { import { createDefaultHumanResponse } from "../utils"; import { toast } from "sonner"; import { HumanInterrupt, HumanResponse } from "@langchain/langgraph/prebuilt"; -import { useThreads } from "@/providers/Thread"; -import { StringParam, useQueryParam } from "use-query-params"; +import { END } from "@langchain/langgraph/web"; +import { useStreamContext } from "@/providers/Stream"; interface UseInterruptedActionsInput { interrupt: HumanInterrupt; @@ -53,9 +53,7 @@ interface UseInterruptedActionsValue { export default function useInterruptedActions({ interrupt, }: UseInterruptedActionsInput): UseInterruptedActionsValue { - const { resumeRun, ignoreRun } = useThreads(); - const [threadId] = useQueryParam("threadId", StringParam); - + const thread = useStreamContext(); const [humanResponse, setHumanResponse] = useState( [], ); @@ -82,6 +80,28 @@ export default function useInterruptedActions({ } }, [interrupt]); + const resumeRun = ( + response: HumanResponse[], + ): boolean => { + try { + thread.submit({}, { + command: { + resume: response, + update: { + messages: [{ + type: "human", + content: `Sending type '${response[0].type}' to interrupt...` + }] + } + }, + }) + return true; + } catch (e: any) { + console.error("Error sending human response", e); + return false; + } + }; + const handleSubmit = async ( e: React.MouseEvent | KeyboardEvent, ) => { @@ -95,15 +115,6 @@ export default function useInterruptedActions({ }); return; } - if (!threadId) { - toast("Error", { - description: "Please select a thread.", - duration: 5000, - richColors: true, - closeButton: true, - }); - return; - } let errorOccurred = false; initialHumanInterruptEditValue.current = {}; @@ -156,10 +167,8 @@ export default function useInterruptedActions({ setLoading(true); setStreaming(true); - const response = resumeRun(threadId, [input], { - stream: true, - }); - if (!response) { + const resumedSuccessfully = resumeRun([input]); + if (!resumedSuccessfully) { // This will only be undefined if the graph ID is not found // in this case, the method will trigger a toast for us. return; @@ -204,7 +213,7 @@ export default function useInterruptedActions({ } } else { setLoading(true); - await resumeRun(threadId, humanResponse); + resumeRun(humanResponse); toast("Success", { description: "Response submitted successfully.", @@ -219,15 +228,6 @@ export default function useInterruptedActions({ e: React.MouseEvent, ) => { e.preventDefault(); - if (!threadId) { - toast("Error", { - description: "Please select a thread.", - duration: 5000, - richColors: true, - closeButton: true, - }); - return; - } const ignoreResponse = humanResponse.find((r) => r.type === "ignore"); if (!ignoreResponse) { @@ -241,7 +241,7 @@ export default function useInterruptedActions({ setLoading(true); initialHumanInterruptEditValue.current = {}; - await resumeRun(threadId, [ignoreResponse]); + resumeRun([ignoreResponse]); setLoading(false); toast("Successfully ignored thread", { @@ -253,20 +253,36 @@ export default function useInterruptedActions({ e: React.MouseEvent, ) => { e.preventDefault(); - if (!threadId) { - toast("Error", { - description: "Please select a thread.", - duration: 5000, - richColors: true, - closeButton: true, - }); - return; - } setLoading(true); initialHumanInterruptEditValue.current = {}; - await ignoreRun(threadId); + try { + thread.submit({}, { + command: { + goto: END, + update: { + messages: [{ + type: "human", + content: "Marking thread as resolved." + }] + } + } + }) + + toast("Success", { + description: "Marked thread as resolved.", + duration: 3000, + }); + } catch (e) { + console.error("Error marking thread as resolved", e); + toast("Error", { + description: "Failed to mark thread as resolved.", + richColors: true, + closeButton: true, + duration: 3000, + }); + } setLoading(false); }; diff --git a/src/components/thread/messages/ai.tsx b/src/components/thread/messages/ai.tsx index 61c47c3..76e2f7f 100644 --- a/src/components/thread/messages/ai.tsx +++ b/src/components/thread/messages/ai.tsx @@ -80,6 +80,7 @@ export function AssistantMessage({ const contentString = getContentString(message.content); const thread = useStreamContext(); + const isLastMessage = thread.messages[thread.messages.length - 1].id === message.id; const meta = thread.getMessagesMetadata(message); const interrupt = thread.interrupt; const parentCheckpoint = meta?.firstSeenState?.parent_checkpoint; @@ -118,7 +119,7 @@ export function AssistantMessage({ )) || (hasToolCalls && )} - {isAgentInboxInterruptSchema(interrupt?.value) && ( + {isAgentInboxInterruptSchema(interrupt?.value) && isLastMessage && ( )}
{children} diff --git a/src/providers/Thread.tsx b/src/providers/Thread.tsx index 1852cf6..8af653f 100644 --- a/src/providers/Thread.tsx +++ b/src/providers/Thread.tsx @@ -1,6 +1,6 @@ import { validate } from "uuid"; import { getApiKey } from "@/lib/api-key"; -import { Client, Run, Thread } from "@langchain/langgraph-sdk"; +import { Thread } from "@langchain/langgraph-sdk"; import { useQueryParam, StringParam } from "use-query-params"; import { createContext, @@ -11,9 +11,7 @@ import { Dispatch, SetStateAction, } from "react"; -import { END } from "@langchain/langgraph/web"; -import { HumanResponse } from "@langchain/langgraph/prebuilt"; -import { toast } from "sonner"; +import { createClient } from "./client"; interface ThreadContextType { getThreads: () => Promise; @@ -21,32 +19,10 @@ interface ThreadContextType { setThreads: Dispatch>; threadsLoading: boolean; setThreadsLoading: Dispatch>; - resumeRun: ( - threadId: string, - response: HumanResponse[], - options?: { - stream?: TStream; - }, - ) => TStream extends true - ? - | AsyncGenerator<{ - event: Record; - data: any; - }> - | undefined - : Promise | undefined; - ignoreRun: (threadId: string) => Promise; } const ThreadContext = createContext(undefined); -function createClient(apiUrl: string, apiKey: string | undefined) { - return new Client({ - apiKey, - apiUrl, - }); -} - function getThreadSearchMetadata( assistantId: string, ): { graph_id: string } | { assistant_id: string } { @@ -77,76 +53,12 @@ export function ThreadProvider({ children }: { children: ReactNode }) { return threads; }, [apiUrl, assistantId]); - const resumeRun = ( - threadId: string, - response: HumanResponse[], - options?: { - stream?: TStream; - }, - ): TStream extends true - ? - | AsyncGenerator<{ - event: Record; - data: any; - }> - | undefined - : Promise | undefined => { - if (!apiUrl || !assistantId) return undefined; - const client = createClient(apiUrl, getApiKey() ?? undefined); - - try { - if (options?.stream) { - return client.runs.stream(threadId, assistantId, { - command: { - resume: response, - }, - streamMode: "events", - }) as any; // Type assertion needed due to conditional return type - } - return client.runs.create(threadId, assistantId, { - command: { - resume: response, - }, - }) as any; // Type assertion needed due to conditional return type - } catch (e: any) { - console.error("Error sending human response", e); - throw e; - } - }; - - const ignoreRun = async (threadId: string) => { - if (!apiUrl || !assistantId) return; - const client = createClient(apiUrl, getApiKey() ?? undefined); - - try { - await client.threads.updateState(threadId, { - values: null, - asNode: END, - }); - - toast("Success", { - description: "Ignored thread", - duration: 3000, - }); - } catch (e) { - console.error("Error ignoring thread", e); - toast("Error", { - description: "Failed to ignore thread", - richColors: true, - closeButton: true, - duration: 3000, - }); - } - }; - const value = { getThreads, threads, setThreads, threadsLoading, setThreadsLoading, - resumeRun, - ignoreRun, }; return ( diff --git a/src/providers/client.ts b/src/providers/client.ts new file mode 100644 index 0000000..2b9d073 --- /dev/null +++ b/src/providers/client.ts @@ -0,0 +1,8 @@ +import { Client } from "@langchain/langgraph-sdk"; + +export function createClient(apiUrl: string, apiKey: string | undefined) { + return new Client({ + apiKey, + apiUrl, + }); +}