llama-sidecar v0.1.0: daemon + benchmarks + eval suite
Go daemon (cmd/llama-sidecar): per-agent llama-server process pool with LRU eviction, OpenAI-compatible proxy, flag validation (Unsloth port), deterministic hash-keyed sidecar reuse. Windows service support via schtasks/NSSM with DETACHED_PROCESS, stdout pipe drain, and request-ctx decoupled child lifetime. Bug fixes (3b.1–3b5): -c flag drop from StripShadowingFlags, UTF-8 BOM in JSON config, -fa → --flash-attn on default, child process exit after one request (stdin devnull, stdout pipe, CREATE_NO_WINDOW → DETACHED, context.Background for child lifetime, background reaper goroutine). bench/: MTP on/off throughput sweep across 8 GGUFs via SSH+schtasks automation to sam-desktop. Per-GGUF production flags from llama-swap config with --ctx-size 32768 override. eval/: accuracy benchmarks (MMLU 100q, GSM8K 50q, HumanEval 164) + A/B model comparison (14 agent-typed prompts × 8 models). All scripts resumable at individual question level. 94 Go tests, race detector clean. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
139
internal/config/config.go
Normal file
139
internal/config/config.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// utf8BOM is the three-byte UTF-8 byte order mark. It is trimmed from JSON
// inputs in this package (the model map file and the BASE_ARGS value) before
// they are handed to json.Unmarshal.
var utf8BOM = []byte{0xEF, 0xBB, 0xBF}
|
||||
|
||||
// Config holds the daemon settings assembled by Load from environment
// variables.
type Config struct {
	// Bind comes from LLAMA_SIDECAR_BIND (default "127.0.0.1:8402").
	// Presumably the daemon's listen address — confirm against the server code.
	Bind string
	// LlamaServerBin is the path from LLAMA_SERVER_BIN; Load requires it to
	// be set and to pass os.Stat.
	LlamaServerBin string
	// ModelDirMap is decoded from the JSON object in MODEL_DIR_MAP_FILE. The
	// pool looks a model ID up in it to obtain the model path.
	ModelDirMap map[string]string
	// PortRangeLo and PortRangeHi are the bounds parsed from PORT_RANGE
	// (default "8500-8599"); both ends are handed to the port allocator.
	PortRangeLo int
	PortRangeHi int
	// MaxSidecars comes from MAX_SIDECARS (default 2).
	MaxSidecars int
	// LogLevel comes from LOG_LEVEL (default "info").
	LogLevel string
	// BaseArgs is the JSON array in BASE_ARGS, or defaultBaseArgs() when the
	// variable is empty.
	BaseArgs []string
	// HealthTimeoutSeconds comes from HEALTH_TIMEOUT_SECONDS (default 60).
	HealthTimeoutSeconds int
	// HealthIntervalSeconds comes from HEALTH_INTERVAL_SECONDS (default 30).
	HealthIntervalSeconds int
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
bin := os.Getenv("LLAMA_SERVER_BIN")
|
||||
if bin == "" {
|
||||
return nil, fmt.Errorf("LLAMA_SERVER_BIN is required")
|
||||
}
|
||||
if _, err := os.Stat(bin); err != nil {
|
||||
return nil, fmt.Errorf("LLAMA_SERVER_BIN %q: %w", bin, err)
|
||||
}
|
||||
|
||||
mapFile := os.Getenv("MODEL_DIR_MAP_FILE")
|
||||
if mapFile == "" {
|
||||
return nil, fmt.Errorf("MODEL_DIR_MAP_FILE is required")
|
||||
}
|
||||
modelMap, err := loadModelMap(mapFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("MODEL_DIR_MAP_FILE: %w", err)
|
||||
}
|
||||
|
||||
bind := envOr("LLAMA_SIDECAR_BIND", "127.0.0.1:8402")
|
||||
logLevel := envOr("LOG_LEVEL", "info")
|
||||
maxSidecars := envIntOr("MAX_SIDECARS", 2)
|
||||
healthTimeout := envIntOr("HEALTH_TIMEOUT_SECONDS", 60)
|
||||
healthInterval := envIntOr("HEALTH_INTERVAL_SECONDS", 30)
|
||||
|
||||
lo, hi, err := parsePortRange(envOr("PORT_RANGE", "8500-8599"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("PORT_RANGE: %w", err)
|
||||
}
|
||||
if hi-lo+1 < maxSidecars {
|
||||
return nil, fmt.Errorf("PORT_RANGE %d-%d has %d ports but MAX_SIDECARS is %d", lo, hi, hi-lo+1, maxSidecars)
|
||||
}
|
||||
|
||||
baseArgs := defaultBaseArgs()
|
||||
if env := os.Getenv("BASE_ARGS"); env != "" {
|
||||
var parsed []string
|
||||
envBytes := bytes.TrimPrefix([]byte(env), utf8BOM)
|
||||
if err := json.Unmarshal(envBytes, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("BASE_ARGS: invalid JSON array: %w", err)
|
||||
}
|
||||
baseArgs = parsed
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Bind: bind,
|
||||
LlamaServerBin: bin,
|
||||
ModelDirMap: modelMap,
|
||||
PortRangeLo: lo,
|
||||
PortRangeHi: hi,
|
||||
MaxSidecars: maxSidecars,
|
||||
LogLevel: logLevel,
|
||||
BaseArgs: baseArgs,
|
||||
HealthTimeoutSeconds: healthTimeout,
|
||||
HealthIntervalSeconds: healthInterval,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// defaultBaseArgs returns a fresh slice holding the llama-server arguments
// used when BASE_ARGS is not set.
func defaultBaseArgs() []string {
	args := make([]string, 0, 7)
	args = append(args, "-ngl", "999")
	args = append(args, "-c", "32768")
	args = append(args, "--flash-attn", "on")
	args = append(args, "--no-mmap")
	return args
}
|
||||
|
||||
func loadModelMap(path string) (map[string]string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data = bytes.TrimPrefix(data, utf8BOM)
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
if len(m) == 0 {
|
||||
return nil, fmt.Errorf("model map is empty")
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// parsePortRange parses a "lo-hi" string into its two integer bounds.
// Whitespace around either number is ignored.
//
// It returns an error when the separator is missing, when either side is not
// an integer, when either port lies outside 1-65535 (such a port can never be
// bound, so it is rejected here rather than at spawn time), or when hi is not
// strictly greater than lo.
func parsePortRange(s string) (int, int, error) {
	const minPort, maxPort = 1, 65535

	loText, hiText, found := strings.Cut(s, "-")
	if !found {
		return 0, 0, fmt.Errorf("expected lo-hi format, got %q", s)
	}
	lo, err := strconv.Atoi(strings.TrimSpace(loText))
	if err != nil {
		return 0, 0, fmt.Errorf("invalid lo port: %w", err)
	}
	hi, err := strconv.Atoi(strings.TrimSpace(hiText))
	if err != nil {
		return 0, 0, fmt.Errorf("invalid hi port: %w", err)
	}
	if lo < minPort || lo > maxPort || hi < minPort || hi > maxPort {
		return 0, 0, fmt.Errorf("ports must be within %d-%d, got %d-%d", minPort, maxPort, lo, hi)
	}
	if hi <= lo {
		return 0, 0, fmt.Errorf("hi (%d) must be > lo (%d)", hi, lo)
	}
	return lo, hi, nil
}
|
||||
|
||||
// envOr returns the value of environment variable key, or fallback when the
// variable is unset or empty.
func envOr(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
||||
|
||||
// envIntOr returns environment variable key parsed as a decimal integer.
// When the variable is unset, empty, or fails strconv.Atoi, fallback is
// returned and the parse error is discarded.
func envIntOr(key string, fallback int) int {
	if raw := os.Getenv(key); raw != "" {
		if parsed, err := strconv.Atoi(raw); err == nil {
			return parsed
		}
	}
	return fallback
}
|
||||
79
internal/config/config_test.go
Normal file
79
internal/config/config_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad_MissingRequired(t *testing.T) {
|
||||
os.Unsetenv("LLAMA_SERVER_BIN")
|
||||
os.Unsetenv("MODEL_DIR_MAP_FILE")
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing LLAMA_SERVER_BIN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortRange(t *testing.T) {
|
||||
lo, hi, err := parsePortRange("8500-8599")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if lo != 8500 || hi != 8599 {
|
||||
t.Fatalf("got %d-%d", lo, hi)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortRange_Bad(t *testing.T) {
|
||||
_, _, err := parsePortRange("abc")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
_, _, err = parsePortRange("100-50")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for hi <= lo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadModelMap_BOM(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "model_map.json")
|
||||
content := append([]byte{0xEF, 0xBB, 0xBF}, []byte(`{"test-model": "/fake/path.gguf"}`)...)
|
||||
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m, err := loadModelMap(path)
|
||||
if err != nil {
|
||||
t.Fatalf("BOM-prefixed JSON should parse: %v", err)
|
||||
}
|
||||
if m["test-model"] != "/fake/path.gguf" {
|
||||
t.Fatalf("unexpected map: %v", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultBaseArgs_FlashAttn(t *testing.T) {
|
||||
args := defaultBaseArgs()
|
||||
for i, a := range args {
|
||||
if a == "--flash-attn" && i+1 < len(args) && args[i+1] == "on" {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatal("expected --flash-attn on in default args")
|
||||
}
|
||||
|
||||
func TestDefaultBaseArgs(t *testing.T) {
|
||||
args := defaultBaseArgs()
|
||||
if len(args) == 0 {
|
||||
t.Fatal("expected non-empty default args")
|
||||
}
|
||||
found := false
|
||||
for _, a := range args {
|
||||
if a == "--no-mmap" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected --no-mmap in default args")
|
||||
}
|
||||
}
|
||||
53
internal/pool/hash.go
Normal file
53
internal/pool/hash.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/validator"
|
||||
)
|
||||
|
||||
// Hash computes a deterministic hash for a (modelID, flags) pair.
|
||||
// Flag order does not affect the result.
|
||||
func Hash(modelID string, flags []string) string {
|
||||
type pair struct {
|
||||
key, val string
|
||||
}
|
||||
|
||||
var pairs []pair
|
||||
i := 0
|
||||
for i < len(flags) {
|
||||
tok := flags[i]
|
||||
key := validator.FlagName(tok)
|
||||
if key == "" {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if idx := strings.IndexByte(tok, '='); idx >= 0 {
|
||||
pairs = append(pairs, pair{key: tok[:idx], val: tok[idx+1:]})
|
||||
i++
|
||||
} else if i+1 < len(flags) && validator.FlagName(flags[i+1]) == "" {
|
||||
pairs = append(pairs, pair{key: key, val: flags[i+1]})
|
||||
i += 2
|
||||
} else {
|
||||
pairs = append(pairs, pair{key: key, val: ""})
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(pairs, func(a, b int) bool {
|
||||
return pairs[a].key < pairs[b].key
|
||||
})
|
||||
|
||||
var parts []string
|
||||
for _, p := range pairs {
|
||||
parts = append(parts, p.key+"\x1f"+p.val)
|
||||
}
|
||||
serialized := strings.Join(parts, "\x1e")
|
||||
input := modelID + "\x1d" + serialized
|
||||
|
||||
sum := sha256.Sum256([]byte(input))
|
||||
return fmt.Sprintf("%x", sum[:8])
|
||||
}
|
||||
53
internal/pool/hash_test.go
Normal file
53
internal/pool/hash_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHash_OrderIndependence(t *testing.T) {
|
||||
flags1 := []string{"--a", "1", "--b", "2", "--c", "3"}
|
||||
h1 := Hash("foo", flags1)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
shuffled := make([]string, len(flags1))
|
||||
copy(shuffled, flags1)
|
||||
// Shuffle pairs (each pair is 2 tokens)
|
||||
pairs := make([][2]string, 0)
|
||||
for j := 0; j < len(shuffled); j += 2 {
|
||||
pairs = append(pairs, [2]string{shuffled[j], shuffled[j+1]})
|
||||
}
|
||||
rand.Shuffle(len(pairs), func(a, b int) { pairs[a], pairs[b] = pairs[b], pairs[a] })
|
||||
var flat []string
|
||||
for _, p := range pairs {
|
||||
flat = append(flat, p[0], p[1])
|
||||
}
|
||||
h := Hash("foo", flat)
|
||||
if h != h1 {
|
||||
t.Errorf("iteration %d: hash %s != %s for order %v", i, h, h1, flat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash_SeparatorCollision(t *testing.T) {
|
||||
h1 := Hash("foo", []string{"--a\x1eb", "1"})
|
||||
h2 := Hash("foo", []string{"--ab", "1"})
|
||||
if h1 == h2 {
|
||||
t.Error("separator collision: hashes should differ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash_Length(t *testing.T) {
|
||||
h := Hash("model", []string{"--top-k", "20"})
|
||||
if len(h) != 16 {
|
||||
t.Errorf("expected 16 hex chars, got %d: %s", len(h), h)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash_DifferentModels(t *testing.T) {
|
||||
h1 := Hash("model-a", []string{"--top-k", "20"})
|
||||
h2 := Hash("model-b", []string{"--top-k", "20"})
|
||||
if h1 == h2 {
|
||||
t.Error("different models should produce different hashes")
|
||||
}
|
||||
}
|
||||
188
internal/pool/pool.go
Normal file
188
internal/pool/pool.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/config"
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/validator"
|
||||
)
|
||||
|
||||
// SidecarInfo is the JSON-serialisable snapshot of one sidecar that
// Pool.List returns.
type SidecarInfo struct {
	Hash      string    `json:"hash"`
	ModelID   string    `json:"model_id"`
	Flags     []string  `json:"flags"`
	Port      int       `json:"port"`
	Pid       int       `json:"pid"`
	StartedAt time.Time `json:"started_at"`
	// LastUsed is converted from the sidecar's stored Unix-nanosecond value.
	LastUsed time.Time `json:"last_used"`
	Healthy  bool      `json:"healthy"`
}
|
||||
|
||||
// Pool tracks running sidecars keyed by their hash, together with an LRU
// list used to choose an eviction victim.
type Pool struct {
	// mu guards sidecars, lru and lruIdx.
	mu  sync.Mutex
	cfg *config.Config
	// sidecars maps a hash to its sidecar.
	sidecars map[string]*Sidecar
	// lru holds hashes; the front is the most recently used and the back is
	// evicted first.
	lru *list.List
	// lruIdx maps a hash to its element in lru.
	lruIdx  map[string]*list.Element
	ports   *PortAllocator
	spawner Spawner
}
|
||||
|
||||
func New(cfg *config.Config, spawner Spawner) *Pool {
|
||||
return &Pool{
|
||||
cfg: cfg,
|
||||
sidecars: make(map[string]*Sidecar),
|
||||
lru: list.New(),
|
||||
lruIdx: make(map[string]*list.Element),
|
||||
ports: NewPortAllocator(cfg.PortRangeLo, cfg.PortRangeHi),
|
||||
spawner: spawner,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) Acquire(ctx context.Context, modelID string, flags []string) (*Sidecar, error) {
|
||||
if _, err := validator.ValidateExtraArgs(flags); err != nil {
|
||||
return nil, fmt.Errorf("validation: %w", err)
|
||||
}
|
||||
|
||||
modelPath, ok := p.cfg.ModelDirMap[modelID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model: %s", modelID)
|
||||
}
|
||||
|
||||
hash := Hash(modelID, flags)
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if s, ok := p.sidecars[hash]; ok {
|
||||
if s.Healthy() {
|
||||
if el, ok := p.lruIdx[hash]; ok {
|
||||
p.lru.MoveToFront(el)
|
||||
}
|
||||
s.TouchLastUsed()
|
||||
return s, nil
|
||||
}
|
||||
p.removeLocked(hash)
|
||||
}
|
||||
|
||||
if len(p.sidecars) >= p.cfg.MaxSidecars {
|
||||
if err := p.evictLRULocked(); err != nil {
|
||||
return nil, fmt.Errorf("eviction failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port, err := p.ports.Allocate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port allocation: %w", err)
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
s, err := p.spawner.Spawn(ctx, p.cfg, modelID, modelPath, flags, port, hash)
|
||||
p.mu.Lock()
|
||||
|
||||
if err != nil {
|
||||
p.ports.Release(port)
|
||||
return nil, fmt.Errorf("spawn: %w", err)
|
||||
}
|
||||
|
||||
p.sidecars[hash] = s
|
||||
el := p.lru.PushFront(hash)
|
||||
p.lruIdx[hash] = el
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *Pool) List() []SidecarInfo {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
out := make([]SidecarInfo, 0, len(p.sidecars))
|
||||
for _, s := range p.sidecars {
|
||||
out = append(out, SidecarInfo{
|
||||
Hash: s.Hash,
|
||||
ModelID: s.ModelID,
|
||||
Flags: s.Flags,
|
||||
Port: s.Port,
|
||||
Pid: s.Pid,
|
||||
StartedAt: s.StartedAt,
|
||||
LastUsed: time.Unix(0, s.LastUsed.Load()),
|
||||
Healthy: s.Healthy(),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *Pool) Remove(hash string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if _, ok := p.sidecars[hash]; !ok {
|
||||
return fmt.Errorf("sidecar %s not found", hash)
|
||||
}
|
||||
return p.removeLocked(hash)
|
||||
}
|
||||
|
||||
// Shutdown kills every sidecar that is in the pool when it is called, one
// goroutine per sidecar, and waits for them to finish or for ctx to be done.
//
// It returns ctx.Err() if ctx finishes first; the kill goroutines are not
// cancelled in that case and keep running. Kill failures are logged, not
// returned. Entries are not deleted from the pool and ports are not released
// here.
func (p *Pool) Shutdown(ctx context.Context) error {
	// Snapshot the keys so the lock is not held while processes are killed.
	p.mu.Lock()
	hashes := make([]string, 0, len(p.sidecars))
	for h := range p.sidecars {
		hashes = append(hashes, h)
	}
	p.mu.Unlock()

	var wg sync.WaitGroup
	for _, h := range hashes {
		wg.Add(1)
		go func(hash string) {
			defer wg.Done()
			// Look the entry up again: it may have been removed since the
			// snapshot was taken.
			p.mu.Lock()
			s, ok := p.sidecars[hash]
			p.mu.Unlock()
			if !ok {
				return
			}
			if err := p.spawner.Kill(s); err != nil {
				slog.Error("shutdown kill failed", "hash", hash, "err", err)
			}
		}(h)
	}

	// Turn wg.Wait into a channel so it can be selected against ctx.
	done := make(chan struct{})
	go func() { wg.Wait(); close(done) }()
	select {
	case <-done:
	case <-ctx.Done():
		return ctx.Err()
	}
	slog.Info("pool shutdown complete", "count", len(hashes))
	return nil
}
|
||||
|
||||
func (p *Pool) removeLocked(hash string) error {
|
||||
s, ok := p.sidecars[hash]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
delete(p.sidecars, hash)
|
||||
if el, ok := p.lruIdx[hash]; ok {
|
||||
p.lru.Remove(el)
|
||||
delete(p.lruIdx, hash)
|
||||
}
|
||||
if err := p.spawner.Kill(s); err != nil {
|
||||
slog.Error("kill failed during remove", "hash", hash, "err", err)
|
||||
}
|
||||
p.ports.Release(s.Port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pool) evictLRULocked() error {
|
||||
back := p.lru.Back()
|
||||
if back == nil {
|
||||
return fmt.Errorf("pool full but LRU empty")
|
||||
}
|
||||
hash := back.Value.(string)
|
||||
slog.Info("evicting LRU sidecar", "hash", hash)
|
||||
return p.removeLocked(hash)
|
||||
}
|
||||
151
internal/pool/pool_test.go
Normal file
151
internal/pool/pool_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/config"
|
||||
)
|
||||
|
||||
// fakeSpawner is a test Spawner that starts no processes; it only counts how
// often Spawn and Kill are called.
type fakeSpawner struct {
	spawnCount atomic.Int32
	killCount  atomic.Int32
}
|
||||
|
||||
func (f *fakeSpawner) Spawn(ctx context.Context, cfg *config.Config, modelID, modelPath string, flags []string, port int, hash string) (*Sidecar, error) {
|
||||
f.spawnCount.Add(1)
|
||||
s := &Sidecar{
|
||||
Hash: hash,
|
||||
ModelID: modelID,
|
||||
ModelPath: modelPath,
|
||||
Flags: flags,
|
||||
Port: port,
|
||||
Pid: 99999,
|
||||
StartedAt: time.Now(),
|
||||
stderr: newRingBuffer(8),
|
||||
cancel: func() {},
|
||||
}
|
||||
s.healthy.Store(true)
|
||||
s.LastUsed.Store(time.Now().UnixNano())
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *fakeSpawner) Kill(s *Sidecar) error {
|
||||
f.killCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func testConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Bind: "127.0.0.1:0",
|
||||
LlamaServerBin: "/fake/llama-server",
|
||||
ModelDirMap: map[string]string{
|
||||
"model-a": "/fake/model-a.gguf",
|
||||
"model-b": "/fake/model-b.gguf",
|
||||
},
|
||||
PortRangeLo: 8500,
|
||||
PortRangeHi: 8509,
|
||||
MaxSidecars: 2,
|
||||
BaseArgs: []string{"-ngl", "999"},
|
||||
HealthTimeoutSeconds: 60,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_AcquireSameKey(t *testing.T) {
|
||||
fs := &fakeSpawner{}
|
||||
p := New(testConfig(), fs)
|
||||
ctx := context.Background()
|
||||
|
||||
s1, err := p.Acquire(ctx, "model-a", []string{"--top-k", "20"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2, err := p.Acquire(ctx, "model-a", []string{"--top-k", "20"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s1.Hash != s2.Hash {
|
||||
t.Fatalf("expected same sidecar, got different hashes: %s vs %s", s1.Hash, s2.Hash)
|
||||
}
|
||||
if fs.spawnCount.Load() != 1 {
|
||||
t.Fatalf("expected 1 spawn, got %d", fs.spawnCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_EvictLRU(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.MaxSidecars = 1
|
||||
fs := &fakeSpawner{}
|
||||
p := New(cfg, fs)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := p.Acquire(ctx, "model-a", []string{"--top-k", "20"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = p.Acquire(ctx, "model-b", []string{"--top-k", "40"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if fs.spawnCount.Load() != 2 {
|
||||
t.Fatalf("expected 2 spawns, got %d", fs.spawnCount.Load())
|
||||
}
|
||||
if fs.killCount.Load() != 1 {
|
||||
t.Fatalf("expected 1 kill (eviction), got %d", fs.killCount.Load())
|
||||
}
|
||||
list := p.List()
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("expected 1 sidecar, got %d", len(list))
|
||||
}
|
||||
if list[0].ModelID != "model-b" {
|
||||
t.Fatalf("expected model-b, got %s", list[0].ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_ValidatorReject(t *testing.T) {
|
||||
fs := &fakeSpawner{}
|
||||
p := New(testConfig(), fs)
|
||||
_, err := p.Acquire(context.Background(), "model-a", []string{"--model", "evil.gguf"})
|
||||
if err == nil {
|
||||
t.Fatal("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_UnknownModel(t *testing.T) {
|
||||
fs := &fakeSpawner{}
|
||||
p := New(testConfig(), fs)
|
||||
_, err := p.Acquire(context.Background(), "nonexistent", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected unknown model error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPool_ConcurrentAcquire(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.MaxSidecars = 10
|
||||
cfg.PortRangeHi = 8599
|
||||
fs := &fakeSpawner{}
|
||||
p := New(cfg, fs)
|
||||
ctx := context.Background()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 50; j++ {
|
||||
_, _ = p.Acquire(ctx, "model-a", []string{"--top-k", "20"})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
list := p.List()
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("expected 1 sidecar (same key), got %d", len(list))
|
||||
}
|
||||
}
|
||||
28
internal/pool/ports.go
Normal file
28
internal/pool/ports.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package pool
|
||||
|
||||
import "fmt"
|
||||
|
||||
// PortAllocator hands out ports from a fixed inclusive range. Free ports sit
// in a buffered channel, which makes Allocate and Release safe to call from
// several goroutines.
type PortAllocator struct {
	ports chan int
}

// NewPortAllocator returns an allocator holding every port from lo through
// hi. An inverted range (hi < lo) gives an allocator with no ports instead of
// panicking on a negative channel capacity.
func NewPortAllocator(lo, hi int) *PortAllocator {
	size := hi - lo + 1
	if size < 0 {
		size = 0
	}
	ch := make(chan int, size)
	for p := lo; p <= hi; p++ {
		ch <- p
	}
	return &PortAllocator{ports: ch}
}

// Allocate takes a free port without blocking. It returns an error when no
// port is free.
func (pa *PortAllocator) Allocate() (int, error) {
	select {
	case p := <-pa.ports:
		return p, nil
	default:
		return 0, fmt.Errorf("port allocator exhausted")
	}
}

// Release returns port to the allocator without blocking.
//
// The channel can only be full when a port is released twice or never came
// from this allocator. In that case the port is dropped: the pool calls
// Release while holding its mutex, so a blocking send here would stall it.
func (pa *PortAllocator) Release(port int) {
	select {
	case pa.ports <- port:
	default:
	}
}
|
||||
74
internal/pool/ports_test.go
Normal file
74
internal/pool/ports_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPortAllocator_AllocateRelease(t *testing.T) {
|
||||
pa := NewPortAllocator(8500, 8502)
|
||||
p1, err := pa.Allocate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p2, err := pa.Allocate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p3, err := pa.Allocate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// All three ports should be distinct
|
||||
if p1 == p2 || p2 == p3 || p1 == p3 {
|
||||
t.Fatalf("expected distinct ports: %d, %d, %d", p1, p2, p3)
|
||||
}
|
||||
|
||||
// Exhausted
|
||||
_, err = pa.Allocate()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when exhausted")
|
||||
}
|
||||
|
||||
// Release and re-allocate
|
||||
pa.Release(p2)
|
||||
p4, err := pa.Allocate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if p4 != p2 {
|
||||
t.Fatalf("expected released port %d, got %d", p2, p4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortAllocator_Concurrent(t *testing.T) {
|
||||
pa := NewPortAllocator(8500, 8599)
|
||||
var wg sync.WaitGroup
|
||||
allocated := make(chan int, 100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p, err := pa.Allocate()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
allocated <- p
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
close(allocated)
|
||||
|
||||
seen := make(map[int]bool)
|
||||
for p := range allocated {
|
||||
if seen[p] {
|
||||
t.Fatalf("duplicate port %d", p)
|
||||
}
|
||||
seen[p] = true
|
||||
}
|
||||
if len(seen) != 100 {
|
||||
t.Fatalf("expected 100 ports, got %d", len(seen))
|
||||
}
|
||||
}
|
||||
313
internal/pool/sidecar.go
Normal file
313
internal/pool/sidecar.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/config"
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/validator"
|
||||
)
|
||||
|
||||
// Sidecar describes one llama-server child process managed by the pool.
type Sidecar struct {
	// Hash is the pool key this sidecar is registered under.
	Hash      string
	ModelID   string
	ModelPath string
	// Flags are the caller-supplied extra arguments.
	Flags     []string
	Port      int
	Pid       int
	StartedAt time.Time
	// LastUsed holds a Unix-nanosecond timestamp (see TouchLastUsed).
	LastUsed atomic.Int64
	// healthy is set once the /health probe answers 200 and is cleared when
	// the process exits or the health monitor sees three failed probes in a
	// row.
	healthy atomic.Bool
	cmd     *exec.Cmd
	// cancel cancels the context the child command was created with.
	cancel context.CancelFunc
	// done receives the result of cmd.Wait and is then closed.
	done chan error
	// stderr keeps the most recent stderr lines of the child.
	stderr *ringBuffer
	// stopMon cancels the health monitor; it is nil until the sidecar first
	// becomes healthy.
	stopMon context.CancelFunc
	// stdinFile is the os.DevNull handle used as the child's stdin.
	stdinFile *os.File
	// stdoutR and stdoutFile are the read and write ends of the pipe attached
	// to the child's stdout.
	stdoutR    *os.File
	stdoutFile *os.File
}
|
||||
|
||||
func (s *Sidecar) Healthy() bool {
|
||||
return s.healthy.Load()
|
||||
}
|
||||
|
||||
func (s *Sidecar) TouchLastUsed() {
|
||||
s.LastUsed.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (s *Sidecar) LastStderr() string {
|
||||
return s.stderr.String()
|
||||
}
|
||||
|
||||
// Spawner abstracts sidecar creation and teardown so the pool can be tested
// without starting real processes.
type Spawner interface {
	// Spawn starts a sidecar for modelID on port. The pool passes the hash it
	// will register the sidecar under.
	Spawn(ctx context.Context, cfg *config.Config, modelID, modelPath string, flags []string, port int, hash string) (*Sidecar, error)
	// Kill stops the given sidecar.
	Kill(s *Sidecar) error
}
|
||||
|
||||
// RealSpawner is the Spawner that launches llama-server child processes.
type RealSpawner struct{}
|
||||
|
||||
func (rs *RealSpawner) Spawn(ctx context.Context, cfg *config.Config, modelID, modelPath string, flags []string, port int, hash string) (*Sidecar, error) {
|
||||
args := buildArgs(cfg.BaseArgs, modelPath, port, flags)
|
||||
_ = ctx
|
||||
childCtx, cancel := context.WithCancel(context.Background())
|
||||
cmd := exec.CommandContext(childCtx, cfg.LlamaServerBin, args...)
|
||||
setPlatformAttrs(cmd)
|
||||
|
||||
devNull, err := os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("open devnull: %w", err)
|
||||
}
|
||||
cmd.Stdin = devNull
|
||||
|
||||
stderr := newRingBuffer(64)
|
||||
prefix := fmt.Sprintf("[sidecar:%s:%d] ", hash[:8], port)
|
||||
cmd.Stderr = io.MultiWriter(stderr, &prefixWriter{prefix: prefix})
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
devNull.Close()
|
||||
return nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
go io.Copy(io.Discard, stdoutR)
|
||||
cmd.Stdout = stdoutW
|
||||
|
||||
slog.Info("spawning sidecar", "hash", hash, "model", modelID, "port", port, "args", strings.Join(args, " "))
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("spawn failed: %w", err)
|
||||
}
|
||||
|
||||
s := &Sidecar{
|
||||
Hash: hash,
|
||||
ModelID: modelID,
|
||||
ModelPath: modelPath,
|
||||
Flags: flags,
|
||||
Port: port,
|
||||
Pid: cmd.Process.Pid,
|
||||
StartedAt: time.Now(),
|
||||
cmd: cmd,
|
||||
cancel: cancel,
|
||||
done: make(chan error, 1),
|
||||
stderr: stderr,
|
||||
stdinFile: devNull,
|
||||
stdoutR: stdoutR,
|
||||
stdoutFile: stdoutW,
|
||||
}
|
||||
s.LastUsed.Store(time.Now().UnixNano())
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
s.healthy.Store(false)
|
||||
exitCode := -1
|
||||
if cmd.ProcessState != nil {
|
||||
exitCode = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
slog.Error("sidecar child exited",
|
||||
"hash", hash,
|
||||
"port", port,
|
||||
"pid", s.Pid,
|
||||
"exit_code", exitCode,
|
||||
"wait_err", fmt.Sprintf("%v", err),
|
||||
"uptime", time.Since(s.StartedAt).Round(time.Millisecond),
|
||||
"stderr_tail", stderr.String(),
|
||||
)
|
||||
s.done <- err
|
||||
close(s.done)
|
||||
}()
|
||||
|
||||
// Wait for health
|
||||
healthURL := fmt.Sprintf("http://127.0.0.1:%d/health", port)
|
||||
deadline := time.Now().Add(time.Duration(cfg.HealthTimeoutSeconds) * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
resp, err := http.Get(healthURL)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == 200 {
|
||||
s.healthy.Store(true)
|
||||
slog.Info("sidecar healthy", "hash", hash, "port", port, "elapsed", time.Since(s.StartedAt).Round(time.Millisecond))
|
||||
monCtx, monCancel := context.WithCancel(ctx)
|
||||
s.stopMon = monCancel
|
||||
go s.healthMonitor(monCtx, cfg.HealthIntervalSeconds)
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-childCtx.Done():
|
||||
return nil, fmt.Errorf("sidecar process exited during health check")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
_ = rs.Kill(s)
|
||||
return nil, fmt.Errorf("health check timed out after %ds, last stderr: %s", cfg.HealthTimeoutSeconds, s.stderr.LastLine())
|
||||
}
|
||||
|
||||
// Kill stops a sidecar: it stops the health monitor if one is running,
// cancels the child's context, waits for the child to exit and closes the
// stdin and stdout handles. It always returns nil.
func (rs *RealSpawner) Kill(s *Sidecar) error {
	if s.stopMon != nil {
		s.stopMon()
	}
	// The child was created with exec.CommandContext, so cancelling its
	// context terminates the process.
	s.cancel()
	select {
	case <-s.done:
	case <-time.After(5 * time.Second):
		// Not gone after five seconds: kill it directly, then wait without a
		// time limit.
		if s.cmd.Process != nil {
			_ = s.cmd.Process.Kill()
		}
		<-s.done
	}
	if s.stdinFile != nil {
		s.stdinFile.Close()
	}
	if s.stdoutFile != nil {
		s.stdoutFile.Close()
	}
	if s.stdoutR != nil {
		s.stdoutR.Close()
	}
	slog.Info("sidecar killed", "hash", s.Hash, "port", s.Port)
	return nil
}
|
||||
|
||||
func (s *Sidecar) healthMonitor(ctx context.Context, intervalSec int) {
|
||||
ticker := time.NewTicker(time.Duration(intervalSec) * time.Second)
|
||||
defer ticker.Stop()
|
||||
failures := 0
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/health", s.Port)
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
resp, err := client.Get(url)
|
||||
if err != nil || resp.StatusCode != 200 {
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
failures++
|
||||
if failures >= 3 {
|
||||
slog.Warn("sidecar unhealthy, marking for eviction", "hash", s.Hash, "port", s.Port)
|
||||
s.healthy.Store(false)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
resp.Body.Close()
|
||||
failures = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildArgs(baseArgs []string, modelPath string, port int, userFlags []string) []string {
|
||||
deduped := dedupFlags(baseArgs, userFlags)
|
||||
args := make([]string, 0, len(deduped)+len(userFlags)+4)
|
||||
args = append(args, deduped...)
|
||||
args = append(args, "--model", modelPath)
|
||||
args = append(args, "--port", strconv.Itoa(port))
|
||||
args = append(args, userFlags...)
|
||||
return args
|
||||
}
|
||||
|
||||
// dedupFlags removes from autoArgs any flag that the user also supplied,
|
||||
// so the user's value wins via llama.cpp's last-wins CLI parsing.
|
||||
func dedupFlags(autoArgs, userArgs []string) []string {
|
||||
userNames := make(map[string]bool)
|
||||
for _, tok := range userArgs {
|
||||
if name := validator.FlagName(tok); name != "" {
|
||||
userNames[name] = true
|
||||
}
|
||||
}
|
||||
out := make([]string, 0, len(autoArgs))
|
||||
i := 0
|
||||
for i < len(autoArgs) {
|
||||
tok := autoArgs[i]
|
||||
name := validator.FlagName(tok)
|
||||
if name == "" || !userNames[name] {
|
||||
out = append(out, tok)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if strings.Contains(tok, "=") {
|
||||
i++
|
||||
} else if i+1 < len(autoArgs) && validator.FlagName(autoArgs[i+1]) == "" {
|
||||
i += 2
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ringBuffer keeps the last max complete lines written to it, plus any
// trailing text that has not been ended by a newline yet. It implements
// io.Writer and is safe for concurrent use.
type ringBuffer struct {
	mu    sync.Mutex
	lines []string
	max   int
	// partial is the unterminated tail of the most recent writes. It is held
	// back so that a line split across two Write calls is stored as one line.
	partial string
}

// newRingBuffer returns a buffer that retains up to max lines. A max below 1
// is treated as 1, since a zero-capacity buffer would panic on eviction.
func newRingBuffer(max int) *ringBuffer {
	if max < 1 {
		max = 1
	}
	return &ringBuffer{lines: make([]string, 0, max), max: max}
}

// Write appends p to the buffer. Every newline-terminated line is stored with
// trailing "\r" removed; empty lines are skipped. Text after the last newline
// is kept as the partial line and joined with the next write. It always
// reports len(p) bytes written and a nil error.
func (rb *ringBuffer) Write(p []byte) (int, error) {
	// An unterminated run longer than this is flushed as a line so that a
	// child that never prints a newline cannot grow the buffer without bound.
	const maxPartial = 64 << 10

	rb.mu.Lock()
	defer rb.mu.Unlock()

	data := rb.partial + string(p)
	for {
		nl := strings.IndexByte(data, '\n')
		if nl < 0 {
			break
		}
		rb.pushLocked(data[:nl])
		data = data[nl+1:]
	}
	if len(data) > maxPartial {
		rb.pushLocked(data)
		data = ""
	}
	rb.partial = data
	return len(p), nil
}

// pushLocked stores one line and evicts the oldest when the buffer is full.
// The caller must hold rb.mu.
func (rb *ringBuffer) pushLocked(line string) {
	line = strings.TrimRight(line, "\r\n")
	if line == "" {
		return
	}
	limit := rb.max
	if limit < 1 {
		limit = 1 // covers a zero-value ringBuffer
	}
	if len(rb.lines) >= limit {
		// Shift in place so the backing array is reused.
		copy(rb.lines, rb.lines[1:])
		rb.lines = rb.lines[:len(rb.lines)-1]
	}
	rb.lines = append(rb.lines, line)
}

// String returns the retained lines joined by "\n", followed by the pending
// partial line if there is one.
func (rb *ringBuffer) String() string {
	rb.mu.Lock()
	defer rb.mu.Unlock()

	joined := strings.Join(rb.lines, "\n")
	tail := strings.TrimRight(rb.partial, "\r\n")
	switch {
	case tail == "":
		return joined
	case joined == "":
		return tail
	default:
		return joined + "\n" + tail
	}
}

// LastLine returns the most recent text: the pending partial line if there is
// one, otherwise the newest complete line, or "" when nothing was written.
func (rb *ringBuffer) LastLine() string {
	rb.mu.Lock()
	defer rb.mu.Unlock()

	if tail := strings.TrimRight(rb.partial, "\r\n"); tail != "" {
		return tail
	}
	if len(rb.lines) == 0 {
		return ""
	}
	return rb.lines[len(rb.lines)-1]
}
|
||||
|
||||
// prefixWriter copies what is written to it onto os.Stderr line by line, with
// prefix in front of each line. Text not yet ended by a newline waits in buf
// for a later Write.
type prefixWriter struct {
	prefix string
	buf    bytes.Buffer
}

// Write buffers p and emits every complete line now available. It always
// reports len(p) bytes written and a nil error.
func (pw *prefixWriter) Write(p []byte) (int, error) {
	pw.buf.Write(p)
	for {
		end := bytes.IndexByte(pw.buf.Bytes(), '\n')
		if end < 0 {
			// Only an unterminated tail is left; keep it for the next call.
			break
		}
		line := string(pw.buf.Next(end + 1))
		fmt.Fprint(os.Stderr, pw.prefix+line)
	}
	return len(p), nil
}
|
||||
96
internal/pool/sidecar_test.go
Normal file
96
internal/pool/sidecar_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildArgs_PreservesNonOverlapping(t *testing.T) {
|
||||
base := []string{"-ngl", "999", "-c", "32768", "--flash-attn", "on", "--no-mmap"}
|
||||
user := []string{"--top-k", "20"}
|
||||
got := buildArgs(base, "/model.gguf", 8500, user)
|
||||
|
||||
// -c 32768 must survive (user didn't supply -c)
|
||||
if !containsSeq(got, "-c", "32768") {
|
||||
t.Errorf("-c 32768 missing from args: %v", got)
|
||||
}
|
||||
// --top-k 20 must be present (user flag)
|
||||
if !containsSeq(got, "--top-k", "20") {
|
||||
t.Errorf("--top-k 20 missing from args: %v", got)
|
||||
}
|
||||
// --model and --port injected
|
||||
if !containsSeq(got, "--model", "/model.gguf") {
|
||||
t.Errorf("--model missing: %v", got)
|
||||
}
|
||||
if !containsSeq(got, "--port", "8500") {
|
||||
t.Errorf("--port missing: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildArgs_UserOverridesBase(t *testing.T) {
|
||||
base := []string{"-ngl", "999", "-c", "32768"}
|
||||
user := []string{"-c", "131072"}
|
||||
got := buildArgs(base, "/model.gguf", 8500, user)
|
||||
|
||||
// base -c should be dropped, user -c should be present
|
||||
count := 0
|
||||
for i, tok := range got {
|
||||
if tok == "-c" && i+1 < len(got) {
|
||||
count++
|
||||
if got[i+1] == "32768" {
|
||||
t.Errorf("base -c 32768 should have been deduped: %v", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected exactly 1 -c flag, got %d in %v", count, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildArgs_NoUserFlags(t *testing.T) {
|
||||
base := []string{"-ngl", "999", "-c", "32768", "--no-mmap"}
|
||||
got := buildArgs(base, "/model.gguf", 8500, nil)
|
||||
|
||||
if !containsSeq(got, "-c", "32768") {
|
||||
t.Errorf("-c 32768 missing when no user flags: %v", got)
|
||||
}
|
||||
if !containsSeq(got, "--no-mmap") {
|
||||
t.Errorf("--no-mmap missing: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupFlags_Mixed(t *testing.T) {
|
||||
auto := []string{"--top-k", "40", "-c", "32768", "--no-mmap"}
|
||||
user := []string{"--top-k", "20"}
|
||||
got := dedupFlags(auto, user)
|
||||
want := []string{"-c", "32768", "--no-mmap"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("dedupFlags = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupFlags_EqualsForm(t *testing.T) {
|
||||
auto := []string{"--ctx-size=4096", "--no-mmap"}
|
||||
user := []string{"--ctx-size", "8192"}
|
||||
got := dedupFlags(auto, user)
|
||||
want := []string{"--no-mmap"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("dedupFlags = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// containsSeq reports whether seq occurs in args as a contiguous run of
// tokens. An empty seq is always found.
func containsSeq(args []string, seq ...string) bool {
	lastStart := len(args) - len(seq)
	for start := 0; start <= lastStart; start++ {
		window := args[start : start+len(seq)]
		matched := true
		for k := range seq {
			if window[k] != seq[k] {
				matched = false
				break
			}
		}
		if matched {
			return true
		}
	}
	return false
}
|
||||
7
internal/pool/sidecar_unix.go
Normal file
7
internal/pool/sidecar_unix.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package pool
|
||||
|
||||
import "os/exec"
|
||||
|
||||
// setPlatformAttrs is the non-Windows variant (see the build constraint on
// this file): it leaves cmd untouched.
func setPlatformAttrs(cmd *exec.Cmd) {
	_ = cmd // nothing to configure on this platform
}
|
||||
15
internal/pool/sidecar_windows.go
Normal file
15
internal/pool/sidecar_windows.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build windows
|
||||
|
||||
package pool
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setPlatformAttrs(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x00000008 | 0x00000200, // DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP
|
||||
}
|
||||
}
|
||||
42
internal/server/admin.go
Normal file
42
internal/server/admin.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/config"
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/pool"
|
||||
)
|
||||
|
||||
func healthHandler(p *pool.Pool, cfg *config.Config, startedAt time.Time) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
sidecars := p.List()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"status": "ok",
|
||||
"sidecars": len(sidecars),
|
||||
"max": cfg.MaxSidecars,
|
||||
"uptime_seconds": int(time.Since(startedAt).Seconds()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func listSidecarsHandler(p *pool.Pool) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, p.List())
|
||||
}
|
||||
}
|
||||
|
||||
func deleteSidecarHandler(p *pool.Pool) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
hash := r.PathValue("hash")
|
||||
if hash == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "hash required"})
|
||||
return
|
||||
}
|
||||
if err := p.Remove(hash); err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "removed"})
|
||||
}
|
||||
}
|
||||
111
internal/server/proxy.go
Normal file
111
internal/server/proxy.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/pool"
|
||||
)
|
||||
|
||||
// shellUnsafe deletes the characters that parseFlags refuses to accept; a
// string that changes under it contains at least one of them.
var shellUnsafe = strings.NewReplacer(
	"`", "",
	"$", "",
	"|", "",
	";", "",
	"&", "",
	"\n", "",
)

// parseFlags turns a raw flag string into whitespace-separated tokens. It
// returns an error when raw contains any character removed by shellUnsafe,
// and a nil slice when raw is empty or only whitespace.
func parseFlags(raw string) ([]string, error) {
	if sanitized := shellUnsafe.Replace(raw); sanitized != raw {
		return nil, fmt.Errorf("flags contain unsafe characters")
	}
	trimmed := strings.TrimSpace(raw)
	return splitArgs(trimmed), nil
}

// splitArgs splits s on runs of whitespace; the empty string yields nil.
func splitArgs(s string) []string {
	if len(s) > 0 {
		return strings.Fields(s)
	}
	return nil
}
|
||||
|
||||
func proxyHandler(p *pool.Pool) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
flagsRaw := r.Header.Get("X-Agent-Flags")
|
||||
var flags []string
|
||||
if flagsRaw != "" {
|
||||
var err error
|
||||
flags, err = parseFlags(flagsRaw)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
modelID := r.Header.Get("X-Model-Id")
|
||||
if modelID == "" {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
|
||||
modelID = req.Model
|
||||
}
|
||||
r.Body = io.NopCloser(strings.NewReader(string(body)))
|
||||
r.ContentLength = int64(len(body))
|
||||
}
|
||||
if modelID == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "model not specified (X-Model-Id header or body.model)"})
|
||||
return
|
||||
}
|
||||
|
||||
sidecar, err := p.Acquire(r.Context(), modelID, flags)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(errMsg, "validation:") {
|
||||
status = http.StatusBadRequest
|
||||
} else if strings.Contains(errMsg, "unknown model:") {
|
||||
status = http.StatusNotFound
|
||||
} else if strings.Contains(errMsg, "port allocation:") {
|
||||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
writeJSON(w, status, map[string]string{"error": errMsg})
|
||||
return
|
||||
}
|
||||
|
||||
target := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: fmt.Sprintf("127.0.0.1:%d", sidecar.Port),
|
||||
}
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
slog.Error("upstream error", "hash", sidecar.Hash, "port", sidecar.Port, "err", err)
|
||||
writeJSON(rw, http.StatusBadGateway, map[string]any{
|
||||
"error": "upstream unavailable",
|
||||
"error_detail": err.Error(),
|
||||
"sidecar_hash": sidecar.Hash,
|
||||
"sidecar_port": sidecar.Port,
|
||||
"last_stderr": sidecar.LastStderr(),
|
||||
})
|
||||
}
|
||||
|
||||
sidecar.TouchLastUsed()
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// writeJSON sends v as a JSON response with the given status code and a
// Content-Type of application/json.
//
// The value is encoded before any header is written, so a value that cannot
// be encoded produces a 500 with a fixed error body instead of the requested
// status followed by an empty or partial payload.
func writeJSON(w http.ResponseWriter, status int, v any) {
	payload, err := json.Marshal(v)
	if err != nil {
		slog.Error("encode JSON response", "err", err)
		status = http.StatusInternalServerError
		payload = []byte(`{"error":"response encoding failed"}`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The trailing newline keeps the body identical to what
	// json.Encoder.Encode produced. A write error means the client is gone
	// and there is nothing useful left to do with it.
	_, _ = w.Write(append(payload, '\n'))
}
|
||||
56
internal/server/server.go
Normal file
56
internal/server/server.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/config"
|
||||
"github.com/indifferentketchup/llama-sidecar/internal/pool"
|
||||
)
|
||||
|
||||
func New(cfg *config.Config, p *pool.Pool, startedAt time.Time) *http.Server {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /health", healthHandler(p, cfg, startedAt))
|
||||
mux.HandleFunc("GET /sidecars", listSidecarsHandler(p))
|
||||
mux.HandleFunc("DELETE /sidecars/{hash}", deleteSidecarHandler(p))
|
||||
mux.HandleFunc("POST /v1/chat/completions", proxyHandler(p))
|
||||
mux.HandleFunc("POST /v1/completions", proxyHandler(p))
|
||||
|
||||
handler := requestLogger(mux)
|
||||
|
||||
return &http.Server{
|
||||
Addr: cfg.Bind,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func requestLogger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rw := &statusRecorder{ResponseWriter: w, status: 200}
|
||||
next.ServeHTTP(rw, r)
|
||||
slog.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", rw.status,
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// most recently passed to WriteHeader.
type statusRecorder struct {
	http.ResponseWriter
	status int
}

// WriteHeader records code and forwards it to the wrapped writer.
func (sr *statusRecorder) WriteHeader(code int) {
	sr.status = code
	sr.ResponseWriter.WriteHeader(code)
}

// Flush forwards to the wrapped writer when it implements http.Flusher and
// is a no-op otherwise.
func (sr *statusRecorder) Flush() {
	if f, ok := sr.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Unwrap exposes the wrapped writer so http.ResponseController can reach
// optional features (hijacking, deadlines) that this wrapper does not
// re-implement itself.
func (sr *statusRecorder) Unwrap() http.ResponseWriter {
	return sr.ResponseWriter
}
|
||||
156
internal/validator/validator.go
Normal file
156
internal/validator/validator.go
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
|
||||
// Ported from studio/backend/core/inference/llama_server_args.py.
|
||||
// Original: https://github.com/unslothai/unsloth/blob/main/studio/backend/core/inference/llama_server_args.py
|
||||
|
||||
package validator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var denylistGroups = [][]string{
|
||||
// Model identity
|
||||
{"-m", "--model"},
|
||||
{"-mu", "--model-url"},
|
||||
{"-dr", "--docker-repo"},
|
||||
{"-hf", "-hfr", "--hf-repo"},
|
||||
{"-hff", "--hf-file"},
|
||||
{"-hfv", "-hfrv", "--hf-repo-v"},
|
||||
{"-hffv", "--hf-file-v"},
|
||||
{"-hft", "--hf-token"},
|
||||
{"-mm", "--mmproj"},
|
||||
{"-mmu", "--mmproj-url"},
|
||||
// Networking
|
||||
{"--host"},
|
||||
{"--port"},
|
||||
{"--path"},
|
||||
{"--api-prefix"},
|
||||
{"--reuse-port"},
|
||||
// Auth / TLS
|
||||
{"--api-key"},
|
||||
{"--api-key-file"},
|
||||
{"--ssl-key-file"},
|
||||
{"--ssl-cert-file"},
|
||||
// Server UI / multi-model
|
||||
{"--webui", "--no-webui"},
|
||||
{"--ui", "--no-ui"},
|
||||
{"--ui-config"},
|
||||
{"--ui-config-file"},
|
||||
{"--ui-mcp-proxy", "--no-ui-mcp-proxy"},
|
||||
{"--models-dir"},
|
||||
{"--models-preset"},
|
||||
{"--models-max"},
|
||||
{"--models-autoload", "--no-models-autoload"},
|
||||
}
|
||||
|
||||
var denylist map[string]bool
|
||||
|
||||
func init() {
|
||||
denylist = make(map[string]bool)
|
||||
for _, group := range denylistGroups {
|
||||
for _, flag := range group {
|
||||
denylist[flag] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FlagName extracts the flag name from a CLI token and returns "" when the
// token is not a flag.
//
// A flag starts with "-" and is neither "-" nor "--". Tokens whose second
// character is a digit or '.' (such as -1 or -0.5, as in "--seed -1") are
// treated as numeric values, not flags. For "--key=value" only "--key" is
// returned.
func FlagName(token string) string {
	switch {
	case token == "-", token == "--", !strings.HasPrefix(token, "-"):
		return ""
	}
	// token is at least two bytes long here, so token[1] is in range.
	if second := token[1]; second == '.' || (second >= '0' && second <= '9') {
		return ""
	}
	name, _, _ := strings.Cut(token, "=")
	return name
}
|
||||
|
||||
// ValidateExtraArgs validates user-supplied llama-server args. Returns the
|
||||
// args as a flat slice. Returns an error with the offending flag if any
|
||||
// token resolves to a managed flag.
|
||||
func ValidateExtraArgs(args []string) ([]string, error) {
|
||||
if len(args) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
out := make([]string, 0, len(args))
|
||||
for _, raw := range args {
|
||||
flag := FlagName(raw)
|
||||
if flag != "" && denylist[flag] {
|
||||
return nil, fmt.Errorf("llama-server flag '%s' is managed and cannot be passed as an extra arg", flag)
|
||||
}
|
||||
out = append(out, raw)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// IsManagedFlag returns true if flag is a managed llama-server flag.
|
||||
func IsManagedFlag(flag string) bool {
|
||||
return denylist[flag]
|
||||
}
|
||||
|
||||
var contextFlags = setOf("-c", "--ctx-size")
|
||||
var cacheFlags = setOf("-ctk", "--cache-type-k", "-ctv", "--cache-type-v")
|
||||
var specFlags = setOf(
|
||||
"--spec-default", "--spec-type", "--spec-ngram-size-n", "--spec-ngram-size",
|
||||
"--draft-min", "--draft-max",
|
||||
"--spec-draft-n-max", "--spec-draft-n-min", "--spec-draft-p-min", "--spec-draft-p-split",
|
||||
"--spec-ngram-mod-n-match", "--spec-ngram-mod-n-min", "--spec-ngram-mod-n-max",
|
||||
)
|
||||
var templateFlags = setOf(
|
||||
"--chat-template", "--chat-template-file", "--chat-template-kwargs",
|
||||
"--jinja", "--no-jinja",
|
||||
)
|
||||
var booleanShadowingFlags = setOf("--spec-default", "--jinja", "--no-jinja")
|
||||
|
||||
// setOf builds a membership set in which every given value maps to true.
func setOf(vals ...string) map[string]bool {
	set := make(map[string]bool, len(vals))
	for i := range vals {
		set[vals[i]] = true
	}
	return set
}
|
||||
|
||||
// StripShadowingFlags removes flags that shadow first-class settings from
|
||||
// the arg list. By default all shadowing groups are stripped.
|
||||
func StripShadowingFlags(args []string) []string {
|
||||
shadowing := make(map[string]bool)
|
||||
for k, v := range contextFlags {
|
||||
shadowing[k] = v
|
||||
}
|
||||
for k, v := range cacheFlags {
|
||||
shadowing[k] = v
|
||||
}
|
||||
for k, v := range specFlags {
|
||||
shadowing[k] = v
|
||||
}
|
||||
for k, v := range templateFlags {
|
||||
shadowing[k] = v
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(args))
|
||||
i, n := 0, len(args)
|
||||
for i < n {
|
||||
tok := args[i]
|
||||
flag := FlagName(tok)
|
||||
if flag == "" || !shadowing[flag] {
|
||||
out = append(out, tok)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if booleanShadowingFlags[flag] || strings.Contains(tok, "=") {
|
||||
i++
|
||||
} else if i+1 < n && FlagName(args[i+1]) == "" {
|
||||
i += 2
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
150
internal/validator/validator_test.go
Normal file
150
internal/validator/validator_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateExtraArgs_DenyList(t *testing.T) {
|
||||
denied := []string{
|
||||
"-m", "--model",
|
||||
"-mu", "--model-url",
|
||||
"-dr", "--docker-repo",
|
||||
"-hf", "-hfr", "--hf-repo",
|
||||
"-hff", "--hf-file",
|
||||
"-hfv", "-hfrv", "--hf-repo-v",
|
||||
"-hffv", "--hf-file-v",
|
||||
"-hft", "--hf-token",
|
||||
"-mm", "--mmproj",
|
||||
"-mmu", "--mmproj-url",
|
||||
"--host", "--port", "--path", "--api-prefix", "--reuse-port",
|
||||
"--api-key", "--api-key-file",
|
||||
"--ssl-key-file", "--ssl-cert-file",
|
||||
"--webui", "--no-webui", "--ui", "--no-ui",
|
||||
"--ui-config", "--ui-config-file",
|
||||
"--ui-mcp-proxy", "--no-ui-mcp-proxy",
|
||||
"--models-dir", "--models-preset", "--models-max",
|
||||
"--models-autoload", "--no-models-autoload",
|
||||
}
|
||||
for _, flag := range denied {
|
||||
t.Run(flag, func(t *testing.T) {
|
||||
_, err := ValidateExtraArgs([]string{flag})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for %s", flag)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExtraArgs_SafeFlags(t *testing.T) {
|
||||
safe := []string{
|
||||
"-c", "--ctx-size", "-ngl", "--gpu-layers",
|
||||
"--top-k", "--cache-type-k", "--jinja", "--no-jinja",
|
||||
"--spec-draft-n-max", "-fa", "--flash-attn",
|
||||
"-t", "--threads", "-np", "--parallel", "--no-mmap",
|
||||
}
|
||||
for _, flag := range safe {
|
||||
t.Run(flag, func(t *testing.T) {
|
||||
out, err := ValidateExtraArgs([]string{flag})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for %s: %v", flag, err)
|
||||
}
|
||||
if len(out) != 1 || out[0] != flag {
|
||||
t.Fatalf("expected [%s], got %v", flag, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExtraArgs_FlagEqualsValue(t *testing.T) {
|
||||
_, err := ValidateExtraArgs([]string{"--model=evil.gguf"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for --model=evil.gguf")
|
||||
}
|
||||
out, err := ValidateExtraArgs([]string{"--ctx-size=4096"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out) != 1 || out[0] != "--ctx-size=4096" {
|
||||
t.Fatalf("expected [--ctx-size=4096], got %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExtraArgs_NegativeNumber(t *testing.T) {
|
||||
out, err := ValidateExtraArgs([]string{"--seed", "-1"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("expected 2 tokens, got %d", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExtraArgs_Empty(t *testing.T) {
|
||||
out, err := ValidateExtraArgs(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected nil, got %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsManagedFlag(t *testing.T) {
|
||||
if !IsManagedFlag("--model") {
|
||||
t.Fatal("--model should be managed")
|
||||
}
|
||||
if !IsManagedFlag("-m") {
|
||||
t.Fatal("-m should be managed")
|
||||
}
|
||||
if IsManagedFlag("-c") {
|
||||
t.Fatal("-c should not be managed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlagName(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, want string
|
||||
}{
|
||||
{"--model=foo", "--model"},
|
||||
{"-c", "-c"},
|
||||
{"--top-k", "--top-k"},
|
||||
{"-1", ""},
|
||||
{"-0.5", ""},
|
||||
{"-", ""},
|
||||
{"--", ""},
|
||||
{"hello", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := FlagName(tt.in)
|
||||
if got != tt.want {
|
||||
t.Errorf("FlagName(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripShadowingFlags(t *testing.T) {
|
||||
t.Run("strips context flag with value", func(t *testing.T) {
|
||||
out := StripShadowingFlags([]string{"-c", "4096", "--top-k", "40"})
|
||||
if len(out) != 2 || out[0] != "--top-k" || out[1] != "40" {
|
||||
t.Fatalf("got %v", out)
|
||||
}
|
||||
})
|
||||
t.Run("retains non-shadowing flags", func(t *testing.T) {
|
||||
out := StripShadowingFlags([]string{"--top-k", "40", "--top-p", "0.95"})
|
||||
if len(out) != 4 {
|
||||
t.Fatalf("got %v", out)
|
||||
}
|
||||
})
|
||||
t.Run("strips boolean jinja flag", func(t *testing.T) {
|
||||
out := StripShadowingFlags([]string{"--jinja", "--top-k", "40"})
|
||||
if len(out) != 2 || out[0] != "--top-k" {
|
||||
t.Fatalf("got %v", out)
|
||||
}
|
||||
})
|
||||
t.Run("strips equals form", func(t *testing.T) {
|
||||
out := StripShadowingFlags([]string{"--ctx-size=4096"})
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("got %v", out)
|
||||
}
|
||||
})
|
||||
}
|
||||
26
internal/winsvc/winsvc_unix.go
Normal file
26
internal/winsvc/winsvc_unix.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build !windows
|
||||
|
||||
package winsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RegisterShutdownHandler blocks until SIGTERM or SIGINT arrives, then runs
// shutdownFunc with a context derived from ctx that expires after 30 seconds,
// and terminates the process: exit code 1 if shutdownFunc returned an error,
// 0 otherwise. It never returns.
func RegisterShutdownHandler(ctx context.Context, shutdownFunc func(context.Context) error) {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	<-signals
	slog.Info("shutdown signal received")

	deadlineCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	exitCode := 0
	if err := shutdownFunc(deadlineCtx); err != nil {
		slog.Error("shutdown error", "err", err)
		exitCode = 1
	}
	os.Exit(exitCode)
}
|
||||
25
internal/winsvc/winsvc_windows.go
Normal file
25
internal/winsvc/winsvc_windows.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build windows
|
||||
|
||||
package winsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RegisterShutdownHandler blocks until os.Interrupt is delivered, then runs
// shutdownFunc with a context derived from ctx that expires after 30 seconds,
// and terminates the process: exit code 1 if shutdownFunc returned an error,
// 0 otherwise. It never returns.
func RegisterShutdownHandler(ctx context.Context, shutdownFunc func(context.Context) error) {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt)
	<-signals
	slog.Info("shutdown signal received")

	deadlineCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	exitCode := 0
	if err := shutdownFunc(deadlineCtx); err != nil {
		slog.Error("shutdown error", "err", err)
		exitCode = 1
	}
	os.Exit(exitCode)
}
|
||||
Reference in New Issue
Block a user