-rw-r--r-- | PLAN3.md                                     | 20
-rw-r--r-- | README.md                                    | 1
-rw-r--r-- | SCRATCHPAD.md                                | 9
-rw-r--r-- | config.toml.example                          | 1
-rw-r--r-- | docs/configuration.md                        | 4
-rw-r--r-- | internal/appconfig/config.go                 | 10
-rw-r--r-- | internal/appconfig/config_test.go            | 28
-rw-r--r-- | internal/hexaicli/run.go                     | 181
-rw-r--r-- | internal/hexaicli/run_test.go                | 50
-rw-r--r-- | internal/lsp/handlers.go                     | 12
-rw-r--r-- | internal/lsp/handlers_completion.go          | 196
-rw-r--r-- | internal/lsp/handlers_document.go            | 5
-rw-r--r-- | internal/lsp/handlers_utils.go               | 129
-rw-r--r-- | internal/lsp/llm_request_opts_test.go        | 41
-rw-r--r-- | internal/lsp/provider_native_success_test.go | 9
-rw-r--r-- | internal/lsp/server.go                       | 22
16 files changed, 505 insertions, 213 deletions
diff --git a/PLAN3.md b/PLAN3.md
--- a/PLAN3.md
+++ b/PLAN3.md
@@ -8,17 +8,17 @@ Goal: allow configuring multiple provider:model pairs per surface so users can r
 - [x] Define merge, validation, and backward-compatibility rules (single entry auto-wraps into list).
 
 ## Phase 2 – Runtime Plumbing
-- Extend appconfig/runtime store to emit ordered slices for multi-entry surfaces, including diff output.
-- Update request-spec helpers to iterate across configured entries, building dedicated request specs (and caching clients per provider/model combo).
-- Ensure logging/stats capture provider/model context per entry.
+- [x] Extend appconfig/runtime store to emit ordered slices for multi-entry surfaces, including diff output.
+- [x] Update request-spec helpers to iterate across configured entries, building dedicated request specs (and caching clients per provider/model combo).
+- [x] Ensure logging/stats capture provider/model context per entry.
 
 ## Phase 3 – Surface Implementations
-- Completion: fan out requests sequentially, gather one suggestion per entry, and surface them distinctly to the editor (label with provider/model).
-- CLI: stream or print separate responses per entry, with clear headers and stats per run.
-- Code actions: keep single-provider flow but ensure config ignores extra entries with validation warnings.
-- Add reasonable concurrency limits / timeouts so multi-provider usage stays responsive.
+- [x] Completion: fan out requests concurrently, gather one suggestion per entry, and surface them distinctly to the editor (labelled with provider/model).
+- [x] CLI: run all configured providers in parallel and print separate responses per entry with stats.
+- [x] Code actions: keep single-provider flow and warn/ignore additional `[[models.code_action]]` entries.
+- [x] Add concurrency safeguards (debounce/throttle gate still respected before fan-out).
 
 ## Phase 4 – UX & Validation
-- Tests covering multi-entry parsing, diffing, and surface-specific behavior (mock providers to simulate dual responses).
-- Update docs and example TOML with new array syntax, including env override strategy.
-- Capture lessons/issues in scratchpad for follow-up polishing.
+- [x] Tests covering multi-entry parsing, diffing, and surface behavior (expanded CLI/LSP/appconfig suites).
+- [x] Update docs and example TOML with array syntax and dual-provider guidance.
+- [x] Capture lessons/issues in scratchpad for follow-up polishing.
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ It has improved capabilities for Go code understanding (for example, create
 * LSP AI Code actions
 * LSP in-editor chat with the LLM
 * Stand-alone command line tool for LLM interaction
+* Parallel completions and CLI responses from multiple providers/models for side-by-side comparison
 * TUI AI code-action runner (`hexai-tmux-action`) with Bubble Tea
   - Includes a “Custom prompt” action (hotkey `p`) that opens your editor (`$HEXAI_EDITOR` or `$EDITOR`) on a temporary Markdown file.
 * Support for OpenAI, GitHub Copilot, and Ollama
diff --git a/SCRATCHPAD.md b/SCRATCHPAD.md
index c6b0f54..803b1a3 100644
--- a/SCRATCHPAD.md
+++ b/SCRATCHPAD.md
@@ -1,16 +1,9 @@
 # Project scratch pad
 This document shows future items and items in progress. Already completed ones are deleted from this document as updates occur.
-
-## Features
-
 * [ ] hexai cli to keep context for the follow-up question/prompt?
-* [/] Be able to switch LLMs ad-hoc by re-reading the config.
-
-## More
-
+* [/] configure multiple models for cli and code completion
 * [ ] Exclude the test coverage files from git and wipe them from the history
 * [/] Review documentation
 * [/] Manually review the code
 * [ ] ASCIInema: Record and share terminal sessions for demos and bug reports
-* [ ] Release notes: highlight per-surface model overrides once bundled
diff --git a/config.toml.example b/config.toml.example
index 1bed1a9..e054ef0 100644
--- a/config.toml.example
+++ b/config.toml.example
@@ -44,6 +44,7 @@ chat_prefixes = ["?", "!", ":", ";"] # single-character items
 # temperature = 0.2
 
 # [[models.code_action]]
+# # Only the first entry is used; extras are ignored with a warning.
 # provider = "copilot"
 # model = "gpt-4o"
 # temperature = 0.4
diff --git a/docs/configuration.md b/docs/configuration.md
index 6db7a27..6e42172 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -27,7 +27,7 @@ Environment overrides
 Per-surface models
 
 - Use the `[models]` table in `config.toml` to tailor individual entry points (completion, code actions, chat, CLI) without changing the global provider default.
-- Each key accepts either a string (shortcut) or a table with `model` / `temperature` fields, e.g.:
+- Each key accepts either a string (shortcut) or one or more tables with `model` / `temperature` fields, e.g.:
 
 ```toml
 [models]
@@ -43,6 +43,8 @@ Per-surface models
 provider = "openai"
 ```
 
+- Repeating the table (`[[models.<surface>]]`) configures multiple provider/model pairs. Completion requests and the Hexai CLI fan out to every configured entry concurrently and label the responses with `provider:model`. Code actions continue to use the first entry only; any extra `[[models.code_action]]` tables are ignored at runtime, and the loader logs a warning so you know an entry was skipped.
+
 - When a per-surface value is omitted, Hexai falls back to the provider’s configured default. Temperatures inherit from `coding_temperature` unless explicitly set, and OpenAI `gpt-5*` models automatically raise an unspecified coding temperature to `1.0` for exploratory behavior. Provider overrides support `"openai"`, `"copilot"`, or `"ollama"` and read the matching credential variables.
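A minimal sketch of the array syntax documented above, with the provider and model values assumed for illustration (they mirror config.toml.example):

```toml
[models]
  # Two completion entries: completions fan out to both concurrently,
  # and each suggestion is labelled "provider:model" in the editor.
  [[models.completion]]
  provider = "openai"
  model = "gpt-4o"

  [[models.completion]]
  provider = "ollama"
  model = "llama3"
  temperature = 0.3

  # code_action uses only the first entry; extras log a warning and are skipped.
  [[models.code_action]]
  provider = "copilot"
  model = "gpt-4o"
```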
Runtime reloads diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index 63d0437..27c7e02 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -670,7 +670,15 @@ func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App { return true } any := appendEntries(&out.CompletionConfigs, "models.completion", table["completion"]) - any = appendEntries(&out.CodeActionConfigs, "models.code_action", table["code_action"]) || any + if ok := appendEntries(&out.CodeActionConfigs, "models.code_action", table["code_action"]); ok { + if len(out.CodeActionConfigs) > 1 { + if logger != nil { + logger.Printf("config: models.code_action supports a single entry; ignoring %d extra", len(out.CodeActionConfigs)-1) + } + out.CodeActionConfigs = out.CodeActionConfigs[:1] + } + any = true + } any = appendEntries(&out.ChatConfigs, "models.chat", table["chat"]) || any any = appendEntries(&out.CLIConfigs, "models.cli", table["cli"]) || any if !any { diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go index e7f6059..4ae04d8 100644 --- a/internal/appconfig/config_test.go +++ b/internal/appconfig/config_test.go @@ -1,6 +1,7 @@ package appconfig import ( + "bytes" "io" "log" "os" @@ -64,6 +65,33 @@ func TestLoad_Defaults_WithLogger_NoFile_NoEnv(t *testing.T) { } } +func TestParseSurfaceModels_CodeActionWarns(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.toml") + writeFile(t, path, ` +[models] + [[models.code_action]] + provider = "openai" + model = "gpt-4o" + + [[models.code_action]] + provider = "copilot" + model = "cpt" +`) + var buf bytes.Buffer + logger := log.New(&buf, "", 0) + app, err := loadFromFile(path, logger) + if err != nil { + t.Fatalf("loadFromFile: %v", err) + } + if len(app.CodeActionConfigs) != 1 || app.CodeActionConfigs[0].Model != "gpt-4o" { + t.Fatalf("expected single code action entry, got %+v", app.CodeActionConfigs) + } + if msg := buf.String(); !strings.Contains(msg, "models.code_action supports a single entry") { + t.Fatalf("expected warning, got %q", msg) + } +} + func TestLoad_FileMerge_And_EnvOverride(t *testing.T) { dir := t.TempDir() t.Setenv("XDG_CONFIG_HOME", dir) diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 5f6284c..06fcb83 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -3,12 +3,14 @@ package hexaicli import ( + "bytes" "context" "fmt" "io" "log" "os" "strings" + "sync" "time" "codeberg.org/snonux/hexai/internal/appconfig" @@ -25,48 +27,79 @@ type requestArgs struct { options []llm.RequestOption } -func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { - provider := canonicalProvider(cfg.Provider) +type cliJob struct { + index int + provider string + entry appconfig.SurfaceConfig + client llm.Client + req requestArgs +} + +func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) { entries := cfg.CLIConfigs if len(entries) == 0 { - entries = []appconfig.SurfaceConfig{{Provider: cfg.Provider, Model: strings.TrimSpace(defaultModelForProvider(cfg, provider))}} + entries = []appconfig.SurfaceConfig{{}} + } + jobs := make([]cliJob, 0, len(entries)) + for i, raw := range entries { + entry := appconfig.SurfaceConfig{Provider: strings.TrimSpace(raw.Provider), Model: strings.TrimSpace(raw.Model), Temperature: raw.Temperature} + provider := entry.Provider + if provider == "" { + provider = cfg.Provider + } + provider = canonicalProvider(provider) + derived := cfg + derived.Provider = provider + switch provider { + 
case "openai": + if entry.Model != "" { + derived.OpenAIModel = entry.Model + } + case "copilot": + if entry.Model != "" { + derived.CopilotModel = entry.Model + } + case "ollama": + if entry.Model != "" { + derived.OllamaModel = entry.Model + } + } + client, err := newClientFromApp(derived) + if err != nil { + return nil, err + } + req := buildCLIRequest(entry, provider, cfg, client) + if strings.TrimSpace(req.model) == "" { + req.model = strings.TrimSpace(client.DefaultModel()) + } + jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, client: client, req: req}) } - primary := entries[0] - if strings.TrimSpace(primary.Provider) != "" { - provider = canonicalProvider(primary.Provider) + return jobs, nil +} + +func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App, client llm.Client) requestArgs { + opts := make([]llm.RequestOption, 0, 2) + if cfg.MaxTokens > 0 { + opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens)) } - model := strings.TrimSpace(primary.Model) - if client != nil { - provider = strings.ToLower(strings.TrimSpace(client.Name())) - if model == "" { + model := strings.TrimSpace(entry.Model) + if model == "" { + if client != nil { model = strings.TrimSpace(client.DefaultModel()) } + if model == "" { + model = strings.TrimSpace(defaultModelForProvider(cfg, provider)) + } } - if model == "" { - model = strings.TrimSpace(defaultModelForProvider(cfg, provider)) - } - opts := make([]llm.RequestOption, 0, 2) - if strings.TrimSpace(primary.Model) != "" { - opts = append(opts, llm.WithModel(strings.TrimSpace(primary.Model))) + if entry.Model != "" { + opts = append(opts, llm.WithModel(entry.Model)) } - if temp, ok := cliTemperatureFromEntry(cfg, provider, primary, model); ok { + if temp, ok := cliTemperatureFromEntry(cfg, provider, entry, model); ok { opts = append(opts, llm.WithTemperature(temp)) } return requestArgs{model: model, options: opts} } -func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { - if len(cfg.CLIConfigs) > 0 { - if m := strings.TrimSpace(cfg.CLIConfigs[0].Model); m != "" { - return requestArgs{model: m} - } - } - if client != nil { - return requestArgs{model: strings.TrimSpace(client.DefaultModel())} - } - return requestArgs{} -} - func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) { if entry.Temperature != nil { return *entry.Temperature, true @@ -112,17 +145,14 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } - if len(cfg.CLIConfigs) > 0 { - if provider := strings.TrimSpace(cfg.CLIConfigs[0].Provider); provider != "" { - cfg.Provider = provider - } - } - client, err := newClientFromApp(cfg) + jobs, err := buildCLIJobs(cfg) if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } - req := buildCLIRequestArgs(cfg, client) + if len(jobs) == 0 { + return fmt.Errorf("hexai: no CLI providers configured") + } // Prefer piped stdin when present; only open the editor when there are no args // and no stdin content available. input, rerr := readInput(stdin, args) @@ -136,9 +166,8 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. 
fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) return rerr } - printProviderInfo(stderr, client, req.model) msgs := buildMessagesFromConfig(cfg, input) - if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { + if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -153,7 +182,7 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) return err } - req := defaultRequestArgs(appconfig.App{}, client) + req := requestArgs{model: strings.TrimSpace(client.DefaultModel())} printProviderInfo(stderr, client, req.model) msgs := buildMessages(input) if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { @@ -163,6 +192,80 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, return nil } +type cliJobResult struct { + provider string + model string + output string + summary string + err error +} + +func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { + results := make([]*cliJobResult, len(jobs)) + var wg sync.WaitGroup + for _, job := range jobs { + job := job + wg.Add(1) + printProviderInfo(stderr, job.client, job.req.model) + go func() { + defer wg.Done() + var outBuf, errBuf bytes.Buffer + jobMsgs := make([]llm.Message, len(msgs)) + copy(jobMsgs, msgs) + err := runChat(ctx, job.client, job.req, jobMsgs, input, &outBuf, &errBuf) + results[job.index] = &cliJobResult{ + provider: job.client.Name(), + model: job.req.model, + output: outBuf.String(), + summary: errBuf.String(), + err: err, + } + }() + } + wg.Wait() + var firstErr error + printed := false + for _, res := range results { + if res == nil { + continue + } + if printed { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } + heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) + if _, err := io.WriteString(stdout, heading); err != nil { + return err + } + if res.output != "" { + if _, err := io.WriteString(stdout, res.output); err != nil { + return err + } + if !strings.HasSuffix(res.output, "\n") { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } + } + printed = true + if res.summary != "" { + if _, err := io.WriteString(stderr, res.summary); err != nil { + return err + } + } + if res.err != nil { + if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil { + return err + } + if firstErr == nil { + firstErr = res.err + } + } + } + return firstErr +} + // readInput reads from stdin and args, then combines them per CLI rules. 
func readInput(stdin io.Reader, args []string) (string, error) { var stdinData string diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index 0250ac9..f11545e 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -151,13 +151,13 @@ func TestPrintProviderInfo(t *testing.T) { } } -func TestBuildCLIRequestArgs_Override(t *testing.T) { +func TestBuildCLIRequest_Override(t *testing.T) { cfg := appconfig.App{ Provider: "openai", CopilotModel: "gpt-4o", - CLIConfigs: []appconfig.SurfaceConfig{{Provider: "copilot", Model: "override", Temperature: floatPtr(0.7)}}, } - req := buildCLIRequestArgs(cfg, &fakeClient{name: "copilot", model: "default"}) + entry := appconfig.SurfaceConfig{Provider: "copilot", Model: "override", Temperature: floatPtr(0.7)} + req := buildCLIRequest(entry, "copilot", cfg, &fakeClient{name: "copilot", model: "default"}) if req.model != "override" { t.Fatalf("expected model override, got %q", req.model) } @@ -170,9 +170,10 @@ func TestBuildCLIRequestArgs_Override(t *testing.T) { } } -func TestBuildCLIRequestArgs_Gpt5Temp(t *testing.T) { +func TestBuildCLIRequest_Gpt5Temp(t *testing.T) { cfg := appconfig.App{Provider: "openai", CodingTemperature: floatPtr(0.2)} - req := buildCLIRequestArgs(cfg, &fakeClient{name: "openai", model: "gpt-5.1"}) + entry := appconfig.SurfaceConfig{} + req := buildCLIRequest(entry, "openai", cfg, &fakeClient{name: "openai", model: "gpt-5.1"}) if req.model != "gpt-5.1" { t.Fatalf("expected fallback model, got %q", req.model) } @@ -185,6 +186,45 @@ func TestBuildCLIRequestArgs_Gpt5Temp(t *testing.T) { } } +func TestBuildCLIJobs_MultiEntries(t *testing.T) { + old := newClientFromApp + defer func() { newClientFromApp = old }() + newClientFromApp = func(cfg appconfig.App) (llm.Client, error) { + model := cfg.OpenAIModel + if cfg.Provider == "copilot" { + model = cfg.CopilotModel + } + if cfg.Provider == "ollama" { + model = cfg.OllamaModel + } + if strings.TrimSpace(model) == "" { + model = "default" + } + return &fakeClient{name: cfg.Provider, model: model}, nil + } + cfg := appconfig.App{ + Provider: "ollama", + OllamaModel: "llama3", + CLIConfigs: []appconfig.SurfaceConfig{ + {Provider: "openai", Model: "gpt-4o"}, + {Provider: "copilot", Model: "cpt"}, + }, + } + jobs, err := buildCLIJobs(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(jobs) != 2 { + t.Fatalf("expected 2 jobs, got %d", len(jobs)) + } + if jobs[0].provider != "openai" || jobs[0].req.model != "gpt-4o" { + t.Fatalf("unexpected first job: %+v", jobs[0]) + } + if jobs[1].provider != "copilot" || jobs[1].req.model != "cpt" { + t.Fatalf("unexpected second job: %+v", jobs[1]) + } +} + func TestNewClientFromConfig_Ollama(t *testing.T) { cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"} c, err := newClientFromConfig(cfg) diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index c1a637f..94b6348 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -343,13 +343,15 @@ func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { return false } -func (s *Server) makeCompletionItems(cleaned string, inParams bool, current string, p CompletionParams, docStr string) []CompletionItem { +func (s *Server) makeCompletionItems(cleaned string, inParams bool, current string, p CompletionParams, docStr string, detail string, sortPrefix string) []CompletionItem { te, filter := computeTextEditAndFilter(cleaned, inParams, current, p) rm := 
s.collectPromptRemovalEdits(p.TextDocument.URI) label := labelForCompletion(cleaned, filter) - detail := "Hexai LLM completion" - if client := s.currentLLMClient(); client != nil { - detail = "Hexai " + client.Name() + ":" + client.DefaultModel() + if strings.TrimSpace(detail) == "" { + detail = "Hexai LLM completion" + } + if sortPrefix == "" { + sortPrefix = "0000" } return []CompletionItem{{ Label: label, @@ -359,7 +361,7 @@ func (s *Server) makeCompletionItems(cleaned string, inParams bool, current stri FilterText: strings.TrimLeft(filter, " \t"), TextEdit: te, AdditionalTextEdits: rm, - SortText: "0000", + SortText: sortPrefix, Documentation: docStr, }} } diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index d115741..237d34d 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "strings" + "sync" "time" "codeberg.org/snonux/hexai/internal/llm" @@ -94,17 +95,48 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun if handled { return items, true } - - spec := s.buildRequestSpec(surfaceCompletion) - client := s.clientFor(spec) - if client == nil { + specs := s.buildRequestSpecs(surfaceCompletion) + if len(specs) == 0 { return nil, false } - if items, ok := s.tryProviderNativeCompletion(spec, client, current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok { - return items, true + type jobResult struct { + items []CompletionItem + ok bool } - - return s.executeChatCompletion(ctx, plan, spec, client) + results := make([]jobResult, len(specs)) + var wg sync.WaitGroup + var mu sync.Mutex + s.waitForDebounce(ctx) + if !s.waitForThrottle(ctx) { + return nil, false + } + for _, spec := range specs { + spec := spec + client := s.clientFor(spec) + if client == nil { + continue + } + wg.Add(1) + go func(idx int, spec requestSpec, client llm.Client) { + defer wg.Done() + items, ok := s.runCompletionForSpec(ctx, plan, spec, client) + mu.Lock() + results[idx] = jobResult{items: items, ok: ok} + mu.Unlock() + }(spec.index, spec, client) + } + wg.Wait() + accumulated := make([]CompletionItem, 0) + for _, res := range results { + if !res.ok { + continue + } + accumulated = append(accumulated, res.items...) 
+ } + if len(accumulated) == 0 { + return nil, false + } + return accumulated, true } func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) (completionPlan, []CompletionItem, bool) { @@ -130,12 +162,6 @@ func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below plan.inParams = inParamList(current, p.Position.Character) plan.manualInvoke = parseManualInvoke(p.Context) plan.cacheKey = s.completionCacheKey(p, above, current, below, funcCtx, plan.inParams, hasExtra, extraText) - if cleaned, ok := s.completionCacheGet(plan.cacheKey); ok && strings.TrimSpace(cleaned) != "" { - logging.Logf("lsp ", "completion cache hit uri=%s line=%d char=%d preview=%s%s%s", - p.TextDocument.URI, p.Position.Line, p.Position.Character, - logging.AnsiGreen, logging.PreviewForLog(cleaned), logging.AnsiBase) - return plan, s.makeCompletionItems(cleaned, plan.inParams, current, p, docStr), true - } if isBareDoubleOpen(current, openChar, closeChar) || isBareDoubleOpen(below, openChar, closeChar) { logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) return plan, []CompletionItem{}, true @@ -147,38 +173,58 @@ func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below return plan, nil, false } -func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client) ([]CompletionItem, bool) { +func (s *Server) runCompletionForSpec(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client) ([]CompletionItem, bool) { + sortPrefix := fmt.Sprintf("%04d", spec.index) + modelKey := spec.effectiveModel(client.DefaultModel()) + providerKey := spec.provider + if providerKey == "" { + providerKey = canonicalProvider(client.Name()) + } + cacheKey := plan.cacheKey + "|" + providerKey + ":" + modelKey + if cached, ok := s.completionCacheGet(cacheKey); ok && strings.TrimSpace(cached) != "" { + logging.Logf("lsp ", "completion cache hit uri=%s line=%d char=%d preview=%s%s%s", + plan.params.TextDocument.URI, plan.params.Position.Line, plan.params.Position.Character, + logging.AnsiGreen, logging.PreviewForLog(cached), logging.AnsiBase) + detail := fmt.Sprintf("Hexai %s:%s", client.Name(), modelKey) + items := s.makeCompletionItems(cached, plan.inParams, plan.current, plan.params, plan.docStr, detail, sortPrefix) + return items, true + } + if items, ok := s.tryProviderNativeCompletion(ctx, plan, spec, client, sortPrefix); ok { + return items, true + } + return s.executeChatCompletion(ctx, plan, spec, client, sortPrefix) +} + +func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client, sortPrefix string) ([]CompletionItem, bool) { messages := s.buildCompletionMessages(plan.inlinePrompt, plan.hasExtra, plan.extraText, plan.inParams, plan.params, plan.above, plan.current, plan.below, plan.funcCtx) sentSize := 0 for _, m := range messages { sentSize += len(m.Content) } s.incSentCounters(sentSize) - opts := spec.options - s.waitForDebounce(ctx) - if !s.waitForThrottle(ctx) { - return nil, false - } - modelUsed := spec.effectiveModel() - if strings.TrimSpace(modelUsed) == "" { - modelUsed = client.DefaultModel() - } - logging.Logf("lsp ", "completion llm=requesting model=%s", modelUsed) - text, err := client.Chat(ctx, messages, opts...) 
+ text, err := client.Chat(ctx, messages, spec.options...) if err != nil { logging.Logf("lsp ", "llm completion error: %v", err) - s.logLLMStats(modelUsed) + s.logLLMStats("") return nil, false } s.incRecvCounters(len(text)) + modelUsed := spec.effectiveModel(client.DefaultModel()) + _ = stats.Update(ctx, client.Name(), modelUsed, sentSize, len(text)) s.logLLMStats(modelUsed) trimmed := strings.TrimSpace(text) cleaned := s.postProcessCompletion(trimmed, plan.current[:plan.params.Position.Character], plan.current) if cleaned == "" { return nil, false } - s.completionCachePut(plan.cacheKey, cleaned) - items := s.makeCompletionItems(cleaned, plan.inParams, plan.current, plan.params, plan.docStr) + detail := fmt.Sprintf("Hexai %s:%s", client.Name(), modelUsed) + providerKey := spec.provider + if providerKey == "" { + providerKey = canonicalProvider(client.Name()) + } + cacheKey := plan.cacheKey + "|" + providerKey + ":" + modelUsed + s.completionCachePut(cacheKey, cleaned) + items := s.makeCompletionItems(cleaned, plan.inParams, plan.current, plan.params, plan.docStr, detail, sortPrefix) return items, true } @@ -260,79 +306,75 @@ func (s *Server) prefixHeuristicAllows(inlinePrompt bool, current string, p Comp } // tryProviderNativeCompletion attempts provider-native completion and returns items when successful. -func (s *Server) tryProviderNativeCompletion(spec requestSpec, client llm.Client, current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) { +func (s *Server) tryProviderNativeCompletion(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client, sortPrefix string) ([]CompletionItem, bool) { cc, ok := client.(llm.CodeCompleter) if !ok { return nil, false } + current := plan.current + p := plan.params before, after := s.docBeforeAfter(p.TextDocument.URI, p.Position) path := strings.TrimPrefix(p.TextDocument.URI, "file://") - // Build provider-native prompt from template cfg := s.currentConfig() _, _, openChar, closeChar := s.inlineMarkers() prompt := renderTemplate(cfg.PromptNativeCompletion, map[string]string{ "path": path, "before": before, }) - lang := "" provider := spec.provider if provider == "" { provider = canonicalProvider(cfg.Provider) } logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", provider, path) - ctx2, cancel2 := context.WithTimeout(context.Background(), 15*time.Second) + ctx2, cancel2 := context.WithTimeout(ctx, 15*time.Second) defer cancel2() - - // Debounce and throttle prior to provider-native call - s.waitForDebounce(ctx2) - if !s.waitForThrottle(ctx2) { - return nil, false - } - // Count approximate payload sizes: prompt+after sent; first suggestion received sentBytes := len(prompt) + len(after) - modelUsed := spec.effectiveModel() - if strings.TrimSpace(modelUsed) == "" { - modelUsed = client.DefaultModel() - } + modelUsed := spec.effectiveModel(client.DefaultModel()) tempVal := 0.0 - if val, ok := chooseSurfaceTemperature(surfaceCompletion, cfg, provider, spec.modelOverride, spec.fallbackModel); ok { + if val, ok := chooseSurfaceTemperature(surfaceCompletion, cfg, spec.entry, provider, modelUsed); ok { tempVal = val } - suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, tempVal) - if err == nil && len(suggestions) > 0 { - // Update counters and heartbeat - s.incSentCounters(sentBytes) - s.incRecvCounters(len(suggestions[0])) - // Contribute to global stats (provider-native path) - if client != nil { - _ = 
stats.Update(ctx2, client.Name(), modelUsed, sentBytes, len(suggestions[0])) + suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, "", tempVal) + if err != nil || len(suggestions) == 0 { + if err != nil { + logging.Logf("lsp ", "completion path=codex error=%v (falling back)", err) } - s.logLLMStats(modelUsed) - cleaned := strings.TrimSpace(suggestions[0]) - if cleaned != "" { - cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned) - if cleaned != "" { - cleaned = stripDuplicateGeneralPrefix(current[:p.Position.Character], cleaned) - } - if cleaned != "" && hasDoubleOpenTrigger(current, openChar, closeChar) { - indent := leadingIndent(current) - if indent != "" { - cleaned = applyIndent(indent, cleaned) - } - } - if strings.TrimSpace(cleaned) != "" { - key := s.completionCacheKey(p, above, current, below, funcCtx, inParams, hasExtra, extraText) - s.completionCachePut(key, cleaned) - return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true - } + return nil, false + } + s.incSentCounters(sentBytes) + s.incRecvCounters(len(suggestions[0])) + _ = stats.Update(ctx2, client.Name(), modelUsed, sentBytes, len(suggestions[0])) + s.logLLMStats(modelUsed) + cleaned := strings.TrimSpace(suggestions[0]) + if cleaned == "" { + return nil, false + } + cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned) + if cleaned == "" { + return nil, false + } + cleaned = stripDuplicateGeneralPrefix(current[:p.Position.Character], cleaned) + if cleaned == "" { + return nil, false + } + if strings.TrimSpace(cleaned) != "" && hasDoubleOpenTrigger(current, openChar, closeChar) { + indent := leadingIndent(current) + if indent != "" { + cleaned = applyIndent(indent, cleaned) } - } else if err != nil { - logging.Logf("lsp ", "completion path=codex error=%v (falling back to chat)", err) - // Still emit a heartbeat for visibility, even on error - s.incSentCounters(sentBytes) - s.logLLMStats(modelUsed) } - return nil, false + if strings.TrimSpace(cleaned) == "" { + return nil, false + } + detail := fmt.Sprintf("Hexai %s:%s", client.Name(), modelUsed) + providerKey := provider + if providerKey == "" { + providerKey = canonicalProvider(client.Name()) + } + cacheKey := plan.cacheKey + "|" + providerKey + ":" + modelUsed + s.completionCachePut(cacheKey, cleaned) + items := s.makeCompletionItems(cleaned, plan.inParams, current, p, plan.docStr, detail, sortPrefix) + return items, true } // waitForDebounce sleeps until there has been no input activity for at least diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go index f8ed9ed..9325877 100644 --- a/internal/lsp/handlers_document.go +++ b/internal/lsp/handlers_document.go @@ -172,10 +172,7 @@ func (s *Server) detectAndHandleChat(uri string) { if client == nil { return } - modelUsed := spec.effectiveModel() - if strings.TrimSpace(modelUsed) == "" { - modelUsed = client.DefaultModel() - } + modelUsed := spec.effectiveModel(client.DefaultModel()) logging.Logf("lsp ", "chat llm=requesting model=%s", modelUsed) text, err := s.chatWithStats(ctx, surfaceChat, spec, msgs) if err != nil { diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index c8d2d24..2748a60 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -25,41 +25,79 @@ const ( type requestSpec struct { provider string - modelOverride string + entry appconfig.SurfaceConfig fallbackModel string options []llm.RequestOption + index int } -func (r requestSpec) 
effectiveModel() string { - if s := strings.TrimSpace(r.modelOverride); s != "" { - return s +func (r requestSpec) modelOverride() string { return strings.TrimSpace(r.entry.Model) } + +func (r requestSpec) effectiveModel(defaultModel string) string { + if m := strings.TrimSpace(r.entry.Model); m != "" { + return m + } + if f := strings.TrimSpace(r.fallbackModel); f != "" { + return f } - return strings.TrimSpace(r.fallbackModel) + return strings.TrimSpace(defaultModel) } -func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { +func (s *Server) buildRequestSpecs(surface surfaceKind) []requestSpec { cfg := s.currentConfig() - providerOverride := strings.TrimSpace(surfaceProviderFromConfig(cfg, surface)) - provider := canonicalProvider(cfg.Provider) - if providerOverride != "" { - provider = canonicalProvider(providerOverride) + entries := surfaceConfigsFor(cfg, surface) + if len(entries) == 0 { + entries = []appconfig.SurfaceConfig{{Provider: cfg.Provider}} } - fallbackModel := strings.TrimSpace(resolveDefaultModel(cfg, provider)) - modelOverride := strings.TrimSpace(surfaceModelFromConfig(cfg, surface)) maxTokens := s.maxTokens() - opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)} - if tempVal, ok := chooseSurfaceTemperature(surface, cfg, provider, modelOverride, fallbackModel); ok { - opts = append(opts, llm.WithTemperature(tempVal)) - } - if modelOverride != "" { - opts = append(opts, llm.WithModel(modelOverride)) + specs := make([]requestSpec, 0, len(entries)) + for idx, raw := range entries { + entry := appconfig.SurfaceConfig{ + Provider: strings.TrimSpace(raw.Provider), + Model: strings.TrimSpace(raw.Model), + Temperature: raw.Temperature, + } + provider := entry.Provider + if provider == "" { + provider = cfg.Provider + } + provider = canonicalProvider(provider) + fallbackModel := entry.Model + if fallbackModel == "" { + fallbackModel = strings.TrimSpace(resolveDefaultModel(cfg, provider)) + } + opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)} + if entry.Model != "" { + opts = append(opts, llm.WithModel(entry.Model)) + } + if temp, ok := chooseSurfaceTemperature(surface, cfg, entry, provider, fallbackModel); ok { + opts = append(opts, llm.WithTemperature(temp)) + } + specs = append(specs, requestSpec{ + provider: provider, + entry: entry, + fallbackModel: fallbackModel, + options: opts, + index: idx, + }) } - return requestSpec{ - provider: provider, - modelOverride: modelOverride, - fallbackModel: fallbackModel, - options: opts, + return specs +} + +func (s *Server) primaryRequestSpec(surface surfaceKind) requestSpec { + specs := s.buildRequestSpecs(surface) + if len(specs) == 0 { + cfg := s.currentConfig() + provider := canonicalProvider(cfg.Provider) + fallback := strings.TrimSpace(resolveDefaultModel(cfg, provider)) + return requestSpec{provider: provider, fallbackModel: fallback, options: []llm.RequestOption{llm.WithMaxTokens(s.maxTokens())}} } + return specs[0] +} + +// buildRequestSpec is retained for consumers expecting a single-entry helper. 
+func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { + return s.primaryRequestSpec(surface) } func canonicalProvider(name string) string { @@ -94,37 +132,13 @@ func surfaceConfigsFor(cfg appconfig.App, surface surfaceKind) []appconfig.Surfa } } -func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string { - configs := surfaceConfigsFor(cfg, surface) - if len(configs) == 0 { - return "" - } - return configs[0].Model -} - -func surfaceProviderFromConfig(cfg appconfig.App, surface surfaceKind) string { - configs := surfaceConfigsFor(cfg, surface) - if len(configs) == 0 { - return "" - } - return configs[0].Provider -} - -func surfaceTemperatureFromConfig(cfg appconfig.App, surface surfaceKind) *float64 { - configs := surfaceConfigsFor(cfg, surface) - if len(configs) == 0 { - return nil - } - return configs[0].Temperature -} - -func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider string, overrideModel, fallbackModel string) (float64, bool) { - if t := surfaceTemperatureFromConfig(cfg, surface); t != nil { - return *t, true +func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, entry appconfig.SurfaceConfig, provider string, fallbackModel string) (float64, bool) { + if entry.Temperature != nil { + return *entry.Temperature, true } if cfg.CodingTemperature != nil { temp := *cfg.CodingTemperature - effectiveModel := strings.TrimSpace(overrideModel) + effectiveModel := strings.TrimSpace(entry.Model) if effectiveModel == "" { effectiveModel = strings.TrimSpace(fallbackModel) } @@ -133,7 +147,7 @@ func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider s } return temp, true } - effectiveModel := strings.TrimSpace(overrideModel) + effectiveModel := strings.TrimSpace(entry.Model) if effectiveModel == "" { effectiveModel = strings.TrimSpace(fallbackModel) } @@ -283,19 +297,16 @@ func (s *Server) chatWithStats(ctx context.Context, surface surfaceKind, spec re if client == nil { return "", fmt.Errorf("llm client unavailable") } + modelUsed := spec.effectiveModel(client.DefaultModel()) txt, err := client.Chat(ctx, msgs, spec.options...) 
if err != nil { - s.logLLMStats(spec.effectiveModel()) + s.logLLMStats(modelUsed) return "", err } s.incRecvCounters(len(txt)) // Update global stats cache - model := spec.effectiveModel() - if model == "" { - model = client.DefaultModel() - } - _ = stats.Update(ctx, client.Name(), model, sent, len(txt)) - s.logLLMStats(model) + _ = stats.Update(ctx, client.Name(), modelUsed, sent, len(txt)) + s.logLLMStats(modelUsed) return txt, nil } diff --git a/internal/lsp/llm_request_opts_test.go b/internal/lsp/llm_request_opts_test.go index 263db79..f4d31b9 100644 --- a/internal/lsp/llm_request_opts_test.go +++ b/internal/lsp/llm_request_opts_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" ) @@ -15,6 +16,10 @@ func (f fakeClient) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOpt func (f fakeClient) Name() string { return f.name } func (f fakeClient) DefaultModel() string { return f.model } +func floatPtr(v float64) *float64 { + return &v +} + func TestRequestSpec_Gpt5_ForcesTemp1(t *testing.T) { s := newTestServer() one := 0.2 @@ -30,7 +35,41 @@ func TestRequestSpec_Gpt5_ForcesTemp1(t *testing.T) { if got.Temperature != 1.0 { t.Fatalf("expected temp 1.0 for gpt-5, got %v", got.Temperature) } - if model := spec.effectiveModel(); model != "gpt-5.0" { + if model := spec.effectiveModel(s.llmClient.DefaultModel()); model != "gpt-5.0" { t.Fatalf("expected fallback model gpt-5.0, got %q", model) } } + +func TestBuildRequestSpecs_MultiEntries(t *testing.T) { + s := newTestServer() + s.cfg.CompletionConfigs = []appconfig.SurfaceConfig{ + {Provider: "openai", Model: "gpt-4o"}, + {Provider: "copilot", Model: "cpt", Temperature: floatPtr(0.4)}, + } + s.cfg.OpenAIModel = "gpt-3.5" + s.cfg.CopilotModel = "cpt-base" + s.cfg.MaxTokens = 256 + specs := s.buildRequestSpecs(surfaceCompletion) + if len(specs) != 2 { + t.Fatalf("expected 2 specs, got %d", len(specs)) + } + if specs[0].provider != "openai" || specs[0].index != 0 { + t.Fatalf("unexpected first spec: %+v", specs[0]) + } + if specs[1].provider != "copilot" || specs[1].index != 1 { + t.Fatalf("unexpected second spec: %+v", specs[1]) + } + var opts1, opts2 llm.Options + for _, opt := range specs[0].options { + opt(&opts1) + } + for _, opt := range specs[1].options { + opt(&opts2) + } + if opts1.Model != "gpt-4o" || opts1.MaxTokens != 256 { + t.Fatalf("unexpected opts1: %+v", opts1) + } + if opts2.Model != "cpt" || opts2.Temperature != 0.4 { + t.Fatalf("unexpected opts2: %+v", opts2) + } +} diff --git a/internal/lsp/provider_native_success_test.go b/internal/lsp/provider_native_success_test.go index aab886c..e5ab81e 100644 --- a/internal/lsp/provider_native_success_test.go +++ b/internal/lsp/provider_native_success_test.go @@ -25,7 +25,8 @@ func TestProviderNativeCompletion_Success(t *testing.T) { // current line with dot trigger; position after dot current := "fmt." 
 p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
-	items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false)
+	plan := completionPlan{current: current, params: p, funcCtx: "func f(){}", docStr: "doc", cacheKey: "k"}
+	items, ok := s.tryProviderNativeCompletion(context.Background(), plan, spec, s.llmClient, "0000")
 	if !ok || len(items) == 0 {
 		t.Fatalf("expected provider-native items")
 	}
@@ -51,7 +52,8 @@ func TestProviderNativeCompletion_IndentWithDoubleOpen(t *testing.T) {
 	spec := s.buildRequestSpec(surfaceCompletion)
 	current := " >>do>" // leading indent + double-open marker
 	p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
-	items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false)
+	plan := completionPlan{current: current, params: p, funcCtx: "func f(){}", docStr: "doc", cacheKey: "k"}
+	items, ok := s.tryProviderNativeCompletion(context.Background(), plan, spec, s.llmClient, "0000")
 	if !ok || len(items) == 0 {
 		t.Fatalf("expected provider-native items")
 	}
@@ -88,7 +90,8 @@ func TestProviderNativeCompletion_UsesPromptTemplate(t *testing.T) {
 	current := "fmt."
 	// Cursor at line 1, char 1 -> before should be "AAA\nB"
 	p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Position: Position{Line: 1, Character: 1}}
-	if _, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false); !ok {
+	plan := completionPlan{current: current, params: p, funcCtx: "func f(){}", docStr: "doc", cacheKey: "k"}
+	if _, ok := s.tryProviderNativeCompletion(context.Background(), plan, spec, s.llmClient, "0000"); !ok {
 		t.Fatalf("expected provider-native path")
 	}
 	if cap.lastPrompt == "" {
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 28f3218..1fbb0cc 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -257,6 +257,28 @@ func (s *Server) clientFor(spec requestSpec) llm.Client {
 	if store != nil {
 		cfg = store.Snapshot()
 	}
+	cfg.Provider = provider
+	modelOverride := strings.TrimSpace(spec.entry.Model)
+	switch provider {
+	case "openai":
+		if modelOverride != "" {
+			cfg.OpenAIModel = modelOverride
+		} else if spec.fallbackModel != "" {
+			cfg.OpenAIModel = spec.fallbackModel
+		}
+	case "copilot":
+		if modelOverride != "" {
+			cfg.CopilotModel = modelOverride
+		} else if spec.fallbackModel != "" {
+			cfg.CopilotModel = spec.fallbackModel
+		}
+	case "ollama":
+		if modelOverride != "" {
+			cfg.OllamaModel = modelOverride
+		} else if spec.fallbackModel != "" {
+			cfg.OllamaModel = spec.fallbackModel
+		}
+	}
 	client, err := newClientForProvider(cfg, provider)
 	if err != nil {
 		logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err)
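Both the CLI path (runCLIJobs) and the completion path (tryLLMCompletion) in this commit share one fan-out shape: one goroutine per configured entry, an index-keyed results slice so output order matches config order, and the first error reported after all jobs finish. A condensed, self-contained sketch of that pattern, using simplified stand-in types rather than the actual hexai API:

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// entry is a stand-in for appconfig.SurfaceConfig (illustrative only).
type entry struct{ provider, model string }

// fanOut runs one job per entry concurrently; results land at the entry's
// index, so printed output preserves the configured order.
func fanOut(ctx context.Context, entries []entry, run func(context.Context, entry) (string, error)) error {
	results := make([]string, len(entries))
	errs := make([]error, len(entries))
	var wg sync.WaitGroup
	for i, e := range entries {
		i, e := i, e // per-iteration copies for the goroutine closure (pre-Go 1.22 idiom)
		wg.Add(1)
		go func() {
			defer wg.Done()
			results[i], errs[i] = run(ctx, e)
		}()
	}
	wg.Wait()
	var firstErr error
	for i, e := range entries {
		fmt.Printf("=== %s:%s ===\n%s\n", e.provider, e.model, results[i])
		if errs[i] != nil && firstErr == nil {
			firstErr = errs[i] // first error wins, as in runCLIJobs
		}
	}
	return firstErr
}

func main() {
	entries := []entry{{"openai", "gpt-4o"}, {"copilot", "cpt"}}
	_ = fanOut(context.Background(), entries, func(_ context.Context, e entry) (string, error) {
		return "response from " + e.provider, nil // stub provider call
	})
}
```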