-rw-r--r--  PLAN2.md                                      |  28
-rw-r--r--  SCRATCHPAD.md                                 |   1
-rw-r--r--  config.toml.example                           |  15
-rw-r--r--  docs/configuration.md                         |  23
-rw-r--r--  internal/appconfig/config.go                  | 264
-rw-r--r--  internal/appconfig/config_env_model_test.go   |  74
-rw-r--r--  internal/appconfig/config_test.go             |  84
-rw-r--r--  internal/hexaiaction/prompts.go               |  86
-rw-r--r--  internal/hexaiaction/prompts_more_test.go     |  33
-rw-r--r--  internal/hexaiaction/run.go                   |   9
-rw-r--r--  internal/hexaicli/run.go                      | 119
-rw-r--r--  internal/hexaicli/run_more_test.go            |   3
-rw-r--r--  internal/hexaicli/run_test.go                 |  44
-rw-r--r--  internal/lsp/document_test.go                 |   9
-rw-r--r--  internal/lsp/handlers_codeaction.go           |  32
-rw-r--r--  internal/lsp/handlers_completion.go           |  56
-rw-r--r--  internal/lsp/handlers_document.go             |  15
-rw-r--r--  internal/lsp/handlers_utils.go                | 166
-rw-r--r--  internal/lsp/llm_request_opts_test.go         |  11
-rw-r--r--  internal/lsp/llm_stats_test.go                |   2
-rw-r--r--  internal/lsp/provider_native_success_test.go  |   9
-rw-r--r--  internal/lsp/server.go                        |  85
-rw-r--r--  internal/runtimeconfig/store_test.go          |  15
23 files changed, 1040 insertions, 143 deletions
diff --git a/PLAN2.md b/PLAN2.md
new file mode 100644
index 0000000..ff518e9
--- /dev/null
+++ b/PLAN2.md
@@ -0,0 +1,28 @@
+# Per-Surface LLM Model Configuration Plan
+
+Goal: allow users to configure distinct LLM models for (1) code completion, (2) code actions, (3) in-editor chat, and (4) the `hexai` CLI while keeping defaults sensible and maintaining backward compatibility. The new options must remain hot-reloadable via the existing runtime config store.
+
+## Phase 1 – Configuration Design
+- [x] Audit current config structures (`internal/appconfig`) and identify the model/temperature fields each surface consumes.
+- [x] Propose TOML schema extensions (e.g., `[models] completion = "..."`) plus environment variable overrides.
+- [x] Define precedence rules and fallback behavior when only a global model is provided.
+- [x] Sketch migration approach (default legacy fields map to all surfaces).
+
+## Phase 2 – Loader & Runtime Store Updates
+- [x] Extend `appconfig` to parse per-surface model settings (and optional temperature overrides) with validation.
+- [x] Update `runtimeconfig.Store` diff/flatten logic to include the new fields and guarantee reload propagation works without restart.
+- [x] Ensure reload summaries list per-surface changes cleanly.
+- [x] Add unit tests covering config parsing, env overrides, and diff output, plus runtime reload coverage.
+
+## Phase 3 – Surface Wiring
+- [x] Completion: adjust LSP completion code to pick the configured completion model, falling back to provider defaults.
+- [x] Code actions: ensure code-action prompts and CLI action runner request the configured model.
+- [x] In-editor chat: pass chat-specific model to chat requests and CLI chat command handling.
+- [x] Hexai CLI: respect the CLI model when building `llm.Config` or request options.
+- [x] Provide logging to confirm which model each surface uses for easier debugging.
+
+## Phase 4 – Validation & Docs
+- [x] Add integration/unit tests covering each surface model selection path.
+- [x] Verify runtime reload switches models without restart (including diff output).
+- [x] Update docs (`docs/configuration.md`, examples) with new keys and environment variables.
+- [x] Announce in scratchpad or release notes placeholder for future update.
diff --git a/SCRATCHPAD.md b/SCRATCHPAD.md
index dba529c..c6b0f54 100644
--- a/SCRATCHPAD.md
+++ b/SCRATCHPAD.md
@@ -13,3 +13,4 @@ This document shows future items and items in progress. Already completed ones a
 * [/] Review documentation
 * [/] Manual review the code
 * [ ] ASCIInema: Record and share terminal sessions for demos and bug reports
+* [ ] Release notes: highlight per-surface model overrides once bundled
diff --git a/config.toml.example b/config.toml.example
index 9ac6f51..e5a75f4 100644
--- a/config.toml.example
+++ b/config.toml.example
@@ -27,6 +27,21 @@ inline_close = ">" # single-character
 chat_suffix = ">" # single-character
 chat_prefixes = ["?", "!", ":", ";"] # single-character
+[models]
+# Shorthand string form per surface
+# completion = "gpt-4o-mini"
+# chat = "gpt-4.1"
+
+[models.code_action]
+# model = "gpt-4o"
+# provider = "copilot"
+# temperature = 0.4
+
+[models.cli]
+# model = "gpt-4.1"
+# provider = "openai"
+# temperature = 0.6
+
 [provider]
 name = "openai" # openai | copilot | ollama
diff --git a/docs/configuration.md b/docs/configuration.md
index 6239a4c..6db7a27 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -21,6 +21,29 @@ Environment overrides
 - `HEXAI_OPENAI_MODEL`, `HEXAI_OPENAI_BASE_URL`, `HEXAI_OPENAI_TEMPERATURE`
 - `HEXAI_COPILOT_MODEL`, `HEXAI_COPILOT_BASE_URL`, `HEXAI_COPILOT_TEMPERATURE`
 - `HEXAI_OLLAMA_MODEL`, `HEXAI_OLLAMA_BASE_URL`, `HEXAI_OLLAMA_TEMPERATURE`
+ - Per-surface overrides: `HEXAI_MODEL_COMPLETION`, `HEXAI_MODEL_CODE_ACTION`, `HEXAI_MODEL_CHAT`, `HEXAI_MODEL_CLI`
+ - Per-surface temperatures: `HEXAI_TEMPERATURE_COMPLETION`, `HEXAI_TEMPERATURE_CODE_ACTION`, `HEXAI_TEMPERATURE_CHAT`, `HEXAI_TEMPERATURE_CLI`
+
+Per-surface models
+
+- Use the `[models]` table in `config.toml` to tailor individual entry points (completion, code actions, chat, CLI) without changing the global provider default.
+- Each key accepts either a string (shortcut) or a table with `model` / `temperature` fields, e.g.:
+
+  ```toml
+  [models]
+  completion = "gpt-4.1-mini"
+
+  [models.code_action]
+  model = "gpt-4o"
+  provider = "copilot"
+  temperature = 0.4
+
+  [models.cli]
+  model = "gpt-4.1"
+  provider = "openai"
+  ```
+
+- When a per-surface value is omitted, Hexai falls back to the provider’s configured default. Temperatures inherit from `coding_temperature` unless explicitly set, and OpenAI `gpt-5*` models automatically raise an unspecified coding temperature to `1.0` for exploratory behavior. Provider overrides support `"openai"`, `"copilot"`, or `"ollama"` and read the matching credential variables.
Runtime reloads diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index adf9b75..47abaaf 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -58,6 +58,20 @@ type App struct { // Default temperature for Copilot requests (nil means use provider default) CopilotTemperature *float64 `json:"copilot_temperature" toml:"copilot_temperature"` + // Per-surface model overrides (fall back to provider defaults when unset) + CompletionModel string `json:"completion_model" toml:"completion_model"` + CompletionTemperature *float64 `json:"completion_temperature" toml:"completion_temperature"` + CompletionProvider string `json:"completion_provider" toml:"completion_provider"` + CodeActionModel string `json:"code_action_model" toml:"code_action_model"` + CodeActionTemperature *float64 `json:"code_action_temperature" toml:"code_action_temperature"` + CodeActionProvider string `json:"code_action_provider" toml:"code_action_provider"` + ChatModel string `json:"chat_model" toml:"chat_model"` + ChatTemperature *float64 `json:"chat_temperature" toml:"chat_temperature"` + ChatProvider string `json:"chat_provider" toml:"chat_provider"` + CLIModel string `json:"cli_model" toml:"cli_model"` + CLITemperature *float64 `json:"cli_temperature" toml:"cli_temperature"` + CLIProvider string `json:"cli_provider" toml:"cli_provider"` + // Prompt templates (configured only via file; no env overrides) // Completion/chat/code action/CLI prompt strings. See config.toml.example for placeholders. // Completion @@ -589,7 +603,7 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { "copilot_model": {}, "copilot_base_url": {}, "copilot_temperature": {}, } for k := range raw { - if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable { + if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "models": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable { continue } if _, isLegacy := legacy[k]; isLegacy { @@ -629,12 +643,170 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { } } } + if m := parseSurfaceModels(raw, logger); m != nil { + tab.mergeSurfaceModels(m) + } return &tab, nil } +func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App { + modelsRaw, ok := raw["models"] + if !ok { + return nil + } + table, ok := modelsRaw.(map[string]any) + if !ok { + if logger != nil { + logger.Printf("config: ignoring models section (expected table, got %T)", modelsRaw) + } + return nil + } + var out App + var any bool + if model, provider, temp, ok := decodeModelEntry(table["completion"], "models.completion", logger); ok { + if model != "" { + out.CompletionModel = model + } + if provider != "" { + out.CompletionProvider = provider + } + if temp != nil { + out.CompletionTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["code_action"], "models.code_action", logger); ok { + if model != "" { + out.CodeActionModel = model + } + if provider != "" { + out.CodeActionProvider = provider + } + if temp != nil { + out.CodeActionTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["chat"], "models.chat", logger); ok { + if model != "" { + out.ChatModel = model + } + if provider != "" { + out.ChatProvider = 
provider + } + if temp != nil { + out.ChatTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["cli"], "models.cli", logger); ok { + if model != "" { + out.CLIModel = model + } + if provider != "" { + out.CLIProvider = provider + } + if temp != nil { + out.CLITemperature = temp + } + any = true + } + if !any { + return nil + } + return &out +} + +func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, *float64, bool) { + if raw == nil { + return "", "", nil, false + } + switch v := raw.(type) { + case string: + model := strings.TrimSpace(v) + if model == "" { + return "", "", nil, false + } + return model, "", nil, true + case map[string]any: + model := "" + provider := "" + if m, ok := v["model"]; ok { + s, ok := m.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.model must be a string", path) + } + return "", "", nil, false + } + model = strings.TrimSpace(s) + } + if pRaw, ok := v["provider"]; ok { + ps, ok := pRaw.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.provider must be a string", path) + } + return "", "", nil, false + } + provider = strings.TrimSpace(ps) + } + var tempPtr *float64 + if tRaw, ok := v["temperature"]; ok { + parsed, ok := parseTemperatureValue(tRaw, path, logger) + if !ok { + return "", "", nil, false + } + tempPtr = parsed + } + if model == "" && tempPtr == nil && provider == "" { + return "", "", nil, false + } + return model, provider, tempPtr, true + default: + if logger != nil { + logger.Printf("config: %s must be a string or table, got %T", path, raw) + } + return "", "", nil, false + } +} + +func parseTemperatureValue(raw any, path string, logger *log.Logger) (*float64, bool) { + switch v := raw.(type) { + case float64: + return floatPtr(v), true + case int64: + return floatPtr(float64(v)), true + case string: + s := strings.TrimSpace(v) + if s == "" { + return nil, true + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + if logger != nil { + logger.Printf("config: %s.temperature invalid: %v", path, err) + } + return nil, false + } + return floatPtr(f), true + default: + if logger != nil { + logger.Printf("config: %s.temperature must be numeric or string, got %T", path, raw) + } + return nil, false + } +} + +func floatPtr(v float64) *float64 { + f := v + return &f +} + func (a *App) mergeWith(other *App) { a.mergeBasics(other) a.mergeProviderFields(other) + a.mergeSurfaceModels(other) a.mergePrompts(other) } @@ -687,6 +859,46 @@ func (a *App) mergeBasics(other *App) { } } +// mergeSurfaceModels copies per-surface model and temperature overrides. 
+func (a *App) mergeSurfaceModels(other *App) { + if s := strings.TrimSpace(other.CompletionModel); s != "" { + a.CompletionModel = s + } + if other.CompletionTemperature != nil { + a.CompletionTemperature = other.CompletionTemperature + } + if s := strings.TrimSpace(other.CompletionProvider); s != "" { + a.CompletionProvider = s + } + if s := strings.TrimSpace(other.CodeActionModel); s != "" { + a.CodeActionModel = s + } + if other.CodeActionTemperature != nil { + a.CodeActionTemperature = other.CodeActionTemperature + } + if s := strings.TrimSpace(other.CodeActionProvider); s != "" { + a.CodeActionProvider = s + } + if s := strings.TrimSpace(other.ChatModel); s != "" { + a.ChatModel = s + } + if other.ChatTemperature != nil { + a.ChatTemperature = other.ChatTemperature + } + if s := strings.TrimSpace(other.ChatProvider); s != "" { + a.ChatProvider = s + } + if s := strings.TrimSpace(other.CLIModel); s != "" { + a.CLIModel = s + } + if other.CLITemperature != nil { + a.CLITemperature = other.CLITemperature + } + if s := strings.TrimSpace(other.CLIProvider); s != "" { + a.CLIProvider = s + } +} + // mergePrompts copies non-empty prompt templates from other. func (a *App) mergePrompts(other *App) { // Completion @@ -1050,6 +1262,56 @@ func loadFromEnv(logger *log.Logger) *App { any = true } + // Per-surface overrides + if s := getenv("HEXAI_MODEL_COMPLETION"); s != "" { + out.CompletionModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_COMPLETION"); ok { + out.CompletionTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_COMPLETION"); s != "" { + out.CompletionProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CODE_ACTION"); s != "" { + out.CodeActionModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CODE_ACTION"); ok { + out.CodeActionTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CODE_ACTION"); s != "" { + out.CodeActionProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CHAT"); s != "" { + out.ChatModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CHAT"); ok { + out.ChatTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CHAT"); s != "" { + out.ChatProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CLI"); s != "" { + out.CLIModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CLI"); ok { + out.CLITemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CLI"); s != "" { + out.CLIProvider = s + any = true + } + if !any { return nil } diff --git a/internal/appconfig/config_env_model_test.go b/internal/appconfig/config_env_model_test.go index 2db2bb5..f34416d 100644 --- a/internal/appconfig/config_env_model_test.go +++ b/internal/appconfig/config_env_model_test.go @@ -1,37 +1,65 @@ package appconfig import ( - "log" - "os" - "testing" + "log" + "os" + "testing" ) // Test that HEXAI_MODEL applies to provider model fields and that // provider-specific envs take precedence when both are set. 
func TestEnv_GenericModelOverrideAndPrecedence(t *testing.T) { - t.Setenv("HEXAI_MODEL", "gpt-5-codex") - t.Setenv("HEXAI_PROVIDER", "openai") - // No provider-specific env set yet: HEXAI_MODEL should flow into OpenAIModel - cfg := Load(log.New(os.Stderr, "test ", 0)) - if cfg.OpenAIModel != "gpt-5-codex" { - t.Fatalf("expected OpenAIModel=gpt-5-codex via HEXAI_MODEL, got %q", cfg.OpenAIModel) - } + t.Setenv("HEXAI_MODEL", "gpt-5-codex") + t.Setenv("HEXAI_PROVIDER", "openai") + // No provider-specific env set yet: HEXAI_MODEL should flow into OpenAIModel + cfg := Load(log.New(os.Stderr, "test ", 0)) + if cfg.OpenAIModel != "gpt-5-codex" { + t.Fatalf("expected OpenAIModel=gpt-5-codex via HEXAI_MODEL, got %q", cfg.OpenAIModel) + } - // Now set a provider-specific model; it should win over HEXAI_MODEL - t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-thinking") - cfg2 := Load(log.New(os.Stderr, "test ", 0)) - if cfg2.OpenAIModel != "gpt-5-thinking" { - t.Fatalf("expected OpenAIModel from HEXAI_OPENAI_MODEL to win, got %q", cfg2.OpenAIModel) - } + // Now set a provider-specific model; it should win over HEXAI_MODEL + t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-thinking") + cfg2 := Load(log.New(os.Stderr, "test ", 0)) + if cfg2.OpenAIModel != "gpt-5-thinking" { + t.Fatalf("expected OpenAIModel from HEXAI_OPENAI_MODEL to win, got %q", cfg2.OpenAIModel) + } } // Test that HEXAI_MODEL_FORCE overrides provider-specific envs (used by CLI --model). func TestEnv_ModelForce_OverridesProviderSpecific(t *testing.T) { - t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-main") - t.Setenv("HEXAI_MODEL_FORCE", "gpt-5-codex") - t.Setenv("HEXAI_PROVIDER", "openai") - cfg := Load(log.New(os.Stderr, "test ", 0)) - if cfg.OpenAIModel != "gpt-5-codex" { - t.Fatalf("expected OpenAIModel forced to gpt-5-codex, got %q", cfg.OpenAIModel) - } + t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-main") + t.Setenv("HEXAI_MODEL_FORCE", "gpt-5-codex") + t.Setenv("HEXAI_PROVIDER", "openai") + cfg := Load(log.New(os.Stderr, "test ", 0)) + if cfg.OpenAIModel != "gpt-5-codex" { + t.Fatalf("expected OpenAIModel forced to gpt-5-codex, got %q", cfg.OpenAIModel) + } +} + +func TestEnv_SurfaceModelOverrides(t *testing.T) { + t.Setenv("HEXAI_MODEL_COMPLETION", "gpt-c") + t.Setenv("HEXAI_TEMPERATURE_COMPLETION", "0.44") + t.Setenv("HEXAI_PROVIDER_COMPLETION", "copilot") + t.Setenv("HEXAI_MODEL_CLI", "gpt-cli") + t.Setenv("HEXAI_TEMPERATURE_CLI", "0.22") + t.Setenv("HEXAI_PROVIDER_CLI", "ollama") + cfg := Load(log.New(os.Stderr, "test ", 0)) + if cfg.CompletionModel != "gpt-c" { + t.Fatalf("expected completion model override, got %q", cfg.CompletionModel) + } + if cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.44 { + t.Fatalf("expected completion temperature override, got %v", cfg.CompletionTemperature) + } + if cfg.CompletionProvider != "copilot" { + t.Fatalf("expected completion provider override, got %q", cfg.CompletionProvider) + } + if cfg.CLIModel != "gpt-cli" { + t.Fatalf("expected cli model override, got %q", cfg.CLIModel) + } + if cfg.CLITemperature == nil || *cfg.CLITemperature != 0.22 { + t.Fatalf("expected cli temperature override, got %v", cfg.CLITemperature) + } + if cfg.CLIProvider != "ollama" { + t.Fatalf("expected cli provider override, got %q", cfg.CLIProvider) + } } diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go index b03137e..ea68305 100644 --- a/internal/appconfig/config_test.go +++ b/internal/appconfig/config_test.go @@ -88,6 +88,24 @@ completion_throttle_ms = 300 [triggers] trigger_characters = 
[".", ":"] +[models.completion] +model = "gpt-file-complete" +provider = "openai" + +[models.code_action] +model = "gpt-file-action" +temperature = 0.45 +provider = "copilot" + +[models.chat] +model = "gpt-file-chat" +provider = "openai" + +[models.cli] +model = "gpt-file-cli" +temperature = 0.15 +provider = "ollama" + [provider] name = "openai" @@ -107,6 +125,10 @@ model = "ghost" temperature = 0.0 `) + if _, err := loadFromFile(cfgPath, newLogger()); err != nil { + t.Fatalf("loadFromFile: %v", err) + } + // Env overrides take precedence withEnv(t, "HEXAI_MAX_TOKENS", "321") withEnv(t, "HEXAI_CONTEXT_MODE", "always-full") @@ -128,6 +150,18 @@ temperature = 0.0 withEnv(t, "HEXAI_COPILOT_BASE_URL", "http://copilot-override") withEnv(t, "HEXAI_COPILOT_MODEL", "ghost-override") withEnv(t, "HEXAI_COPILOT_TEMPERATURE", "0.3") + withEnv(t, "HEXAI_MODEL_COMPLETION", "env-completion") + withEnv(t, "HEXAI_TEMPERATURE_COMPLETION", "0.33") + withEnv(t, "HEXAI_PROVIDER_COMPLETION", "copilot") + withEnv(t, "HEXAI_MODEL_CODE_ACTION", "env-action") + withEnv(t, "HEXAI_TEMPERATURE_CODE_ACTION", "0.55") + withEnv(t, "HEXAI_PROVIDER_CODE_ACTION", "openai") + withEnv(t, "HEXAI_MODEL_CHAT", "env-chat") + withEnv(t, "HEXAI_TEMPERATURE_CHAT", "0.66") + withEnv(t, "HEXAI_PROVIDER_CHAT", "copilot") + withEnv(t, "HEXAI_MODEL_CLI", "env-cli") + withEnv(t, "HEXAI_TEMPERATURE_CLI", "0.77") + withEnv(t, "HEXAI_PROVIDER_CLI", "ollama") logger := newLogger() cfg := Load(logger) @@ -158,11 +192,35 @@ temperature = 0.0 if cfg.CopilotBaseURL != "http://copilot-override" || cfg.CopilotModel != "ghost-override" || cfg.CopilotTemperature == nil || *cfg.CopilotTemperature != 0.3 { t.Fatalf("copilot overrides not applied: %+v", cfg) } + if cfg.CompletionModel != "env-completion" || cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.33 { + t.Fatalf("completion overrides not applied: model=%q temp=%v", cfg.CompletionModel, cfg.CompletionTemperature) + } + if cfg.CompletionProvider != "copilot" { + t.Fatalf("completion provider override not applied: %q", cfg.CompletionProvider) + } + if cfg.CodeActionModel != "env-action" || cfg.CodeActionTemperature == nil || *cfg.CodeActionTemperature != 0.55 { + t.Fatalf("code action overrides not applied: model=%q temp=%v", cfg.CodeActionModel, cfg.CodeActionTemperature) + } + if cfg.CodeActionProvider != "openai" { + t.Fatalf("code action provider override not applied: %q", cfg.CodeActionProvider) + } + if cfg.ChatModel != "env-chat" || cfg.ChatTemperature == nil || *cfg.ChatTemperature != 0.66 { + t.Fatalf("chat overrides not applied: model=%q temp=%v", cfg.ChatModel, cfg.ChatTemperature) + } + if cfg.ChatProvider != "copilot" { + t.Fatalf("chat provider override not applied: %q", cfg.ChatProvider) + } + if cfg.CLIModel != "env-cli" || cfg.CLITemperature == nil || *cfg.CLITemperature != 0.77 { + t.Fatalf("cli overrides not applied: model=%q temp=%v", cfg.CLIModel, cfg.CLITemperature) + } + if cfg.CLIProvider != "ollama" { + t.Fatalf("cli provider override not applied: %q", cfg.CLIProvider) + } // Ensure file values would have applied absent env // Spot-check: reset env and reload for _, k := range []string{ - "HEXAI_MAX_TOKENS", "HEXAI_CONTEXT_MODE", "HEXAI_CONTEXT_WINDOW_LINES", "HEXAI_MAX_CONTEXT_TOKENS", "HEXAI_LOG_PREVIEW_LIMIT", "HEXAI_CODING_TEMPERATURE", "HEXAI_MANUAL_INVOKE_MIN_PREFIX", "HEXAI_COMPLETION_DEBOUNCE_MS", "HEXAI_COMPLETION_THROTTLE_MS", "HEXAI_TRIGGER_CHARACTERS", "HEXAI_PROVIDER", "HEXAI_OPENAI_BASE_URL", "HEXAI_OPENAI_MODEL", 
"HEXAI_OPENAI_TEMPERATURE", "HEXAI_OLLAMA_BASE_URL", "HEXAI_OLLAMA_MODEL", "HEXAI_OLLAMA_TEMPERATURE", "HEXAI_COPILOT_BASE_URL", "HEXAI_COPILOT_MODEL", "HEXAI_COPILOT_TEMPERATURE", + "HEXAI_MAX_TOKENS", "HEXAI_CONTEXT_MODE", "HEXAI_CONTEXT_WINDOW_LINES", "HEXAI_MAX_CONTEXT_TOKENS", "HEXAI_LOG_PREVIEW_LIMIT", "HEXAI_CODING_TEMPERATURE", "HEXAI_MANUAL_INVOKE_MIN_PREFIX", "HEXAI_COMPLETION_DEBOUNCE_MS", "HEXAI_COMPLETION_THROTTLE_MS", "HEXAI_TRIGGER_CHARACTERS", "HEXAI_PROVIDER", "HEXAI_OPENAI_BASE_URL", "HEXAI_OPENAI_MODEL", "HEXAI_OPENAI_TEMPERATURE", "HEXAI_OLLAMA_BASE_URL", "HEXAI_OLLAMA_MODEL", "HEXAI_OLLAMA_TEMPERATURE", "HEXAI_COPILOT_BASE_URL", "HEXAI_COPILOT_MODEL", "HEXAI_COPILOT_TEMPERATURE", "HEXAI_MODEL_COMPLETION", "HEXAI_TEMPERATURE_COMPLETION", "HEXAI_MODEL_CODE_ACTION", "HEXAI_TEMPERATURE_CODE_ACTION", "HEXAI_MODEL_CHAT", "HEXAI_TEMPERATURE_CHAT", "HEXAI_MODEL_CLI", "HEXAI_TEMPERATURE_CLI", "HEXAI_PROVIDER_COMPLETION", "HEXAI_PROVIDER_CODE_ACTION", "HEXAI_PROVIDER_CHAT", "HEXAI_PROVIDER_CLI", } { t.Setenv(k, "") } @@ -176,6 +234,30 @@ temperature = 0.0 if cfg2.OpenAIBaseURL != "https://api.example" || cfg2.OpenAIModel != "gpt-x" || cfg2.OpenAITemperature == nil || *cfg2.OpenAITemperature != 0.0 { t.Fatalf("file merge (openai) not applied: %+v", cfg2) } + if cfg2.CompletionModel != "gpt-file-complete" || cfg2.CompletionTemperature != nil { + t.Fatalf("file merge (completion) not applied: %+v", cfg2) + } + if cfg2.CompletionProvider != "openai" { + t.Fatalf("file merge (completion provider) not applied: %q", cfg2.CompletionProvider) + } + if cfg2.CodeActionModel != "gpt-file-action" || cfg2.CodeActionTemperature == nil || *cfg2.CodeActionTemperature != 0.45 { + t.Fatalf("file merge (code action) not applied: %+v", cfg2) + } + if cfg2.CodeActionProvider != "copilot" { + t.Fatalf("file merge (code action provider) not applied: %q", cfg2.CodeActionProvider) + } + if cfg2.ChatModel != "gpt-file-chat" || cfg2.ChatTemperature != nil { + t.Fatalf("file merge (chat) not applied: %+v", cfg2) + } + if cfg2.ChatProvider != "openai" { + t.Fatalf("file merge (chat provider) not applied: %q", cfg2.ChatProvider) + } + if cfg2.CLIModel != "gpt-file-cli" || cfg2.CLITemperature == nil || *cfg2.CLITemperature != 0.15 { + t.Fatalf("file merge (cli) not applied: %+v", cfg2) + } + if cfg2.CLIProvider != "ollama" { + t.Fatalf("file merge (cli provider) not applied: %q", cfg2.CLIProvider) + } } func TestGetConfigPath_XDG(t *testing.T) { diff --git a/internal/hexaiaction/prompts.go b/internal/hexaiaction/prompts.go index 207302e..47dadbf 100644 --- a/internal/hexaiaction/prompts.go +++ b/internal/hexaiaction/prompts.go @@ -25,6 +25,11 @@ type chatDoer interface { type providerNamer interface{ Name() string } +type requestArgs struct { + model string + options []llm.RequestOption +} + func providerOf(c any) string { if n, ok := c.(providerNamer); ok { return n.Name() @@ -32,6 +37,42 @@ func providerOf(c any) string { return "llm" } +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func defaultModelForProvider(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return cfg.OllamaModel + case "copilot": + return cfg.CopilotModel + default: + return cfg.OpenAIModel + } +} + +func selectActionTemperature(cfg appconfig.App, provider, model string) (float64, bool) { + if cfg.CodeActionTemperature != nil { + return *cfg.CodeActionTemperature, true + } + if cfg.CodingTemperature != 
nil { + temp := *cfg.CodingTemperature + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 { + temp = 1.0 + } + return temp, true + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") { + return 1.0, true + } + return 0, false +} + func runRewrite(ctx context.Context, cfg appconfig.App, client chatDoer, instruction, selection string) (string, error) { sys := cfg.PromptCodeActionRewriteSystem user := Render(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": instruction, "selection": selection}) @@ -118,9 +159,9 @@ func runOnce(ctx context.Context, client chatDoer, sys, user string) (string, er return out, nil } -func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opts []llm.RequestOption) (string, error) { +func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, req requestArgs) (string, error) { msgs := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - txt, err := client.Chat(ctx, msgs, opts...) + txt, err := client.Chat(ctx, msgs, req.options...) if err != nil { return "", err } @@ -131,7 +172,11 @@ func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opt sent += len(m.Content) } recv := len(out) - _ = stats.Update(ctx, providerOf(client), client.DefaultModel(), sent, recv) + model := strings.TrimSpace(req.model) + if model == "" { + model = client.DefaultModel() + } + _ = stats.Update(ctx, providerOf(client), model, sent, recv) if snap, err := stats.TakeSnapshot(); err == nil { minsWin := snap.Window.Minutes() if minsWin <= 0 { @@ -139,30 +184,39 @@ func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opt } scopeReqs := int64(0) if pe, ok := snap.Providers[providerOf(client)]; ok { - if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 { + if mc, ok2 := pe.Models[model]; ok2 { scopeReqs = mc.Reqs } } scopeRPM := float64(scopeReqs) / minsWin - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, providerOf(client), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window)) + _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, providerOf(client), model, scopeRPM, scopeReqs, snap.Window)) } return out, nil } // reqOptsFrom builds LLM request options similar to LSP behavior. 
-func reqOptsFrom(cfg appconfig.App) []llm.RequestOption { - opts := []llm.RequestOption{llm.WithMaxTokens(cfg.MaxTokens)} - // Apply temperature, with special-case for gpt-5 (default temp must be 1.0) - if cfg.CodingTemperature != nil { - temp := *cfg.CodingTemperature - prov := strings.ToLower(strings.TrimSpace(cfg.Provider)) - model := strings.ToLower(strings.TrimSpace(cfg.OpenAIModel)) - if prov == "openai" && strings.HasPrefix(model, "gpt-5") { - temp = 1.0 - } +func reqOptsFrom(cfg appconfig.App) requestArgs { + opts := make([]llm.RequestOption, 0, 3) + if cfg.MaxTokens > 0 { + opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens)) + } + provider := canonicalProvider(cfg.Provider) + if strings.TrimSpace(cfg.CodeActionProvider) != "" { + provider = canonicalProvider(cfg.CodeActionProvider) + } + override := strings.TrimSpace(cfg.CodeActionModel) + fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider)) + effective := override + if effective == "" { + effective = fallback + } + if override != "" { + opts = append(opts, llm.WithModel(override)) + } + if temp, ok := selectActionTemperature(cfg, provider, effective); ok { opts = append(opts, llm.WithTemperature(temp)) } - return opts + return requestArgs{model: effective, options: opts} } // Timeout helpers to mirror LSP behavior. diff --git a/internal/hexaiaction/prompts_more_test.go b/internal/hexaiaction/prompts_more_test.go index 9f5d6cb..97d3979 100644 --- a/internal/hexaiaction/prompts_more_test.go +++ b/internal/hexaiaction/prompts_more_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" ) @@ -15,6 +16,11 @@ func (d simpleDoer) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOpt } func (d simpleDoer) DefaultModel() string { return "m" } +func ptrFloat(v float64) *float64 { + x := v + return &x +} + func TestRunOnce_StripsFences(t *testing.T) { got, err := runOnce(context.Background(), simpleDoer{"```\nok\n```"}, "SYS", "USER") if err != nil { @@ -24,3 +30,30 @@ func TestRunOnce_StripsFences(t *testing.T) { t.Fatalf("got %q", got) } } + +func TestReqOptsFrom_Override(t *testing.T) { + cfg := appconfig.App{MaxTokens: 123, CodeActionModel: "override", CodeActionTemperature: ptrFloat(0.6), Provider: "openai", CodeActionProvider: "copilot", CopilotModel: "gpt-4o"} + req := reqOptsFrom(cfg) + if req.model != "override" { + t.Fatalf("expected override model, got %q", req.model) + } + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.MaxTokens != 123 || opts.Model != "override" || opts.Temperature != 0.6 { + t.Fatalf("unexpected options: %+v", opts) + } +} + +func TestReqOptsFrom_Gpt5Temp(t *testing.T) { + cfg := appconfig.App{Provider: "openai", CodingTemperature: ptrFloat(0.2), OpenAIModel: "gpt-5.0"} + req := reqOptsFrom(cfg) + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.Temperature != 1.0 { + t.Fatalf("expected gpt-5 temp adjustment to 1.0, got %v", opts.Temperature) + } +} diff --git a/internal/hexaiaction/run.go b/internal/hexaiaction/run.go index a48bf94..953da80 100644 --- a/internal/hexaiaction/run.go +++ b/internal/hexaiaction/run.go @@ -41,12 +41,19 @@ func Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error { if len(cfg.CustomActions) > 0 { chooseActionFn = func() (ActionKind, error) { return RunTUIWithCustom(cfg.CustomActions, cfg.TmuxCustomMenuHotkey) } } + if providerOverride := strings.TrimSpace(cfg.CodeActionProvider); 
providerOverride != "" { + cfg.Provider = providerOverride + } cli, err := newClientFromApp(cfg) if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai-tmux-action: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(cli.Name(), cli.DefaultModel())) + primaryModel := strings.TrimSpace(reqOptsFrom(cfg).model) + if primaryModel == "" { + primaryModel = cli.DefaultModel() + } + _ = tmux.SetStatus(tmux.FormatLLMStartStatus(cli.Name(), primaryModel)) var client chatDoer = cli parts, err := ParseInput(stdin) if err != nil { diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 11e8938..b965261 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -20,6 +20,84 @@ import ( "codeberg.org/snonux/hexai/internal/tmux" ) +type requestArgs struct { + model string + options []llm.RequestOption +} + +func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { + provider := canonicalProvider(cfg.Provider) + if strings.TrimSpace(cfg.CLIProvider) != "" { + provider = canonicalProvider(cfg.CLIProvider) + } + if client != nil { + provider = strings.ToLower(strings.TrimSpace(client.Name())) + } + override := strings.TrimSpace(cfg.CLIModel) + fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider)) + if client != nil { + if dm := strings.TrimSpace(client.DefaultModel()); dm != "" { + fallback = dm + } + } + effective := override + if effective == "" { + effective = fallback + } + opts := make([]llm.RequestOption, 0, 2) + if override != "" { + opts = append(opts, llm.WithModel(override)) + } + if temp, ok := cliTemperature(cfg, provider, effective); ok { + opts = append(opts, llm.WithTemperature(temp)) + } + return requestArgs{model: effective, options: opts} +} + +func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { + model := strings.TrimSpace(cfg.CLIModel) + if model == "" && client != nil { + model = strings.TrimSpace(client.DefaultModel()) + } + return requestArgs{model: model} +} + +func cliTemperature(cfg appconfig.App, provider, model string) (float64, bool) { + if cfg.CLITemperature != nil { + return *cfg.CLITemperature, true + } + if cfg.CodingTemperature != nil { + temp := *cfg.CodingTemperature + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 { + temp = 1.0 + } + return temp, true + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") { + return 1.0, true + } + return 0, false +} + +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func defaultModelForProvider(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return cfg.OllamaModel + case "copilot": + return cfg.CopilotModel + default: + return cfg.OpenAIModel + } +} + // Run executes the Hexai CLI behavior given arguments and I/O streams. // It assumes flags have already been parsed by the caller. func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { @@ -29,11 +107,16 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. 
if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } + providerOverride := strings.TrimSpace(cfg.CLIProvider) + if providerOverride != "" { + cfg.Provider = providerOverride + } client, err := newClientFromApp(cfg) if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } + req := buildCLIRequestArgs(cfg, client) // Prefer piped stdin when present; only open the editor when there are no args // and no stdin content available. input, rerr := readInput(stdin, args) @@ -47,9 +130,9 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) return rerr } - printProviderInfo(stderr, client) + printProviderInfo(stderr, client, req.model) msgs := buildMessagesFromConfig(cfg, input) - if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { + if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -64,9 +147,10 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) return err } - printProviderInfo(stderr, client) + req := defaultRequestArgs(appconfig.App{}, client) + printProviderInfo(stderr, client, req.model) msgs := buildMessages(input) - if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { + if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -128,22 +212,26 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message { } // runChat executes the chat request, handling streaming and summary output. -func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { +func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { start := time.Now() // Best-effort tmux status update (colored start heartbeat) - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), client.DefaultModel())) + model := strings.TrimSpace(req.model) + if model == "" { + model = client.DefaultModel() + } + _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model)) var output string if s, ok := client.(llm.Streamer); ok { var b strings.Builder if err := s.ChatStream(ctx, msgs, func(chunk string) { b.WriteString(chunk) fmt.Fprint(out, chunk) - }); err != nil { + }, req.options...); err != nil { return err } output = b.String() } else { - txt, err := client.Chat(ctx, msgs) + txt, err := client.Chat(ctx, msgs, req.options...) 
if err != nil { return err } @@ -157,7 +245,7 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s sent += len(m.Content) } recv := len(output) - _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, recv) + _ = stats.Update(ctx, client.Name(), model, sent, recv) snap, _ := stats.TakeSnapshot() minsWin := snap.Window.Minutes() if minsWin <= 0 { @@ -165,20 +253,23 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s } scopeReqs := int64(0) if pe, ok := snap.Providers[client.Name()]; ok { - if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 { + if mc, ok2 := pe.Models[model]; ok2 { scopeReqs = mc.Reqs } } scopeRPM := float64(scopeReqs) / minsWin fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n", - client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM) - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window)) + client.Name(), model, dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM) + _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), model, scopeRPM, scopeReqs, snap.Window)) return nil } // printProviderInfo writes the provider/model line to stderr. -func printProviderInfo(errw io.Writer, client llm.Client) { - fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel()) +func printProviderInfo(errw io.Writer, client llm.Client, model string) { + if strings.TrimSpace(model) == "" { + model = client.DefaultModel() + } + fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model) } // newClientFromConfig is kept for tests; delegates to llmutils. 
diff --git a/internal/hexaicli/run_more_test.go b/internal/hexaicli/run_more_test.go index bd88d56..469f0c0 100644 --- a/internal/hexaicli/run_more_test.go +++ b/internal/hexaicli/run_more_test.go @@ -26,7 +26,8 @@ func TestRunChat_Streaming(t *testing.T) { var out, errw bytes.Buffer input := "hello" msgs := []llm.Message{{Role: "user", Content: input}} - if err := runChat(context.Background(), streamClient{}, msgs, input, &out, &errw); err != nil { + req := requestArgs{model: "m"} + if err := runChat(context.Background(), streamClient{}, req, msgs, input, &out, &errw); err != nil { t.Fatalf("runChat failed: %v", err) } if out.String() != "AB" { diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index a4184f6..4dcbbc5 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -16,6 +16,11 @@ type failingReader struct{ err error } func (f failingReader) Read([]byte) (int, error) { return 0, f.err } +func floatPtr(v float64) *float64 { + x := v + return &x +} + func TestReadInput_Combinations(t *testing.T) { // stdin + arg restore, f := setStdin(t, "from-stdin") @@ -72,7 +77,8 @@ func TestRunChat_StreamAndNonStream(t *testing.T) { // stream path fc := &fakeStreamer{fakeClient: fakeClient{name: "p", model: "m"}, chunks: []string{"H", "i", "!"}} var out, errb bytes.Buffer - if err := runChat(context.Background(), fc, buildMessages("hello"), "hello", &out, &errb); err != nil { + req := requestArgs{model: fc.DefaultModel()} + if err := runChat(context.Background(), fc, req, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("stream: %v", err) } if out.String() != "Hi!" || !strings.Contains(errb.String(), "provider=p model=m") { @@ -82,7 +88,7 @@ func TestRunChat_StreamAndNonStream(t *testing.T) { fc2 := &fakeClient{name: "p2", model: "m2", resp: "Yo"} out.Reset() errb.Reset() - if err := runChat(context.Background(), fc2, buildMessages("hello"), "hello", &out, &errb); err != nil { + if err := runChat(context.Background(), fc2, requestArgs{model: fc2.DefaultModel()}, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("non-stream: %v", err) } if out.String() != "Yo" || !strings.Contains(errb.String(), "provider=p2 model=m2") { @@ -101,7 +107,7 @@ func (c clientErr) DefaultModel() string { return c.model } func TestRunChat_ErrorPaths(t *testing.T) { ctx := context.Background() out, errb := &bytes.Buffer{}, &bytes.Buffer{} - if err := runChat(ctx, clientErr{"p", "m"}, buildMessages("hi"), "hi", out, errb); err == nil { + if err := runChat(ctx, clientErr{"p", "m"}, requestArgs{model: "m"}, buildMessages("hi"), "hi", out, errb); err == nil { t.Fatalf("expected error from Chat") } } @@ -139,12 +145,42 @@ func TestRun_OpenAI_NoKey_ShowsError(t *testing.T) { func TestPrintProviderInfo(t *testing.T) { var b bytes.Buffer - printProviderInfo(&b, &fakeClient{name: "x", model: "y"}) + printProviderInfo(&b, &fakeClient{name: "x", model: "y"}, "y") if !strings.Contains(b.String(), "provider=x model=y") { t.Fatalf("missing provider line: %q", b.String()) } } +func TestBuildCLIRequestArgs_Override(t *testing.T) { + cfg := appconfig.App{CLIModel: "override", CLITemperature: floatPtr(0.7), Provider: "openai", CLIProvider: "copilot", CopilotModel: "gpt-4o"} + req := buildCLIRequestArgs(cfg, &fakeClient{name: "copilot", model: "default"}) + if req.model != "override" { + t.Fatalf("expected model override, got %q", req.model) + } + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.Model != "override" || 
opts.Temperature != 0.7 { + t.Fatalf("unexpected options: %+v", opts) + } +} + +func TestBuildCLIRequestArgs_Gpt5Temp(t *testing.T) { + cfg := appconfig.App{Provider: "openai", CodingTemperature: floatPtr(0.2)} + req := buildCLIRequestArgs(cfg, &fakeClient{name: "openai", model: "gpt-5.1"}) + if req.model != "gpt-5.1" { + t.Fatalf("expected fallback model, got %q", req.model) + } + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.Temperature != 1.0 { + t.Fatalf("expected temp 1.0, got %v", opts.Temperature) + } +} + func TestNewClientFromConfig_Ollama(t *testing.T) { cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"} c, err := newClientFromConfig(cfg) diff --git a/internal/lsp/document_test.go b/internal/lsp/document_test.go index ed2ccea..fd13e5d 100644 --- a/internal/lsp/document_test.go +++ b/internal/lsp/document_test.go @@ -8,6 +8,7 @@ import ( "testing" "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" ) func newTestServer() *Server { @@ -35,9 +36,11 @@ func newTestServer() *Server { PromptCodeActionGoTestUser: "Function under test:\n{{function}}", } return &Server{ - logger: log.New(io.Discard, "", 0), - docs: make(map[string]*document), - cfg: cfg, + logger: log.New(io.Discard, "", 0), + docs: make(map[string]*document), + cfg: cfg, + altClients: make(map[string]llm.Client), + llmProvider: canonicalProvider(cfg.Provider), } } diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go index 7631935..24429a1 100644 --- a/internal/lsp/handlers_codeaction.go +++ b/internal/lsp/handlers_codeaction.go @@ -245,8 +245,8 @@ func (s *Server) completeCodeAction(ca CodeAction, uri string, rng Range, sys, u ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - opts := s.llmRequestOpts() - if text, err := s.chatWithStats(ctx, messages, opts...); err == nil { + spec := s.buildRequestSpec(surfaceCodeAction) + if text, err := s.chatWithStats(ctx, surfaceCodeAction, spec, messages); err == nil { if out := stripCodeFences(strings.TrimSpace(text)); out != "" { edit := WorkspaceEdit{Changes: map[string][]TextEdit{uri: {{Range: rng, NewText: out}}}} ca.Edit = &edit @@ -555,22 +555,20 @@ func findGoFunctionAtLine(lines []string, idx int) (int, int) { // generateGoTestFunction uses LLM to produce a test function; falls back to a stub when unavailable. 
func (s *Server) generateGoTestFunction(funcCode string) string { - if client := s.currentLLMClient(); client != nil { - cfg := s.currentConfig() - sys := cfg.PromptCodeActionGoTestSystem - user := renderTemplate(cfg.PromptCodeActionGoTestUser, map[string]string{"function": funcCode}) - ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second) - defer cancel() - messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - opts := s.llmRequestOpts() - if out, err := s.chatWithStats(ctx, messages, opts...); err == nil { - cleaned := strings.TrimSpace(stripCodeFences(out)) - if cleaned != "" { - return cleaned - } - } else { - logging.Logf("lsp ", "codeAction go_test llm error: %v", err) + spec := s.buildRequestSpec(surfaceCodeAction) + cfg := s.currentConfig() + sys := cfg.PromptCodeActionGoTestSystem + user := renderTemplate(cfg.PromptCodeActionGoTestUser, map[string]string{"function": funcCode}) + ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second) + defer cancel() + messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} + if out, err := s.chatWithStats(ctx, surfaceCodeAction, spec, messages); err == nil { + cleaned := strings.TrimSpace(stripCodeFences(out)) + if cleaned != "" { + return cleaned } + } else { + logging.Logf("lsp ", "codeAction go_test llm error: %v", err) } // Fallback stub name := deriveGoFuncName(funcCode) diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index f7f41ef..d115741 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -95,11 +95,16 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun return items, true } - if items, ok := s.tryProviderNativeCompletion(current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok { + spec := s.buildRequestSpec(surfaceCompletion) + client := s.clientFor(spec) + if client == nil { + return nil, false + } + if items, ok := s.tryProviderNativeCompletion(spec, client, current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok { return items, true } - return s.executeChatCompletion(ctx, plan) + return s.executeChatCompletion(ctx, plan, spec, client) } func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) (completionPlan, []CompletionItem, bool) { @@ -142,31 +147,31 @@ func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below return plan, nil, false } -func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan) ([]CompletionItem, bool) { +func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client) ([]CompletionItem, bool) { messages := s.buildCompletionMessages(plan.inlinePrompt, plan.hasExtra, plan.extraText, plan.inParams, plan.params, plan.above, plan.current, plan.below, plan.funcCtx) sentSize := 0 for _, m := range messages { sentSize += len(m.Content) } s.incSentCounters(sentSize) - opts := s.llmRequestOpts() + opts := spec.options s.waitForDebounce(ctx) if !s.waitForThrottle(ctx) { return nil, false } - client := s.currentLLMClient() - if client == nil { - return nil, false + modelUsed := spec.effectiveModel() + if strings.TrimSpace(modelUsed) == "" { + modelUsed = client.DefaultModel() } - logging.Logf("lsp ", "completion llm=requesting model=%s", client.DefaultModel()) + logging.Logf("lsp ", 
"completion llm=requesting model=%s", modelUsed) text, err := client.Chat(ctx, messages, opts...) if err != nil { logging.Logf("lsp ", "llm completion error: %v", err) - s.logLLMStats() + s.logLLMStats(modelUsed) return nil, false } s.incRecvCounters(len(text)) - s.logLLMStats() + s.logLLMStats(modelUsed) trimmed := strings.TrimSpace(text) cleaned := s.postProcessCompletion(trimmed, plan.current[:plan.params.Position.Character], plan.current) if cleaned == "" { @@ -255,8 +260,7 @@ func (s *Server) prefixHeuristicAllows(inlinePrompt bool, current string, p Comp } // tryProviderNativeCompletion attempts provider-native completion and returns items when successful. -func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) { - client := s.currentLLMClient() +func (s *Server) tryProviderNativeCompletion(spec requestSpec, client llm.Client, current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) { cc, ok := client.(llm.CodeCompleter) if !ok { return nil, false @@ -271,15 +275,11 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, "before": before, }) lang := "" - temp := 0.0 - if cfg.CodingTemperature != nil { - temp = *cfg.CodingTemperature - } - prov := "" - if client != nil { - prov = client.Name() + provider := spec.provider + if provider == "" { + provider = canonicalProvider(cfg.Provider) } - logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", prov, path) + logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", provider, path) ctx2, cancel2 := context.WithTimeout(context.Background(), 15*time.Second) defer cancel2() @@ -290,16 +290,24 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, } // Count approximate payload sizes: prompt+after sent; first suggestion received sentBytes := len(prompt) + len(after) - suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp) + modelUsed := spec.effectiveModel() + if strings.TrimSpace(modelUsed) == "" { + modelUsed = client.DefaultModel() + } + tempVal := 0.0 + if val, ok := chooseSurfaceTemperature(surfaceCompletion, cfg, provider, spec.modelOverride, spec.fallbackModel); ok { + tempVal = val + } + suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, tempVal) if err == nil && len(suggestions) > 0 { // Update counters and heartbeat s.incSentCounters(sentBytes) s.incRecvCounters(len(suggestions[0])) // Contribute to global stats (provider-native path) if client != nil { - _ = stats.Update(ctx2, client.Name(), client.DefaultModel(), sentBytes, len(suggestions[0])) + _ = stats.Update(ctx2, client.Name(), modelUsed, sentBytes, len(suggestions[0])) } - s.logLLMStats() + s.logLLMStats(modelUsed) cleaned := strings.TrimSpace(suggestions[0]) if cleaned != "" { cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned) @@ -322,7 +330,7 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, logging.Logf("lsp ", "completion path=codex error=%v (falling back to chat)", err) // Still emit a heartbeat for visibility, even on error s.incSentCounters(sentBytes) - s.logLLMStats() + s.logLLMStats(modelUsed) } return nil, false } diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go index 0340866..f8ed9ed 100644 --- a/internal/lsp/handlers_document.go 
+++ b/internal/lsp/handlers_document.go @@ -161,22 +161,23 @@ func (s *Server) detectAndHandleChat(uri string) { } return } - if s.currentLLMClient() == nil { - continue - } go func(prompt string, remove int) { ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) defer cancel() // Build messages with history and context_mode aware extras. pos := Position{Line: lineIdx, Character: lastIdx + 1} msgs := s.buildChatMessages(uri, pos, prompt) - opts := s.llmRequestOpts() - client := s.currentLLMClient() + spec := s.buildRequestSpec(surfaceChat) + client := s.clientFor(spec) if client == nil { return } - logging.Logf("lsp ", "chat llm=requesting model=%s", client.DefaultModel()) - text, err := s.chatWithStats(ctx, msgs, opts...) + modelUsed := spec.effectiveModel() + if strings.TrimSpace(modelUsed) == "" { + modelUsed = client.DefaultModel() + } + logging.Logf("lsp ", "chat llm=requesting model=%s", modelUsed) + text, err := s.chatWithStats(ctx, surfaceChat, spec, msgs) if err != nil { logging.Logf("lsp ", "chat llm error: %v", err) return diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index 5d5ca27..3bd13ee 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" @@ -14,24 +15,134 @@ import ( tmx "codeberg.org/snonux/hexai/internal/tmux" ) -// llmRequestOpts builds request options from server settings. -func (s *Server) llmRequestOpts() []llm.RequestOption { +type surfaceKind string + +const ( + surfaceCompletion surfaceKind = "completion" + surfaceCodeAction surfaceKind = "code_action" + surfaceChat surfaceKind = "chat" +) + +type requestSpec struct { + provider string + modelOverride string + fallbackModel string + options []llm.RequestOption +} + +func (r requestSpec) effectiveModel() string { + if s := strings.TrimSpace(r.modelOverride); s != "" { + return s + } + return strings.TrimSpace(r.fallbackModel) +} + +func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { + cfg := s.currentConfig() + providerOverride := strings.TrimSpace(surfaceProviderFromConfig(cfg, surface)) + provider := canonicalProvider(cfg.Provider) + if providerOverride != "" { + provider = canonicalProvider(providerOverride) + } + fallbackModel := strings.TrimSpace(resolveDefaultModel(cfg, provider)) + modelOverride := strings.TrimSpace(surfaceModelFromConfig(cfg, surface)) maxTokens := s.maxTokens() - client := s.currentLLMClient() - tempPtr := s.codingTemperature() opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)} - if tempPtr != nil { - temp := *tempPtr - if client != nil { - prov := strings.ToLower(strings.TrimSpace(client.Name())) - model := strings.ToLower(strings.TrimSpace(client.DefaultModel())) - if prov == "openai" && strings.HasPrefix(model, "gpt-5") { - temp = 1.0 - } + if tempVal, ok := chooseSurfaceTemperature(surface, cfg, provider, modelOverride, fallbackModel); ok { + opts = append(opts, llm.WithTemperature(tempVal)) + } + if modelOverride != "" { + opts = append(opts, llm.WithModel(modelOverride)) + } + return requestSpec{ + provider: provider, + modelOverride: modelOverride, + fallbackModel: fallbackModel, + options: opts, + } +} + +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func 
resolveDefaultModel(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return strings.TrimSpace(cfg.OllamaModel) + case "copilot": + return strings.TrimSpace(cfg.CopilotModel) + default: + return strings.TrimSpace(cfg.OpenAIModel) + } +} + +func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string { + switch surface { + case surfaceCompletion: + return cfg.CompletionModel + case surfaceCodeAction: + return cfg.CodeActionModel + case surfaceChat: + return cfg.ChatModel + default: + return "" + } +} + +func surfaceProviderFromConfig(cfg appconfig.App, surface surfaceKind) string { + switch surface { + case surfaceCompletion: + return cfg.CompletionProvider + case surfaceCodeAction: + return cfg.CodeActionProvider + case surfaceChat: + return cfg.ChatProvider + default: + return "" + } +} + +func surfaceTemperatureFromConfig(cfg appconfig.App, surface surfaceKind) *float64 { + switch surface { + case surfaceCompletion: + return cfg.CompletionTemperature + case surfaceCodeAction: + return cfg.CodeActionTemperature + case surfaceChat: + return cfg.ChatTemperature + default: + return nil + } +} + +func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider string, overrideModel, fallbackModel string) (float64, bool) { + if t := surfaceTemperatureFromConfig(cfg, surface); t != nil { + return *t, true + } + if cfg.CodingTemperature != nil { + temp := *cfg.CodingTemperature + effectiveModel := strings.TrimSpace(overrideModel) + if effectiveModel == "" { + effectiveModel = strings.TrimSpace(fallbackModel) + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") && temp == 0.2 { + temp = 1.0 } - opts = append(opts, llm.WithTemperature(temp)) + return temp, true } - return opts + effectiveModel := strings.TrimSpace(overrideModel) + if effectiveModel == "" { + effectiveModel = strings.TrimSpace(fallbackModel) + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") { + return 1.0, true + } + return 0, false } // small helpers for LLM traffic stats @@ -49,7 +160,7 @@ func (s *Server) incRecvCounters(n int) { s.mu.Unlock() } -func (s *Server) logLLMStats() { +func (s *Server) logLLMStats(model string) { s.mu.RLock() avgSent := int64(0) if s.llmReqTotal > 0 { @@ -75,11 +186,14 @@ func (s *Server) logLLMStats() { if err == nil { if client := s.currentLLMClient(); client != nil { provider := client.Name() - model := client.DefaultModel() + modelName := strings.TrimSpace(model) + if modelName == "" { + modelName = client.DefaultModel() + } // Per-scope rpm estimated from window scopeReqs := int64(0) if pe, ok := snap.Providers[provider]; ok { - if mc, ok2 := pe.Models[model]; ok2 { + if mc, ok2 := pe.Models[modelName]; ok2 { scopeReqs = mc.Reqs } } @@ -88,7 +202,7 @@ func (s *Server) logLLMStats() { minsWin = 0.001 } scopeRPM := float64(scopeReqs) / minsWin - status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, model, scopeRPM, scopeReqs, snap.Window) + status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, modelName, scopeRPM, scopeReqs, snap.Window) _ = tmx.SetStatus(status) } } @@ -154,7 +268,7 @@ func isIdentChar(ch byte) bool { } // chatWithStats wraps llmClient.Chat to increment counters and emit a tmux heartbeat. 
-func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ...llm.RequestOption) (string, error) { +func (s *Server) chatWithStats(ctx context.Context, surface surfaceKind, spec requestSpec, msgs []llm.Message) (string, error) { // Count bytes sent sent := 0 for _, m := range msgs { @@ -167,19 +281,23 @@ func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ... return "", context.Canceled } // Perform request - client := s.currentLLMClient() + client := s.clientFor(spec) if client == nil { return "", fmt.Errorf("llm client unavailable") } - txt, err := client.Chat(ctx, msgs, opts...) + txt, err := client.Chat(ctx, msgs, spec.options...) if err != nil { - s.logLLMStats() + s.logLLMStats(spec.effectiveModel()) return "", err } s.incRecvCounters(len(txt)) // Update global stats cache - _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, len(txt)) - s.logLLMStats() + model := spec.effectiveModel() + if model == "" { + model = client.DefaultModel() + } + _ = stats.Update(ctx, client.Name(), model, sent, len(txt)) + s.logLLMStats(model) return txt, nil } diff --git a/internal/lsp/llm_request_opts_test.go b/internal/lsp/llm_request_opts_test.go index c6699b0..263db79 100644 --- a/internal/lsp/llm_request_opts_test.go +++ b/internal/lsp/llm_request_opts_test.go @@ -15,17 +15,22 @@ func (f fakeClient) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOpt func (f fakeClient) Name() string { return f.name } func (f fakeClient) DefaultModel() string { return f.model } -func TestLlmRequestOpts_Gpt5_ForcesTemp1(t *testing.T) { +func TestRequestSpec_Gpt5_ForcesTemp1(t *testing.T) { s := newTestServer() one := 0.2 s.cfg.CodingTemperature = &one s.llmClient = fakeClient{name: "openai", model: "gpt-5.0"} - opts := s.llmRequestOpts() + s.cfg.OpenAIModel = "gpt-5.0" + + spec := s.buildRequestSpec(surfaceCompletion) var got llm.Options - for _, o := range opts { + for _, o := range spec.options { o(&got) } if got.Temperature != 1.0 { t.Fatalf("expected temp 1.0 for gpt-5, got %v", got.Temperature) } + if model := spec.effectiveModel(); model != "gpt-5.0" { + t.Fatalf("expected fallback model gpt-5.0, got %q", model) + } } diff --git a/internal/lsp/llm_stats_test.go b/internal/lsp/llm_stats_test.go index 43582a2..7813c10 100644 --- a/internal/lsp/llm_stats_test.go +++ b/internal/lsp/llm_stats_test.go @@ -6,5 +6,5 @@ func TestLogLLMStats_CoversCounters(t *testing.T) { s := newTestServer() s.incSentCounters(10) s.incRecvCounters(20) - s.logLLMStats() // just ensure it does not panic and executes + s.logLLMStats("model") // just ensure it does not panic and executes } diff --git a/internal/lsp/provider_native_success_test.go b/internal/lsp/provider_native_success_test.go index 6df5698..aab886c 100644 --- a/internal/lsp/provider_native_success_test.go +++ b/internal/lsp/provider_native_success_test.go @@ -21,10 +21,11 @@ func (fakeCompleterOk) CodeCompletion(context.Context, string, string, int, stri func TestProviderNativeCompletion_Success(t *testing.T) { s := newTestServer() s.llmClient = fakeCompleterOk{} + spec := s.buildRequestSpec(surfaceCompletion) // current line with dot trigger; position after dot current := "fmt." 
p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}} - items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false) + items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false) if !ok || len(items) == 0 { t.Fatalf("expected provider-native items") } @@ -47,9 +48,10 @@ func (fakeCompleterIndent) CodeCompletion(context.Context, string, string, int, func TestProviderNativeCompletion_IndentWithDoubleOpen(t *testing.T) { s := newTestServer() s.llmClient = fakeCompleterIndent{} + spec := s.buildRequestSpec(surfaceCompletion) current := " >>do>" // leading indent + double-open marker p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}} - items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false) + items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false) if !ok || len(items) == 0 { t.Fatalf("expected provider-native items") } @@ -80,12 +82,13 @@ func TestProviderNativeCompletion_UsesPromptTemplate(t *testing.T) { cfg := s.cfg cfg.PromptNativeCompletion = "NATIVE {{path}} {{before}}" s.cfg = cfg + spec := s.buildRequestSpec(surfaceCompletion) uri := "file:///x.go" s.setDocument(uri, "AAA\nBBB\nCCC") current := "fmt." // Cursor at line 1, char 1 -> before should be "AAA\nB" p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Position: Position{Line: 1, Character: 1}} - if _, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false); !ok { + if _, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false); !ok { t.Fatalf("expected provider-native path") } if cap.lastPrompt == "" { diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 7b8bc88..28f3218 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -6,6 +6,7 @@ import ( "encoding/json" "io" "log" + "os" "strings" "sync" "time" @@ -29,6 +30,8 @@ type Server struct { configStore *runtimeconfig.Store cfg appconfig.App llmClient llm.Client + llmProvider string + altClients map[string]llm.Client lastInput time.Time // LLM request stats llmReqTotal int64 @@ -186,6 +189,12 @@ func (s *Server) applyOptions(opts ServerOptions) { } } s.llmClient = opts.Client + if opts.Client != nil { + s.llmProvider = canonicalProvider(opts.Client.Name()) + } else { + s.llmProvider = canonicalProvider(s.cfg.Provider) + } + s.altClients = make(map[string]llm.Client) } // ApplyOptions updates the server's configuration at runtime. 
@@ -199,6 +208,82 @@ func (s *Server) currentLLMClient() llm.Client { return s.llmClient }
 
+func newClientForProvider(cfg appconfig.App, provider string) (llm.Client, error) {
+	llmCfg := llm.Config{
+		Provider: provider,
+		OpenAIBaseURL: cfg.OpenAIBaseURL,
+		OpenAIModel: cfg.OpenAIModel,
+		OpenAITemperature: cfg.OpenAITemperature,
+		OllamaBaseURL: cfg.OllamaBaseURL,
+		OllamaModel: cfg.OllamaModel,
+		OllamaTemperature: cfg.OllamaTemperature,
+		CopilotBaseURL: cfg.CopilotBaseURL,
+		CopilotModel: cfg.CopilotModel,
+		CopilotTemperature: cfg.CopilotTemperature,
+	}
+	oaKey := strings.TrimSpace(os.Getenv("HEXAI_OPENAI_API_KEY"))
+	if oaKey == "" {
+		oaKey = strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
+	}
+	cpKey := strings.TrimSpace(os.Getenv("HEXAI_COPILOT_API_KEY"))
+	if cpKey == "" {
+		cpKey = strings.TrimSpace(os.Getenv("COPILOT_API_KEY"))
+	}
+	return llm.NewFromConfig(llmCfg, oaKey, cpKey)
+}
+
+func (s *Server) clientFor(spec requestSpec) llm.Client {
+	provider := canonicalProvider(spec.provider)
+	s.mu.RLock()
+	baseProvider := s.llmProvider
+	baseClient := s.llmClient
+	if baseClient != nil && strings.TrimSpace(baseProvider) == "" {
+		baseProvider = canonicalProvider(baseClient.Name())
+	}
+	if provider == "" {
+		provider = baseProvider
+	}
+	if provider == baseProvider && baseClient != nil {
+		s.mu.RUnlock()
+		return baseClient
+	}
+	if c, ok := s.altClients[provider]; ok {
+		s.mu.RUnlock()
+		return c
+	}
+	cfg := s.cfg
+	store := s.configStore
+	s.mu.RUnlock()
+	if store != nil {
+		cfg = store.Snapshot()
+	}
+	client, err := newClientForProvider(cfg, provider)
+	if err != nil {
+		logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err)
+		if baseClient != nil {
+			return baseClient
+		}
+		return nil
+	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if provider == s.llmProvider {
+		if s.llmClient == nil {
+			s.llmClient = client
+			s.llmProvider = provider
+		}
+		return s.llmClient
+	}
+	if existing, ok := s.altClients[provider]; ok {
+		return existing
+	}
+	if s.altClients == nil {
+		s.altClients = make(map[string]llm.Client)
+	}
+	s.altClients[provider] = client
+	return client
+}
+
 func (s *Server) currentConfig() appconfig.App {
 	if s.configStore != nil {
 		return s.configStore.Snapshot()
diff --git a/internal/runtimeconfig/store_test.go b/internal/runtimeconfig/store_test.go
index 6e40c76..6d23b33 100644
--- a/internal/runtimeconfig/store_test.go
+++ b/internal/runtimeconfig/store_test.go
@@ -94,3 +94,18 @@ func TestStoreReloadLogsSummary(t *testing.T) {
 		t.Fatalf("expected change details in log, got %q", logOutput)
 	}
 }
+
+func TestDiff_SurfaceModel(t *testing.T) {
+	oldCfg := appconfig.App{CompletionModel: "gpt-4o", CompletionProvider: "openai"}
+	newCfg := appconfig.App{CompletionModel: "gpt-4.1", CompletionProvider: "copilot"}
+	changes := Diff(oldCfg, newCfg)
+	if len(changes) != 2 {
+		t.Fatalf("expected two changes, got %+v", changes)
+	}
+	if changes[0].Key != "completion_model" || changes[0].Old != "gpt-4o" || changes[0].New != "gpt-4.1" {
+		t.Fatalf("unexpected diff entry: %+v", changes[0])
+	}
+	if changes[1].Key != "completion_provider" || changes[1].Old != "openai" || changes[1].New != "copilot" {
+		t.Fatalf("unexpected provider diff: %+v", changes[1])
+	}
+}
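
The per-surface selection in this diff boils down to: an explicit per-surface model, provider, or temperature wins; otherwise the provider's default model and the legacy coding_temperature apply, keeping the existing gpt-5 temperature exception. The snippet below is an illustrative sketch only and is not part of the commit: it calls chooseSurfaceTemperature from handlers_utils.go with the appconfig.App fields that appear above, assumes it sits inside package lsp where those unexported identifiers are reachable, and the helper name and concrete values are made up for the example.

package lsp

import (
	"fmt"

	"codeberg.org/snonux/hexai/internal/appconfig"
)

// sketchTemperaturePrecedence is a hypothetical helper that walks through the
// temperature precedence implemented by chooseSurfaceTemperature above.
func sketchTemperaturePrecedence() {
	chatTemp := 0.7
	coding := 0.2
	cfg := appconfig.App{
		ChatTemperature:   &chatTemp, // surface-specific override: highest precedence
		CodingTemperature: &coding,   // legacy global setting: used when no override exists
		OpenAIModel:       "gpt-5.0", // provider default model used as the fallback
	}

	// Chat surface: ChatTemperature wins, so 0.7 is used.
	if t, ok := chooseSurfaceTemperature(surfaceChat, cfg, "openai", "", "gpt-5.0"); ok {
		fmt.Println("chat temperature:", t)
	}

	// Completion surface: no CompletionTemperature is set, so coding_temperature (0.2)
	// applies, and the gpt-5 rule for the openai provider bumps it to 1.0.
	if t, ok := chooseSurfaceTemperature(surfaceCompletion, cfg, "openai", "", "gpt-5.0"); ok {
		fmt.Println("completion temperature:", t)
	}
}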
