diff options
| author | Paul Buetow <paul@buetow.org> | 2025-08-28 23:28:31 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-08-28 23:28:31 +0300 |
| commit | 86aafe22eaf04687288e5a730acf0a473719c514 (patch) | |
| tree | b1a618442f903cf5de37d65f0b9d86392bb8194d | |
| parent | e3a7a18558fa5631c44fb70af673877d855206fc (diff) | |
copilot: add session token + codex code completion; lsp: prefer native CodeCompleter with chat fallback; remove obsolete throttle path; add tests; bump version to 0.3.0 (tag: v0.3.0)
| -rw-r--r-- | go.sum | 2 | ||||
| -rw-r--r-- | internal/llm/copilot.go | 194 | ||||
| -rw-r--r-- | internal/llm/provider.go | 18 | ||||
| -rw-r--r-- | internal/lsp/completion_cache_test.go | 4 | ||||
| -rw-r--r-- | internal/lsp/completion_codex_path_test.go | 58 | ||||
| -rw-r--r-- | internal/lsp/completion_prefix_strip_test.go | 14 | ||||
| -rw-r--r-- | internal/lsp/completion_throttle_test.go | 102 | ||||
| -rw-r--r-- | internal/lsp/handlers.go | 155 | ||||
| -rw-r--r-- | internal/lsp/testfakes_test.go | 18 | ||||
| -rw-r--r-- | internal/version.go | 2 |
10 files changed, 382 insertions, 185 deletions
@@ -0,0 +1,2 @@ +github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= +github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go index 6ab3a0d..7b3574c 100644 --- a/internal/llm/copilot.go +++ b/internal/llm/copilot.go @@ -1,4 +1,4 @@ -// Summary: GitHub Copilot client implementation for chat completions using the Copilot API. +// Summary: GitHub Copilot client for chat and Codex-style code completion. package llm import ( @@ -7,10 +7,13 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" + "regexp" "strings" "time" + "encoding/base64" appver "hexai/internal" "hexai/internal/logging" ) @@ -23,6 +26,10 @@ type copilotClient struct { defaultModel string chatLogger logging.ChatLogger defaultTemperature *float64 + + // cached Copilot session token retrieved from GitHub API using apiKey + sessionToken string + tokenExpiry time.Time } type copilotChatRequest struct { @@ -79,6 +86,10 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req if strings.TrimSpace(c.apiKey) == "" { return nilStringErr("missing Copilot API key") } + // Ensure we have a fresh session token + if err := c.ensureSession(ctx); err != nil { + return "", err + } o := Options{Model: c.defaultModel} for _, opt := range opts { opt(&o) @@ -102,9 +113,7 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req endpoint := c.baseURL + "/chat/completions" logging.Logf("llm/copilot ", "POST %s", endpoint) - resp, err := c.doJSON(ctx, endpoint, body, map[string]string{ - "Authorization": "Bearer " + c.apiKey, - }) + resp, err := c.postJSON(ctx, endpoint, body, c.headersChat()) if err != nil { logging.Logf("llm/copilot ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return "", err @@ -152,20 +161,11 @@ func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64 
return req } -func (c copilotClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - // GitHub Copilot (GitHub Models) requires an API version header and a UA. - req.Header.Set("Accept", "application/json") - req.Header.Set("X-GitHub-Api-Version", "2023-07-07") - req.Header.Set("User-Agent", "hexai/"+appver.Version) - for k, v := range headers { - req.Header.Set(k, v) - } - return c.httpClient.Do(req) +func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { return nil, err } + for k, v := range headers { req.Header.Set(k, v) } + return c.httpClient.Do(req) } func handleCopilotNon2xx(resp *http.Response, start time.Time) error { @@ -190,3 +190,161 @@ func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatRespons } return out, nil } + +// --- Copilot session token management --- + +type ghCopilotTokenResp struct { + Token string `json:"token"` +} + +func (c *copilotClient) ensureSession(ctx context.Context) error { + // If token valid for >60s, reuse + if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) { + return nil + } + if strings.TrimSpace(c.apiKey) == "" { + return errors.New("missing Copilot API key") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil) + if err != nil { return err } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "hexai/"+appver.Version) + resp, err := c.httpClient.Do(req) + if err != nil { return err } + defer 
resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("copilot token http error: %d", resp.StatusCode) + } + var out ghCopilotTokenResp + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return err } + if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") } + // Parse JWT exp + exp := parseJWTExp(out.Token) + if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) } + c.sessionToken = out.Token + c.tokenExpiry = exp + return nil +} + +var jwtExpRe = regexp.MustCompile(`"exp"\s*:\s*([0-9]+)`) // fallback if we can't base64 decode + +func parseJWTExp(token string) time.Time { + parts := strings.Split(token, ".") + if len(parts) < 2 { return time.Time{} } + b, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 { + if n, err2 := parseInt64(m[1]); err2 == nil { return time.Unix(n, 0) } + } + return time.Time{} + } + var payload struct{ Exp int64 `json:"exp"` } + _ = json.Unmarshal(b, &payload) + if payload.Exp == 0 { return time.Time{} } + return time.Unix(payload.Exp, 0) +} + +func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err } + +// --- Copilot headers --- + +func (c *copilotClient) headersChat() map[string]string { + _ = c.ensureSession(context.Background()) + h := map[string]string{ + "Content-Type": "application/json; charset=utf-8", + "Accept": "application/json", + "Authorization": "Bearer " + c.sessionToken, + "User-Agent": "GitHubCopilotChat/0.8.0", + "Editor-Plugin-Version": "copilot-chat/0.8.0", + "Editor-Version": "vscode/1.85.1", + "Openai-Intent": "conversation-panel", + "Openai-Organization": "github-copilot", + "VScode-MachineId": randHex(64), + "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" 
+ randHex(12), + } + return h +} + +func (c *copilotClient) headersGhost() map[string]string { + _ = c.ensureSession(context.Background()) + h := map[string]string{ + "Content-Type": "application/json; charset=utf-8", + "Accept": "*/*", + "Authorization": "Bearer " + c.sessionToken, + "User-Agent": "GithubCopilot/1.155.0", + "Editor-Plugin-Version": "copilot/1.155.0", + "Editor-Version": "vscode/1.85.1", + "Openai-Intent": "copilot-ghost", + "Openai-Organization": "github-copilot", + "VScode-MachineId": randHex(64), + "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + } + return h +} + +func randHex(n int) string { + const hex = "0123456789abcdef" + b := make([]byte, n) + for i := range b { + b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)] + } + return string(b) +} + +// --- Codex-style code completion --- + +// CodeCompletion implements CodeCompleter; returns up to n suggestions. 
+func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) { + if strings.TrimSpace(c.apiKey) == "" { return nil, errors.New("missing Copilot API key") } + if err := c.ensureSession(ctx); err != nil { return nil, err } + if n <= 0 { n = 1 } + maxTokens := 500 + body := map[string]any{ + "extra": map[string]any{ + "language": language, + "next_indent": 0, + "prompt_tokens": 500, + "suffix_tokens": 400, + "trim_by_indentation": true, + }, + "max_tokens": maxTokens, + "n": n, + "nwo": "hexai", + "prompt": prompt, + "stop": []string{"\n\n"}, + "stream": true, + "suffix": suffix, + "temperature": temperature, + "top_p": 1, + } + buf, _ := json.Marshal(body) + url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions" + resp, err := c.postJSON(ctx, url, buf, c.headersGhost()) + if err != nil { return nil, err } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode) + } + // Read all and parse lines that start with "data: " accumulating by index + raw, _ := io.ReadAll(resp.Body) + byIndex := make(map[int]string) + lines := strings.Split(string(raw), "\n") + for _, ln := range lines { + if !strings.HasPrefix(ln, "data: ") { continue } + var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` } + if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue } + for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text } + } + out := make([]string, 0, len(byIndex)) + for i := 0; i < n; i++ { + if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) } + } + return out, nil +} + +// newLineDataReader wraps a streaming body and exposes a JSON decoder that +// decodes successive objects from lines prefixed by "data: ". 
+// (no streaming decoder needed; we parse whole body lines) diff --git a/internal/llm/provider.go b/internal/llm/provider.go index ed9ca59..7ab58c6 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -28,10 +28,20 @@ type Client interface { // token-by-token streaming responses. Callers can type-assert to Streamer and // fall back to Client.Chat when not implemented. type Streamer interface { - // ChatStream sends chat messages and invokes onDelta with incremental text - // chunks as they are produced by the model. Implementations should call - // onDelta with empty strings sparingly (prefer only non-empty chunks). - ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error + // ChatStream sends chat messages and invokes onDelta with incremental text + // chunks as they are produced by the model. Implementations should call + // onDelta with empty strings sparingly (prefer only non-empty chunks). + ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error +} + +// CodeCompleter is an optional interface for providers that support a +// prompt/suffix code-completion API (e.g., Copilot Codex endpoint). Clients +// can type-assert to this and prefer it over chat when available. +type CodeCompleter interface { + // CodeCompletion requests up to n suggestions given a left-hand prompt and + // right-hand suffix around the cursor. Language is advisory and may be + // ignored. Temperature applies when provider supports it. + CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) } // Options for a request. Providers may ignore unsupported fields. 
diff --git a/internal/lsp/completion_cache_test.go b/internal/lsp/completion_cache_test.go index 0207a9f..a350281 100644 --- a/internal/lsp/completion_cache_test.go +++ b/internal/lsp/completion_cache_test.go @@ -21,7 +21,7 @@ func TestCompletionCache_IgnoresWhitespaceBeforeCursor(t *testing.T) { // First request with trailing spaces before cursor line := "foo " p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } - items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok || len(items) == 0 || fake.calls != 1 { t.Fatalf("expected first call to invoke LLM; ok=%v len=%d calls=%d", ok, len(items), fake.calls) } @@ -29,7 +29,7 @@ func TestCompletionCache_IgnoresWhitespaceBeforeCursor(t *testing.T) { // Same logical context but with a different amount of trailing whitespace line2 := "foo " p2 := CompletionParams{ Position: Position{ Line: 0, Character: len(line2) }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } - items2, ok2, _ := s.tryLLMCompletion(p2, "", line2, "", "", "", false, "") + items2, ok2 := s.tryLLMCompletion(p2, "", line2, "", "", "", false, "") if !ok2 || len(items2) == 0 { t.Fatalf("expected cache hit to still return items") } diff --git a/internal/lsp/completion_codex_path_test.go b/internal/lsp/completion_codex_path_test.go new file mode 100644 index 0000000..65ab75a --- /dev/null +++ b/internal/lsp/completion_codex_path_test.go @@ -0,0 +1,58 @@ +package lsp + +import ( + "context" + "errors" + "testing" + + "hexai/internal/llm" +) + +// fakeCodeLLM implements both llm.Client and llm.CodeCompleter. 
+type fakeCodeLLM struct{ + codeCalls int + chatCalls int + result string + codeErr error +} + +func (f *fakeCodeLLM) CodeCompletion(_ context.Context, _ string, _ string, n int, _ string, _ float64) ([]string, error) { + f.codeCalls++ + if f.codeErr != nil { return nil, f.codeErr } + if n <= 0 { n = 1 } + out := make([]string, n) + for i := 0; i < n; i++ { out[i] = f.result } + return out, nil +} + +func (f *fakeCodeLLM) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { + f.chatCalls++ + return "chat", nil +} +func (f *fakeCodeLLM) Name() string { return "fake" } +func (f *fakeCodeLLM) DefaultModel() string { return "m" } + +func TestTryLLMCompletion_PrefersCodeCompleterOverChat(t *testing.T) { + s := &Server{ maxTokens: 32, triggerChars: []string{"."}, compCache: make(map[string]string) } + fake := &fakeCodeLLM{ result: "DoThing()" } + s.llmClient = fake + line := "obj." + p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + if !ok || len(items) == 0 { t.Fatalf("expected completion items via CodeCompleter path") } + if fake.codeCalls == 0 { t.Fatalf("expected CodeCompletion to be called") } + if fake.chatCalls != 0 { t.Fatalf("did not expect Chat fallback when CodeCompletion succeeds") } +} + +func TestTryLLMCompletion_FallsBackToChatOnCodeCompleterError(t *testing.T) { + s := &Server{ maxTokens: 32, triggerChars: []string{"."}, compCache: make(map[string]string) } + fake := &fakeCodeLLM{ result: "DoThing()", codeErr: errors.New("boom") } + s.llmClient = fake + line := "obj." 
+ p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://y.go"} } + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + if !ok { t.Fatalf("expected ok=true even on fallback path") } + if len(items) == 0 { t.Fatalf("expected some items from Chat fallback") } + if fake.codeCalls == 0 { t.Fatalf("expected CodeCompletion to be attempted first") } + if fake.chatCalls == 0 { t.Fatalf("expected Chat fallback to be called when CodeCompletion errors") } +} diff --git a/internal/lsp/completion_prefix_strip_test.go b/internal/lsp/completion_prefix_strip_test.go index 5527cb2..9953714 100644 --- a/internal/lsp/completion_prefix_strip_test.go +++ b/internal/lsp/completion_prefix_strip_test.go @@ -45,8 +45,7 @@ func TestTryLLMCompletion_ManualInvokeAfterWhitespace_Allows(t *testing.T) { p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } // Simulate manual user invocation (TriggerKind=1) p.Context = json.RawMessage([]byte(`{"triggerKind":1}`)) - items, ok, busy := s.tryLLMCompletion(p, "", line, "", "", "", false, "") - if busy { t.Fatalf("unexpected busy=true") } + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true for manual invoke after whitespace") } if len(items) == 0 { t.Fatalf("expected at least one completion item") } } @@ -57,8 +56,7 @@ func TestTryLLMCompletion_InlineSemicolonPromptAlwaysTriggers(t *testing.T) { line := "prefix ;do something; suffix" // No trigger char immediately before cursor; place cursor at end p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://inline.go"} } - items, ok, busy := s.tryLLMCompletion(p, "", line, "", "", "", false, "") - if busy { t.Fatalf("unexpected busy=true") } + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok 
|| len(items) == 0 { t.Fatalf("expected completion to trigger on inline ;text; prompt") } } @@ -68,7 +66,7 @@ func TestTryLLMCompletion_DoubleSemicolonEmpty_DoesNotAutoTrigger(t *testing.T) s.llmClient = fake line := ";; " // empty content after ';;' should not force-trigger p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://empty-inline.go"} } - items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true for non-trigger path") } if len(items) != 0 { t.Fatalf("expected no items when inline ';;' is empty") } if fake.calls != 0 { t.Fatalf("LLM should not be called; calls=%d", fake.calls) } @@ -88,7 +86,7 @@ func TestBareDoubleSemicolonPreventsAutoTriggerEvenWithOtherTriggers(t *testing. // Place a '.' earlier but also include bare ';;' at end; should not auto-trigger line := "obj. call ;;" p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds.go"} } - items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true (handled), but not auto-triggering") } if len(items) != 0 { t.Fatalf("expected no items due to bare ';;'") } if fake.calls != 0 { t.Fatalf("LLM should not be called; calls=%d", fake.calls) } @@ -101,7 +99,7 @@ func TestBareDoubleSemicolonOnNextLine_PreventsAutoTrigger(t *testing.T) { current := "expression := flag.String(\"expression\", \"\", \"Expression to evaluate\")" below := ";;" p := CompletionParams{ Position: Position{ Line: 0, Character: len(current) }, TextDocument: TextDocumentIdentifier{URI: "file://nextline.go"} } - items, ok, _ := s.tryLLMCompletion(p, "", current, below, "", "", false, "") + items, ok := s.tryLLMCompletion(p, "", current, below, "", "", false, "") 
if !ok { t.Fatalf("expected ok=true handled") } if len(items) != 0 { t.Fatalf("expected no items due to bare ';;' on next line") } if fake.calls != 0 { t.Fatalf("LLM should not be called; calls=%d", fake.calls) } @@ -115,7 +113,7 @@ func TestBareDoubleSemicolonPreventsManualInvoke(t *testing.T) { p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds-manual.go"} } // Simulate manual invoke p.Context = json.RawMessage([]byte(`{"triggerKind":1}`)) - items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") + items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "") if !ok { t.Fatalf("expected ok=true (handled)") } if len(items) != 0 { t.Fatalf("expected no items for bare ';;' even with manual invoke") } if fake.calls != 0 { t.Fatalf("LLM should not be called; calls=%d", fake.calls) } diff --git a/internal/lsp/completion_throttle_test.go b/internal/lsp/completion_throttle_test.go deleted file mode 100644 index 11b0e7a..0000000 --- a/internal/lsp/completion_throttle_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package lsp - -import ( - "bytes" - "context" - "log" - "testing" - - "hexai/internal/llm" -) - -// countingLLM counts Chat calls; minimal implementation for tests. 
-type countingLLM struct{ calls int } - -func (f *countingLLM) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { - f.calls++ - return "x := 1", nil -} -func (f *countingLLM) Name() string { return "fake" } -func (f *countingLLM) DefaultModel() string { return "m" } - -func TestDefaultTriggerChars_DoesNotIncludeSemicolonOrQuestion(t *testing.T) { - var buf bytes.Buffer - logger := log.New(&buf, "", 0) - s := NewServer(bytes.NewBuffer(nil), &buf, logger, ServerOptions{}) - has := func(ch string) bool { - for _, c := range s.triggerChars { - if c == ch { return true } - } - return false - } - if has(";") || has("?") { - t.Fatalf("default trigger chars should not include ';' or '?' got=%v", s.triggerChars) - } -} - -// Note: The server no longer exposes a busy guard; completion requests are -// handled sequentially and the LSP can request again if needed. This test used -// to assert a busy path; it now asserts that a normal trigger proceeds and -// calls the LLM without reporting busy. 
-func TestTryLLMCompletion_NoBusyPath_CurrentBehavior(t *testing.T) { - s := &Server{ maxTokens: 32, triggerChars: []string{".", ":", "/", "_"} } - fake := &countingLLM{} - s.llmClient = fake - p := CompletionParams{ Position: Position{ Line: 0, Character: 4 }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } - items, ok, busy := s.tryLLMCompletion(p, "", "foo.", "", "", "", false, "") - if !ok { - t.Fatalf("expected ok=true for a normal triggered completion") - } - if busy { - t.Fatalf("did not expect busy=true in current behavior") - } - if len(items) == 0 { - t.Fatalf("expected some completion items when triggered") - } - if fake.calls == 0 { - t.Fatalf("expected LLM Chat to be called") - } -} - -func TestTryLLMCompletion_MinPrefixSkipsEarly(t *testing.T) { - s := &Server{ maxTokens: 32, triggerChars: []string{".", ":", "/", "_"} } - fake := &countingLLM{} - s.llmClient = fake - // No trigger character -> skip regardless of prefix - p := CompletionParams{ Position: Position{ Line: 0, Character: 1 }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } - items, ok, _ := s.tryLLMCompletion(p, "", "a", "", "", "", false, "") - if !ok { - t.Fatalf("expected ok=true when skipped by min-prefix heuristic") - } - if len(items) != 0 { - t.Fatalf("expected zero items when not triggered") - } - if fake.calls != 0 { - t.Fatalf("LLM Chat should not be called when not triggered; calls=%d", fake.calls) - } -} - -func TestTryLLMCompletion_RequiresTriggerChar(t *testing.T) { - s := &Server{ maxTokens: 32, triggerChars: []string{".", ":", "/", "_", " "} } - fake := &countingLLM{} - s.llmClient = fake - // With trigger character '.' directly before cursor -> allowed - items, ok, _ := s.tryLLMCompletion(CompletionParams{ Position: Position{ Line: 0, Character: 1 }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} }, "", ".", "", "", "", false, "") - if !ok || len(items) == 0 || fake.calls == 0 { t.Fatalf("expected allowed with '.' 
trigger") } - // Without trigger -> skipped - fake.calls = 0 - items, ok, _ = s.tryLLMCompletion(CompletionParams{ Position: Position{ Line: 0, Character: 1 }, TextDocument: TextDocumentIdentifier{URI: "file://y.go"} }, "", "a", "", "", "", false, "") - if !ok || len(items) != 0 || fake.calls != 0 { t.Fatalf("expected skip without trigger; ok=%v len=%d calls=%d", ok, len(items), fake.calls) } -} - -func TestTryLLMCompletion_AllowsSpaceTrigger(t *testing.T) { - s := &Server{ maxTokens: 32, triggerChars: []string{".", ":", "/", "_", " "} } - fake := &countingLLM{} - s.llmClient = fake - line := "type Matrix " - p := CompletionParams{ Position: Position{ Line: 0, Character: len(line) }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } - items, ok, _ := s.tryLLMCompletion(p, "", line, "", "", "", false, "") - if !ok || len(items) == 0 || fake.calls == 0 { - t.Fatalf("expected allowed with space trigger; ok=%v len=%d calls=%d", ok, len(items), fake.calls) - } -} diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 1b7436e..8efc48e 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -483,26 +483,45 @@ func (s *Server) handleCompletion(req Request) { if s.llmClient != nil { newFunc := s.isDefiningNewFunction(p.TextDocument.URI, p.Position) extra, has := s.buildAdditionalContext(newFunc, p.TextDocument.URI, p.Position) - items, ok, busy := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra) - if ok { - s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) - return - } - if busy { - // Inform client that results are incomplete so it may try again shortly. 
- logging.Logf("lsp ", "completion busy uri=%s line=%d char=%d returning isIncomplete", p.TextDocument.URI, p.Position.Line, p.Position.Character) - s.reply(req.ID, CompletionList{IsIncomplete: true, Items: []CompletionItem{}}, nil) - return - } - } - } - items := s.fallbackCompletionItems(docStr) - s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) + items, ok := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra) + if ok { + s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) + return + } + } + } + items := s.fallbackCompletionItems(docStr) + s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) } func (s *Server) reply(id json.RawMessage, result any, err *RespError) { - resp := Response{JSONRPC: "2.0", ID: id, Result: result, Error: err} - s.writeMessage(resp) + resp := Response{JSONRPC: "2.0", ID: id, Result: result, Error: err} + s.writeMessage(resp) +} + +// docBeforeAfter returns the full document text split at the given position. +// The returned strings are the text before the cursor (inclusive of anything +// left of the position) and the text after the cursor. 
+func (s *Server) docBeforeAfter(uri string, pos Position) (string, string) { + d := s.getDocument(uri) + if d == nil { return "", "" } + // Clamp indices + line := pos.Line + if line < 0 { line = 0 } + if line >= len(d.lines) { line = len(d.lines) - 1 } + col := pos.Character + if col < 0 { col = 0 } + if col > len(d.lines[line]) { col = len(d.lines[line]) } + // Build before + var b strings.Builder + for i := 0; i < line; i++ { b.WriteString(d.lines[i]); b.WriteByte('\n') } + b.WriteString(d.lines[line][:col]) + before := b.String() + // Build after + var a strings.Builder + a.WriteString(d.lines[line][col:]) + for i := line + 1; i < len(d.lines); i++ { a.WriteByte('\n'); a.WriteString(d.lines[i]) } + return before, a.String() } // extractTriggerInfo returns the LSP completion TriggerKind and TriggerCharacter @@ -742,17 +761,17 @@ func (s *Server) logCompletionContext(p CompletionParams, above, current, below, p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) } -func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool, bool) { +func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool) { ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) defer cancel() // Inline prompt markers (strict ;text; or double-; patterns) explicitly allow triggering. inlinePrompt := lineHasInlinePrompt(current) // Only invoke LLM when triggered by our characters, manual invoke, or inline prompt markers. 
- if !inlinePrompt && !s.isTriggerEvent(p, current) { - logging.Logf("lsp ", "%scompletion skip=no-trigger line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) - return []CompletionItem{}, true, false - } + if !inlinePrompt && !s.isTriggerEvent(p, current) { + logging.Logf("lsp ", "%scompletion skip=no-trigger line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) + return []CompletionItem{}, true + } inParams := inParamList(current, p.Position.Character) @@ -776,18 +795,18 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun // Build a cache key for this completion context (ignore trailing whitespace // before the cursor when forming the key) and try cache before any LLM call. key := s.completionCacheKey(p, above, current, below, funcCtx, inParams, hasExtra, extraText) - if cleaned, ok := s.completionCacheGet(key); ok && strings.TrimSpace(cleaned) != "" { - logging.Logf("lsp ", "completion cache hit uri=%s line=%d char=%d preview=%s%s%s", - p.TextDocument.URI, p.Position.Line, p.Position.Character, - logging.AnsiGreen, logging.PreviewForLog(cleaned), logging.AnsiBase) - return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true, false - } + if cleaned, ok := s.completionCacheGet(key); ok && strings.TrimSpace(cleaned) != "" { + logging.Logf("lsp ", "completion cache hit uri=%s line=%d char=%d preview=%s%s%s", + p.TextDocument.URI, p.Position.Line, p.Position.Character, + logging.AnsiGreen, logging.PreviewForLog(cleaned), logging.AnsiBase) + return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true + } // If there is a bare ';;' on the current or next line (no valid ';;text;'), // do not auto-trigger unless it was a manual invoke. 
- if (isBareDoubleSemicolon(current) || isBareDoubleSemicolon(below)) && !manualInvoke { - logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) - return []CompletionItem{}, true, false - } + if (isBareDoubleSemicolon(current) || isBareDoubleSemicolon(below)) && !manualInvoke { + logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) + return []CompletionItem{}, true + } // Heuristic 1: Require a minimal typed identifier prefix to avoid early triggers, // but allow immediate completion after structural trigger chars like '.', ':', '/'. @@ -827,13 +846,49 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun if manualInvoke && s.manualInvokeMinPrefix >= 0 { min = s.manualInvokeMinPrefix } - if j-start < min { // require at least min identifier chars - logging.Logf("lsp ", "%scompletion skip=short-prefix line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) - return []CompletionItem{}, true, false - } - } - } - sysPrompt, userPrompt := buildPrompts(inParams, p, above, current, below, funcCtx) + if j-start < min { // require at least min identifier chars + logging.Logf("lsp ", "%scompletion skip=short-prefix line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) + return []CompletionItem{}, true + } + } + } + // Prefer provider-native code completion when available (e.g., Copilot Codex) + if cc, ok := s.llmClient.(llm.CodeCompleter); ok { + before, after := s.docBeforeAfter(p.TextDocument.URI, p.Position) + // Construct prompt/suffix similar to helix-gpt + path := strings.TrimPrefix(p.TextDocument.URI, "file://") + prompt := "// Path: 
" + path + "\n" + before + lang := "" + temp := 0.0 + if s.codingTemperature != nil { temp = *s.codingTemperature } + prov := "" + if s.llmClient != nil { prov = s.llmClient.Name() } + logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", prov, path) + ctx2, cancel2 := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel2() + suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp) + if err == nil && len(suggestions) > 0 { + cleaned := strings.TrimSpace(suggestions[0]) + if cleaned != "" { + cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned) + if cleaned != "" { cleaned = stripDuplicateGeneralPrefix(current[:p.Position.Character], cleaned) } + if cleaned != "" && hasDoubleSemicolonTrigger(current) { + indent := leadingIndent(current) + if indent != "" { cleaned = applyIndent(indent, cleaned) } + } + if strings.TrimSpace(cleaned) != "" { + key := s.completionCacheKey(p, above, current, below, funcCtx, inParams, hasExtra, extraText) + s.completionCachePut(key, cleaned) + return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true + } + } + } else if err != nil { + logging.Logf("lsp ", "completion path=codex error=%v (falling back to chat)", err) + } + // If provider-native path failed, fall back to chat below. + } + + sysPrompt, userPrompt := buildPrompts(inParams, p, above, current, below, funcCtx) messages := []llm.Message{ {Role: "system", Content: sysPrompt}, {Role: "user", Content: userPrompt}, @@ -859,13 +914,13 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun opts = append(opts, llm.WithTemperature(*s.codingTemperature)) } logging.Logf("lsp ", "completion llm=requesting model=%s", s.llmClient.DefaultModel()) - text, err := s.llmClient.Chat(ctx, messages, opts...) 
- if err != nil { - logging.Logf("lsp ", "llm completion error: %v", err) - // Log updated averages after this request (even if failed) - s.logLLMStats() - return nil, false, false - } + text, err := s.llmClient.Chat(ctx, messages, opts...) + if err != nil { + logging.Logf("lsp ", "llm completion error: %v", err) + // Log updated averages after this request (even if failed) + s.logLLMStats() + return nil, false + } // Update response counters (received) s.incRecvCounters(len(text)) s.logLLMStats() @@ -895,14 +950,14 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun cleaned = applyIndent(indent, cleaned) } } - if cleaned == "" { - return nil, false, false - } + if cleaned == "" { + return nil, false + } // Store successful completion in cache s.completionCachePut(key, cleaned) - return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true, false + return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true } // --- small completion cache (last ~10 entries) --- diff --git a/internal/lsp/testfakes_test.go b/internal/lsp/testfakes_test.go new file mode 100644 index 0000000..bfe536e --- /dev/null +++ b/internal/lsp/testfakes_test.go @@ -0,0 +1,18 @@ +package lsp + +import ( + "context" + "hexai/internal/llm" +) + +// countingLLM counts Chat calls; minimal implementation for tests that need +// to assert whether the chat-based completion path ran. +type countingLLM struct{ calls int } + +func (f *countingLLM) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { + f.calls++ + return "x := 1", nil +} +func (f *countingLLM) Name() string { return "fake" } +func (f *countingLLM) DefaultModel() string { return "m" } + diff --git a/internal/version.go b/internal/version.go index 40e4f17..8f22b67 100644 --- a/internal/version.go +++ b/internal/version.go @@ -1,4 +1,4 @@ // Summary: Hexai semantic version identifier used by CLI and LSP binaries. 
package internal -const Version = "0.2.1" +const Version = "0.3.0" |
