diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-26 20:19:41 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-26 20:19:41 +0300 |
| commit | 1731126b52e406a300270c8fc8ac1061a4422b27 (patch) | |
| tree | c74768df49994aa9676cbc69ebfb461ed0422e01 | |
| parent | 0583b360ceb606b8e58f12a17f588bd27feeb117 (diff) | |
Refactor surface config to support multi-provider arrays
| -rw-r--r-- | PLAN3.md | 24 | ||||
| -rw-r--r-- | config.toml.example | 21 | ||||
| -rw-r--r-- | internal/appconfig/config.go | 232 | ||||
| -rw-r--r-- | internal/appconfig/config_env_model_test.go | 32 | ||||
| -rw-r--r-- | internal/appconfig/config_test.go | 96 | ||||
| -rw-r--r-- | internal/hexaiaction/prompts.go | 31 | ||||
| -rw-r--r-- | internal/hexaiaction/prompts_more_test.go | 7 | ||||
| -rw-r--r-- | internal/hexaiaction/run.go | 6 | ||||
| -rw-r--r-- | internal/hexaicli/run.go | 56 | ||||
| -rw-r--r-- | internal/hexaicli/run_test.go | 6 | ||||
| -rw-r--r-- | internal/lsp/handlers_utils.go | 38 | ||||
| -rw-r--r-- | internal/runtimeconfig/store.go | 24 | ||||
| -rw-r--r-- | internal/runtimeconfig/store_test.go | 22 |
13 files changed, 330 insertions, 265 deletions
diff --git a/PLAN3.md b/PLAN3.md new file mode 100644 index 0000000..c03405c --- /dev/null +++ b/PLAN3.md @@ -0,0 +1,24 @@ +# Parallel Provider/Model Comparison Plan + +Goal: allow configuring multiple provider:model pairs per surface so users can run a local and cloud LLM side-by-side for manual comparison. Completions should fan out and show one suggestion per configured entry; Hexai CLI should emit one response per entry; code actions remain single-provider. + +## Phase 1 – Configuration & Schema +- [x] Audit existing per-surface config to support arrays of provider:model entries (preserving current single-entry behavior by default). +- [x] Design updated TOML and env schema (e.g., `[[models.completion]] provider="openai" model="gpt-4o"`). +- [x] Define merge, validation, and backward-compatibility rules (single entry auto-wraps into list). + +## Phase 2 – Runtime Plumbing +- Extend appconfig/runtime store to emit ordered slices for multi-entry surfaces, including diff output. +- Update request-spec helpers to iterate across configured entries, building dedicated request specs (and caching clients per provider/model combo). +- Ensure logging/stats capture provider/model context per entry. + +## Phase 3 – Surface Implementations +- Completion: fan out requests sequentially, gather one suggestion per entry, and surface them distinctly to the editor (label with provider/model). +- CLI: stream or print separate responses per entry, with clear headers and stats per run. +- Code actions: keep single-provider flow but ensure config ignores extra entries with validation warnings. +- Add reasonable concurrency limits / timeouts so multi-provider usage stays responsive. + +## Phase 4 – UX & Validation +- Tests covering multi-entry parsing, diffing, and surface-specific behavior (mock providers to simulate dual responses). +- Update docs and example TOML with new array syntax, including env override strategy. +- Capture lessons/issues in scratchpad for follow-up polishing. diff --git a/config.toml.example b/config.toml.example index e5a75f4..1bed1a9 100644 --- a/config.toml.example +++ b/config.toml.example @@ -28,18 +28,29 @@ chat_suffix = ">" # single-character chat_prefixes = ["?", "!", ":", ";"] # single-character items [models] -# Shorthand string form per surface +# Shorthand string form per surface (single entry) # completion = "gpt-4o-mini" # chat = "gpt-4.1" -[models.code_action] -# model = "gpt-4o" +# Full array form for multiple entries +# [[models.completion]] +# provider = "openai" +# model = "gpt-4o-mini" +# temperature = 0.2 +# +# [[models.completion]] +# provider = "ollama" +# model = "mistral" +# temperature = 0.2 + +# [[models.code_action]] # provider = "copilot" +# model = "gpt-4o" # temperature = 0.4 -[models.cli] -# model = "gpt-4.1" +# [[models.cli]] # provider = "openai" +# model = "gpt-4.1" # temperature = 0.6 [provider] diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index 47abaaf..63d0437 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -13,6 +13,13 @@ import ( "github.com/pelletier/go-toml/v2" ) +// SurfaceConfig describes a provider/model pairing (with optional temperature). +type SurfaceConfig struct { + Provider string + Model string + Temperature *float64 +} + // App holds user-configurable settings read from ~/.config/hexai/config.toml. type App struct { MaxTokens int `json:"max_tokens" toml:"max_tokens"` @@ -58,19 +65,11 @@ type App struct { // Default temperature for Copilot requests (nil means use provider default) CopilotTemperature *float64 `json:"copilot_temperature" toml:"copilot_temperature"` - // Per-surface model overrides (fall back to provider defaults when unset) - CompletionModel string `json:"completion_model" toml:"completion_model"` - CompletionTemperature *float64 `json:"completion_temperature" toml:"completion_temperature"` - CompletionProvider string `json:"completion_provider" toml:"completion_provider"` - CodeActionModel string `json:"code_action_model" toml:"code_action_model"` - CodeActionTemperature *float64 `json:"code_action_temperature" toml:"code_action_temperature"` - CodeActionProvider string `json:"code_action_provider" toml:"code_action_provider"` - ChatModel string `json:"chat_model" toml:"chat_model"` - ChatTemperature *float64 `json:"chat_temperature" toml:"chat_temperature"` - ChatProvider string `json:"chat_provider" toml:"chat_provider"` - CLIModel string `json:"cli_model" toml:"cli_model"` - CLITemperature *float64 `json:"cli_temperature" toml:"cli_temperature"` - CLIProvider string `json:"cli_provider" toml:"cli_provider"` + // Per-surface provider/model configurations (ordered; first entry is primary) + CompletionConfigs []SurfaceConfig `json:"-" toml:"-"` + CodeActionConfigs []SurfaceConfig `json:"-" toml:"-"` + ChatConfigs []SurfaceConfig `json:"-" toml:"-"` + CLIConfigs []SurfaceConfig `json:"-" toml:"-"` // Prompt templates (configured only via file; no env overrides) // Completion/chat/code action/CLI prompt strings. See config.toml.example for placeholders. @@ -662,72 +661,66 @@ func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App { return nil } var out App - var any bool - if model, provider, temp, ok := decodeModelEntry(table["completion"], "models.completion", logger); ok { - if model != "" { - out.CompletionModel = model - } - if provider != "" { - out.CompletionProvider = provider - } - if temp != nil { - out.CompletionTemperature = temp - } - any = true - } - if model, provider, temp, ok := decodeModelEntry(table["code_action"], "models.code_action", logger); ok { - if model != "" { - out.CodeActionModel = model - } - if provider != "" { - out.CodeActionProvider = provider - } - if temp != nil { - out.CodeActionTemperature = temp - } - any = true - } - if model, provider, temp, ok := decodeModelEntry(table["chat"], "models.chat", logger); ok { - if model != "" { - out.ChatModel = model - } - if provider != "" { - out.ChatProvider = provider - } - if temp != nil { - out.ChatTemperature = temp - } - any = true + appendEntries := func(dest *[]SurfaceConfig, key string, val any) bool { + entries, ok := parseSurfaceEntries(val, key, logger) + if !ok || len(entries) == 0 { + return false + } + *dest = append(*dest, entries...) + return true + } + any := appendEntries(&out.CompletionConfigs, "models.completion", table["completion"]) + any = appendEntries(&out.CodeActionConfigs, "models.code_action", table["code_action"]) || any + any = appendEntries(&out.ChatConfigs, "models.chat", table["chat"]) || any + any = appendEntries(&out.CLIConfigs, "models.cli", table["cli"]) || any + if !any { + return nil } - if model, provider, temp, ok := decodeModelEntry(table["cli"], "models.cli", logger); ok { - if model != "" { - out.CLIModel = model - } - if provider != "" { - out.CLIProvider = provider + return &out +} + +func parseSurfaceEntries(raw any, path string, logger *log.Logger) ([]SurfaceConfig, bool) { + switch v := raw.(type) { + case nil: + return nil, false + case []any: + var out []SurfaceConfig + for i, entry := range v { + cfg, ok := decodeModelEntry(entry, fmt.Sprintf("%s[%d]", path, i), logger) + if !ok || cfg == nil { + continue + } + out = append(out, *cfg) } - if temp != nil { - out.CLITemperature = temp + return out, len(out) > 0 + default: + if cfg, ok := decodeModelEntry(v, path, logger); ok && cfg != nil { + return []SurfaceConfig{*cfg}, true } - any = true + return nil, false } - if !any { +} + +func cloneSurfaceConfigs(src []SurfaceConfig) []SurfaceConfig { + if len(src) == 0 { return nil } - return &out + out := make([]SurfaceConfig, len(src)) + copy(out, src) + return out } -func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, *float64, bool) { +func decodeModelEntry(raw any, path string, logger *log.Logger) (*SurfaceConfig, bool) { if raw == nil { - return "", "", nil, false + return nil, false } switch v := raw.(type) { case string: model := strings.TrimSpace(v) if model == "" { - return "", "", nil, false + return nil, false } - return model, "", nil, true + return &SurfaceConfig{Model: model}, true case map[string]any: model := "" provider := "" @@ -737,7 +730,7 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if logger != nil { logger.Printf("config: %s.model must be a string", path) } - return "", "", nil, false + return nil, false } model = strings.TrimSpace(s) } @@ -747,7 +740,7 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if logger != nil { logger.Printf("config: %s.provider must be a string", path) } - return "", "", nil, false + return nil, false } provider = strings.TrimSpace(ps) } @@ -755,19 +748,19 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if tRaw, ok := v["temperature"]; ok { parsed, ok := parseTemperatureValue(tRaw, path, logger) if !ok { - return "", "", nil, false + return nil, false } tempPtr = parsed } if model == "" && tempPtr == nil && provider == "" { - return "", "", nil, false + return nil, false } - return model, provider, tempPtr, true + return &SurfaceConfig{Provider: provider, Model: model, Temperature: tempPtr}, true default: if logger != nil { logger.Printf("config: %s must be a string or table, got %T", path, raw) } - return "", "", nil, false + return nil, false } } @@ -861,41 +854,17 @@ func (a *App) mergeBasics(other *App) { // mergeSurfaceModels copies per-surface model and temperature overrides. func (a *App) mergeSurfaceModels(other *App) { - if s := strings.TrimSpace(other.CompletionModel); s != "" { - a.CompletionModel = s - } - if other.CompletionTemperature != nil { - a.CompletionTemperature = other.CompletionTemperature - } - if s := strings.TrimSpace(other.CompletionProvider); s != "" { - a.CompletionProvider = s + if len(other.CompletionConfigs) > 0 { + a.CompletionConfigs = cloneSurfaceConfigs(other.CompletionConfigs) } - if s := strings.TrimSpace(other.CodeActionModel); s != "" { - a.CodeActionModel = s + if len(other.CodeActionConfigs) > 0 { + a.CodeActionConfigs = cloneSurfaceConfigs(other.CodeActionConfigs) } - if other.CodeActionTemperature != nil { - a.CodeActionTemperature = other.CodeActionTemperature + if len(other.ChatConfigs) > 0 { + a.ChatConfigs = cloneSurfaceConfigs(other.ChatConfigs) } - if s := strings.TrimSpace(other.CodeActionProvider); s != "" { - a.CodeActionProvider = s - } - if s := strings.TrimSpace(other.ChatModel); s != "" { - a.ChatModel = s - } - if other.ChatTemperature != nil { - a.ChatTemperature = other.ChatTemperature - } - if s := strings.TrimSpace(other.ChatProvider); s != "" { - a.ChatProvider = s - } - if s := strings.TrimSpace(other.CLIModel); s != "" { - a.CLIModel = s - } - if other.CLITemperature != nil { - a.CLITemperature = other.CLITemperature - } - if s := strings.TrimSpace(other.CLIProvider); s != "" { - a.CLIProvider = s + if len(other.CLIConfigs) > 0 { + a.CLIConfigs = cloneSurfaceConfigs(other.CLIConfigs) } } @@ -1263,52 +1232,33 @@ func loadFromEnv(logger *log.Logger) *App { } // Per-surface overrides - if s := getenv("HEXAI_MODEL_COMPLETION"); s != "" { - out.CompletionModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_COMPLETION"); ok { - out.CompletionTemperature = f - any = true - } - if s := getenv("HEXAI_PROVIDER_COMPLETION"); s != "" { - out.CompletionProvider = s - any = true - } - if s := getenv("HEXAI_MODEL_CODE_ACTION"); s != "" { - out.CodeActionModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CODE_ACTION"); ok { - out.CodeActionTemperature = f - any = true - } - if s := getenv("HEXAI_PROVIDER_CODE_ACTION"); s != "" { - out.CodeActionProvider = s - any = true - } - if s := getenv("HEXAI_MODEL_CHAT"); s != "" { - out.ChatModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CHAT"); ok { - out.ChatTemperature = f - any = true + buildEntry := func(modelKey, tempKey, providerKey string) ([]SurfaceConfig, bool) { + model := getenv(modelKey) + tempPtr, tempSet := parseFloatPtr(tempKey) + provider := getenv(providerKey) + if model == "" && provider == "" && !tempSet { + return nil, false + } + entry := SurfaceConfig{Provider: provider, Model: model} + if tempSet { + entry.Temperature = tempPtr + } + return []SurfaceConfig{entry}, true } - if s := getenv("HEXAI_PROVIDER_CHAT"); s != "" { - out.ChatProvider = s + if entries, ok := buildEntry("HEXAI_MODEL_COMPLETION", "HEXAI_TEMPERATURE_COMPLETION", "HEXAI_PROVIDER_COMPLETION"); ok { + out.CompletionConfigs = entries any = true } - if s := getenv("HEXAI_MODEL_CLI"); s != "" { - out.CLIModel = s + if entries, ok := buildEntry("HEXAI_MODEL_CODE_ACTION", "HEXAI_TEMPERATURE_CODE_ACTION", "HEXAI_PROVIDER_CODE_ACTION"); ok { + out.CodeActionConfigs = entries any = true } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CLI"); ok { - out.CLITemperature = f + if entries, ok := buildEntry("HEXAI_MODEL_CHAT", "HEXAI_TEMPERATURE_CHAT", "HEXAI_PROVIDER_CHAT"); ok { + out.ChatConfigs = entries any = true } - if s := getenv("HEXAI_PROVIDER_CLI"); s != "" { - out.CLIProvider = s + if entries, ok := buildEntry("HEXAI_MODEL_CLI", "HEXAI_TEMPERATURE_CLI", "HEXAI_PROVIDER_CLI"); ok { + out.CLIConfigs = entries any = true } diff --git a/internal/appconfig/config_env_model_test.go b/internal/appconfig/config_env_model_test.go index f34416d..7038819 100644 --- a/internal/appconfig/config_env_model_test.go +++ b/internal/appconfig/config_env_model_test.go @@ -44,22 +44,30 @@ func TestEnv_SurfaceModelOverrides(t *testing.T) { t.Setenv("HEXAI_TEMPERATURE_CLI", "0.22") t.Setenv("HEXAI_PROVIDER_CLI", "ollama") cfg := Load(log.New(os.Stderr, "test ", 0)) - if cfg.CompletionModel != "gpt-c" { - t.Fatalf("expected completion model override, got %q", cfg.CompletionModel) + if len(cfg.CompletionConfigs) != 1 { + t.Fatalf("expected single completion entry, got %+v", cfg.CompletionConfigs) } - if cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.44 { - t.Fatalf("expected completion temperature override, got %v", cfg.CompletionTemperature) + comp := cfg.CompletionConfigs[0] + if comp.Model != "gpt-c" { + t.Fatalf("expected completion model override, got %+v", comp) } - if cfg.CompletionProvider != "copilot" { - t.Fatalf("expected completion provider override, got %q", cfg.CompletionProvider) + if comp.Temperature == nil || *comp.Temperature != 0.44 { + t.Fatalf("expected completion temperature override, got %+v", comp) } - if cfg.CLIModel != "gpt-cli" { - t.Fatalf("expected cli model override, got %q", cfg.CLIModel) + if comp.Provider != "copilot" { + t.Fatalf("expected completion provider override, got %+v", comp) } - if cfg.CLITemperature == nil || *cfg.CLITemperature != 0.22 { - t.Fatalf("expected cli temperature override, got %v", cfg.CLITemperature) + if len(cfg.CLIConfigs) != 1 { + t.Fatalf("expected single CLI entry, got %+v", cfg.CLIConfigs) } - if cfg.CLIProvider != "ollama" { - t.Fatalf("expected cli provider override, got %q", cfg.CLIProvider) + cli := cfg.CLIConfigs[0] + if cli.Model != "gpt-cli" { + t.Fatalf("expected cli model override, got %+v", cli) + } + if cli.Temperature == nil || *cli.Temperature != 0.22 { + t.Fatalf("expected cli temperature override, got %+v", cli) + } + if cli.Provider != "ollama" { + t.Fatalf("expected cli provider override, got %+v", cli) } } diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go index ea68305..e7f6059 100644 --- a/internal/appconfig/config_test.go +++ b/internal/appconfig/config_test.go @@ -88,20 +88,20 @@ completion_throttle_ms = 300 [triggers] trigger_characters = [".", ":"] -[models.completion] +[[models.completion]] model = "gpt-file-complete" provider = "openai" -[models.code_action] +[[models.code_action]] model = "gpt-file-action" temperature = 0.45 provider = "copilot" -[models.chat] +[[models.chat]] model = "gpt-file-chat" provider = "openai" -[models.cli] +[[models.cli]] model = "gpt-file-cli" temperature = 0.15 provider = "ollama" @@ -192,29 +192,41 @@ temperature = 0.0 if cfg.CopilotBaseURL != "http://copilot-override" || cfg.CopilotModel != "ghost-override" || cfg.CopilotTemperature == nil || *cfg.CopilotTemperature != 0.3 { t.Fatalf("copilot overrides not applied: %+v", cfg) } - if cfg.CompletionModel != "env-completion" || cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.33 { - t.Fatalf("completion overrides not applied: model=%q temp=%v", cfg.CompletionModel, cfg.CompletionTemperature) + if len(cfg.CompletionConfigs) != 1 || cfg.CompletionConfigs[0].Model != "env-completion" { + t.Fatalf("completion overrides not applied: %+v", cfg.CompletionConfigs) } - if cfg.CompletionProvider != "copilot" { - t.Fatalf("completion provider override not applied: %q", cfg.CompletionProvider) + if cfg.CompletionConfigs[0].Temperature == nil || *cfg.CompletionConfigs[0].Temperature != 0.33 { + t.Fatalf("completion temperature override missing: %+v", cfg.CompletionConfigs[0]) } - if cfg.CodeActionModel != "env-action" || cfg.CodeActionTemperature == nil || *cfg.CodeActionTemperature != 0.55 { - t.Fatalf("code action overrides not applied: model=%q temp=%v", cfg.CodeActionModel, cfg.CodeActionTemperature) + if cfg.CompletionConfigs[0].Provider != "copilot" { + t.Fatalf("completion provider override not applied: %+v", cfg.CompletionConfigs[0]) } - if cfg.CodeActionProvider != "openai" { - t.Fatalf("code action provider override not applied: %q", cfg.CodeActionProvider) + if len(cfg.CodeActionConfigs) != 1 || cfg.CodeActionConfigs[0].Model != "env-action" { + t.Fatalf("code action overrides not applied: %+v", cfg.CodeActionConfigs) } - if cfg.ChatModel != "env-chat" || cfg.ChatTemperature == nil || *cfg.ChatTemperature != 0.66 { - t.Fatalf("chat overrides not applied: model=%q temp=%v", cfg.ChatModel, cfg.ChatTemperature) + if cfg.CodeActionConfigs[0].Temperature == nil || *cfg.CodeActionConfigs[0].Temperature != 0.55 { + t.Fatalf("code action temp override missing: %+v", cfg.CodeActionConfigs[0]) } - if cfg.ChatProvider != "copilot" { - t.Fatalf("chat provider override not applied: %q", cfg.ChatProvider) + if cfg.CodeActionConfigs[0].Provider != "openai" { + t.Fatalf("code action provider override not applied: %+v", cfg.CodeActionConfigs[0]) } - if cfg.CLIModel != "env-cli" || cfg.CLITemperature == nil || *cfg.CLITemperature != 0.77 { - t.Fatalf("cli overrides not applied: model=%q temp=%v", cfg.CLIModel, cfg.CLITemperature) + if len(cfg.ChatConfigs) != 1 || cfg.ChatConfigs[0].Model != "env-chat" { + t.Fatalf("chat overrides not applied: %+v", cfg.ChatConfigs) } - if cfg.CLIProvider != "ollama" { - t.Fatalf("cli provider override not applied: %q", cfg.CLIProvider) + if cfg.ChatConfigs[0].Temperature == nil || *cfg.ChatConfigs[0].Temperature != 0.66 { + t.Fatalf("chat temp override missing: %+v", cfg.ChatConfigs[0]) + } + if cfg.ChatConfigs[0].Provider != "copilot" { + t.Fatalf("chat provider override not applied: %+v", cfg.ChatConfigs[0]) + } + if len(cfg.CLIConfigs) != 1 || cfg.CLIConfigs[0].Model != "env-cli" { + t.Fatalf("cli overrides not applied: %+v", cfg.CLIConfigs) + } + if cfg.CLIConfigs[0].Temperature == nil || *cfg.CLIConfigs[0].Temperature != 0.77 { + t.Fatalf("cli temp override missing: %+v", cfg.CLIConfigs[0]) + } + if cfg.CLIConfigs[0].Provider != "ollama" { + t.Fatalf("cli provider override not applied: %+v", cfg.CLIConfigs[0]) } // Ensure file values would have applied absent env @@ -234,29 +246,41 @@ temperature = 0.0 if cfg2.OpenAIBaseURL != "https://api.example" || cfg2.OpenAIModel != "gpt-x" || cfg2.OpenAITemperature == nil || *cfg2.OpenAITemperature != 0.0 { t.Fatalf("file merge (openai) not applied: %+v", cfg2) } - if cfg2.CompletionModel != "gpt-file-complete" || cfg2.CompletionTemperature != nil { - t.Fatalf("file merge (completion) not applied: %+v", cfg2) + if len(cfg2.CompletionConfigs) != 1 || cfg2.CompletionConfigs[0].Model != "gpt-file-complete" { + t.Fatalf("file merge (completion) not applied: %+v", cfg2.CompletionConfigs) + } + if cfg2.CompletionConfigs[0].Temperature != nil { + t.Fatalf("expected nil completion temperature, got %+v", cfg2.CompletionConfigs[0]) + } + if cfg2.CompletionConfigs[0].Provider != "openai" { + t.Fatalf("file merge (completion provider) not applied: %+v", cfg2.CompletionConfigs[0]) + } + if len(cfg2.CodeActionConfigs) != 1 || cfg2.CodeActionConfigs[0].Model != "gpt-file-action" { + t.Fatalf("file merge (code action) not applied: %+v", cfg2.CodeActionConfigs) + } + if cfg2.CodeActionConfigs[0].Temperature == nil || *cfg2.CodeActionConfigs[0].Temperature != 0.45 { + t.Fatalf("expected code action temp 0.45, got %+v", cfg2.CodeActionConfigs[0]) } - if cfg2.CompletionProvider != "openai" { - t.Fatalf("file merge (completion provider) not applied: %q", cfg2.CompletionProvider) + if cfg2.CodeActionConfigs[0].Provider != "copilot" { + t.Fatalf("file merge (code action provider) not applied: %+v", cfg2.CodeActionConfigs[0]) } - if cfg2.CodeActionModel != "gpt-file-action" || cfg2.CodeActionTemperature == nil || *cfg2.CodeActionTemperature != 0.45 { - t.Fatalf("file merge (code action) not applied: %+v", cfg2) + if len(cfg2.ChatConfigs) != 1 || cfg2.ChatConfigs[0].Model != "gpt-file-chat" { + t.Fatalf("file merge (chat) not applied: %+v", cfg2.ChatConfigs) } - if cfg2.CodeActionProvider != "copilot" { - t.Fatalf("file merge (code action provider) not applied: %q", cfg2.CodeActionProvider) + if cfg2.ChatConfigs[0].Temperature != nil { + t.Fatalf("expected nil chat temp, got %+v", cfg2.ChatConfigs[0]) } - if cfg2.ChatModel != "gpt-file-chat" || cfg2.ChatTemperature != nil { - t.Fatalf("file merge (chat) not applied: %+v", cfg2) + if cfg2.ChatConfigs[0].Provider != "openai" { + t.Fatalf("file merge (chat provider) not applied: %+v", cfg2.ChatConfigs[0]) } - if cfg2.ChatProvider != "openai" { - t.Fatalf("file merge (chat provider) not applied: %q", cfg2.ChatProvider) + if len(cfg2.CLIConfigs) != 1 || cfg2.CLIConfigs[0].Model != "gpt-file-cli" { + t.Fatalf("file merge (cli) not applied: %+v", cfg2.CLIConfigs) } - if cfg2.CLIModel != "gpt-file-cli" || cfg2.CLITemperature == nil || *cfg2.CLITemperature != 0.15 { - t.Fatalf("file merge (cli) not applied: %+v", cfg2) + if cfg2.CLIConfigs[0].Temperature == nil || *cfg2.CLIConfigs[0].Temperature != 0.15 { + t.Fatalf("expected CLI temp 0.15, got %+v", cfg2.CLIConfigs[0]) } - if cfg2.CLIProvider != "ollama" { - t.Fatalf("file merge (cli provider) not applied: %q", cfg2.CLIProvider) + if cfg2.CLIConfigs[0].Provider != "ollama" { + t.Fatalf("file merge (cli provider) not applied: %+v", cfg2.CLIConfigs[0]) } } diff --git a/internal/hexaiaction/prompts.go b/internal/hexaiaction/prompts.go index 47dadbf..a113391 100644 --- a/internal/hexaiaction/prompts.go +++ b/internal/hexaiaction/prompts.go @@ -56,9 +56,9 @@ func defaultModelForProvider(cfg appconfig.App, provider string) string { } } -func selectActionTemperature(cfg appconfig.App, provider, model string) (float64, bool) { - if cfg.CodeActionTemperature != nil { - return *cfg.CodeActionTemperature, true +func selectActionTemperature(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) { + if entry.Temperature != nil { + return *entry.Temperature, true } if cfg.CodingTemperature != nil { temp := *cfg.CodingTemperature @@ -201,22 +201,25 @@ func reqOptsFrom(cfg appconfig.App) requestArgs { opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens)) } provider := canonicalProvider(cfg.Provider) - if strings.TrimSpace(cfg.CodeActionProvider) != "" { - provider = canonicalProvider(cfg.CodeActionProvider) + entries := cfg.CodeActionConfigs + if len(entries) == 0 { + entries = []appconfig.SurfaceConfig{{Provider: cfg.Provider, Model: strings.TrimSpace(defaultModelForProvider(cfg, provider))}} } - override := strings.TrimSpace(cfg.CodeActionModel) - fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider)) - effective := override - if effective == "" { - effective = fallback + primary := entries[0] + if strings.TrimSpace(primary.Provider) != "" { + provider = canonicalProvider(primary.Provider) } - if override != "" { - opts = append(opts, llm.WithModel(override)) + model := strings.TrimSpace(primary.Model) + if model == "" { + model = strings.TrimSpace(defaultModelForProvider(cfg, provider)) + } + if strings.TrimSpace(primary.Model) != "" { + opts = append(opts, llm.WithModel(strings.TrimSpace(primary.Model))) } - if temp, ok := selectActionTemperature(cfg, provider, effective); ok { + if temp, ok := selectActionTemperature(cfg, provider, primary, model); ok { opts = append(opts, llm.WithTemperature(temp)) } - return requestArgs{model: effective, options: opts} + return requestArgs{model: model, options: opts} } // Timeout helpers to mirror LSP behavior. diff --git a/internal/hexaiaction/prompts_more_test.go b/internal/hexaiaction/prompts_more_test.go index 97d3979..cfccd0c 100644 --- a/internal/hexaiaction/prompts_more_test.go +++ b/internal/hexaiaction/prompts_more_test.go @@ -32,7 +32,12 @@ func TestRunOnce_StripsFences(t *testing.T) { } func TestReqOptsFrom_Override(t *testing.T) { - cfg := appconfig.App{MaxTokens: 123, CodeActionModel: "override", CodeActionTemperature: ptrFloat(0.6), Provider: "openai", CodeActionProvider: "copilot", CopilotModel: "gpt-4o"} + cfg := appconfig.App{ + MaxTokens: 123, + Provider: "openai", + CopilotModel: "gpt-4o", + CodeActionConfigs: []appconfig.SurfaceConfig{{Provider: "copilot", Model: "override", Temperature: ptrFloat(0.6)}}, + } req := reqOptsFrom(cfg) if req.model != "override" { t.Fatalf("expected override model, got %q", req.model) diff --git a/internal/hexaiaction/run.go b/internal/hexaiaction/run.go index 953da80..a5f47cf 100644 --- a/internal/hexaiaction/run.go +++ b/internal/hexaiaction/run.go @@ -41,8 +41,10 @@ func Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error { if len(cfg.CustomActions) > 0 { chooseActionFn = func() (ActionKind, error) { return RunTUIWithCustom(cfg.CustomActions, cfg.TmuxCustomMenuHotkey) } } - if providerOverride := strings.TrimSpace(cfg.CodeActionProvider); providerOverride != "" { - cfg.Provider = providerOverride + if len(cfg.CodeActionConfigs) > 0 { + if provider := strings.TrimSpace(cfg.CodeActionConfigs[0].Provider); provider != "" { + cfg.Provider = provider + } } cli, err := newClientFromApp(cfg) if err != nil { diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index b965261..5f6284c 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -27,44 +27,49 @@ type requestArgs struct { func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { provider := canonicalProvider(cfg.Provider) - if strings.TrimSpace(cfg.CLIProvider) != "" { - provider = canonicalProvider(cfg.CLIProvider) + entries := cfg.CLIConfigs + if len(entries) == 0 { + entries = []appconfig.SurfaceConfig{{Provider: cfg.Provider, Model: strings.TrimSpace(defaultModelForProvider(cfg, provider))}} } - if client != nil { - provider = strings.ToLower(strings.TrimSpace(client.Name())) + primary := entries[0] + if strings.TrimSpace(primary.Provider) != "" { + provider = canonicalProvider(primary.Provider) } - override := strings.TrimSpace(cfg.CLIModel) - fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider)) + model := strings.TrimSpace(primary.Model) if client != nil { - if dm := strings.TrimSpace(client.DefaultModel()); dm != "" { - fallback = dm + provider = strings.ToLower(strings.TrimSpace(client.Name())) + if model == "" { + model = strings.TrimSpace(client.DefaultModel()) } } - effective := override - if effective == "" { - effective = fallback + if model == "" { + model = strings.TrimSpace(defaultModelForProvider(cfg, provider)) } opts := make([]llm.RequestOption, 0, 2) - if override != "" { - opts = append(opts, llm.WithModel(override)) + if strings.TrimSpace(primary.Model) != "" { + opts = append(opts, llm.WithModel(strings.TrimSpace(primary.Model))) } - if temp, ok := cliTemperature(cfg, provider, effective); ok { + if temp, ok := cliTemperatureFromEntry(cfg, provider, primary, model); ok { opts = append(opts, llm.WithTemperature(temp)) } - return requestArgs{model: effective, options: opts} + return requestArgs{model: model, options: opts} } func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { - model := strings.TrimSpace(cfg.CLIModel) - if model == "" && client != nil { - model = strings.TrimSpace(client.DefaultModel()) + if len(cfg.CLIConfigs) > 0 { + if m := strings.TrimSpace(cfg.CLIConfigs[0].Model); m != "" { + return requestArgs{model: m} + } + } + if client != nil { + return requestArgs{model: strings.TrimSpace(client.DefaultModel())} } - return requestArgs{model: model} + return requestArgs{} } -func cliTemperature(cfg appconfig.App, provider, model string) (float64, bool) { - if cfg.CLITemperature != nil { - return *cfg.CLITemperature, true +func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) { + if entry.Temperature != nil { + return *entry.Temperature, true } if cfg.CodingTemperature != nil { temp := *cfg.CodingTemperature @@ -107,9 +112,10 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } - providerOverride := strings.TrimSpace(cfg.CLIProvider) - if providerOverride != "" { - cfg.Provider = providerOverride + if len(cfg.CLIConfigs) > 0 { + if provider := strings.TrimSpace(cfg.CLIConfigs[0].Provider); provider != "" { + cfg.Provider = provider + } } client, err := newClientFromApp(cfg) if err != nil { diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index 4dcbbc5..0250ac9 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -152,7 +152,11 @@ func TestPrintProviderInfo(t *testing.T) { } func TestBuildCLIRequestArgs_Override(t *testing.T) { - cfg := appconfig.App{CLIModel: "override", CLITemperature: floatPtr(0.7), Provider: "openai", CLIProvider: "copilot", CopilotModel: "gpt-4o"} + cfg := appconfig.App{ + Provider: "openai", + CopilotModel: "gpt-4o", + CLIConfigs: []appconfig.SurfaceConfig{{Provider: "copilot", Model: "override", Temperature: floatPtr(0.7)}}, + } req := buildCLIRequestArgs(cfg, &fakeClient{name: "copilot", model: "default"}) if req.model != "override" { t.Fatalf("expected model override, got %q", req.model) diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index 3bd13ee..c8d2d24 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -81,43 +81,41 @@ func resolveDefaultModel(cfg appconfig.App, provider string) string { } } -func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string { +func surfaceConfigsFor(cfg appconfig.App, surface surfaceKind) []appconfig.SurfaceConfig { switch surface { case surfaceCompletion: - return cfg.CompletionModel + return cfg.CompletionConfigs case surfaceCodeAction: - return cfg.CodeActionModel + return cfg.CodeActionConfigs case surfaceChat: - return cfg.ChatModel + return cfg.ChatConfigs default: + return nil + } +} + +func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string { + configs := surfaceConfigsFor(cfg, surface) + if len(configs) == 0 { return "" } + return configs[0].Model } func surfaceProviderFromConfig(cfg appconfig.App, surface surfaceKind) string { - switch surface { - case surfaceCompletion: - return cfg.CompletionProvider - case surfaceCodeAction: - return cfg.CodeActionProvider - case surfaceChat: - return cfg.ChatProvider - default: + configs := surfaceConfigsFor(cfg, surface) + if len(configs) == 0 { return "" } + return configs[0].Provider } func surfaceTemperatureFromConfig(cfg appconfig.App, surface surfaceKind) *float64 { - switch surface { - case surfaceCompletion: - return cfg.CompletionTemperature - case surfaceCodeAction: - return cfg.CodeActionTemperature - case surfaceChat: - return cfg.ChatTemperature - default: + configs := surfaceConfigsFor(cfg, surface) + if len(configs) == 0 { return nil } + return configs[0].Temperature } func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider string, overrideModel, fallbackModel string) (float64, bool) { diff --git a/internal/runtimeconfig/store.go b/internal/runtimeconfig/store.go index 3112951..4ee7ada 100644 --- a/internal/runtimeconfig/store.go +++ b/internal/runtimeconfig/store.go @@ -129,6 +129,14 @@ func flattenAppConfig(cfg appconfig.App) map[string]string { switch field.Name { case "StatsWindowMinutes": key = "stats_window_minutes" + case "CompletionConfigs": + key = "completion_configs" + case "CodeActionConfigs": + key = "code_action_configs" + case "ChatConfigs": + key = "chat_configs" + case "CLIConfigs": + key = "cli_configs" default: continue } @@ -170,6 +178,22 @@ func stringifyValue(v reflect.Value) string { } return strings.Join(parts, ",") } + if v.Type().Elem() == reflect.TypeOf(appconfig.SurfaceConfig{}) { + parts := make([]string, 0, v.Len()) + for i := 0; i < v.Len(); i++ { + entry := v.Index(i).Interface().(appconfig.SurfaceConfig) + segment := strings.TrimSpace(entry.Provider) + if segment != "" { + segment += ":" + } + segment += strings.TrimSpace(entry.Model) + if entry.Temperature != nil { + segment += fmt.Sprintf("@%.3f", *entry.Temperature) + } + parts = append(parts, segment) + } + return strings.Join(parts, "|") + } return fmt.Sprint(v.Interface()) case reflect.Ptr: if v.IsNil() { diff --git a/internal/runtimeconfig/store_test.go b/internal/runtimeconfig/store_test.go index 6d23b33..1c05cc9 100644 --- a/internal/runtimeconfig/store_test.go +++ b/internal/runtimeconfig/store_test.go @@ -96,16 +96,22 @@ func TestStoreReloadLogsSummary(t *testing.T) { } func TestDiff_SurfaceModel(t *testing.T) { - oldCfg := appconfig.App{CompletionModel: "gpt-4o", CompletionProvider: "openai"} - newCfg := appconfig.App{CompletionModel: "gpt-4.1", CompletionProvider: "copilot"} + oldCfg := appconfig.App{CompletionConfigs: []appconfig.SurfaceConfig{{Provider: "openai", Model: "gpt-4o"}}} + newCfg := appconfig.App{CompletionConfigs: []appconfig.SurfaceConfig{{Provider: "copilot", Model: "gpt-4.1"}}} changes := Diff(oldCfg, newCfg) - if len(changes) != 2 { - t.Fatalf("expected single change, got %+v", changes) + if len(changes) == 0 { + t.Fatalf("expected diff entries, got none") } - if changes[0].Key != "completion_model" || changes[0].Old != "gpt-4o" || changes[0].New != "gpt-4.1" { - t.Fatalf("unexpected diff entry: %+v", changes[0]) + found := false + for _, ch := range changes { + if ch.Key == "completion_configs" { + if !strings.Contains(ch.Old, "gpt-4o") || !strings.Contains(ch.New, "gpt-4.1") { + t.Fatalf("unexpected diff contents: %+v", ch) + } + found = true + } } - if changes[1].Key != "completion_provider" || changes[1].Old != "openai" || changes[1].New != "copilot" { - t.Fatalf("unexpected provider diff: %+v", changes[1]) + if !found { + t.Fatalf("expected completion configs diff, got %+v", changes) } } |
