author     Paul Buetow <paul@buetow.org>  2025-09-28 00:20:05 +0300
committer  Paul Buetow <paul@buetow.org>  2025-09-28 00:20:05 +0300
commit     0ac2d186e84f77d73d924e2c0ce975a17c3a8078 (patch)
tree       49f3e2def38449544e1d67f047cbcb4aab802658
parent     51b2621d58633aa5c0f5cc7b64616d70d41acc91 (diff)
Improve multi-provider completion streaming and CLI selector flags
-rw-r--r--  cmd/hexai/main.go                               54
-rw-r--r--  internal/hexaicli/run.go                       316
-rw-r--r--  internal/hexaicli/run_model_override_test.go    54
-rw-r--r--  internal/hexaicli/run_test.go                   17
-rw-r--r--  internal/lsp/chat_trigger_suppression_test.go    2
-rw-r--r--  internal/lsp/completion_cache_test.go            4
-rw-r--r--  internal/lsp/completion_codex_path_test.go       4
-rw-r--r--  internal/lsp/completion_prefix_strip_test.go    12
-rw-r--r--  internal/lsp/debounce_throttle_test.go           6
-rw-r--r--  internal/lsp/handlers_completion.go             90
-rw-r--r--  internal/lsp/handlers_document.go                2
-rw-r--r--  internal/lsp/server.go                          36
12 files changed, 506 insertions, 91 deletions
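
The selector flags added below are index-based: each configured CLI provider gets a boolean flag named after its position in the config, so (illustrative only, assuming the binary is installed as hexai) invocations like

    hexai -1 "summarise this"        # run only the provider configured at index 1
    hexai -0 -2 "summarise this"     # run providers #0 and #2 side by side in columns

restrict the run to the selected providers instead of fanning out to all of them.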
diff --git a/cmd/hexai/main.go b/cmd/hexai/main.go
index 00fa4a3..4c4fbd2 100644
--- a/cmd/hexai/main.go
+++ b/cmd/hexai/main.go
@@ -5,21 +5,67 @@ import (
"context"
"flag"
"fmt"
+ "io"
+ "log"
"os"
+ "strconv"
+ "strings"
"codeberg.org/snonux/hexai/internal"
+ "codeberg.org/snonux/hexai/internal/appconfig"
"codeberg.org/snonux/hexai/internal/hexaicli"
)
func main() {
- showVersion := flag.Bool("version", false, "print version and exit")
- flag.Parse()
+ logger := log.New(io.Discard, "", 0)
+ cfg := appconfig.Load(logger)
+ cliEntries := cfg.CLIConfigs
+ if len(cliEntries) == 0 {
+ cliEntries = []appconfig.SurfaceConfig{{Provider: cfg.Provider}}
+ }
+ fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
+ showVersion := fs.Bool("version", false, "print version and exit")
+ selectedFlags := make([]bool, len(cliEntries))
+ for i, entry := range cliEntries {
+ name := strconv.Itoa(i)
+ provider := strings.TrimSpace(entry.Provider)
+ if provider == "" {
+ provider = cfg.Provider
+ }
+ model := strings.TrimSpace(entry.Model)
+ if model == "" {
+ model = pickDefaultModel(cfg, provider)
+ }
+ desc := fmt.Sprintf("use only provider #%d (%s:%s)", i, provider, model)
+ fs.BoolVar(&selectedFlags[i], name, false, desc)
+ }
+ _ = fs.Parse(os.Args[1:])
if *showVersion {
fmt.Fprintln(os.Stdout, internal.Version)
return
}
-
- if err := hexaicli.Run(context.Background(), flag.Args(), os.Stdin, os.Stdout, os.Stderr); err != nil {
+ var selection []int
+ for i, sel := range selectedFlags {
+ if sel {
+ selection = append(selection, i)
+ }
+ }
+ ctx := context.Background()
+ if len(selection) > 0 {
+ ctx = hexaicli.WithCLISelection(ctx, selection)
+ }
+ if err := hexaicli.Run(ctx, fs.Args(), os.Stdin, os.Stdout, os.Stderr); err != nil {
os.Exit(1)
}
}
+
+func pickDefaultModel(cfg appconfig.App, provider string) string {
+ switch strings.ToLower(strings.TrimSpace(provider)) {
+ case "ollama":
+ return strings.TrimSpace(cfg.OllamaModel)
+ case "copilot":
+ return strings.TrimSpace(cfg.CopilotModel)
+ default:
+ return strings.TrimSpace(cfg.OpenAIModel)
+ }
+}
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 06fcb83..b7745c8 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -20,6 +20,8 @@ import (
"codeberg.org/snonux/hexai/internal/logging"
"codeberg.org/snonux/hexai/internal/stats"
"codeberg.org/snonux/hexai/internal/tmux"
+ "github.com/mattn/go-runewidth"
+ "golang.org/x/term"
)
type requestArgs struct {
@@ -35,6 +37,23 @@ type cliJob struct {
req requestArgs
}
+type columnPrinter struct {
+ mu sync.Mutex
+ stdout io.Writer
+ columns int
+ colWidth int
+ partial []string
+ providers []string
+ models []string
+}
+
+type columnWriter struct {
+ printer *columnPrinter
+ index int
+}
+
+type selectionContextKey struct{}
+
func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) {
entries := cfg.CLIConfigs
if len(entries) == 0 {
@@ -150,6 +169,13 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
return err
}
+ if selected := selectionFromContext(ctx); len(selected) > 0 {
+ jobs, err = filterJobsBySelection(jobs, selected)
+ if err != nil {
+ fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err)
+ return err
+ }
+ }
if len(jobs) == 0 {
return fmt.Errorf("hexai: no CLI providers configured")
}
@@ -203,16 +229,29 @@ type cliJobResult struct {
func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error {
results := make([]*cliJobResult, len(jobs))
var wg sync.WaitGroup
+ var printer *columnPrinter
+ if len(jobs) > 0 {
+ printer = newColumnPrinter(stdout, jobs)
+ printer.PrintHeader()
+ }
for _, job := range jobs {
job := job
wg.Add(1)
printProviderInfo(stderr, job.client, job.req.model)
go func() {
defer wg.Done()
- var outBuf, errBuf bytes.Buffer
+ var errBuf bytes.Buffer
+ var outBuf bytes.Buffer
jobMsgs := make([]llm.Message, len(msgs))
copy(jobMsgs, msgs)
- err := runChat(ctx, job.client, job.req, jobMsgs, input, &outBuf, &errBuf)
+ writer := io.Writer(&outBuf)
+ if printer != nil {
+ writer = printer.Writer(job.index)
+ }
+ err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
+ if printer != nil {
+ printer.Flush(job.index)
+ }
results[job.index] = &cliJobResult{
provider: job.client.Name(),
model: job.req.model,
@@ -224,48 +263,275 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st
}
wg.Wait()
var firstErr error
- printed := false
- for _, res := range results {
- if res == nil {
- continue
- }
- if printed {
- if _, err := io.WriteString(stdout, "\n"); err != nil {
- return err
+ if printer == nil {
+ printed := false
+ for _, res := range results {
+ if res == nil {
+ continue
}
- }
- heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
- if _, err := io.WriteString(stdout, heading); err != nil {
- return err
- }
- if res.output != "" {
- if _, err := io.WriteString(stdout, res.output); err != nil {
+ if printed {
+ if _, err := io.WriteString(stdout, "\n"); err != nil {
+ return err
+ }
+ }
+ heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
+ if _, err := io.WriteString(stdout, heading); err != nil {
return err
}
- if !strings.HasSuffix(res.output, "\n") {
- if _, err := io.WriteString(stdout, "\n"); err != nil {
+ if res.output != "" {
+ if _, err := io.WriteString(stdout, res.output); err != nil {
return err
}
+ if !strings.HasSuffix(res.output, "\n") {
+ if _, err := io.WriteString(stdout, "\n"); err != nil {
+ return err
+ }
+ }
}
+ printed = true
+ }
+ }
+ for _, res := range results {
+ if res == nil {
+ continue
}
- printed = true
if res.summary != "" {
- if _, err := io.WriteString(stderr, res.summary); err != nil {
- return err
+ summary := strings.TrimLeft(res.summary, "\n")
+ if summary != "" {
+ if _, err := io.WriteString(stderr, summary); err != nil {
+ return err
+ }
}
}
if res.err != nil {
if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil {
return err
}
- if firstErr == nil {
- firstErr = res.err
- }
+ }
+ if firstErr == nil && res.err != nil {
+ firstErr = res.err
}
}
return firstErr
}
+func newColumnPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter {
+ cols := len(jobs)
+ width := detectTerminalWidth(stdout)
+ if width <= 0 {
+ width = 100
+ }
+ sepWidth := (cols - 1) * 3
+ colWidth := (width - sepWidth) / cols
+ if colWidth < 20 {
+ colWidth = 20
+ }
+ providers := make([]string, cols)
+ models := make([]string, cols)
+ for _, job := range jobs {
+ providers[job.index] = job.client.Name()
+ models[job.index] = job.req.model
+ }
+ return &columnPrinter{
+ stdout: stdout,
+ columns: cols,
+ colWidth: colWidth,
+ partial: make([]string, cols),
+ providers: providers,
+ models: models,
+ }
+}
+
+func detectTerminalWidth(w io.Writer) int {
+ type fder interface{ Fd() uintptr }
+ if f, ok := w.(*os.File); ok {
+ if width, _, err := term.GetSize(int(f.Fd())); err == nil {
+ return width
+ }
+ }
+ if f, ok := w.(fder); ok {
+ if width, _, err := term.GetSize(int(f.Fd())); err == nil {
+ return width
+ }
+ }
+ return 0
+}
+
+func (cp *columnPrinter) Writer(idx int) io.Writer {
+ return columnWriter{printer: cp, index: idx}
+}
+
+func (cp *columnPrinter) PrintHeader() {
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ combo := make([]string, cp.columns)
+ for i := 0; i < cp.columns; i++ {
+ provider := strings.TrimSpace(cp.providers[i])
+ model := strings.TrimSpace(cp.models[i])
+ switch {
+ case provider != "" && model != "":
+ combo[i] = provider + ":" + model
+ case provider != "":
+ combo[i] = provider
+ case model != "":
+ combo[i] = model
+ default:
+ combo[i] = ""
+ }
+ }
+ cp.writeLine(combo)
+ divider := make([]string, cp.columns)
+ line := strings.Repeat("─", cp.colWidth)
+ for i := range divider {
+ divider[i] = line
+ }
+ cp.writeLine(divider)
+}
+
+func (cp *columnPrinter) Flush(idx int) {
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ if idx < 0 || idx >= len(cp.partial) {
+ return
+ }
+ if cp.partial[idx] == "" {
+ return
+ }
+ cp.emitJobLine(idx, cp.partial[idx])
+ cp.partial[idx] = ""
+}
+
+func (w columnWriter) Write(p []byte) (int, error) {
+ return w.printer.write(w.index, string(p))
+}
+
+func (cp *columnPrinter) write(idx int, data string) (int, error) {
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ if idx < 0 || idx >= len(cp.partial) {
+ return len(data), nil
+ }
+ data = strings.ReplaceAll(data, "\r", "")
+ cp.partial[idx] += data
+ for strings.Contains(cp.partial[idx], "\n") {
+ line, rest, _ := strings.Cut(cp.partial[idx], "\n")
+ cp.partial[idx] = rest
+ cp.emitJobLine(idx, line)
+ }
+ return len(data), nil
+}
+
+func (cp *columnPrinter) emitJobLine(idx int, line string) {
+ segments := cp.wrap(line)
+ for _, seg := range segments {
+ cells := make([]string, cp.columns)
+ if idx >= 0 && idx < len(cells) {
+ cells[idx] = seg
+ }
+ cp.writeLine(cells)
+ }
+}
+
+func (cp *columnPrinter) wrap(text string) []string {
+ text = strings.ReplaceAll(text, "\t", " ")
+ if runewidth.StringWidth(text) <= cp.colWidth {
+ return []string{text}
+ }
+ var lines []string
+ var current strings.Builder
+ width := 0
+ for _, r := range text {
+ rw := runewidth.RuneWidth(r)
+ if width+rw > cp.colWidth && current.Len() > 0 {
+ lines = append(lines, current.String())
+ current.Reset()
+ width = 0
+ }
+ current.WriteRune(r)
+ width += rw
+ }
+ if current.Len() > 0 {
+ lines = append(lines, current.String())
+ }
+ if len(lines) == 0 {
+ lines = append(lines, "")
+ }
+ return lines
+}
+
+func (cp *columnPrinter) writeLine(cells []string) {
+ if len(cells) < cp.columns {
+ extra := make([]string, cp.columns-len(cells))
+ cells = append(cells, extra...)
+ }
+ var builder strings.Builder
+ for i := 0; i < cp.columns; i++ {
+ cell := cells[i]
+ width := runewidth.StringWidth(cell)
+ if width > cp.colWidth {
+ cell = runewidth.Truncate(cell, cp.colWidth, "…")
+ width = runewidth.StringWidth(cell)
+ }
+ builder.WriteString(cell)
+ if pad := cp.colWidth - width; pad > 0 {
+ builder.WriteString(strings.Repeat(" ", pad))
+ }
+ if i != cp.columns-1 {
+ builder.WriteString(" │ ")
+ }
+ }
+ builder.WriteByte('\n')
+ _, _ = cp.stdout.Write([]byte(builder.String()))
+}
+
+// WithCLISelection injects provider indices into the context so Run only executes those jobs.
+func WithCLISelection(ctx context.Context, indices []int) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ cpy := make([]int, len(indices))
+ copy(cpy, indices)
+ return context.WithValue(ctx, selectionContextKey{}, cpy)
+}
+
+func selectionFromContext(ctx context.Context) []int {
+ if ctx == nil {
+ return nil
+ }
+ if v, ok := ctx.Value(selectionContextKey{}).([]int); ok {
+ cpy := make([]int, len(v))
+ copy(cpy, v)
+ return cpy
+ }
+ return nil
+}
+
+func filterJobsBySelection(jobs []cliJob, indices []int) ([]cliJob, error) {
+ if len(indices) == 0 {
+ return jobs, nil
+ }
+ filtered := make([]cliJob, 0, len(indices))
+ seen := make(map[int]struct{}, len(indices))
+ for _, idx := range indices {
+ if idx < 0 || idx >= len(jobs) {
+ return nil, fmt.Errorf("provider index %d out of range (0-%d)", idx, len(jobs)-1)
+ }
+ if _, ok := seen[idx]; ok {
+ continue
+ }
+ clone := jobs[idx]
+ filtered = append(filtered, clone)
+ seen[idx] = struct{}{}
+ }
+ for i := range filtered {
+ filtered[i].index = i
+ }
+ if len(filtered) == 0 {
+ return nil, fmt.Errorf("no CLI providers matched selection")
+ }
+ return filtered, nil
+}
+
// readInput reads from stdin and args, then combines them per CLI rules.
func readInput(stdin io.Reader, args []string) (string, error) {
var stdinData string
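
The columnPrinter added above interleaves several concurrent provider streams into side-by-side columns. A minimal standalone sketch of that idea, with simplified stand-in types and a fixed column width rather than the project's terminal-width detection and wrapping:

package main

import (
	"fmt"
	"strings"
	"sync"
)

// demoPrinter is a stand-in for columnPrinter: it buffers partial output per
// column and emits one padded row for every completed line.
type demoPrinter struct {
	mu       sync.Mutex
	colWidth int
	partial  []string
}

func (p *demoPrinter) write(col int, s string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.partial[col] += s
	for strings.Contains(p.partial[col], "\n") {
		line, rest, _ := strings.Cut(p.partial[col], "\n")
		p.partial[col] = rest
		row := make([]string, len(p.partial))
		row[col] = line // only the owning column gets text on this row
		for i, cell := range row {
			if i > 0 {
				fmt.Print(" │ ")
			}
			fmt.Printf("%-*s", p.colWidth, cell)
		}
		fmt.Println()
	}
}

func main() {
	p := &demoPrinter{colWidth: 12, partial: make([]string, 2)}
	var wg sync.WaitGroup
	streams := [][]string{{"alpha\n", "beta\n"}, {"one\n", "two\n"}}
	for col, chunks := range streams {
		wg.Add(1)
		go func(col int, chunks []string) {
			defer wg.Done()
			for _, c := range chunks {
				p.write(col, c) // each provider streams into its own column
			}
		}(col, chunks)
	}
	wg.Wait()
}

Each goroutine owns one column; the shared mutex guarantees whole rows are emitted, and the per-column buffer keeps partial writes from mixing, which is the same property the real printer relies on while providers stream concurrently.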
diff --git a/internal/hexaicli/run_model_override_test.go b/internal/hexaicli/run_model_override_test.go
index 6394bd1..b32b172 100644
--- a/internal/hexaicli/run_model_override_test.go
+++ b/internal/hexaicli/run_model_override_test.go
@@ -1,39 +1,45 @@
package hexaicli
import (
- "bytes"
- "context"
- "strings"
- "testing"
+ "bytes"
+ "context"
+ "strings"
+ "testing"
- "codeberg.org/snonux/hexai/internal/appconfig"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/appconfig"
+ "codeberg.org/snonux/hexai/internal/llm"
)
type fakeClientModelEnv struct{ name, model string }
-func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { return "ok", nil }
+
+func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) {
+ return "ok", nil
+}
func (f fakeClientModelEnv) Name() string { return f.name }
func (f fakeClientModelEnv) DefaultModel() string { return f.model }
// Ensure that HEXAI_MODEL overrides config for CLI runs.
func TestRun_ModelEnvOverride_FlowsIntoClient(t *testing.T) {
- t.Setenv("HEXAI_MODEL", "gpt-5-codex")
- t.Setenv("HEXAI_PROVIDER", "openai")
- // Replace client constructor to assert model was overridden
- oldNew := newClientFromApp
- defer func() { newClientFromApp = oldNew }()
+ t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+ t.Setenv("HEXAI_MODEL", "gpt-5-codex")
+ t.Setenv("HEXAI_PROVIDER", "openai")
+ // Replace client constructor to assert model was overridden
+ oldNew := newClientFromApp
+ defer func() { newClientFromApp = oldNew }()
+ var seenModel string
newClientFromApp = func(cfg appconfig.App) (llm.Client, error) {
- if strings.TrimSpace(cfg.OpenAIModel) != "gpt-5-codex" {
- t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", cfg.OpenAIModel)
- }
- return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil
- }
+ seenModel = strings.TrimSpace(cfg.OpenAIModel)
+ return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil
+ }
- var out, errb bytes.Buffer
- if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil {
- t.Fatalf("run error: %v", err)
- }
- if !strings.Contains(errb.String(), "model=gpt-5-codex") {
- t.Fatalf("stderr should print effective model, got: %s", errb.String())
- }
+ var out, errb bytes.Buffer
+ if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil {
+ t.Fatalf("run error: %v", err)
+ }
+ if seenModel != "gpt-5-codex" {
+ t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", seenModel)
+ }
+ if !strings.Contains(errb.String(), "model=gpt-5-codex") {
+ t.Fatalf("stderr should print effective model, got: %s", errb.String())
+ }
}
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
index f11545e..dfde068 100644
--- a/internal/hexaicli/run_test.go
+++ b/internal/hexaicli/run_test.go
@@ -225,6 +225,23 @@ func TestBuildCLIJobs_MultiEntries(t *testing.T) {
}
}
+func TestFilterJobsBySelection(t *testing.T) {
+ jobs := []cliJob{{index: 0, provider: "openai"}, {index: 1, provider: "ollama"}, {index: 2, provider: "copilot"}}
+ filtered, err := filterJobsBySelection(jobs, []int{2, 0})
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(filtered) != 2 || filtered[0].provider != "copilot" || filtered[1].provider != "openai" {
+ t.Fatalf("unexpected filtered order: %+v", filtered)
+ }
+ if filtered[0].index != 0 || filtered[1].index != 1 {
+ t.Fatalf("expected reindexed jobs, got %+v", filtered)
+ }
+ if _, err := filterJobsBySelection(jobs, []int{5}); err == nil {
+ t.Fatalf("expected out-of-range error")
+ }
+}
+
func TestNewClientFromConfig_Ollama(t *testing.T) {
cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"}
c, err := newClientFromConfig(cfg)
diff --git a/internal/lsp/chat_trigger_suppression_test.go b/internal/lsp/chat_trigger_suppression_test.go
index 9f9f5bc..852f955 100644
--- a/internal/lsp/chat_trigger_suppression_test.go
+++ b/internal/lsp/chat_trigger_suppression_test.go
@@ -13,7 +13,7 @@ func TestCompletionSuppressedOnChatTriggerEOL(t *testing.T) {
tests := []string{"What now?>", "Explain!>", "Refactor:>", "note ;>"}
for i, line := range tests {
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://chat-suppr.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("case %d: expected ok=true", i)
}
diff --git a/internal/lsp/completion_cache_test.go b/internal/lsp/completion_cache_test.go
index 057b5c5..ff85906 100644
--- a/internal/lsp/completion_cache_test.go
+++ b/internal/lsp/completion_cache_test.go
@@ -25,7 +25,7 @@ func TestCompletionCache_IgnoresWhitespaceBeforeCursor(t *testing.T) {
// First request with trailing spaces before cursor
line := "foo "
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok || len(items) == 0 || fake.calls != 1 {
t.Fatalf("expected first call to invoke LLM; ok=%v len=%d calls=%d", ok, len(items), fake.calls)
}
@@ -33,7 +33,7 @@ func TestCompletionCache_IgnoresWhitespaceBeforeCursor(t *testing.T) {
// Same logical context but with a different amount of trailing whitespace
line2 := "foo "
p2 := CompletionParams{Position: Position{Line: 0, Character: len(line2)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}}
- items2, ok2 := s.tryLLMCompletion(p2, "", line2, "", "", "", false, "")
+ items2, ok2, _ := s.tryLLMCompletion(p2, "", line2, "", "", "", false, "")
if !ok2 || len(items2) == 0 {
t.Fatalf("expected cache hit to still return items")
}
diff --git a/internal/lsp/completion_codex_path_test.go b/internal/lsp/completion_codex_path_test.go
index ea27c6e..6ee8c97 100644
--- a/internal/lsp/completion_codex_path_test.go
+++ b/internal/lsp/completion_codex_path_test.go
@@ -48,7 +48,7 @@ func TestTryLLMCompletion_PrefersCodeCompleterOverChat(t *testing.T) {
s.llmClient = fake
line := "obj."
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok || len(items) == 0 {
t.Fatalf("expected completion items via CodeCompleter path")
}
@@ -70,7 +70,7 @@ func TestTryLLMCompletion_FallsBackToChatOnCodeCompleterError(t *testing.T) {
s.llmClient = fake
line := "obj."
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://y.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("expected ok=true even on fallback path")
}
diff --git a/internal/lsp/completion_prefix_strip_test.go b/internal/lsp/completion_prefix_strip_test.go
index 6173d6f..e0c655c 100644
--- a/internal/lsp/completion_prefix_strip_test.go
+++ b/internal/lsp/completion_prefix_strip_test.go
@@ -52,7 +52,7 @@ func TestTryLLMCompletion_ManualInvokeAfterWhitespace_Allows(t *testing.T) {
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}}
// Simulate manual user invocation (TriggerKind=1)
p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("expected ok=true for manual invoke after whitespace")
}
@@ -72,7 +72,7 @@ func TestTryLLMCompletion_InlinePromptAlwaysTriggers(t *testing.T) {
line := "prefix >do something> suffix"
// No trigger char immediately before cursor; place cursor at end
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://inline.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok || len(items) == 0 {
t.Fatalf("expected completion to trigger on inline >text> prompt")
}
@@ -89,7 +89,7 @@ func TestTryLLMCompletion_DoubleOpenEmpty_DoesNotAutoTrigger(t *testing.T) {
s.llmClient = fake
line := ">> " // empty content after double-open should not force-trigger
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://empty-inline.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("expected ok=true for non-trigger path")
}
@@ -128,7 +128,7 @@ func TestBareDoubleOpenPreventsAutoTriggerEvenWithOtherTriggers(t *testing.T) {
// Place a '.' earlier but also include bare double-open at end; should not auto-trigger
line := "obj. call >>"
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("expected ok=true (handled), but not auto-triggering")
}
@@ -152,7 +152,7 @@ func TestBareDoubleOpenOnNextLine_PreventsAutoTrigger(t *testing.T) {
current := "expression := flag.String(\"expression\", \"\", \"Expression to evaluate\")"
below := ">>"
p := CompletionParams{Position: Position{Line: 0, Character: len(current)}, TextDocument: TextDocumentIdentifier{URI: "file://nextline.go"}}
- items, ok := s.tryLLMCompletion(p, "", current, below, "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", current, below, "", "", false, "")
if !ok {
t.Fatalf("expected ok=true handled")
}
@@ -177,7 +177,7 @@ func TestBareDoubleOpenPreventsManualInvoke(t *testing.T) {
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds-manual.go"}}
// Simulate manual invoke
p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("expected ok=true (handled)")
}
diff --git a/internal/lsp/debounce_throttle_test.go b/internal/lsp/debounce_throttle_test.go
index 81a2c1a..7efd439 100644
--- a/internal/lsp/debounce_throttle_test.go
+++ b/internal/lsp/debounce_throttle_test.go
@@ -37,7 +37,7 @@ func TestCompletionDebounce_WaitsUntilQuiet(t *testing.T) {
p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
start := time.Now()
- _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ _, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("expected ok=true")
}
@@ -65,7 +65,7 @@ func TestCompletionThrottle_SerializesCalls(t *testing.T) {
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://throttle.go"}}
p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
start := time.Now()
- if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
+ if _, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
t.Fatalf("first call expected ok=true")
}
if f1.t.IsZero() {
@@ -77,7 +77,7 @@ func TestCompletionThrottle_SerializesCalls(t *testing.T) {
s.compCache = make(map[string]string)
f2 := &timeLLM{}
s.llmClient = f2
- if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
+ if _, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
t.Fatalf("second call expected ok=true")
}
if f2.t.IsZero() {
diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go
index 237d34d..78e685a 100644
--- a/internal/lsp/handlers_completion.go
+++ b/internal/lsp/handlers_completion.go
@@ -45,9 +45,9 @@ func (s *Server) handleCompletion(req Request) {
if s.llmClient != nil {
newFunc := s.isDefiningNewFunction(p.TextDocument.URI, p.Position)
extra, has := s.buildAdditionalContext(newFunc, p.TextDocument.URI, p.Position)
- items, ok := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra)
+ items, ok, incomplete := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra)
if ok {
- s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil)
+ s.reply(req.ID, CompletionList{IsIncomplete: incomplete, Items: items}, nil)
return
}
}
@@ -87,28 +87,33 @@ func (s *Server) logCompletionContext(p CompletionParams, above, current, below,
p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx))
}
-func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool) {
+func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool, bool) {
ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second)
- defer cancel()
+ var cancelOnce sync.Once
+ end := func() { cancelOnce.Do(cancel) }
plan, items, handled := s.prepareCompletionPlan(p, above, current, below, funcCtx, docStr, hasExtra, extraText)
if handled {
- return items, true
+ end()
+ return items, true, false
}
specs := s.buildRequestSpecs(surfaceCompletion)
if len(specs) == 0 {
- return nil, false
+ end()
+ return nil, false, false
}
type jobResult struct {
items []CompletionItem
ok bool
}
- results := make([]jobResult, len(specs))
+ results := make(chan jobResult, len(specs))
var wg sync.WaitGroup
- var mu sync.Mutex
+ started := 0
s.waitForDebounce(ctx)
if !s.waitForThrottle(ctx) {
- return nil, false
+ end()
+ close(results)
+ return nil, false, false
}
for _, spec := range specs {
spec := spec
@@ -116,27 +121,67 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun
if client == nil {
continue
}
+ started++
wg.Add(1)
go func(idx int, spec requestSpec, client llm.Client) {
defer wg.Done()
items, ok := s.runCompletionForSpec(ctx, plan, spec, client)
- mu.Lock()
- results[idx] = jobResult{items: items, ok: ok}
- mu.Unlock()
+ results <- jobResult{items: items, ok: ok}
}(spec.index, spec, client)
}
- wg.Wait()
- accumulated := make([]CompletionItem, 0)
- for _, res := range results {
- if !res.ok {
- continue
+
+ if started == 0 {
+ end()
+ close(results)
+ return nil, false, false
+ }
+
+ go func() {
+ wg.Wait()
+ close(results)
+ }()
+
+ if started == 1 {
+ res := <-results
+ if !res.ok || len(res.items) == 0 {
+ end()
+ return nil, false, false
}
- accumulated = append(accumulated, res.items...)
+ end()
+ return res.items, true, false
}
- if len(accumulated) == 0 {
- return nil, false
+
+ firstCh := make(chan []CompletionItem, 1)
+ go func(planKey string) {
+ defer end()
+ combined := make([]CompletionItem, 0)
+ firstSent := false
+ for res := range results {
+ if !res.ok || len(res.items) == 0 {
+ continue
+ }
+ combined = append(combined, res.items...)
+ if !firstSent {
+ first := make([]CompletionItem, len(res.items))
+ copy(first, res.items)
+ firstCh <- first
+ firstSent = true
+ }
+ }
+ if !firstSent {
+ close(firstCh)
+ return
+ }
+ s.storePendingCompletion(planKey, combined)
+ close(firstCh)
+ }(plan.cacheKey)
+
+ firstItems, ok := <-firstCh
+ if !ok || len(firstItems) == 0 {
+ end()
+ return nil, false, false
}
- return accumulated, true
+ return firstItems, true, true
}
func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) (completionPlan, []CompletionItem, bool) {
@@ -162,6 +207,9 @@ func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below
plan.inParams = inParamList(current, p.Position.Character)
plan.manualInvoke = parseManualInvoke(p.Context)
plan.cacheKey = s.completionCacheKey(p, above, current, below, funcCtx, plan.inParams, hasExtra, extraText)
+ if pending := s.takePendingCompletion(plan.cacheKey); len(pending) > 0 {
+ return plan, pending, true
+ }
if isBareDoubleOpen(current, openChar, closeChar) || isBareDoubleOpen(below, openChar, closeChar) {
logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase)
return plan, []CompletionItem{}, true
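
The handler change above streams multi-provider completions in two phases: the first provider's items are returned immediately with IsIncomplete=true, while a background goroutine merges the remaining results and parks them under the plan's cache key for the client's follow-up request, which prepareCompletionPlan then serves via takePendingCompletion. A minimal sketch of that first-result-now, rest-later pattern, using hypothetical names and omitting the server's locking and copying:

package main

import (
	"fmt"
	"sync"
)

// firstThenCollect fans out to all sources, returns the first non-empty
// result, and hands the merged results to store once every source is done.
func firstThenCollect(sources []func() []string, store func([]string)) []string {
	results := make(chan []string, len(sources))
	var wg sync.WaitGroup
	for _, src := range sources {
		wg.Add(1)
		go func(src func() []string) {
			defer wg.Done()
			results <- src()
		}(src)
	}
	go func() { wg.Wait(); close(results) }()

	firstCh := make(chan []string, 1)
	go func() {
		var combined []string
		firstSent := false
		for items := range results {
			if len(items) == 0 {
				continue
			}
			combined = append(combined, items...)
			if !firstSent {
				firstCh <- items // reply now; the caller marks it incomplete
				firstSent = true
			}
		}
		if firstSent {
			store(combined) // parked for the follow-up request
		}
		close(firstCh)
	}()

	first, ok := <-firstCh
	if !ok {
		return nil
	}
	return first
}

func main() {
	merged := make(chan []string, 1)
	first := firstThenCollect(
		[]func() []string{
			func() []string { return []string{"from fast provider"} },
			func() []string { return []string{"from slow provider"} },
		},
		func(all []string) { merged <- all },
	)
	fmt.Println("first reply:", first)
	fmt.Println("follow-up reply:", <-merged)
}

In the real handler the combined slice goes into the server's pendingCompletions map rather than a channel, so the next request with the same cache key can pick it up.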
diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go
index 9325877..da7db51 100644
--- a/internal/lsp/handlers_document.go
+++ b/internal/lsp/handlers_document.go
@@ -231,7 +231,7 @@ func (s *Server) runInlinePrompt(uri string, pos Position) {
docStr := s.buildDocString(p, above, current, below, funcCtx)
newFunc := s.isDefiningNewFunction(uri, p.Position)
extra, hasExtra := s.buildAdditionalContext(newFunc, uri, p.Position)
- items, ok := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, hasExtra, extra)
+ items, ok, _ := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, hasExtra, extra)
if !ok || len(items) == 0 {
return
}
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 1fbb0cc..f8b328b 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -40,8 +40,9 @@ type Server struct {
llmRespBytesTotal int64
startTime time.Time
// Small LRU cache for recent code completion outputs (keyed by context)
- compCache map[string]string
- compCacheOrder []string // most-recent at end; cap ~10
+ compCache map[string]string
+ compCacheOrder []string // most-recent at end; cap ~10
+ pendingCompletions map[string][]CompletionItem
// Outgoing JSON-RPC id counter for server-initiated requests
nextID int64
lastLLMCall time.Time
@@ -112,6 +113,7 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions)
s := &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: opts.LogContext, configStore: opts.ConfigStore}
s.startTime = time.Now()
s.compCache = make(map[string]string)
+ s.pendingCompletions = make(map[string][]CompletionItem)
s.applyOptions(opts)
// Initialize dispatch table
s.handlers = map[string]func(Request){
@@ -315,6 +317,36 @@ func (s *Server) currentConfig() appconfig.App {
return s.cfg
}
+func (s *Server) storePendingCompletion(key string, items []CompletionItem) {
+ if len(items) == 0 {
+ return
+ }
+ cpy := make([]CompletionItem, len(items))
+ copy(cpy, items)
+ s.mu.Lock()
+ if s.pendingCompletions == nil {
+ s.pendingCompletions = make(map[string][]CompletionItem)
+ }
+ s.pendingCompletions[key] = cpy
+ s.mu.Unlock()
+}
+
+func (s *Server) takePendingCompletion(key string) []CompletionItem {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if len(s.pendingCompletions) == 0 {
+ return nil
+ }
+ items, ok := s.pendingCompletions[key]
+ if !ok {
+ return nil
+ }
+ delete(s.pendingCompletions, key)
+ cpy := make([]CompletionItem, len(items))
+ copy(cpy, items)
+ return cpy
+}
+
func (s *Server) maxTokens() int {
cfg := s.currentConfig()
if cfg.MaxTokens <= 0 {