package validator

import (
	"testing"
)

// requireArgs fails the test immediately unless got holds exactly the tokens
// in want, in the same order. When want is empty, both a nil and an empty got
// are accepted.
func requireArgs(t *testing.T, got []string, want ...string) {
	t.Helper()
	if len(got) != len(want) {
		t.Fatalf("got %q, want %q", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("got %q, want %q", got, want)
		}
	}
}

// TestValidateExtraArgs_DenyList checks that ValidateExtraArgs returns an
// error for every flag on the deny list when that flag is passed on its own.
func TestValidateExtraArgs_DenyList(t *testing.T) {
	denied := []string{
		"-m", "--model",
		"-mu", "--model-url",
		"-dr", "--docker-repo",
		"-hf", "-hfr", "--hf-repo",
		"-hff", "--hf-file",
		"-hfv", "-hfrv", "--hf-repo-v",
		"-hffv", "--hf-file-v",
		"-hft", "--hf-token",
		"-mm", "--mmproj",
		"-mmu", "--mmproj-url",
		"--host", "--port", "--path", "--api-prefix", "--reuse-port",
		"--api-key", "--api-key-file",
		"--ssl-key-file", "--ssl-cert-file",
		"--webui", "--no-webui",
		"--ui", "--no-ui",
		"--ui-config", "--ui-config-file",
		"--ui-mcp-proxy", "--no-ui-mcp-proxy",
		"--models-dir", "--models-preset", "--models-max",
		"--models-autoload", "--no-models-autoload",
	}
	for _, flag := range denied {
		t.Run(flag, func(t *testing.T) {
			_, err := ValidateExtraArgs([]string{flag})
			if err == nil {
				t.Fatalf("expected error for %s", flag)
			}
		})
	}
}

// TestValidateExtraArgs_SafeFlags checks that flags outside the deny list are
// accepted and returned unchanged.
func TestValidateExtraArgs_SafeFlags(t *testing.T) {
	safe := []string{
		"-c", "--ctx-size",
		"-ngl", "--gpu-layers",
		"--top-k",
		"--cache-type-k",
		"--jinja", "--no-jinja",
		"--spec-draft-n-max",
		"-fa", "--flash-attn",
		"-t", "--threads",
		"-np", "--parallel",
		"--no-mmap",
	}
	for _, flag := range safe {
		t.Run(flag, func(t *testing.T) {
			out, err := ValidateExtraArgs([]string{flag})
			if err != nil {
				t.Fatalf("unexpected error for %s: %v", flag, err)
			}
			requireArgs(t, out, flag)
		})
	}
}

// TestValidateExtraArgs_FlagEqualsValue checks that the "--flag=value" form is
// judged by its flag name. A denied flag is rejected even with an attached
// value, while a safe one passes through verbatim.
func TestValidateExtraArgs_FlagEqualsValue(t *testing.T) {
	if _, err := ValidateExtraArgs([]string{"--model=evil.gguf"}); err == nil {
		t.Fatal("expected error for --model=evil.gguf")
	}

	out, err := ValidateExtraArgs([]string{"--ctx-size=4096"})
	if err != nil {
		t.Fatal(err)
	}
	requireArgs(t, out, "--ctx-size=4096")
}

// TestValidateExtraArgs_NegativeNumber checks that a negative numeric value
// after a flag is accepted as a value and preserved as-is, rather than being
// rejected as an unknown or denied flag.
func TestValidateExtraArgs_NegativeNumber(t *testing.T) {
	out, err := ValidateExtraArgs([]string{"--seed", "-1"})
	if err != nil {
		t.Fatal(err)
	}
	// Check the contents, not just the count, so that a rewritten or dropped
	// "-1" is caught.
	requireArgs(t, out, "--seed", "-1")
}

// TestValidateExtraArgs_Empty checks that a nil argument list yields a nil
// result and no error.
func TestValidateExtraArgs_Empty(t *testing.T) {
	out, err := ValidateExtraArgs(nil)
	if err != nil {
		t.Fatal(err)
	}
	if out != nil {
		t.Fatalf("expected nil, got %v", out)
	}
}

// TestIsManagedFlag checks that the long and short model flags are reported
// as managed and that -c is not.
func TestIsManagedFlag(t *testing.T) {
	if !IsManagedFlag("--model") {
		t.Fatal("--model should be managed")
	}
	if !IsManagedFlag("-m") {
		t.Fatal("-m should be managed")
	}
	if IsManagedFlag("-c") {
		t.Fatal("-c should not be managed")
	}
}

// TestFlagName checks FlagName in three situations. It strips an "=value"
// suffix and returns plain flags unchanged. It returns "" for negative
// numbers, a bare "-" or "--", and words without a leading dash.
func TestFlagName(t *testing.T) {
	tests := []struct {
		in, want string
	}{
		{"--model=foo", "--model"},
		{"-c", "-c"},
		{"--top-k", "--top-k"},
		{"-1", ""},
		{"-0.5", ""},
		{"-", ""},
		{"--", ""},
		{"hello", ""},
	}
	for _, tt := range tests {
		if got := FlagName(tt.in); got != tt.want {
			t.Errorf("FlagName(%q) = %q, want %q", tt.in, got, tt.want)
		}
	}
}

// TestStripShadowingFlags checks that StripShadowingFlags removes the flags it
// treats as shadowing: "-c" together with its value, the boolean "--jinja",
// and the "--ctx-size=value" form. It also checks that all other tokens are
// kept in their original order.
func TestStripShadowingFlags(t *testing.T) {
	t.Run("strips context flag with value", func(t *testing.T) {
		out := StripShadowingFlags([]string{"-c", "4096", "--top-k", "40"})
		requireArgs(t, out, "--top-k", "40")
	})
	t.Run("retains non-shadowing flags", func(t *testing.T) {
		out := StripShadowingFlags([]string{"--top-k", "40", "--top-p", "0.95"})
		requireArgs(t, out, "--top-k", "40", "--top-p", "0.95")
	})
	t.Run("strips boolean jinja flag", func(t *testing.T) {
		out := StripShadowingFlags([]string{"--jinja", "--top-k", "40"})
		requireArgs(t, out, "--top-k", "40")
	})
	t.Run("strips equals form", func(t *testing.T) {
		out := StripShadowingFlags([]string{"--ctx-size=4096"})
		requireArgs(t, out)
	})
}