diff options
| author | Paul Buetow <paul@buetow.org> | 2025-08-16 15:35:02 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-08-16 15:35:02 +0300 |
| commit | 6c8eb6876fe87553770de114ebd34649a0c6ec10 (patch) | |
| tree | 064517edaf9d59522bec7191a61362a853c195bd /internal | |
| parent | 1e1df8c204f6771719f85d8402128d72138bb863 (diff) | |
lsp: split monolithic server.go into modules; add configurable max tokens and context strategies (minimal|window|file-on-new-func|always-full); provide flags/env fallbacks; add unit tests for helpers and context; update README; remove obsolete files
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/lsp/context.go | 80 | ||||
| -rw-r--r-- | internal/lsp/context_test.go | 69 | ||||
| -rw-r--r-- | internal/lsp/document.go | 145 | ||||
| -rw-r--r-- | internal/lsp/document_test.go | 76 | ||||
| -rw-r--r-- | internal/lsp/handlers.go | 251 | ||||
| -rw-r--r-- | internal/lsp/handlers_test.go | 126 | ||||
| -rw-r--r-- | internal/lsp/server.go | 479 | ||||
| -rw-r--r-- | internal/lsp/transport.go | 63 | ||||
| -rw-r--r-- | internal/lsp/types.go | 120 | ||||
| -rw-r--r-- | internal/test.go | 18 |
10 files changed, 962 insertions, 465 deletions
diff --git a/internal/lsp/context.go b/internal/lsp/context.go new file mode 100644 index 0000000..c08a865 --- /dev/null +++ b/internal/lsp/context.go @@ -0,0 +1,80 @@ +package lsp + +import ( + "strings" +) + +// buildAdditionalContext builds extra context messages based on the configured mode. +// Modes: +// - minimal: no extra context +// - window: include a window of lines around the cursor +// - file-on-new-func: include full file only when defining a new function +// - always-full: always include the full file +func (s *Server) buildAdditionalContext(newFunc bool, uri string, pos Position) (string, bool) { + mode := s.contextMode + switch mode { + case "minimal": + return "", false + case "window": + return s.windowContext(uri, pos), true + case "file-on-new-func": + if newFunc { + return s.fullFileContext(uri), true + } + return "", false + case "always-full": + return s.fullFileContext(uri), true + default: + // fallback to minimal if unknown + return "", false + } +} + +func (s *Server) windowContext(uri string, pos Position) string { + d := s.getDocument(uri) + if d == nil || len(d.lines) == 0 { + return "" + } + n := len(d.lines) + half := s.windowLines / 2 + start := pos.Line - half + if start < 0 { + start = 0 + } + end := pos.Line + half + 1 + if end > n { + end = n + } + text := strings.Join(d.lines[start:end], "\n") + return truncateToApproxTokens(text, s.maxContextTokens) +} + +func (s *Server) fullFileContext(uri string) string { + d := s.getDocument(uri) + if d == nil { + return "" + } + return truncateToApproxTokens(d.text, s.maxContextTokens) +} + +// truncateToApproxTokens naively truncates the input to fit approx N tokens. +// Uses 4 chars/token heuristic for speed and determinism. +func truncateToApproxTokens(text string, maxTokens int) string { + if maxTokens <= 0 { + return "" + } + maxChars := maxTokens * 4 + if len(text) <= maxChars { + return text + } + // try to cut on a line boundary near maxChars + cut := maxChars + if cut > len(text) { + cut = len(text) + } + if i := strings.LastIndex(text[:cut], "\n"); i > 0 { + cut = i + } + return text[:cut] +} + diff --git a/internal/lsp/context_test.go b/internal/lsp/context_test.go new file mode 100644 index 0000000..32834b8 --- /dev/null +++ b/internal/lsp/context_test.go @@ -0,0 +1,69 @@ +package lsp + +import ( + "strconv" + "strings" + "testing" +) + +func TestWindowContext_Bounds(t *testing.T) { + s := newTestServer() + s.windowLines = 4 // half=2 + s.maxContextTokens = 9999 + lines := make([]string, 10) + for i := 0; i < 10; i++ { + lines[i] = "L" + strconv.Itoa(i) + } + text := strings.Join(lines, "\n") + uri := "file:///w.go" + s.setDocument(uri, text) + got := s.windowContext(uri, Position{Line: 5, Character: 0}) + // expect lines 3..7 inclusive + want := strings.Join(lines[3:8], "\n") + if got != want { + t.Fatalf("window context got %q want %q", got, want) + } +} + +func TestBuildAdditionalContext_Minimal(t *testing.T) { + s := newTestServer() + s.contextMode = "minimal" + if ctx, ok := s.buildAdditionalContext(false, "file:///x.go", Position{}); ok || ctx != "" { + t.Fatalf("expected no context in minimal mode; got ok=%v ctx=%q", ok, ctx) + } +} + +func TestBuildAdditionalContext_FileOnNewFunc(t *testing.T) { + s := newTestServer() + s.contextMode = "file-on-new-func" + s.maxContextTokens = 9999 + uri := "file:///x.go" + body := "package x\n\nfunc a(){}\n" + s.setDocument(uri, body) + if ctx, ok := s.buildAdditionalContext(true, uri, Position{}); !ok || ctx == "" { + t.Fatalf("expected full context when new func; ok=%v ctx=%q", ok, ctx) + } + if ctx, ok := s.buildAdditionalContext(false, uri, Position{}); ok || ctx != "" { + t.Fatalf("expected no context when not new func; ok=%v ctx=%q", ok, ctx) + } +} + +func TestBuildAdditionalContext_AlwaysFull(t *testing.T) { + s := newTestServer() + s.contextMode = "always-full" + s.maxContextTokens = 9999 + uri := "file:///x.go" + body := "line1\nline2\n" + s.setDocument(uri, body) + if ctx, ok := s.buildAdditionalContext(false, uri, Position{}); !ok || ctx == "" { + t.Fatalf("expected context in always-full; ok=%v ctx=%q", ok, ctx) + } +} + +func TestTruncateToApproxTokens(t *testing.T) { + text := strings.Repeat("abcd", 10) // 40 chars + got := truncateToApproxTokens(text, 5) // ~20 chars + if len(got) > 5*4 { + t.Fatalf("truncate exceeded budget: got len=%d budget=%d", len(got), 5*4) + } +} diff --git a/internal/lsp/document.go b/internal/lsp/document.go new file mode 100644 index 0000000..e5eaf06 --- /dev/null +++ b/internal/lsp/document.go @@ -0,0 +1,145 @@ +package lsp + +import ( + "strings" + "time" +) + +// --- Document store and helpers --- + +type document struct { + uri string + text string + lines []string +} + +func (s *Server) setDocument(uri, text string) { + s.mu.Lock() + defer s.mu.Unlock() + s.docs[uri] = &document{uri: uri, text: text, lines: splitLines(text)} +} + +func (s *Server) deleteDocument(uri string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.docs, uri) +} + +func (s *Server) markActivity() { + s.mu.Lock() + s.lastInput = time.Now() + s.mu.Unlock() +} + +func (s *Server) getDocument(uri string) *document { + s.mu.RLock() + defer s.mu.RUnlock() + return s.docs[uri] +} + +func splitLines(sx string) []string { + sx = strings.ReplaceAll(sx, "\r\n", "\n") + return strings.Split(sx, "\n") +} + +func (s *Server) lineContext(uri string, pos Position) (above, current, below, funcCtx string) { + d := s.getDocument(uri) + if d == nil || len(d.lines) == 0 { + return "", "", "", "" + } + idx := pos.Line + if idx < 0 { + idx = 0 + } + if idx >= len(d.lines) { + idx = len(d.lines) - 1 + } + current = d.lines[idx] + if idx-1 >= 0 { + above = d.lines[idx-1] + } + if idx+1 < len(d.lines) { + below = d.lines[idx+1] + } + for i := idx; i >= 0; i-- { + line := strings.TrimSpace(d.lines[i]) + if hasAny(line, []string{"func ", "def ", "class ", "fn ", "procedure ", "sub "}) { + funcCtx = line + break + } + } + return +} + +// isDefiningNewFunction returns true when the cursor appears to be within +// a function declaration/signature and before the opening '{' of the body. +// Heuristic: find nearest preceding line containing "func "; ensure no '{' +// appears before the cursor across those lines. +func (s *Server) isDefiningNewFunction(uri string, pos Position) bool { + d := s.getDocument(uri) + if d == nil || len(d.lines) == 0 { + return false + } + idx := pos.Line + if idx < 0 { + idx = 0 + } + if idx >= len(d.lines) { + idx = len(d.lines) - 1 + } + // Find signature start + sigStart := -1 + for i := idx; i >= 0; i-- { + if strings.Contains(d.lines[i], "func ") { + sigStart = i + break + } + // stop if we hit a closing brace which likely ends a previous block + if strings.Contains(d.lines[i], "}") { + break + } + } + if sigStart == -1 { + return false + } + // Scan for '{' from sigStart up to cursor position; if found before or at cursor, we're in body + for i := sigStart; i <= idx; i++ { + line := d.lines[i] + brace := strings.Index(line, "{") + if brace >= 0 { + if i < idx { + return false // body started on a previous line + } + // same line as cursor: if brace position < cursor character, then already in body + if pos.Character > brace { + return false + } + } + } + return true +} + +func hasAny(s string, needles []string) bool { + for _, n := range needles { + if strings.Contains(s, n) { + return true + } + } + return false +} + +func trimLen(s string) string { + s = strings.TrimSpace(s) + if len(s) > 200 { + return s[:200] + "…" + } + return s +} + +func firstLine(s string) string { + s = strings.ReplaceAll(s, "\r\n", "\n") + if idx := strings.IndexByte(s, '\n'); idx >= 0 { + return s[:idx] + } + return s +} diff --git a/internal/lsp/document_test.go b/internal/lsp/document_test.go new file mode 100644 index 0000000..8d81a99 --- /dev/null +++ b/internal/lsp/document_test.go @@ -0,0 +1,76 @@ +package lsp + +import ( + "io" + "log" + "strings" + "testing" +) + +func newTestServer() *Server { + return &Server{ + logger: log.New(io.Discard, "", 0), + docs: make(map[string]*document), + } +} + +func TestSplitLines(t *testing.T) { + in := "a\r\nb\nc" + got := splitLines(in) + want := []string{"a", "b", "c"} + if len(got) != len(want) { + t.Fatalf("len mismatch: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("line %d: got %q want %q", i, got[i], want[i]) + } + } +} + +func TestLineContext(t *testing.T) { + s := newTestServer() + src := "package main\n\nfunc add(a, b int) int {\n\treturn a + b\n}\n" + uri := "file:///test.go" + s.setDocument(uri, src) + + // Position on the return line (line 3, zero-based) + above, current, below, funcCtx := s.lineContext(uri, Position{Line: 3, Character: 0}) + + if want := "func add(a, b int) int {"; funcCtx != want { + t.Fatalf("funcCtx got %q want %q", funcCtx, want) + } + if want := "func add(a, b int) int {"; above != want { + t.Fatalf("above got %q want %q", above, want) + } + if want := "\treturn a + b"; current != want { + t.Fatalf("current got %q want %q", current, want) + } + if want := "}"; below != want { + t.Fatalf("below got %q want %q", below, want) + } +} + +func TestLineContext_EmptyDoc(t *testing.T) { + s := newTestServer() + a, c, b, f := s.lineContext("file:///missing.go", Position{Line: 0, Character: 0}) + if a != "" || b != "" || c != "" || f != "" { + t.Fatalf("expected all empty for missing doc; got above=%q current=%q below=%q func=%q", a, c, b, f) + } +} + +func TestTrimLen(t *testing.T) { + long := strings.Repeat("a", 205) + got := trimLen(long) + want := strings.Repeat("a", 200) + "…" + if got != want { + t.Fatalf("trimLen got %q want %q", got, want) + } +} + +func TestFirstLine(t *testing.T) { + s := "first line\r\nsecond line" + if got := firstLine(s); got != "first line" { + t.Fatalf("firstLine got %q want %q", got, "first line") + } +} diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go new file mode 100644 index 0000000..8a782c4 --- /dev/null +++ b/internal/lsp/handlers.go @@ -0,0 +1,251 @@ +package lsp + +import ( + "context" + "encoding/json" + "fmt" + "hexai/internal" + "hexai/internal/llm" + "os" + "strings" + "time" +) + +func (s *Server) handle(req Request) { + switch req.Method { + case "initialize": + s.handleInitialize(req) + case "initialized": + s.handleInitialized() + case "shutdown": + s.handleShutdown(req) + case "exit": + s.handleExit() + case "textDocument/didOpen": + s.handleDidOpen(req) + case "textDocument/didChange": + s.handleDidChange(req) + case "textDocument/didClose": + s.handleDidClose(req) + case "textDocument/completion": + s.handleCompletion(req) + default: + if len(req.ID) != 0 { + s.reply(req.ID, nil, &RespError{Code: -32601, Message: fmt.Sprintf("method not found: %s", req.Method)}) + } + } +} + +func (s *Server) handleInitialize(req Request) { + res := InitializeResult{ + Capabilities: ServerCapabilities{ + TextDocumentSync: 1, // 1 = TextDocumentSyncKindFull + CompletionProvider: &CompletionOptions{ + ResolveProvider: false, + // TODO: Make the trigger characters configurable + TriggerCharacters: []string{".", ":", "/", "_"}, + }, + }, + ServerInfo: &ServerInfo{Name: "hexai", Version: internal.Version}, + } + s.reply(req.ID, res, nil) +} + +func (s *Server) handleInitialized() { + s.logger.Println("client initialized") +} + +func (s *Server) handleShutdown(req Request) { + s.reply(req.ID, nil, nil) +} + +func (s *Server) handleExit() { + s.exited = true + os.Exit(0) +} + +func (s *Server) handleDidOpen(req Request) { + var p DidOpenTextDocumentParams + if err := json.Unmarshal(req.Params, &p); err == nil { + s.setDocument(p.TextDocument.URI, p.TextDocument.Text) + s.markActivity() + } +} + +func (s *Server) handleDidChange(req Request) { + var p DidChangeTextDocumentParams + if err := json.Unmarshal(req.Params, &p); err == nil { + if len(p.ContentChanges) > 0 { + s.setDocument(p.TextDocument.URI, p.ContentChanges[len(p.ContentChanges)-1].Text) + } + s.markActivity() + } +} + +func (s *Server) handleDidClose(req Request) { + var p DidCloseTextDocumentParams + if err := json.Unmarshal(req.Params, &p); err == nil { + s.deleteDocument(p.TextDocument.URI) + s.markActivity() + } +} + +func (s *Server) handleCompletion(req Request) { + var p CompletionParams + var docStr string + if err := json.Unmarshal(req.Params, &p); err == nil { + above, current, below, funcCtx := s.lineContext(p.TextDocument.URI, p.Position) + docStr = s.buildDocString(p, above, current, below, funcCtx) + if s.logContext { + s.logCompletionContext(p, above, current, below, funcCtx) + } + if s.llmClient != nil { + newFunc := s.isDefiningNewFunction(p.TextDocument.URI, p.Position) + extra, has := s.buildAdditionalContext(newFunc, p.TextDocument.URI, p.Position) + items, ok := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra) + if ok { + s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) + return + } + } + } + items := s.fallbackCompletionItems(docStr) + s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) +} + +func (s *Server) reply(id json.RawMessage, result any, err *RespError) { + resp := Response{JSONRPC: "2.0", ID: id, Result: result, Error: err} + s.writeMessage(resp) +} + +// --- completion helpers --- + +func (s *Server) buildDocString(p CompletionParams, above, current, below, funcCtx string) string { + return fmt.Sprintf("file: %s\nline: %d\nabove: %s\ncurrent: %s\nbelow: %s\nfunction: %s", + p.TextDocument.URI, p.Position.Line, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) +} + +func (s *Server) logCompletionContext(p CompletionParams, above, current, below, funcCtx string) { + s.logger.Printf("completion ctx uri=%s line=%d char=%d above=%q current=%q below=%q function=%q", + p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) +} + +func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool) { + ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) + defer cancel() + + inParams := inParamList(current, p.Position.Character) + sysPrompt, userPrompt := buildPrompts(inParams, p, above, current, below, funcCtx) + messages := []llm.Message{ + {Role: "system", Content: sysPrompt}, + {Role: "user", Content: userPrompt}, + } + if hasExtra && extraText != "" { + messages = append(messages, llm.Message{Role: "user", Content: "Additional context:\n" + extraText}) + } + + text, err := s.llmClient.Chat(ctx, messages, llm.WithMaxTokens(s.maxTokens), llm.WithTemperature(0.2)) + if err != nil { + s.logger.Printf("llm completion error: %v", err) + return nil, false + } + cleaned := strings.TrimSpace(text) + if cleaned == "" { + return nil, false + } + + te, filter := computeTextEditAndFilter(cleaned, inParams, current, p) + label := labelForCompletion(cleaned, filter) + items := []CompletionItem{{ + Label: label, + Kind: 1, + Detail: "OpenAI through Hexai completion", + InsertTextFormat: 1, + FilterText: strings.TrimLeft(filter, " \t"), + TextEdit: te, + SortText: "0000", + Documentation: docStr, + }} + return items, true +} + +func inParamList(current string, cursor int) bool { + if !strings.Contains(current, "func ") { + return false + } + open := strings.Index(current, "(") + close := strings.Index(current, ")") + return open >= 0 && cursor > open && (close == -1 || cursor <= close) +} + +func buildPrompts(inParams bool, p CompletionParams, above, current, below, funcCtx string) (string, string) { + if inParams { + sys := "You are a terse Go code completion engine for function signatures. Return only the parameter list contents (without parentheses), no braces, no prose. Prefer idiomatic names and types." + user := fmt.Sprintf("Cursor is inside the function parameter list. Suggest only the parameter list (no parentheses).\nFunction line: %s\nCurrent line (cursor at %d): %s", funcCtx, p.Position.Character, current) + return sys, user + } + sys := "You are a terse code completion engine. Return only the code to insert, no surrounding prose or backticks." + user := fmt.Sprintf("Provide the next likely code to insert at the cursor.\nFile: %s\nFunction/context: %s\nAbove line: %s\nCurrent line (cursor at character %d): %s\nBelow line: %s\nOnly return the completion snippet.", p.TextDocument.URI, funcCtx, above, p.Position.Character, current, below) + return sys, user +} + +func computeTextEditAndFilter(cleaned string, inParams bool, current string, p CompletionParams) (*TextEdit, string) { + if inParams { + open := strings.Index(current, "(") + close := strings.Index(current, ")") + if open >= 0 { + left := open + 1 + right := len(current) + if close >= 0 && close >= left { + right = close + } + if p.Position.Character < right { + right = p.Position.Character + } + te := &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: left}, End: Position{Line: p.Position.Line, Character: right}}, NewText: cleaned} + var filter string + if left >= 0 && right >= left && right <= len(current) { + filter = strings.TrimLeft(current[left:right], " \t") + } + return te, filter + } + } + startChar := computeWordStart(current, p.Position.Character) + te := &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: startChar}, End: Position{Line: p.Position.Line, Character: p.Position.Character}}, NewText: cleaned} + filter := strings.TrimLeft(current[startChar:p.Position.Character], " \t") + return te, filter +} + +func computeWordStart(current string, at int) int { + if at > len(current) { + at = len(current) + } + for at > 0 { + ch := current[at-1] + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' { + at-- + continue + } + break + } + return at +} + +func labelForCompletion(cleaned, filter string) string { + label := trimLen(firstLine(cleaned)) + if filter != "" && !strings.HasPrefix(strings.ToLower(label), strings.ToLower(filter)) { + return filter + } + return label +} + +func (s *Server) fallbackCompletionItems(docStr string) []CompletionItem { + return []CompletionItem{{ + Label: "hexai-complete", + Kind: 1, + Detail: "dummy completion", + InsertText: "hexai", + SortText: "9999", + Documentation: docStr, + }} +} diff --git a/internal/lsp/handlers_test.go b/internal/lsp/handlers_test.go new file mode 100644 index 0000000..bde9b82 --- /dev/null +++ b/internal/lsp/handlers_test.go @@ -0,0 +1,126 @@ +package lsp + +import ( + "strings" + "testing" +) + +func TestInParamList(t *testing.T) { + line := "func foo(a int, b string) int {" + if !inParamList(line, 15) { // inside params + t.Fatalf("expected inParamList true for cursor inside params") + } + if inParamList(line, 2) { // before 'func' + t.Fatalf("expected inParamList false for cursor before params") + } + if inParamList(line, len(line)) { // after ')' + t.Fatalf("expected inParamList false for cursor after params") + } +} + +func TestComputeWordStart(t *testing.T) { + current := "fmt.Prin" + // Cursor after the word (index 8) + got := computeWordStart(current, 8) + // should stop after the dot at index 4 + if want := 4; got != want { + t.Fatalf("computeWordStart got %d want %d", got, want) + } +} + +func TestComputeTextEditAndFilter_InParams(t *testing.T) { + current := "func foo(a int, b string) {" // ')' at index 26 + p := CompletionParams{Position: Position{Line: 10, Character: 20}} + te, filter := computeTextEditAndFilter("x int, y string", true, current, p) + + if te == nil { + t.Fatalf("expected TextEdit") + } + // left should be after '(' which is at index 8 + if te.Range.Start.Line != 10 || te.Range.Start.Character != 9 { + t.Fatalf("start got line=%d char=%d want line=10 char=9", te.Range.Start.Line, te.Range.Start.Character) + } + // right should clamp to cursor (20) + if te.Range.End.Line != 10 || te.Range.End.Character != 20 { + t.Fatalf("end got line=%d char=%d want line=10 char=20", te.Range.End.Line, te.Range.End.Character) + } + if filter == "" { + t.Fatalf("expected non-empty filter inside params") + } +} + +func TestComputeTextEditAndFilter_Word(t *testing.T) { + current := "fmt.Prin" + p := CompletionParams{Position: Position{Line: 2, Character: len(current)}} + te, filter := computeTextEditAndFilter("Println", false, current, p) + if te == nil { + t.Fatalf("expected TextEdit") + } + if te.Range.Start.Character != 4 || te.Range.End.Character != len(current) { + t.Fatalf("range chars got %d..%d want 4..%d", te.Range.Start.Character, te.Range.End.Character, len(current)) + } + if filter != "Prin" { + t.Fatalf("filter got %q want %q", filter, "Prin") + } +} + +func TestLabelForCompletion(t *testing.T) { + if got := labelForCompletion("Println", "Pri"); got != "Println" { + t.Fatalf("label mismatch got %q want %q", got, "Println") + } + if got := labelForCompletion("Println", "X"); got != "X" { + t.Fatalf("label mismatch with filter got %q want %q", got, "X") + } + if got := labelForCompletion("Println\nmore", ""); got != "Println" { + t.Fatalf("label firstLine got %q want %q", got, "Println") + } +} + +func TestBuildPrompts_InParams(t *testing.T) { + p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Position: Position{Line: 1, Character: 12}} + sys, user := buildPrompts(true, p, "above", "func foo(", "below", "func foo(") + if sys == "" || user == "" { + t.Fatalf("expected non-empty prompts") + } + if want := "function signatures"; !contains(sys, want) { + t.Fatalf("system prompt missing %q: %q", want, sys) + } + if want := "parameter list"; !contains(user, want) { + t.Fatalf("user prompt missing %q: %q", want, user) + } +} + +func TestBuildPrompts_Outside(t *testing.T) { + p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Position: Position{Line: 1, Character: 5}} + sys, user := buildPrompts(false, p, "ab", "cur", "be", "fnctx") + if sys == "" || user == "" { + t.Fatalf("expected non-empty prompts") + } + if want := "completion engine"; !contains(sys, want) { + t.Fatalf("system prompt missing %q: %q", want, sys) + } + if want := "Provide the next likely code"; !contains(user, want) { + t.Fatalf("user prompt missing %q: %q", want, user) + } +} + +func TestComputeTextEditAndFilter_NoParensFallback(t *testing.T) { + current := "func foo bar" // no parentheses + cursor := len(current) + p := CompletionParams{Position: Position{Line: 0, Character: cursor}} + te, filter := computeTextEditAndFilter("baz", true, current, p) + if te == nil { + t.Fatalf("expected TextEdit from fallback path") + } + // fallback should behave like word edit; start at last space + 1 + lastSpace := strings.LastIndex(current, " ") + if te.Range.Start.Character != lastSpace+1 || te.Range.End.Character != cursor { + t.Fatalf("range got %d..%d want %d..%d", te.Range.Start.Character, te.Range.End.Character, lastSpace+1, cursor) + } + if filter != "bar" { + t.Fatalf("filter got %q want %q", filter, "bar") + } +} + +// small helper to avoid importing strings +func contains(s, sub string) bool { return len(s) >= len(sub) && (func() bool { i := 0; for i+len(sub) <= len(s) { if s[i:i+len(sub)] == sub { return true }; i++ }; return false })() } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index ec1a113..3154613 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -2,79 +2,14 @@ package lsp import ( "bufio" - "context" "encoding/json" - "fmt" - "hexai/internal" "hexai/internal/llm" "io" "log" - "net/textproto" - "os" - "strconv" - "strings" "sync" "time" ) -// JSON-RPC 2.0 structures (minimal) -type Request struct { - JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id,omitempty"` - Method string `json:"method"` - Params json.RawMessage `json:"params,omitempty"` -} - -type Response struct { - JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id,omitempty"` - Result any `json:"result,omitempty"` - Error *RespError `json:"error,omitempty"` -} - -type RespError struct { - Code int `json:"code"` - Message string `json:"message"` -} - -// LSP responses (subset) -type InitializeResult struct { - Capabilities ServerCapabilities `json:"capabilities"` - ServerInfo *ServerInfo `json:"serverInfo,omitempty"` -} - -type ServerInfo struct { - Name string `json:"name"` - Version string `json:"version,omitempty"` -} - -type ServerCapabilities struct { - TextDocumentSync any `json:"textDocumentSync,omitempty"` - CompletionProvider *CompletionOptions `json:"completionProvider,omitempty"` -} - -type CompletionOptions struct { - ResolveProvider bool `json:"resolveProvider,omitempty"` - TriggerCharacters []string `json:"triggerCharacters,omitempty"` -} - -type CompletionList struct { - IsIncomplete bool `json:"isIncomplete"` - Items []CompletionItem `json:"items"` -} - -type CompletionItem struct { - Label string `json:"label"` - Kind int `json:"kind,omitempty"` - Detail string `json:"detail,omitempty"` - InsertText string `json:"insertText,omitempty"` - InsertTextFormat int `json:"insertTextFormat,omitempty"` - FilterText string `json:"filterText,omitempty"` - TextEdit *TextEdit `json:"textEdit,omitempty"` - SortText string `json:"sortText,omitempty"` - Documentation string `json:"documentation,omitempty"` -} - // Server implements a minimal LSP over stdio. type Server struct { in *bufio.Reader @@ -84,19 +19,38 @@ type Server struct { mu sync.RWMutex docs map[string]*document logContext bool - llmClient llm.Client - lastInput time.Time -} - -func NewServer(r io.Reader, w io.Writer, logger *log.Logger, logContext bool) *Server { - s := &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: logContext} - if c, err := llm.NewDefault(logger); err != nil { - // Keep running without LLM; completions will be basic. - s.logger.Printf("llm disabled: %v", err) - } else { - s.llmClient = c - } - return s + llmClient llm.Client + lastInput time.Time + maxTokens int + contextMode string + windowLines int + maxContextTokens int +} + +func NewServer(r io.Reader, w io.Writer, logger *log.Logger, logContext bool, maxTokens int, contextMode string, windowLines int, maxContextTokens int) *Server { + s := &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: logContext} + if maxTokens <= 0 { + maxTokens = 500 + } + s.maxTokens = maxTokens + if contextMode == "" { + contextMode = "file-on-new-func" + } + if windowLines <= 0 { + windowLines = 120 + } + if maxContextTokens <= 0 { + maxContextTokens = 2000 + } + s.contextMode = contextMode + s.windowLines = windowLines + s.maxContextTokens = maxContextTokens + if c, err := llm.NewDefault(logger); err != nil { + s.logger.Printf("llm disabled: %v", err) + } else { + s.llmClient = c + } + return s } func (s *Server) Run() error { @@ -123,372 +77,3 @@ func (s *Server) Run() error { } } } - -func (s *Server) handle(req Request) { - switch req.Method { - case "initialize": - res := InitializeResult{ - Capabilities: ServerCapabilities{ - // 1 = TextDocumentSyncKindFull - TextDocumentSync: 1, - CompletionProvider: &CompletionOptions{ - ResolveProvider: false, - TriggerCharacters: []string{".", ":", "/", "_"}, - }, - }, - ServerInfo: &ServerInfo{Name: "hexai", Version: internal.Version}, - } - s.reply(req.ID, res, nil) - case "initialized": - // Notification; no response - s.logger.Println("client initialized") - case "shutdown": - s.reply(req.ID, nil, nil) - case "exit": - s.exited = true - // No response per spec. - os.Exit(0) - case "textDocument/didOpen": - var p DidOpenTextDocumentParams - if err := json.Unmarshal(req.Params, &p); err == nil { - s.setDocument(p.TextDocument.URI, p.TextDocument.Text) - s.markActivity() - } - case "textDocument/didChange": - var p DidChangeTextDocumentParams - if err := json.Unmarshal(req.Params, &p); err == nil { - if len(p.ContentChanges) > 0 { - s.setDocument(p.TextDocument.URI, p.ContentChanges[len(p.ContentChanges)-1].Text) - } - s.markActivity() - } - case "textDocument/didClose": - var p DidCloseTextDocumentParams - if err := json.Unmarshal(req.Params, &p); err == nil { - s.deleteDocument(p.TextDocument.URI) - s.markActivity() - } - case "textDocument/completion": - var p CompletionParams - var docStr string - if err := json.Unmarshal(req.Params, &p); err == nil { - above, current, below, funcCtx := s.lineContext(p.TextDocument.URI, p.Position) - docStr = fmt.Sprintf("file: %s\nline: %d\nabove: %s\ncurrent: %s\nbelow: %s\nfunction: %s", p.TextDocument.URI, p.Position.Line, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) - if s.logContext { - s.logger.Printf("completion ctx uri=%s line=%d char=%d above=%q current=%q below=%q function=%q", - p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) - } - // Previously: gated LLM calls until 1s idle. Removed to complete as you type. - // Try LLM-backed suggestion if available (always, no idle gating) - if s.llmClient != nil { - ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) - defer cancel() - // Tailor prompt if inside a Go function parameter list - inParams := false - if strings.Contains(current, "func ") { - open := strings.Index(current, "(") - close := strings.Index(current, ")") - if open >= 0 && p.Position.Character > open && (close == -1 || p.Position.Character <= close) { - inParams = true - } - } - sysPrompt := "You are a terse code completion engine. Return only the code to insert, no surrounding prose or backticks." - userPrompt := fmt.Sprintf("Provide the next likely code to insert at the cursor.\nFile: %s\nFunction/context: %s\nAbove line: %s\nCurrent line (cursor at character %d): %s\nBelow line: %s\nOnly return the completion snippet.", p.TextDocument.URI, funcCtx, above, p.Position.Character, current, below) - if inParams { - sysPrompt = "You are a terse Go code completion engine for function signatures. Return only the parameter list contents (without parentheses), no braces, no prose. Prefer idiomatic names and types." - userPrompt = fmt.Sprintf("Cursor is inside the function parameter list. Suggest only the parameter list (no parentheses).\nFunction line: %s\nCurrent line (cursor at %d): %s", funcCtx, p.Position.Character, current) - } - messages := []llm.Message{ - {Role: "system", Content: sysPrompt}, - {Role: "user", Content: userPrompt}, - } - // keep completions small by default - text, err := s.llmClient.Chat(ctx, messages, llm.WithMaxTokens(96), llm.WithTemperature(0.2)) - if err == nil && strings.TrimSpace(text) != "" { - cleaned := strings.TrimSpace(text) - var te *TextEdit - var filter string - if inParams { - // Replace inside the parentheses - open := strings.Index(current, "(") - close := strings.Index(current, ")") - if open >= 0 { - left := open + 1 - right := len(current) - if close >= 0 && close >= left { - right = close - } - if p.Position.Character < right { - right = p.Position.Character - } - te = &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: left}, End: Position{Line: p.Position.Line, Character: right}}, NewText: cleaned} - if left >= 0 && right >= left && right <= len(current) { - filter = strings.TrimLeft(current[left:right], " \t") - } - } - } - if te == nil { - // compute word start for replacement - startChar := p.Position.Character - if startChar > len(current) { - startChar = len(current) - } - for startChar > 0 { - ch := current[startChar-1] - if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' { - startChar-- - continue - } - break - } - te = &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: startChar}, End: Position{Line: p.Position.Line, Character: p.Position.Character}}, NewText: cleaned} - filter = strings.TrimLeft(current[startChar:p.Position.Character], " \t") - } - // Choose a label that starts with the current prefix when possible so the client doesn't filter it out. - label := trimLen(firstLine(cleaned)) - if filter != "" && !strings.HasPrefix(strings.ToLower(label), strings.ToLower(filter)) { - label = filter - } - items := []CompletionItem{{ - Label: label, - Kind: 1, - Detail: "OpenAI completion", - InsertTextFormat: 1, - FilterText: strings.TrimLeft(filter, " \t"), - TextEdit: te, - SortText: "0000", - Documentation: docStr, - }} - s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) - return - } - if err != nil { - s.logger.Printf("llm completion error: %v", err) - } - } - } - // Fallback basic/dummy completion - items := []CompletionItem{{ - Label: "hexai-complete", - Kind: 1, - Detail: "dummy completion", - InsertText: "hexai", - SortText: "9999", - Documentation: docStr, - }} - s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) - default: - // Unknown method; reply with Method Not Found for requests that have an ID. - if len(req.ID) != 0 { - s.reply(req.ID, nil, &RespError{Code: -32601, Message: fmt.Sprintf("method not found: %s", req.Method)}) - } - } -} - -func (s *Server) reply(id json.RawMessage, result any, err *RespError) { - resp := Response{JSONRPC: "2.0", ID: id, Result: result, Error: err} - s.writeMessage(resp) -} - -func (s *Server) readMessage() ([]byte, error) { - tp := textproto.NewReader(s.in) - var contentLength int - for { - line, err := tp.ReadLine() - if err != nil { - return nil, err - } - if line == "" { // end of headers - break - } - parts := strings.SplitN(line, ":", 2) - if len(parts) != 2 { - continue - } - key := strings.TrimSpace(strings.ToLower(parts[0])) - val := strings.TrimSpace(parts[1]) - switch key { - case "content-length": - n, err := strconv.Atoi(val) - if err != nil { - return nil, fmt.Errorf("invalid Content-Length: %v", err) - } - contentLength = n - } - } - if contentLength <= 0 { - return nil, fmt.Errorf("missing or invalid Content-Length") - } - buf := make([]byte, contentLength) - if _, err := io.ReadFull(s.in, buf); err != nil { - return nil, err - } - return buf, nil -} - -func (s *Server) writeMessage(v any) { - data, err := json.Marshal(v) - if err != nil { - s.logger.Printf("marshal error: %v", err) - return - } - header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) - if _, err := io.WriteString(s.out, header); err != nil { - s.logger.Printf("write header error: %v", err) - return - } - if _, err := s.out.Write(data); err != nil { - s.logger.Printf("write body error: %v", err) - return - } -} - -// --- Document store and helpers --- - -type document struct { - uri string - text string - lines []string -} - -func (s *Server) setDocument(uri, text string) { - s.mu.Lock() - defer s.mu.Unlock() - s.docs[uri] = &document{uri: uri, text: text, lines: splitLines(text)} -} - -func (s *Server) deleteDocument(uri string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.docs, uri) -} - -func (s *Server) markActivity() { - s.mu.Lock() - s.lastInput = time.Now() - s.mu.Unlock() -} - -func (s *Server) getDocument(uri string) *document { - s.mu.RLock() - defer s.mu.RUnlock() - return s.docs[uri] -} - -func splitLines(sx string) []string { - sx = strings.ReplaceAll(sx, "\r\n", "\n") - return strings.Split(sx, "\n") -} - -// LSP param types (subset) -type TextDocumentItem struct { - URI string `json:"uri"` - LanguageID string `json:"languageId,omitempty"` - Version int `json:"version,omitempty"` - Text string `json:"text"` -} - -type VersionedTextDocumentIdentifier struct { - URI string `json:"uri"` - Version int `json:"version,omitempty"` -} - -type TextDocumentIdentifier struct { - URI string `json:"uri"` -} - -type DidOpenTextDocumentParams struct { - TextDocument TextDocumentItem `json:"textDocument"` -} - -type TextDocumentContentChangeEvent struct { - Range any `json:"range,omitempty"` - RangeLength int `json:"rangeLength,omitempty"` - Text string `json:"text"` -} - -type DidChangeTextDocumentParams struct { - TextDocument VersionedTextDocumentIdentifier `json:"textDocument"` - ContentChanges []TextDocumentContentChangeEvent `json:"contentChanges"` -} - -type DidCloseTextDocumentParams struct { - TextDocument TextDocumentIdentifier `json:"textDocument"` -} - -type Position struct { - Line int `json:"line"` - Character int `json:"character"` -} - -type CompletionParams struct { - TextDocument TextDocumentIdentifier `json:"textDocument"` - Position Position `json:"position"` - Context any `json:"context,omitempty"` -} - -// Range defines a text range in a document. -type Range struct { - Start Position `json:"start"` - End Position `json:"end"` -} - -// TextEdit represents a textual edit applicable to a document. -type TextEdit struct { - Range Range `json:"range"` - NewText string `json:"newText"` -} - -func (s *Server) lineContext(uri string, pos Position) (above, current, below, funcCtx string) { - d := s.getDocument(uri) - if d == nil || len(d.lines) == 0 { - return "", "", "", "" - } - idx := pos.Line - if idx < 0 { - idx = 0 - } - if idx >= len(d.lines) { - idx = len(d.lines) - 1 - } - current = d.lines[idx] - if idx-1 >= 0 { - above = d.lines[idx-1] - } - if idx+1 < len(d.lines) { - below = d.lines[idx+1] - } - for i := idx; i >= 0; i-- { - line := strings.TrimSpace(d.lines[i]) - if hasAny(line, []string{"func ", "def ", "class ", "fn ", "procedure ", "sub "}) { - funcCtx = line - break - } - } - return -} - -func hasAny(s string, needles []string) bool { - for _, n := range needles { - if strings.Contains(s, n) { - return true - } - } - return false -} - -func trimLen(s string) string { - s = strings.TrimSpace(s) - if len(s) > 200 { - return s[:200] + "…" - } - return s -} - -func firstLine(s string) string { - s = strings.ReplaceAll(s, "\r\n", "\n") - if idx := strings.IndexByte(s, '\n'); idx >= 0 { - return s[:idx] - } - return s -} diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go new file mode 100644 index 0000000..671d69b --- /dev/null +++ b/internal/lsp/transport.go @@ -0,0 +1,63 @@ +package lsp + +import ( + "encoding/json" + "fmt" + "io" + "net/textproto" + "strconv" + "strings" +) + +func (s *Server) readMessage() ([]byte, error) { + tp := textproto.NewReader(s.in) + var contentLength int + for { + line, err := tp.ReadLine() + if err != nil { + return nil, err + } + if line == "" { // end of headers + break + } + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + continue + } + key := strings.TrimSpace(strings.ToLower(parts[0])) + val := strings.TrimSpace(parts[1]) + switch key { + case "content-length": + n, err := strconv.Atoi(val) + if err != nil { + return nil, fmt.Errorf("invalid Content-Length: %v", err) + } + contentLength = n + } + } + if contentLength <= 0 { + return nil, fmt.Errorf("missing or invalid Content-Length") + } + buf := make([]byte, contentLength) + if _, err := io.ReadFull(s.in, buf); err != nil { + return nil, err + } + return buf, nil +} + +func (s *Server) writeMessage(v any) { + data, err := json.Marshal(v) + if err != nil { + s.logger.Printf("marshal error: %v", err) + return + } + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) + if _, err := io.WriteString(s.out, header); err != nil { + s.logger.Printf("write header error: %v", err) + return + } + if _, err := s.out.Write(data); err != nil { + s.logger.Printf("write body error: %v", err) + return + } +} diff --git a/internal/lsp/types.go b/internal/lsp/types.go new file mode 100644 index 0000000..3a9a397 --- /dev/null +++ b/internal/lsp/types.go @@ -0,0 +1,120 @@ +package lsp + +import "encoding/json" + +// JSON-RPC 2.0 structures (minimal) +type Request struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type Response struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result any `json:"result,omitempty"` + Error *RespError `json:"error,omitempty"` +} + +type RespError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// LSP responses (subset) +type InitializeResult struct { + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo *ServerInfo `json:"serverInfo,omitempty"` +} + +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version,omitempty"` +} + +type ServerCapabilities struct { + TextDocumentSync any `json:"textDocumentSync,omitempty"` + CompletionProvider *CompletionOptions `json:"completionProvider,omitempty"` +} + +type CompletionOptions struct { + ResolveProvider bool `json:"resolveProvider,omitempty"` + TriggerCharacters []string `json:"triggerCharacters,omitempty"` +} + +type CompletionList struct { + IsIncomplete bool `json:"isIncomplete"` + Items []CompletionItem `json:"items"` +} + +type CompletionItem struct { + Label string `json:"label"` + Kind int `json:"kind,omitempty"` + Detail string `json:"detail,omitempty"` + InsertText string `json:"insertText,omitempty"` + InsertTextFormat int `json:"insertTextFormat,omitempty"` + FilterText string `json:"filterText,omitempty"` + TextEdit *TextEdit `json:"textEdit,omitempty"` + SortText string `json:"sortText,omitempty"` + Documentation string `json:"documentation,omitempty"` +} + +// LSP param types (subset) +type TextDocumentItem struct { + URI string `json:"uri"` + LanguageID string `json:"languageId,omitempty"` + Version int `json:"version,omitempty"` + Text string `json:"text"` +} + +type VersionedTextDocumentIdentifier struct { + URI string `json:"uri"` + Version int `json:"version,omitempty"` +} + +type TextDocumentIdentifier struct { + URI string `json:"uri"` +} + +type DidOpenTextDocumentParams struct { + TextDocument TextDocumentItem `json:"textDocument"` +} + +type TextDocumentContentChangeEvent struct { + Range any `json:"range,omitempty"` + RangeLength int `json:"rangeLength,omitempty"` + Text string `json:"text"` +} + +type DidChangeTextDocumentParams struct { + TextDocument VersionedTextDocumentIdentifier `json:"textDocument"` + ContentChanges []TextDocumentContentChangeEvent `json:"contentChanges"` +} + +type DidCloseTextDocumentParams struct { + TextDocument TextDocumentIdentifier `json:"textDocument"` +} + +type Position struct { + Line int `json:"line"` + Character int `json:"character"` +} + +type CompletionParams struct { + TextDocument TextDocumentIdentifier `json:"textDocument"` + Position Position `json:"position"` + Context any `json:"context,omitempty"` +} + +// Range defines a text range in a document. +type Range struct { + Start Position `json:"start"` + End Position `json:"end"` +} + +// TextEdit represents a textual edit applicable to a document. +type TextEdit struct { + Range Range `json:"range"` + NewText string `json:"newText"` +} diff --git a/internal/test.go b/internal/test.go deleted file mode 100644 index 586a1bc..0000000 --- a/internal/test.go +++ /dev/null @@ -1,18 +0,0 @@ -package internal - -import "os" - -func fib(i int) int { - if i <= 1 { - return i - } - return fib(i-1) + fib(i-2) -} - -func countFilesInDir(dirPath string) int { - files, err := os.ReadDir(dirPath) - if err != nil { - return 0 - } - return len(files) -} |
