initial
This commit is contained in:
471
apps/server/src/services/inference.ts
Normal file
471
apps/server/src/services/inference.ts
Normal file
@@ -0,0 +1,471 @@
|
||||
import type { FastifyBaseLogger } from 'fastify';
|
||||
import type { Sql } from '../db.js';
|
||||
import type { Config } from '../config.js';
|
||||
import type { Message, Project, Session, ToolCall } from '../types/api.js';
|
||||
import { ALL_TOOLS, TOOLS_BY_NAME, toolJsonSchemas } from './tools.js';
|
||||
import { PathScopeError, resolveProjectRoot } from './path_guard.js';
|
||||
|
||||
const BASE_SYSTEM_PROMPT = (projectPath: string) =>
|
||||
`You are BooCode Chat, a code investigation assistant. The user is working on a project located at ${projectPath}. Use the file-read tools (view_file, list_dir, grep, find_files) to investigate code when needed. Be concise. Cite file paths and line numbers when discussing code. Do not hallucinate file contents — read the file first. Tool results may be truncated; if so, narrow your query rather than guessing.`;
|
||||
|
||||
const DB_FLUSH_INTERVAL_MS = 500;
|
||||
const MAX_TOOL_LOOP_DEPTH = 5;
|
||||
|
||||
export interface InferenceFrame {
|
||||
type: 'message_started' | 'delta' | 'tool_call' | 'tool_result' | 'message_complete' | 'error';
|
||||
message_id?: string;
|
||||
tool_message_id?: string;
|
||||
tool_call_id?: string;
|
||||
role?: 'assistant' | 'tool' | 'user';
|
||||
content?: string;
|
||||
tool_call?: ToolCall;
|
||||
output?: unknown;
|
||||
truncated?: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export type FramePublisher = (sessionId: string, frame: InferenceFrame) => void;
|
||||
|
||||
interface OpenAiMessage {
|
||||
role: 'system' | 'user' | 'assistant' | 'tool';
|
||||
content: string | null;
|
||||
tool_calls?: Array<{
|
||||
id: string;
|
||||
type: 'function';
|
||||
function: { name: string; arguments: string };
|
||||
}>;
|
||||
tool_call_id?: string;
|
||||
}
|
||||
|
||||
interface ChatCompletionDelta {
|
||||
role?: string;
|
||||
content?: string | null;
|
||||
tool_calls?: Array<{
|
||||
index: number;
|
||||
id?: string;
|
||||
type?: 'function';
|
||||
function?: { name?: string; arguments?: string };
|
||||
}>;
|
||||
}
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices: Array<{
|
||||
delta: ChatCompletionDelta;
|
||||
finish_reason: string | null;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface InferenceContext {
|
||||
sql: Sql;
|
||||
config: Config;
|
||||
log: FastifyBaseLogger;
|
||||
publish: FramePublisher;
|
||||
}
|
||||
|
||||
export function buildMessagesPayload(
|
||||
session: Session,
|
||||
project: Project,
|
||||
history: Message[]
|
||||
): OpenAiMessage[] {
|
||||
const out: OpenAiMessage[] = [];
|
||||
let systemPrompt = BASE_SYSTEM_PROMPT(project.path);
|
||||
if (session.system_prompt && session.system_prompt.trim().length > 0) {
|
||||
systemPrompt += '\n\n' + session.system_prompt.trim();
|
||||
}
|
||||
out.push({ role: 'system', content: systemPrompt });
|
||||
|
||||
for (const m of history) {
|
||||
if (m.role === 'assistant' && m.status === 'streaming') continue;
|
||||
if (m.role === 'tool') {
|
||||
const tr = m.tool_results;
|
||||
if (!tr) continue;
|
||||
const outputText = tr.error
|
||||
? `error: ${tr.error}`
|
||||
: typeof tr.output === 'string'
|
||||
? tr.output
|
||||
: JSON.stringify(tr.output);
|
||||
out.push({
|
||||
role: 'tool',
|
||||
content: outputText,
|
||||
tool_call_id: tr.tool_call_id,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (m.role === 'assistant') {
|
||||
const msg: OpenAiMessage = {
|
||||
role: 'assistant',
|
||||
content: m.content && m.content.length > 0 ? m.content : null,
|
||||
};
|
||||
if (m.tool_calls && m.tool_calls.length > 0) {
|
||||
msg.tool_calls = m.tool_calls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: { name: tc.name, arguments: JSON.stringify(tc.args) },
|
||||
}));
|
||||
}
|
||||
out.push(msg);
|
||||
continue;
|
||||
}
|
||||
out.push({ role: 'user', content: m.content });
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
async function loadContext(
|
||||
sql: Sql,
|
||||
sessionId: string
|
||||
): Promise<{ session: Session; project: Project; history: Message[] } | null> {
|
||||
const sessionRows = await sql<Session[]>`
|
||||
SELECT id, project_id, name, model, system_prompt, created_at, updated_at
|
||||
FROM sessions WHERE id = ${sessionId}
|
||||
`;
|
||||
if (sessionRows.length === 0) return null;
|
||||
const session = sessionRows[0]!;
|
||||
|
||||
const projectRows = await sql<Project[]>`
|
||||
SELECT id, name, path, added_at, last_session_id
|
||||
FROM projects WHERE id = ${session.project_id}
|
||||
`;
|
||||
if (projectRows.length === 0) return null;
|
||||
const project = projectRows[0]!;
|
||||
|
||||
const history = await sql<Message[]>`
|
||||
SELECT id, session_id, role, content, tool_calls, tool_results, status, last_seq, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ${sessionId}
|
||||
ORDER BY created_at ASC, id ASC
|
||||
`;
|
||||
|
||||
return { session, project, history };
|
||||
}
|
||||
|
||||
async function* sseLines(stream: ReadableStream<Uint8Array>): AsyncGenerator<string> {
|
||||
const reader = stream.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let buffer = '';
|
||||
try {
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
let idx;
|
||||
while ((idx = buffer.indexOf('\n')) >= 0) {
|
||||
const line = buffer.slice(0, idx).replace(/\r$/, '');
|
||||
buffer = buffer.slice(idx + 1);
|
||||
if (line.length === 0) continue;
|
||||
yield line;
|
||||
}
|
||||
}
|
||||
if (buffer.length > 0) yield buffer;
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
|
||||
async function streamCompletion(
|
||||
ctx: InferenceContext,
|
||||
model: string,
|
||||
messages: OpenAiMessage[],
|
||||
includeTools: boolean,
|
||||
onDelta: (content: string) => void
|
||||
): Promise<{ finishReason: string | null; content: string; toolCalls: ToolCall[] }> {
|
||||
const body: Record<string, unknown> = { model, messages, stream: true };
|
||||
if (includeTools) {
|
||||
body['tools'] = toolJsonSchemas();
|
||||
body['tool_choice'] = 'auto';
|
||||
}
|
||||
|
||||
const res = await fetch(`${ctx.config.LLAMA_SWAP_URL}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
if (!res.ok || !res.body) {
|
||||
const text = await res.text().catch(() => '');
|
||||
throw new Error(`llama-swap returned ${res.status}: ${text.slice(0, 200)}`);
|
||||
}
|
||||
|
||||
let content = '';
|
||||
let finishReason: string | null = null;
|
||||
const toolCallsBuffer = new Map<number, { id: string; name: string; argsText: string }>();
|
||||
|
||||
for await (const line of sseLines(res.body)) {
|
||||
if (!line.startsWith('data:')) continue;
|
||||
const payload = line.slice(5).trim();
|
||||
if (payload === '[DONE]') break;
|
||||
let parsed: ChatCompletionChunk;
|
||||
try {
|
||||
parsed = JSON.parse(payload);
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
const choice = parsed.choices?.[0];
|
||||
if (!choice) continue;
|
||||
const delta = choice.delta ?? {};
|
||||
if (typeof delta.content === 'string' && delta.content.length > 0) {
|
||||
content += delta.content;
|
||||
onDelta(delta.content);
|
||||
}
|
||||
if (Array.isArray(delta.tool_calls)) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
const idx = tc.index;
|
||||
const existing = toolCallsBuffer.get(idx) ?? { id: '', name: '', argsText: '' };
|
||||
if (tc.id) existing.id = tc.id;
|
||||
if (tc.function?.name) existing.name = tc.function.name;
|
||||
if (typeof tc.function?.arguments === 'string') existing.argsText += tc.function.arguments;
|
||||
toolCallsBuffer.set(idx, existing);
|
||||
}
|
||||
}
|
||||
if (choice.finish_reason) finishReason = choice.finish_reason;
|
||||
}
|
||||
|
||||
const toolCalls: ToolCall[] = [];
|
||||
for (const [, t] of [...toolCallsBuffer.entries()].sort(([a], [b]) => a - b)) {
|
||||
let args: Record<string, unknown> = {};
|
||||
if (t.argsText.length > 0) {
|
||||
try {
|
||||
args = JSON.parse(t.argsText);
|
||||
} catch {
|
||||
args = { _raw: t.argsText };
|
||||
}
|
||||
}
|
||||
toolCalls.push({ id: t.id || `call_${toolCalls.length}`, name: t.name, args });
|
||||
}
|
||||
|
||||
return { finishReason, content, toolCalls };
|
||||
}
|
||||
|
||||
async function executeToolCall(
|
||||
projectRoot: string,
|
||||
toolCall: ToolCall
|
||||
): Promise<{ output: unknown; truncated: boolean; error?: string }> {
|
||||
const tool = TOOLS_BY_NAME[toolCall.name];
|
||||
if (!tool) {
|
||||
return { output: null, truncated: false, error: `unknown tool: ${toolCall.name}` };
|
||||
}
|
||||
const parsed = tool.inputSchema.safeParse(toolCall.args);
|
||||
if (!parsed.success) {
|
||||
return {
|
||||
output: null,
|
||||
truncated: false,
|
||||
error: `invalid input: ${JSON.stringify(parsed.error.flatten())}`,
|
||||
};
|
||||
}
|
||||
try {
|
||||
const output = await tool.execute(parsed.data, projectRoot);
|
||||
const truncated =
|
||||
typeof output === 'object' && output !== null && 'truncated' in output
|
||||
? Boolean((output as { truncated: unknown }).truncated)
|
||||
: false;
|
||||
return { output, truncated };
|
||||
} catch (err) {
|
||||
if (err instanceof PathScopeError) {
|
||||
return { output: null, truncated: false, error: err.message };
|
||||
}
|
||||
return {
|
||||
output: null,
|
||||
truncated: false,
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function runAssistantTurn(
|
||||
ctx: InferenceContext,
|
||||
sessionId: string,
|
||||
assistantMessageId: string,
|
||||
depth: number
|
||||
): Promise<void> {
|
||||
if (depth > MAX_TOOL_LOOP_DEPTH) {
|
||||
await ctx.sql`
|
||||
UPDATE messages
|
||||
SET status = 'failed', content = ${'tool loop depth exceeded'}
|
||||
WHERE id = ${assistantMessageId}
|
||||
`;
|
||||
ctx.publish(sessionId, {
|
||||
type: 'error',
|
||||
message_id: assistantMessageId,
|
||||
error: 'tool loop depth exceeded',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const loaded = await loadContext(ctx.sql, sessionId);
|
||||
if (!loaded) {
|
||||
ctx.log.warn({ sessionId }, 'inference: session or project missing');
|
||||
return;
|
||||
}
|
||||
const { session, project, history } = loaded;
|
||||
const projectRoot = await resolveProjectRoot(project.path);
|
||||
const messages = buildMessagesPayload(session, project, history);
|
||||
|
||||
ctx.publish(sessionId, {
|
||||
type: 'message_started',
|
||||
message_id: assistantMessageId,
|
||||
role: 'assistant',
|
||||
});
|
||||
|
||||
let accumulated = '';
|
||||
let pendingFlushTimer: NodeJS.Timeout | null = null;
|
||||
let flushPromise: Promise<unknown> = Promise.resolve();
|
||||
|
||||
const flushNow = () => {
|
||||
if (pendingFlushTimer) {
|
||||
clearTimeout(pendingFlushTimer);
|
||||
pendingFlushTimer = null;
|
||||
}
|
||||
const snapshot = accumulated;
|
||||
flushPromise = flushPromise.then(() =>
|
||||
ctx.sql`UPDATE messages SET content = ${snapshot} WHERE id = ${assistantMessageId}`
|
||||
);
|
||||
};
|
||||
|
||||
const scheduleFlush = () => {
|
||||
if (pendingFlushTimer) return;
|
||||
pendingFlushTimer = setTimeout(() => {
|
||||
pendingFlushTimer = null;
|
||||
flushNow();
|
||||
}, DB_FLUSH_INTERVAL_MS);
|
||||
};
|
||||
|
||||
let content = '';
|
||||
let finishReason: string | null = null;
|
||||
let toolCalls: ToolCall[] = [];
|
||||
|
||||
try {
|
||||
const result = await streamCompletion(
|
||||
ctx,
|
||||
session.model,
|
||||
messages,
|
||||
true,
|
||||
(delta) => {
|
||||
accumulated += delta;
|
||||
ctx.publish(sessionId, {
|
||||
type: 'delta',
|
||||
message_id: assistantMessageId,
|
||||
content: delta,
|
||||
});
|
||||
ctx.log.debug({ sessionId, delta }, 'inference delta');
|
||||
scheduleFlush();
|
||||
}
|
||||
);
|
||||
content = result.content;
|
||||
finishReason = result.finishReason;
|
||||
toolCalls = result.toolCalls;
|
||||
} catch (err) {
|
||||
if (pendingFlushTimer) {
|
||||
clearTimeout(pendingFlushTimer);
|
||||
pendingFlushTimer = null;
|
||||
}
|
||||
const errMsg = err instanceof Error ? err.message : String(err);
|
||||
await ctx.sql`
|
||||
UPDATE messages
|
||||
SET status = 'failed', content = ${accumulated}
|
||||
WHERE id = ${assistantMessageId}
|
||||
`;
|
||||
ctx.publish(sessionId, {
|
||||
type: 'error',
|
||||
message_id: assistantMessageId,
|
||||
error: errMsg,
|
||||
});
|
||||
ctx.log.error({ err, sessionId, assistantMessageId }, 'inference failed');
|
||||
return;
|
||||
}
|
||||
|
||||
if (pendingFlushTimer) {
|
||||
clearTimeout(pendingFlushTimer);
|
||||
pendingFlushTimer = null;
|
||||
}
|
||||
await flushPromise;
|
||||
|
||||
if (toolCalls.length > 0) {
|
||||
await ctx.sql`
|
||||
UPDATE messages
|
||||
SET content = ${content}, status = 'complete',
|
||||
tool_calls = ${ctx.sql.json(toolCalls as never)}
|
||||
WHERE id = ${assistantMessageId}
|
||||
`;
|
||||
for (const tc of toolCalls) {
|
||||
ctx.publish(sessionId, {
|
||||
type: 'tool_call',
|
||||
message_id: assistantMessageId,
|
||||
tool_call: tc,
|
||||
});
|
||||
}
|
||||
ctx.publish(sessionId, {
|
||||
type: 'message_complete',
|
||||
message_id: assistantMessageId,
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
toolCalls.map(async (tc) => {
|
||||
const [toolRow] = await ctx.sql<{ id: string }[]>`
|
||||
INSERT INTO messages (session_id, role, content, status, created_at)
|
||||
VALUES (${sessionId}, 'tool', '', 'complete', clock_timestamp())
|
||||
RETURNING id
|
||||
`;
|
||||
const toolMessageId = toolRow!.id;
|
||||
const result = await executeToolCall(projectRoot, tc);
|
||||
const stored = {
|
||||
tool_call_id: tc.id,
|
||||
output: result.output,
|
||||
truncated: result.truncated,
|
||||
...(result.error ? { error: result.error } : {}),
|
||||
};
|
||||
await ctx.sql`
|
||||
UPDATE messages
|
||||
SET tool_results = ${ctx.sql.json(stored as never)}
|
||||
WHERE id = ${toolMessageId}
|
||||
`;
|
||||
ctx.publish(sessionId, {
|
||||
type: 'tool_result',
|
||||
tool_message_id: toolMessageId,
|
||||
tool_call_id: tc.id,
|
||||
output: result.output,
|
||||
truncated: result.truncated,
|
||||
...(result.error ? { error: result.error } : {}),
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
const [nextAssistant] = await ctx.sql<{ id: string }[]>`
|
||||
INSERT INTO messages (session_id, role, content, status, created_at)
|
||||
VALUES (${sessionId}, 'assistant', '', 'streaming', clock_timestamp())
|
||||
RETURNING id
|
||||
`;
|
||||
await runAssistantTurn(ctx, sessionId, nextAssistant!.id, depth + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
await ctx.sql`
|
||||
UPDATE messages
|
||||
SET content = ${content}, status = 'complete'
|
||||
WHERE id = ${assistantMessageId}
|
||||
`;
|
||||
ctx.publish(sessionId, {
|
||||
type: 'message_complete',
|
||||
message_id: assistantMessageId,
|
||||
});
|
||||
ctx.log.info({ sessionId, assistantMessageId, finishReason, chars: content.length }, 'inference complete');
|
||||
}
|
||||
|
||||
export async function runInference(
|
||||
ctx: InferenceContext,
|
||||
sessionId: string,
|
||||
assistantMessageId: string
|
||||
): Promise<void> {
|
||||
return runAssistantTurn(ctx, sessionId, assistantMessageId, 0);
|
||||
}
|
||||
|
||||
export function createInferenceRunner(ctx: InferenceContext) {
|
||||
return {
|
||||
enqueue(sessionId: string, assistantMessageId: string) {
|
||||
void runInference(ctx, sessionId, assistantMessageId).catch((err) => {
|
||||
ctx.log.error({ err }, 'unhandled inference error');
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Reference to keep ALL_TOOLS imported for type checks if needed
|
||||
export const _toolNames = ALL_TOOLS.map((t) => t.name);
|
||||
Reference in New Issue
Block a user