package pool

import (
	"reflect"
	"testing"
)

// TestBuildArgs_PreservesNonOverlapping checks that buildArgs keeps base flags
// the user did not supply, carries the user's own flags through, and injects
// --model and --port with the values passed in.
func TestBuildArgs_PreservesNonOverlapping(t *testing.T) {
	base := []string{"-ngl", "999", "-c", "32768", "--flash-attn", "on", "--no-mmap"}
	user := []string{"--top-k", "20"}
	got := buildArgs(base, "/model.gguf", 8500, user)

	// -c 32768 must survive (user didn't supply -c)
	if !containsSeq(got, "-c", "32768") {
		t.Errorf("-c 32768 missing from args: %v", got)
	}
	// --top-k 20 must be present (user flag)
	if !containsSeq(got, "--top-k", "20") {
		t.Errorf("--top-k 20 missing from args: %v", got)
	}
	// --model and --port injected
	if !containsSeq(got, "--model", "/model.gguf") {
		t.Errorf("--model missing: %v", got)
	}
	if !containsSeq(got, "--port", "8500") {
		t.Errorf("--port missing: %v", got)
	}
}

// TestBuildArgs_UserOverridesBase checks that when the user supplies a flag
// that is also in the base set, only the user's occurrence ends up in the
// result.
func TestBuildArgs_UserOverridesBase(t *testing.T) {
	base := []string{"-ngl", "999", "-c", "32768"}
	user := []string{"-c", "131072"}
	got := buildArgs(base, "/model.gguf", 8500, user)

	// base -c should be dropped, user -c should be present
	count := 0
	for i, tok := range got {
		// Only a -c that is followed by another token is counted.
		if tok == "-c" && i+1 < len(got) {
			count++
			if got[i+1] == "32768" {
				t.Errorf("base -c 32768 should have been deduped: %v", got)
			}
		}
	}
	if count != 1 {
		t.Errorf("expected exactly 1 -c flag, got %d in %v", count, got)
	}
	// The loop above only rules out the base value; pin the user's value
	// explicitly so a surviving -c with some other operand is caught too.
	if !containsSeq(got, "-c", "131072") {
		t.Errorf("user -c 131072 missing from args: %v", got)
	}
}

// TestBuildArgs_NoUserFlags checks that base flags are kept when the user
// flag slice is nil.
func TestBuildArgs_NoUserFlags(t *testing.T) {
	base := []string{"-ngl", "999", "-c", "32768", "--no-mmap"}
	got := buildArgs(base, "/model.gguf", 8500, nil)

	if !containsSeq(got, "-c", "32768") {
		t.Errorf("-c 32768 missing when no user flags: %v", got)
	}
	if !containsSeq(got, "--no-mmap") {
		t.Errorf("--no-mmap missing: %v", got)
	}
}

// TestDedupFlags_Mixed checks that dedupFlags removes only the auto flag (and
// its value) that the user also supplied, leaving the remaining auto flags in
// their original order.
func TestDedupFlags_Mixed(t *testing.T) {
	auto := []string{"--top-k", "40", "-c", "32768", "--no-mmap"}
	user := []string{"--top-k", "20"}

	got := dedupFlags(auto, user)
	want := []string{"-c", "32768", "--no-mmap"}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("dedupFlags = %v, want %v", got, want)
	}
}

// TestDedupFlags_EqualsForm checks that an auto flag written as a single
// "--flag=value" token is removed when the user supplies the same flag as
// separate flag and value tokens.
func TestDedupFlags_EqualsForm(t *testing.T) {
	auto := []string{"--ctx-size=4096", "--no-mmap"}
	user := []string{"--ctx-size", "8192"}

	got := dedupFlags(auto, user)
	want := []string{"--no-mmap"}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("dedupFlags = %v, want %v", got, want)
	}
}

// containsSeq reports whether seq occurs as a contiguous run of tokens
// somewhere in args. An empty seq always matches; a seq longer than args
// never does.
func containsSeq(args []string, seq ...string) bool {
	for i := 0; i <= len(args)-len(seq); i++ {
		match := true
		for j, s := range seq {
			if args[i+j] != s {
				match = false
				break
			}
		}
		if match {
			return true
		}
	}
	return false
}