diff --git a/apps/coder/src/routes/messages.ts b/apps/coder/src/routes/messages.ts index 8405ee1..abc988f 100644 --- a/apps/coder/src/routes/messages.ts +++ b/apps/coder/src/routes/messages.ts @@ -5,6 +5,33 @@ import type { Broker } from '@boocode/server/broker'; import type { WsFrame } from '@boocode/server/ws-frames'; import { resolveChatId } from './chat-resolve.js'; +const AnswerUserInputBody = z.object({ + tool_call_id: z.string().min(1), + answers: z + .array( + z.object({ + question: z.string(), + selected_options: z.array(z.string()), + free_text: z.string().nullable(), + }), + ) + .min(1) + .max(3), +}); + +const AskUserInputArgs = z.object({ + questions: z + .array( + z.object({ + question: z.string(), + type: z.enum(['single_select', 'multi_select']), + options: z.array(z.string()).min(1), + }), + ) + .min(1) + .max(3), +}); + const SendBody = z.object({ content: z.string().min(1).max(64_000), pane_id: z.string().min(1).max(200), @@ -219,6 +246,138 @@ export function registerMessageRoutes( }, ); + // POST /api/chats/:id/answer_user_input — answer a pending ask_user_input + app.post<{ Params: { id: string } }>( + '/api/chats/:id/answer_user_input', + async (req, reply) => { + const parsed = AnswerUserInputBody.safeParse(req.body); + if (!parsed.success) { + reply.code(400); + return { error: 'invalid_body', details: parsed.error.flatten() }; + } + const { tool_call_id, answers } = parsed.data; + + const chatRows = await sql<{ id: string; session_id: string }[]>` + SELECT id, session_id FROM chats WHERE id = ${req.params.id} AND status = 'open' + `; + if (chatRows.length === 0) { + reply.code(404); + return { error: 'chat_not_found' }; + } + const chat = chatRows[0]!; + const sessionId = chat.session_id; + + const callerRows = await sql<{ + message_id: string; + payload: { id: string; name: string; args: Record }; + }[]>` + SELECT p.message_id, p.payload + FROM message_parts p + JOIN messages m ON m.id = p.message_id + WHERE m.chat_id = ${chat.id} + AND m.role = 'assistant' + AND p.kind = 'tool_call' + AND p.payload->>'id' = ${tool_call_id} + ORDER BY m.created_at DESC + LIMIT 1 + `; + if (!callerRows[0]) { + reply.code(404); + return { error: 'unknown_tool_call_id' }; + } + const foundCall = callerRows[0].payload; + if (foundCall.name !== 'ask_user_input') { + reply.code(400); + return { error: 'tool_call_not_ask_user_input' }; + } + + const argsParsed = AskUserInputArgs.safeParse(foundCall.args); + if (!argsParsed.success) { + reply.code(400); + return { error: 'mismatched_answer_shape', detail: 'tool_call args invalid' }; + } + const questions = argsParsed.data.questions; + if (answers.length !== questions.length) { + reply.code(400); + return { error: 'mismatched_answer_shape', detail: `expected ${questions.length} answer(s), got ${answers.length}` }; + } + for (let i = 0; i < questions.length; i++) { + const q = questions[i]!; + const a = answers[i]!; + for (const sel of a.selected_options) { + if (!q.options.includes(sel)) { + reply.code(400); + return { error: 'mismatched_answer_shape', detail: `answer ${i + 1} option not in question: ${sel}` }; + } + } + if (q.type === 'single_select' && a.selected_options.length > 1) { + reply.code(400); + return { error: 'mismatched_answer_shape', detail: `answer ${i + 1} multi on single_select` }; + } + if (a.selected_options.length === 0 && (!a.free_text || !a.free_text.trim())) { + reply.code(400); + return { error: 'mismatched_answer_shape', detail: `answer ${i + 1} is empty` }; + } + } + + const toolRows = await sql<{ + message_id: string; + payload: { tool_call_id: string; output: unknown }; + }[]>` + SELECT p.message_id, p.payload + FROM message_parts p + JOIN messages m ON m.id = p.message_id + WHERE m.chat_id = ${chat.id} + AND m.role = 'tool' + AND p.kind = 'tool_result' + AND p.payload->>'tool_call_id' = ${tool_call_id} + ORDER BY m.created_at DESC + LIMIT 1 + `; + if (!toolRows[0]) { + reply.code(404); + return { error: 'unknown_tool_call_id', detail: 'tool message not found' }; + } + if (toolRows[0].payload?.output !== null) { + reply.code(409); + return { error: 'tool_call_already_answered' }; + } + + const answerSet = { answers }; + const newToolResults = { tool_call_id, output: answerSet, truncated: false }; + const toolMessageId = toolRows[0].message_id; + + const result = await sql.begin(async (tx) => { + await tx`DELETE FROM message_parts WHERE message_id = ${toolMessageId} AND kind = 'tool_result'`; + await tx` + INSERT INTO message_parts (message_id, sequence, kind, payload) + VALUES (${toolMessageId}, 0, 'tool_result', ${tx.json(newToolResults as never)}) + `; + const [assistantMsg] = await tx<{ id: string }[]>` + INSERT INTO messages (session_id, chat_id, role, content, status, created_at) + VALUES (${sessionId}, ${chat.id}, 'assistant', '', 'streaming', clock_timestamp()) + RETURNING id + `; + await tx`UPDATE sessions SET updated_at = clock_timestamp() WHERE id = ${sessionId}`; + await tx`UPDATE chats SET updated_at = clock_timestamp() WHERE id = ${chat.id}`; + return { tool_message_id: toolMessageId, assistant_message_id: assistantMsg!.id }; + }); + + broker.publishFrame(sessionId, { + type: 'tool_result', + tool_message_id: result.tool_message_id, + tool_call_id, + chat_id: chat.id, + output: answerSet, + truncated: false, + } as unknown as WsFrame); + inference.enqueue(sessionId, chat.id, result.assistant_message_id, 'default'); + + reply.code(202); + return result; + }, + ); + // POST /api/sessions/:sessionId/stop — cancel active inference app.post<{ Params: { sessionId: string } }>( '/api/sessions/:sessionId/stop', diff --git a/apps/web/src/components/AskUserInputCard.tsx b/apps/web/src/components/AskUserInputCard.tsx index e0e7e8b..0f89cea 100644 --- a/apps/web/src/components/AskUserInputCard.tsx +++ b/apps/web/src/components/AskUserInputCard.tsx @@ -1,7 +1,6 @@ import { useMemo, useState } from 'react'; import { Check } from 'lucide-react'; import { toast } from 'sonner'; -import { api } from '@/api/client'; import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group'; import { Button } from '@/components/ui/button'; import type { @@ -22,6 +21,7 @@ interface Props { toolCall: ToolCall; toolResult: ToolResult | null; chatId: string; + apiPrefix?: string; } function parseQuestions(raw: unknown): AskUserQuestion[] { @@ -63,7 +63,7 @@ function parseAnswerSet(raw: unknown): AskUserAnswerSet | null { return { answers }; } -export function AskUserInputCard({ toolCall, toolResult, chatId }: Props) { +export function AskUserInputCard({ toolCall, toolResult, chatId, apiPrefix = '' }: Props) { const questions = useMemo(() => parseQuestions(toolCall.args), [toolCall.args]); if (questions.length === 0) { @@ -74,9 +74,6 @@ export function AskUserInputCard({ toolCall, toolResult, chatId }: Props) { ); } - // Tool result with a non-null output means the answer is already submitted. - // The pending sentinel uses output=null, so this branch only triggers after - // the real WS tool_result frame lands. const answered = toolResult && toolResult.output !== null; if (answered) { const answerSet = parseAnswerSet(toolResult!.output); @@ -84,7 +81,7 @@ export function AskUserInputCard({ toolCall, toolResult, chatId }: Props) { } return ( - + ); } @@ -92,10 +89,12 @@ function PendingView({ questions, toolCallId, chatId, + apiPrefix = '', }: { questions: AskUserQuestion[]; toolCallId: string; chatId: string; + apiPrefix?: string; }) { // Per-question selections + free text. Selections are option arrays so the // multi_select case is uniform; single_select just constrains to length 1. @@ -133,9 +132,16 @@ function PendingView({ if (submitting) return; setSubmitting(true); try { - await api.chats.answerUserInput(chatId, toolCallId, answers); - // Card stays mounted; the incoming WS tool_result frame will flip it - // into AnsweredView via the parent prop change. + const url = `${apiPrefix}/api/chats/${chatId}/answer_user_input`; + const res = await fetch(url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ tool_call_id: toolCallId, answers }), + }); + if (!res.ok) { + const body = await res.json().catch(() => ({})) as { error?: string; detail?: string }; + throw new Error(body.detail ?? body.error ?? `HTTP ${res.status}`); + } } catch (err) { toast.error(err instanceof Error ? err.message : 'submit failed'); setSubmitting(false); diff --git a/apps/web/src/components/panes/CoderMessageList.tsx b/apps/web/src/components/panes/CoderMessageList.tsx index 427a685..8cf3e03 100644 --- a/apps/web/src/components/panes/CoderMessageList.tsx +++ b/apps/web/src/components/panes/CoderMessageList.tsx @@ -230,6 +230,7 @@ export function CoderMessageList({ messages, chatId, footer }: Props) { toolCall={item.run.call} toolResult={item.run.result} chatId={chatId} + apiPrefix="/api/coder" /> ); }