diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-28 00:20:05 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-28 00:20:05 +0300 |
| commit | 0ac2d186e84f77d73d924e2c0ce975a17c3a8078 (patch) | |
| tree | 49f3e2def38449544e1d67f047cbcb4aab802658 /internal | |
| parent | 51b2621d58633aa5c0f5cc7b64616d70d41acc91 (diff) | |
Improve multi-provider completion streaming and CLI selector flags
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/hexaicli/run.go | 316 | ||||
| -rw-r--r-- | internal/hexaicli/run_model_override_test.go | 54 | ||||
| -rw-r--r-- | internal/hexaicli/run_test.go | 17 | ||||
| -rw-r--r-- | internal/lsp/chat_trigger_suppression_test.go | 2 | ||||
| -rw-r--r-- | internal/lsp/completion_cache_test.go | 4 | ||||
| -rw-r--r-- | internal/lsp/completion_codex_path_test.go | 4 | ||||
| -rw-r--r-- | internal/lsp/completion_prefix_strip_test.go | 12 | ||||
| -rw-r--r-- | internal/lsp/debounce_throttle_test.go | 6 | ||||
| -rw-r--r-- | internal/lsp/handlers_completion.go | 90 | ||||
| -rw-r--r-- | internal/lsp/handlers_document.go | 2 | ||||
| -rw-r--r-- | internal/lsp/server.go | 36 |
11 files changed, 456 insertions, 87 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 06fcb83..b7745c8 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -20,6 +20,8 @@ import ( "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" "codeberg.org/snonux/hexai/internal/tmux" + "github.com/mattn/go-runewidth" + "golang.org/x/term" ) type requestArgs struct { @@ -35,6 +37,23 @@ type cliJob struct { req requestArgs } +type columnPrinter struct { + mu sync.Mutex + stdout io.Writer + columns int + colWidth int + partial []string + providers []string + models []string +} + +type columnWriter struct { + printer *columnPrinter + index int +} + +type selectionContextKey struct{} + func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) { entries := cfg.CLIConfigs if len(entries) == 0 { @@ -150,6 +169,13 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } + if selected := selectionFromContext(ctx); len(selected) > 0 { + jobs, err = filterJobsBySelection(jobs, selected) + if err != nil { + fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err) + return err + } + } if len(jobs) == 0 { return fmt.Errorf("hexai: no CLI providers configured") } @@ -203,16 +229,29 @@ type cliJobResult struct { func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { results := make([]*cliJobResult, len(jobs)) var wg sync.WaitGroup + var printer *columnPrinter + if len(jobs) > 0 { + printer = newColumnPrinter(stdout, jobs) + printer.PrintHeader() + } for _, job := range jobs { job := job wg.Add(1) printProviderInfo(stderr, job.client, job.req.model) go func() { defer wg.Done() - var outBuf, errBuf bytes.Buffer + var errBuf bytes.Buffer + var outBuf bytes.Buffer jobMsgs := make([]llm.Message, len(msgs)) copy(jobMsgs, msgs) - err := runChat(ctx, job.client, job.req, jobMsgs, input, &outBuf, &errBuf) + writer := io.Writer(&outBuf) + if printer != nil { + writer = printer.Writer(job.index) + } + err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf) + if printer != nil { + printer.Flush(job.index) + } results[job.index] = &cliJobResult{ provider: job.client.Name(), model: job.req.model, @@ -224,48 +263,275 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st } wg.Wait() var firstErr error - printed := false - for _, res := range results { - if res == nil { - continue - } - if printed { - if _, err := io.WriteString(stdout, "\n"); err != nil { - return err + if printer == nil { + printed := false + for _, res := range results { + if res == nil { + continue } - } - heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) - if _, err := io.WriteString(stdout, heading); err != nil { - return err - } - if res.output != "" { - if _, err := io.WriteString(stdout, res.output); err != nil { + if printed { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } + heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) + if _, err := io.WriteString(stdout, heading); err != nil { return err } - if !strings.HasSuffix(res.output, "\n") { - if _, err := io.WriteString(stdout, "\n"); err != nil { + if res.output != "" { + if _, err := io.WriteString(stdout, res.output); err != nil { return err } + if !strings.HasSuffix(res.output, "\n") { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } } + printed = true + } + } + for _, res := range results { + if res == nil { + continue } - printed = true if res.summary != "" { - if _, err := io.WriteString(stderr, res.summary); err != nil { - return err + summary := strings.TrimLeft(res.summary, "\n") + if summary != "" { + if _, err := io.WriteString(stderr, summary); err != nil { + return err + } } } if res.err != nil { if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil { return err } - if firstErr == nil { - firstErr = res.err - } + } + if firstErr == nil && res.err != nil { + firstErr = res.err } } return firstErr } +func newColumnPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter { + cols := len(jobs) + width := detectTerminalWidth(stdout) + if width <= 0 { + width = 100 + } + sepWidth := (cols - 1) * 3 + colWidth := (width - sepWidth) / cols + if colWidth < 20 { + colWidth = 20 + } + providers := make([]string, cols) + models := make([]string, cols) + for _, job := range jobs { + providers[job.index] = job.client.Name() + models[job.index] = job.req.model + } + return &columnPrinter{ + stdout: stdout, + columns: cols, + colWidth: colWidth, + partial: make([]string, cols), + providers: providers, + models: models, + } +} + +func detectTerminalWidth(w io.Writer) int { + type fder interface{ Fd() uintptr } + if f, ok := w.(*os.File); ok { + if width, _, err := term.GetSize(int(f.Fd())); err == nil { + return width + } + } + if f, ok := w.(fder); ok { + if width, _, err := term.GetSize(int(f.Fd())); err == nil { + return width + } + } + return 0 +} + +func (cp *columnPrinter) Writer(idx int) io.Writer { + return columnWriter{printer: cp, index: idx} +} + +func (cp *columnPrinter) PrintHeader() { + cp.mu.Lock() + defer cp.mu.Unlock() + combo := make([]string, cp.columns) + for i := 0; i < cp.columns; i++ { + provider := strings.TrimSpace(cp.providers[i]) + model := strings.TrimSpace(cp.models[i]) + switch { + case provider != "" && model != "": + combo[i] = provider + ":" + model + case provider != "": + combo[i] = provider + case model != "": + combo[i] = model + default: + combo[i] = "" + } + } + cp.writeLine(combo) + divider := make([]string, cp.columns) + line := strings.Repeat("─", cp.colWidth) + for i := range divider { + divider[i] = line + } + cp.writeLine(divider) +} + +func (cp *columnPrinter) Flush(idx int) { + cp.mu.Lock() + defer cp.mu.Unlock() + if idx < 0 || idx >= len(cp.partial) { + return + } + if cp.partial[idx] == "" { + return + } + cp.emitJobLine(idx, cp.partial[idx]) + cp.partial[idx] = "" +} + +func (w columnWriter) Write(p []byte) (int, error) { + return w.printer.write(w.index, string(p)) +} + +func (cp *columnPrinter) write(idx int, data string) (int, error) { + cp.mu.Lock() + defer cp.mu.Unlock() + if idx < 0 || idx >= len(cp.partial) { + return len(data), nil + } + data = strings.ReplaceAll(data, "\r", "") + cp.partial[idx] += data + for strings.Contains(cp.partial[idx], "\n") { + line, rest, _ := strings.Cut(cp.partial[idx], "\n") + cp.partial[idx] = rest + cp.emitJobLine(idx, line) + } + return len(data), nil +} + +func (cp *columnPrinter) emitJobLine(idx int, line string) { + segments := cp.wrap(line) + for _, seg := range segments { + cells := make([]string, cp.columns) + if idx >= 0 && idx < len(cells) { + cells[idx] = seg + } + cp.writeLine(cells) + } +} + +func (cp *columnPrinter) wrap(text string) []string { + text = strings.ReplaceAll(text, "\t", " ") + if runewidth.StringWidth(text) <= cp.colWidth { + return []string{text} + } + var lines []string + var current strings.Builder + width := 0 + for _, r := range text { + rw := runewidth.RuneWidth(r) + if width+rw > cp.colWidth && current.Len() > 0 { + lines = append(lines, current.String()) + current.Reset() + width = 0 + } + current.WriteRune(r) + width += rw + } + if current.Len() > 0 { + lines = append(lines, current.String()) + } + if len(lines) == 0 { + lines = append(lines, "") + } + return lines +} + +func (cp *columnPrinter) writeLine(cells []string) { + if len(cells) < cp.columns { + extra := make([]string, cp.columns-len(cells)) + cells = append(cells, extra...) + } + var builder strings.Builder + for i := 0; i < cp.columns; i++ { + cell := cells[i] + width := runewidth.StringWidth(cell) + if width > cp.colWidth { + cell = runewidth.Truncate(cell, cp.colWidth, "…") + width = runewidth.StringWidth(cell) + } + builder.WriteString(cell) + if pad := cp.colWidth - width; pad > 0 { + builder.WriteString(strings.Repeat(" ", pad)) + } + if i != cp.columns-1 { + builder.WriteString(" │ ") + } + } + builder.WriteByte('\n') + _, _ = cp.stdout.Write([]byte(builder.String())) +} + +// WithCLISelection injects provider indices into the context so Run only executes those jobs. +func WithCLISelection(ctx context.Context, indices []int) context.Context { + if ctx == nil { + ctx = context.Background() + } + cpy := make([]int, len(indices)) + copy(cpy, indices) + return context.WithValue(ctx, selectionContextKey{}, cpy) +} + +func selectionFromContext(ctx context.Context) []int { + if ctx == nil { + return nil + } + if v, ok := ctx.Value(selectionContextKey{}).([]int); ok { + cpy := make([]int, len(v)) + copy(cpy, v) + return cpy + } + return nil +} + +func filterJobsBySelection(jobs []cliJob, indices []int) ([]cliJob, error) { + if len(indices) == 0 { + return jobs, nil + } + filtered := make([]cliJob, 0, len(indices)) + seen := make(map[int]struct{}, len(indices)) + for _, idx := range indices { + if idx < 0 || idx >= len(jobs) { + return nil, fmt.Errorf("provider index %d out of range (0-%d)", idx, len(jobs)-1) + } + if _, ok := seen[idx]; ok { + continue + } + clone := jobs[idx] + filtered = append(filtered, clone) + seen[idx] = struct{}{} + } + for i := range filtered { + filtered[i].index = i + } + if len(filtered) == 0 { + return nil, fmt.Errorf("no CLI providers matched selection") + } + return filtered, nil +} + // readInput reads from stdin and args, then combines them per CLI rules. func readInput(stdin io.Reader, args []string) (string, error) { var stdinData string diff --git a/internal/hexaicli/run_model_override_test.go b/internal/hexaicli/run_model_override_test.go index 6394bd1..b32b172 100644 --- a/internal/hexaicli/run_model_override_test.go +++ b/internal/hexaicli/run_model_override_test.go @@ -1,39 +1,45 @@ package hexaicli import ( - "bytes" - "context" - "strings" - "testing" + "bytes" + "context" + "strings" + "testing" - "codeberg.org/snonux/hexai/internal/appconfig" - "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" ) type fakeClientModelEnv struct{ name, model string } -func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { return "ok", nil } + +func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { + return "ok", nil +} func (f fakeClientModelEnv) Name() string { return f.name } func (f fakeClientModelEnv) DefaultModel() string { return f.model } // Ensure that HEXAI_MODEL overrides config for CLI runs. func TestRun_ModelEnvOverride_FlowsIntoClient(t *testing.T) { - t.Setenv("HEXAI_MODEL", "gpt-5-codex") - t.Setenv("HEXAI_PROVIDER", "openai") - // Replace client constructor to assert model was overridden - oldNew := newClientFromApp - defer func() { newClientFromApp = oldNew }() + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("HEXAI_MODEL", "gpt-5-codex") + t.Setenv("HEXAI_PROVIDER", "openai") + // Replace client constructor to assert model was overridden + oldNew := newClientFromApp + defer func() { newClientFromApp = oldNew }() + var seenModel string newClientFromApp = func(cfg appconfig.App) (llm.Client, error) { - if strings.TrimSpace(cfg.OpenAIModel) != "gpt-5-codex" { - t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", cfg.OpenAIModel) - } - return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil - } + seenModel = strings.TrimSpace(cfg.OpenAIModel) + return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil + } - var out, errb bytes.Buffer - if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil { - t.Fatalf("run error: %v", err) - } - if !strings.Contains(errb.String(), "model=gpt-5-codex") { - t.Fatalf("stderr should print effective model, got: %s", errb.String()) - } + var out, errb bytes.Buffer + if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil { + t.Fatalf("run error: %v", err) + } + if seenModel != "gpt-5-codex" { + t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", seenModel) + } + if !strings.Contains(errb.String(), "model=gpt-5-codex") { + t.Fatalf("stderr should print effective model, got: %s", errb.String()) + } } diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index f11545e..dfde068 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -225,6 +225,23 @@ func TestBuildCLIJobs_MultiEntries(t *testing.T) { } } +func TestFilterJobsBySelection(t *testing.T) { + jobs := []cliJob{{index: 0, provider: "openai"}, {index: 1, provider: "ollama"}, {index: 2, provider: "copilot"}} + filtered, err := filterJobsBySelection(jobs, []int{2, 0}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(filtered) != 2 || filtered[0].provider != "copilot" || filtered[1].provider != "openai" { + t.Fatalf("unexpected filtered order: %+v", filtered) + } + if filtered[0].index != 0 || filtered[1].index != 1 { + t.Fatalf("expected reindexed jobs, got %+v", filtered) + } + if _, err := filterJobsBySelection(jobs, []int{5}); err == nil { + t.Fatalf("expected out-of-range error") + } +} + func TestNewClientFromConfig_Ollama(t *testing.T) { cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"} c, err := newClientFromConfig(cfg) diff --git a/internal/lsp/chat_trigger_suppression_test.go b/internal/lsp/chat_trigger_suppression_test.go index 9f9f5bc..852f955 100644 --- a/internal/lsp/chat_trigger_suppression_test.go +++ b/internal/lsp/chat_trigger_suppression_test.go @@ -13,7 +13,7 @@ func TestCompletionSuppressedOnChatTriggerEOL(t *testing.T) { tests := []string{"What now?>", "Explain!>", "Refactor:>", "note ;>"} for i, line := range tests { p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://chat-suppr.go"}} - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("case %d: expected ok=true", i) } diff --git a/internal/lsp/completion_cache_test.go b/internal/lsp/completion_cache_test.go index 057b5c5..ff85906 100644 --- a/internal/lsp/completion_cache_test.go +++ b/internal/lsp/completion_cache_test.go @@ -25,7 +25,7 @@ func TestCompletionCache_IgnoresWhitespaceBeforeCursor(t *testing.T) { // First request with trailing spaces before cursor line := "foo " p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}} - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok || len(items) == 0 || fake.calls != 1 { t.Fatalf("expected first call to invoke LLM; ok=%v len=%d calls=%d", ok, len(items), fake.calls) } @@ -33,7 +33,7 @@ func TestCompletionCache_IgnoresWhitespaceBeforeCursor(t *testing.T) { // Same logical context but with a different amount of trailing whitespace line2 := "foo " p2 := CompletionParams{Position: Position{Line: 0, Character: len(line2)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}} - items2, ok2 := s.tryLLMCompletion(p2, "", line2, "", "", "", false, "") + items2, ok2, _ := s.tryLLMCompletion(p2, "", line2, "", "", "", false, "") if !ok2 || len(items2) == 0 { t.Fatalf("expected cache hit to still return items") } diff --git a/internal/lsp/completion_codex_path_test.go b/internal/lsp/completion_codex_path_test.go index ea27c6e..6ee8c97 100644 --- a/internal/lsp/completion_codex_path_test.go +++ b/internal/lsp/completion_codex_path_test.go @@ -48,7 +48,7 @@ func TestTryLLMCompletion_PrefersCodeCompleterOverChat(t *testing.T) { s.llmClient = fake line := "obj." p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}} - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok || len(items) == 0 { t.Fatalf("expected completion items via CodeCompleter path") } @@ -70,7 +70,7 @@ func TestTryLLMCompletion_FallsBackToChatOnCodeCompleterError(t *testing.T) { s.llmClient = fake line := "obj." p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://y.go"}} - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true even on fallback path") } diff --git a/internal/lsp/completion_prefix_strip_test.go b/internal/lsp/completion_prefix_strip_test.go index 6173d6f..e0c655c 100644 --- a/internal/lsp/completion_prefix_strip_test.go +++ b/internal/lsp/completion_prefix_strip_test.go @@ -52,7 +52,7 @@ func TestTryLLMCompletion_ManualInvokeAfterWhitespace_Allows(t *testing.T) { p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}} // Simulate manual user invocation (TriggerKind=1) p.Context = json.RawMessage([]byte(`{"triggerKind":1}`)) - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true for manual invoke after whitespace") } @@ -72,7 +72,7 @@ func TestTryLLMCompletion_InlinePromptAlwaysTriggers(t *testing.T) { line := "prefix >do something> suffix" // No trigger char immediately before cursor; place cursor at end p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://inline.go"}} - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok || len(items) == 0 { t.Fatalf("expected completion to trigger on inline >text> prompt") } @@ -89,7 +89,7 @@ func TestTryLLMCompletion_DoubleOpenEmpty_DoesNotAutoTrigger(t *testing.T) { s.llmClient = fake line := ">> " // empty content after double-open should not force-trigger p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://empty-inline.go"}} - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true for non-trigger path") } @@ -128,7 +128,7 @@ func TestBareDoubleOpenPreventsAutoTriggerEvenWithOtherTriggers(t *testing.T) { // Place a '.' earlier but also include bare double-open at end; should not auto-trigger line := "obj. call >>" p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds.go"}} - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true (handled), but not auto-triggering") } @@ -152,7 +152,7 @@ func TestBareDoubleOpenOnNextLine_PreventsAutoTrigger(t *testing.T) { current := "expression := flag.String(\"expression\", \"\", \"Expression to evaluate\")" below := ">>" p := CompletionParams{Position: Position{Line: 0, Character: len(current)}, TextDocument: TextDocumentIdentifier{URI: "file://nextline.go"}} - items, ok := s.tryLLMCompletion(p, "", current, below, "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", current, below, "", "", false, "") if !ok { t.Fatalf("expected ok=true handled") } @@ -177,7 +177,7 @@ func TestBareDoubleOpenPreventsManualInvoke(t *testing.T) { p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds-manual.go"}} // Simulate manual invoke p.Context = json.RawMessage([]byte(`{"triggerKind":1}`)) - items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true (handled)") } diff --git a/internal/lsp/debounce_throttle_test.go b/internal/lsp/debounce_throttle_test.go index 81a2c1a..7efd439 100644 --- a/internal/lsp/debounce_throttle_test.go +++ b/internal/lsp/debounce_throttle_test.go @@ -37,7 +37,7 @@ func TestCompletionDebounce_WaitsUntilQuiet(t *testing.T) { p.Context = json.RawMessage([]byte(`{"triggerKind":1}`)) start := time.Now() - _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + _, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true") } @@ -65,7 +65,7 @@ func TestCompletionThrottle_SerializesCalls(t *testing.T) { p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://throttle.go"}} p.Context = json.RawMessage([]byte(`{"triggerKind":1}`)) start := time.Now() - if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok { + if _, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok { t.Fatalf("first call expected ok=true") } if f1.t.IsZero() { @@ -77,7 +77,7 @@ func TestCompletionThrottle_SerializesCalls(t *testing.T) { s.compCache = make(map[string]string) f2 := &timeLLM{} s.llmClient = f2 - if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok { + if _, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok { t.Fatalf("second call expected ok=true") } if f2.t.IsZero() { diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index 237d34d..78e685a 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -45,9 +45,9 @@ func (s *Server) handleCompletion(req Request) { if s.llmClient != nil { newFunc := s.isDefiningNewFunction(p.TextDocument.URI, p.Position) extra, has := s.buildAdditionalContext(newFunc, p.TextDocument.URI, p.Position) - items, ok := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra) + items, ok, incomplete := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra) if ok { - s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) + s.reply(req.ID, CompletionList{IsIncomplete: incomplete, Items: items}, nil) return } } @@ -87,28 +87,33 @@ func (s *Server) logCompletionContext(p CompletionParams, above, current, below, p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) } -func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool) { +func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool, bool) { ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) - defer cancel() + var cancelOnce sync.Once + end := func() { cancelOnce.Do(cancel) } plan, items, handled := s.prepareCompletionPlan(p, above, current, below, funcCtx, docStr, hasExtra, extraText) if handled { - return items, true + end() + return items, true, false } specs := s.buildRequestSpecs(surfaceCompletion) if len(specs) == 0 { - return nil, false + end() + return nil, false, false } type jobResult struct { items []CompletionItem ok bool } - results := make([]jobResult, len(specs)) + results := make(chan jobResult, len(specs)) var wg sync.WaitGroup - var mu sync.Mutex + started := 0 s.waitForDebounce(ctx) if !s.waitForThrottle(ctx) { - return nil, false + end() + close(results) + return nil, false, false } for _, spec := range specs { spec := spec @@ -116,27 +121,67 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun if client == nil { continue } + started++ wg.Add(1) go func(idx int, spec requestSpec, client llm.Client) { defer wg.Done() items, ok := s.runCompletionForSpec(ctx, plan, spec, client) - mu.Lock() - results[idx] = jobResult{items: items, ok: ok} - mu.Unlock() + results <- jobResult{items: items, ok: ok} }(spec.index, spec, client) } - wg.Wait() - accumulated := make([]CompletionItem, 0) - for _, res := range results { - if !res.ok { - continue + + if started == 0 { + end() + close(results) + return nil, false, false + } + + go func() { + wg.Wait() + close(results) + }() + + if started == 1 { + res := <-results + if !res.ok || len(res.items) == 0 { + end() + return nil, false, false } - accumulated = append(accumulated, res.items...) + end() + return res.items, true, false } - if len(accumulated) == 0 { - return nil, false + + firstCh := make(chan []CompletionItem, 1) + go func(planKey string) { + defer end() + combined := make([]CompletionItem, 0) + firstSent := false + for res := range results { + if !res.ok || len(res.items) == 0 { + continue + } + combined = append(combined, res.items...) + if !firstSent { + first := make([]CompletionItem, len(res.items)) + copy(first, res.items) + firstCh <- first + firstSent = true + } + } + if !firstSent { + close(firstCh) + return + } + s.storePendingCompletion(planKey, combined) + close(firstCh) + }(plan.cacheKey) + + firstItems, ok := <-firstCh + if !ok || len(firstItems) == 0 { + end() + return nil, false, false } - return accumulated, true + return firstItems, true, true } func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) (completionPlan, []CompletionItem, bool) { @@ -162,6 +207,9 @@ func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below plan.inParams = inParamList(current, p.Position.Character) plan.manualInvoke = parseManualInvoke(p.Context) plan.cacheKey = s.completionCacheKey(p, above, current, below, funcCtx, plan.inParams, hasExtra, extraText) + if pending := s.takePendingCompletion(plan.cacheKey); len(pending) > 0 { + return plan, pending, true + } if isBareDoubleOpen(current, openChar, closeChar) || isBareDoubleOpen(below, openChar, closeChar) { logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) return plan, []CompletionItem{}, true diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go index 9325877..da7db51 100644 --- a/internal/lsp/handlers_document.go +++ b/internal/lsp/handlers_document.go @@ -231,7 +231,7 @@ func (s *Server) runInlinePrompt(uri string, pos Position) { docStr := s.buildDocString(p, above, current, below, funcCtx) newFunc := s.isDefiningNewFunction(uri, p.Position) extra, hasExtra := s.buildAdditionalContext(newFunc, uri, p.Position) - items, ok := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, hasExtra, extra) + items, ok, _ := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, hasExtra, extra) if !ok || len(items) == 0 { return } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 1fbb0cc..f8b328b 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -40,8 +40,9 @@ type Server struct { llmRespBytesTotal int64 startTime time.Time // Small LRU cache for recent code completion outputs (keyed by context) - compCache map[string]string - compCacheOrder []string // most-recent at end; cap ~10 + compCache map[string]string + compCacheOrder []string // most-recent at end; cap ~10 + pendingCompletions map[string][]CompletionItem // Outgoing JSON-RPC id counter for server-initiated requests nextID int64 lastLLMCall time.Time @@ -112,6 +113,7 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) s := &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: opts.LogContext, configStore: opts.ConfigStore} s.startTime = time.Now() s.compCache = make(map[string]string) + s.pendingCompletions = make(map[string][]CompletionItem) s.applyOptions(opts) // Initialize dispatch table s.handlers = map[string]func(Request){ @@ -315,6 +317,36 @@ func (s *Server) currentConfig() appconfig.App { return s.cfg } +func (s *Server) storePendingCompletion(key string, items []CompletionItem) { + if len(items) == 0 { + return + } + cpy := make([]CompletionItem, len(items)) + copy(cpy, items) + s.mu.Lock() + if s.pendingCompletions == nil { + s.pendingCompletions = make(map[string][]CompletionItem) + } + s.pendingCompletions[key] = cpy + s.mu.Unlock() +} + +func (s *Server) takePendingCompletion(key string) []CompletionItem { + s.mu.Lock() + defer s.mu.Unlock() + if len(s.pendingCompletions) == 0 { + return nil + } + items, ok := s.pendingCompletions[key] + if !ok { + return nil + } + delete(s.pendingCompletions, key) + cpy := make([]CompletionItem, len(items)) + copy(cpy, items) + return cpy +} + func (s *Server) maxTokens() int { cfg := s.currentConfig() if cfg.MaxTokens <= 0 { |
