diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-10 19:34:02 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-10 19:34:02 +0200 |
| commit | de5029e6d4a7efffcccfb08d98770b1c1c4f54fe (patch) | |
| tree | 7429b9a72c96393ca74d323faf4f989224e83db0 | |
| parent | ba4b4b340b17450fa86122f227a75ef054e0ad53 (diff) | |
task bf088a70: extract LSP client and completion state
| -rw-r--r-- | internal/lsp/completion_state.go | 167 | ||||
| -rw-r--r-- | internal/lsp/document_test.go | 5 | ||||
| -rw-r--r-- | internal/lsp/handlers.go | 52 | ||||
| -rw-r--r-- | internal/lsp/handlers_completion.go | 34 | ||||
| -rw-r--r-- | internal/lsp/ignore_test.go | 8 | ||||
| -rw-r--r-- | internal/lsp/llm_client_registry.go | 99 | ||||
| -rw-r--r-- | internal/lsp/llm_client_registry_test.go | 65 | ||||
| -rw-r--r-- | internal/lsp/server.go | 139 |
8 files changed, 355 insertions, 214 deletions
diff --git a/internal/lsp/completion_state.go b/internal/lsp/completion_state.go new file mode 100644 index 0000000..9799561 --- /dev/null +++ b/internal/lsp/completion_state.go @@ -0,0 +1,167 @@ +package lsp + +import ( + "context" + "sync" + "time" +) + +type completionState struct { + stateMu sync.RWMutex + compCache map[string]string + compCacheOrder []string + pendingCompletions map[string][]CompletionItem + lastLLMCall time.Time + completionsDisabled bool +} + +func newCompletionState() completionState { + return completionState{ + compCache: make(map[string]string), + pendingCompletions: make(map[string][]CompletionItem), + } +} + +func (s *completionState) storePendingCompletion(key string, items []CompletionItem) { + if len(items) == 0 { + return + } + cpy := make([]CompletionItem, len(items)) + copy(cpy, items) + s.stateMu.Lock() + defer s.stateMu.Unlock() + if s.pendingCompletions == nil { + s.pendingCompletions = make(map[string][]CompletionItem) + } + s.pendingCompletions[key] = cpy +} + +func (s *completionState) setCompletionsDisabled(disabled bool) bool { + s.stateMu.Lock() + defer s.stateMu.Unlock() + prev := s.completionsDisabled + s.completionsDisabled = disabled + return prev +} + +func (s *completionState) completionDisabled() bool { + s.stateMu.RLock() + defer s.stateMu.RUnlock() + return s.completionsDisabled +} + +func (s *completionState) takePendingCompletion(key string) []CompletionItem { + s.stateMu.Lock() + defer s.stateMu.Unlock() + if len(s.pendingCompletions) == 0 { + return nil + } + items, ok := s.pendingCompletions[key] + if !ok { + return nil + } + delete(s.pendingCompletions, key) + cpy := make([]CompletionItem, len(items)) + copy(cpy, items) + return cpy +} + +func (s *completionState) cacheGet(key string) (string, bool) { + s.stateMu.Lock() + defer s.stateMu.Unlock() + v, ok := s.compCache[key] + if !ok { + return "", false + } + s.touchLocked(key) + return v, true +} + +func (s *completionState) cachePut(key, value string) { + s.stateMu.Lock() + defer s.stateMu.Unlock() + if s.compCache == nil { + s.compCache = make(map[string]string) + } + if _, exists := s.compCache[key]; !exists { + s.compCacheOrder = append(s.compCacheOrder, key) + s.compCache[key] = value + if len(s.compCacheOrder) > 10 { + old := s.compCacheOrder[0] + s.compCacheOrder = s.compCacheOrder[1:] + delete(s.compCache, old) + } + return + } + s.compCache[key] = value + s.touchLocked(key) +} + +func (s *completionState) touchLocked(key string) { + idx := -1 + for i, k := range s.compCacheOrder { + if k == key { + idx = i + break + } + } + if idx >= 0 { + s.compCacheOrder = append(append([]string{}, s.compCacheOrder[:idx]...), s.compCacheOrder[idx+1:]...) + } + s.compCacheOrder = append(s.compCacheOrder, key) +} + +func (s *completionState) waitForThrottle(ctx context.Context, interval time.Duration) bool { + if interval <= 0 { + return true + } + var wait time.Duration + for { + s.stateMu.Lock() + next := s.lastLLMCall.Add(interval) + now := time.Now() + if now.Before(next) { + wait = next.Sub(now) + s.stateMu.Unlock() + timer := time.NewTimer(wait) + select { + case <-ctx.Done(): + timer.Stop() + return false + case <-timer.C: + continue + } + } + s.lastLLMCall = now + s.stateMu.Unlock() + return true + } +} + +func (s *Server) storePendingCompletion(key string, items []CompletionItem) { + s.completionState.storePendingCompletion(key, items) +} + +func (s *Server) setCompletionsDisabled(disabled bool) bool { + return s.completionState.setCompletionsDisabled(disabled) +} + +func (s *Server) completionDisabled() bool { + return s.completionState.completionDisabled() +} + +func (s *Server) takePendingCompletion(key string) []CompletionItem { + return s.completionState.takePendingCompletion(key) +} + +func (s *Server) completionCacheGet(key string) (string, bool) { + return s.completionState.cacheGet(key) +} + +func (s *Server) completionCachePut(key, value string) { + s.completionState.cachePut(key, value) +} + +func (s *Server) waitForThrottle(ctx context.Context) bool { + return s.completionState.waitForThrottle(ctx, s.completionThrottle()) +} diff --git a/internal/lsp/document_test.go b/internal/lsp/document_test.go index 1a3f909..3dc970d 100644 --- a/internal/lsp/document_test.go +++ b/internal/lsp/document_test.go @@ -8,7 +8,6 @@ import ( "testing" "codeberg.org/snonux/hexai/internal/appconfig" - "codeberg.org/snonux/hexai/internal/llm" ) func newTestServer() *Server { @@ -40,9 +39,9 @@ func newTestServer() *Server { docs: make(map[string]*document), cfg: cfg, codeActionSubsystem: codeActionSubsystem{ - altClients: make(map[string]llm.Client), - llmProvider: canonicalProvider(cfg.Provider), + llmClientRegistry: llmClientRegistry{llmProvider: canonicalProvider(cfg.Provider)}, }, + completionSubsystem: completionSubsystem{completionState: completionState{}}, } } diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 7b61970..fe52512 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -187,8 +187,6 @@ func (s *Server) reply(id json.RawMessage, result any, err *RespError) { // that an LLM request is already in flight. // removed: previous single in-flight LLM busy gate and busy item -// --- small completion cache (last ~10 entries) --- - func (s *Server) completionCacheKey(p CompletionParams, above, current, below, funcCtx string, inParams bool, hasExtra bool, extraText string) string { // Normalize left-of-cursor by trimming trailing spaces/tabs idx := p.Position.Character @@ -232,56 +230,6 @@ func (s *Server) completionCacheKey(p CompletionParams, above, current, below, f }, "\x1f") // use unit separator to avoid collisions } -func (s *Server) completionCacheGet(key string) (string, bool) { - s.mu.Lock() - defer s.mu.Unlock() - v, ok := s.compCache[key] - if !ok { - return "", false - } - // move to most-recent - s.compCacheTouchLocked(key) - return v, true -} - -func (s *Server) completionCachePut(key, value string) { - s.mu.Lock() - defer s.mu.Unlock() - if s.compCache == nil { - s.compCache = make(map[string]string) - } - if _, exists := s.compCache[key]; !exists { - s.compCacheOrder = append(s.compCacheOrder, key) - s.compCache[key] = value - if len(s.compCacheOrder) > 10 { - // evict oldest - old := s.compCacheOrder[0] - s.compCacheOrder = s.compCacheOrder[1:] - delete(s.compCache, old) - } - return - } - // update existing and mark most-recent - s.compCache[key] = value - s.compCacheTouchLocked(key) -} - -func (s *Server) compCacheTouchLocked(key string) { - // assumes s.mu is held - // remove any existing occurrence of key in order slice - idx := -1 - for i, k := range s.compCacheOrder { - if k == key { - idx = i - break - } - } - if idx >= 0 { - s.compCacheOrder = append(append([]string{}, s.compCacheOrder[:idx]...), s.compCacheOrder[idx+1:]...) - } - s.compCacheOrder = append(s.compCacheOrder, key) -} - // isTriggerEvent returns true when the completion request appears to be caused // by typing one of our configured trigger characters. It checks the LSP // CompletionContext if provided and also falls back to inspecting the character diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index 3009d50..7c293b0 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -63,7 +63,7 @@ func (s *Server) handleCompletion(req Request) { if s.logContext { s.logCompletionContext(p, above, current, below, funcCtx) } - if s.llmClient != nil { + if s.currentLLMClient() != nil { newFunc := s.isDefiningNewFunction(p.TextDocument.URI, p.Position) extra, has := s.buildAdditionalContext(newFunc, p.TextDocument.URI, p.Position) items, ok, incomplete := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra) @@ -512,38 +512,6 @@ func (s *Server) waitForDebounce(ctx context.Context) { } } -// waitForThrottle enforces a minimum spacing between LLM calls. Returns false -// if the context is canceled while waiting. -func (s *Server) waitForThrottle(ctx context.Context) bool { - interval := s.completionThrottle() - if interval <= 0 { - return true - } - var wait time.Duration - for { - s.mu.Lock() - next := s.lastLLMCall.Add(interval) - now := time.Now() - if now.Before(next) { - wait = next.Sub(now) - s.mu.Unlock() - timer := time.NewTimer(wait) - select { - case <-ctx.Done(): - timer.Stop() - return false - case <-timer.C: - // try again to set the next call time - continue - } - } - // we are allowed to proceed now; record this call as the latest - s.lastLLMCall = now - s.mu.Unlock() - return true - } -} - // buildCompletionMessages constructs the LLM messages for completion. func (s *Server) buildCompletionMessages(inlinePrompt, hasExtra bool, extraText string, inParams bool, p CompletionParams, above, current, below, funcCtx string) []llm.Message { vars := map[string]string{ diff --git a/internal/lsp/ignore_test.go b/internal/lsp/ignore_test.go index f14daf5..7df7428 100644 --- a/internal/lsp/ignore_test.go +++ b/internal/lsp/ignore_test.go @@ -9,7 +9,6 @@ import ( "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/ignore" - "codeberg.org/snonux/hexai/internal/llm" ) // newIgnoreTestServer creates a Server with an ignore checker configured @@ -26,7 +25,8 @@ func newIgnoreTestServer(gitRoot string, useGI bool, extra []string, notifyIgnor logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), cfg: cfg, - codeActionSubsystem: codeActionSubsystem{altClients: make(map[string]llm.Client)}, + codeActionSubsystem: codeActionSubsystem{llmClientRegistry: llmClientRegistry{}}, + completionSubsystem: completionSubsystem{completionState: completionState{}}, ignoreChecker: ignore.New(gitRoot, useGI, extra), } return s @@ -129,7 +129,7 @@ func TestIsFileIgnored_NoChecker(t *testing.T) { s := &Server{ logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), - codeActionSubsystem: codeActionSubsystem{altClients: make(map[string]llm.Client)}, + codeActionSubsystem: codeActionSubsystem{llmClientRegistry: llmClientRegistry{}}, // ignoreChecker is nil } @@ -166,7 +166,7 @@ func TestIgnoreLSPNotifyEnabled_NilConfig(t *testing.T) { s := &Server{ logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), - codeActionSubsystem: codeActionSubsystem{altClients: make(map[string]llm.Client)}, + codeActionSubsystem: codeActionSubsystem{llmClientRegistry: llmClientRegistry{}}, cfg: appconfig.App{}, } if !s.ignoreLSPNotifyEnabled() { diff --git a/internal/lsp/llm_client_registry.go b/internal/lsp/llm_client_registry.go new file mode 100644 index 0000000..53fa25f --- /dev/null +++ b/internal/lsp/llm_client_registry.go @@ -0,0 +1,99 @@ +package lsp + +import ( + "strings" + "sync" + + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/logging" +) + +type llmClientBuilder func(appconfig.App, string, string) (llm.Client, error) + +type llmClientRegistry struct { + clientsMu sync.RWMutex + llmClient llm.Client + llmProvider string + altClients map[string]llm.Client +} + +func newLLMClientRegistry() llmClientRegistry { + return llmClientRegistry{ + altClients: make(map[string]llm.Client), + } +} + +func (r *llmClientRegistry) applyOptions(client llm.Client, configuredProvider string) { + provider := canonicalProvider(configuredProvider) + if client != nil { + if name := canonicalProvider(client.Name()); name != "" { + provider = name + } + } + r.clientsMu.Lock() + defer r.clientsMu.Unlock() + r.llmClient = client + r.llmProvider = provider + r.altClients = make(map[string]llm.Client) +} + +func (r *llmClientRegistry) current() llm.Client { + r.clientsMu.RLock() + defer r.clientsMu.RUnlock() + return r.llmClient +} + +func (r *llmClientRegistry) clientFor(spec requestSpec, cfg appconfig.App, build llmClientBuilder) llm.Client { + provider := canonicalProvider(spec.provider) + + r.clientsMu.RLock() + baseProvider := r.llmProvider + baseClient := r.llmClient + if baseClient != nil && strings.TrimSpace(baseProvider) == "" { + baseProvider = canonicalProvider(baseClient.Name()) + } + if provider == "" { + provider = baseProvider + } + if provider == baseProvider && baseClient != nil { + r.clientsMu.RUnlock() + return baseClient + } + if cached, ok := r.altClients[provider]; ok { + r.clientsMu.RUnlock() + return cached + } + r.clientsMu.RUnlock() + + modelOverride := strings.TrimSpace(spec.entry.Model) + if modelOverride == "" { + modelOverride = strings.TrimSpace(spec.fallbackModel) + } + client, err := build(cfg, provider, modelOverride) + if err != nil { + logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err) + if baseClient != nil { + return baseClient + } + return nil + } + + r.clientsMu.Lock() + defer r.clientsMu.Unlock() + if provider == r.llmProvider { + if r.llmClient == nil { + r.llmClient = client + r.llmProvider = provider + } + return r.llmClient + } + if existing, ok := r.altClients[provider]; ok { + return existing + } + if r.altClients == nil { + r.altClients = make(map[string]llm.Client) + } + r.altClients[provider] = client + return client +} diff --git a/internal/lsp/llm_client_registry_test.go b/internal/lsp/llm_client_registry_test.go new file mode 100644 index 0000000..5700a53 --- /dev/null +++ b/internal/lsp/llm_client_registry_test.go @@ -0,0 +1,65 @@ +package lsp + +import ( + "errors" + "testing" + + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" +) + +func TestLLMClientRegistryClientFor_CachesAlternateProviders(t *testing.T) { + registry := newLLMClientRegistry() + registry.applyOptions(fakeClient{name: "openai", model: "gpt-5.0"}, "openai") + + cfg := appconfig.App{} + spec := requestSpec{ + provider: "anthropic", + entry: appconfig.SurfaceConfig{Model: "claude-3-7-sonnet"}, + fallbackModel: "claude-3-7-sonnet", + } + + buildCalls := 0 + builder := func(_ appconfig.App, provider, modelOverride string) (llm.Client, error) { + buildCalls++ + return fakeClient{name: provider, model: modelOverride}, nil + } + + first := registry.clientFor(spec, cfg, builder) + second := registry.clientFor(spec, cfg, builder) + if first == nil || second == nil { + t.Fatal("expected alternate provider client") + } + if buildCalls != 1 { + t.Fatalf("expected one build for cached alternate client, got %d", buildCalls) + } + if first.Name() != "anthropic" || second.Name() != "anthropic" { + t.Fatalf("expected anthropic client, got %q and %q", first.Name(), second.Name()) + } +} + +func TestLLMClientRegistryClientFor_FallsBackToBaseClientOnBuildError(t *testing.T) { + base := fakeClient{name: "openai", model: "gpt-5.0"} + registry := newLLMClientRegistry() + registry.applyOptions(base, "openai") + + spec := requestSpec{provider: "anthropic"} + builder := func(appconfig.App, string, string) (llm.Client, error) { + return nil, errors.New("boom") + } + + got := registry.clientFor(spec, appconfig.App{}, builder) + if got != base { + t.Fatalf("expected base client fallback, got %#v", got) + } +} + +func TestLLMClientRegistryCurrent_ReturnsConfiguredClient(t *testing.T) { + registry := newLLMClientRegistry() + client := fakeClient{name: "openrouter", model: "gpt-4.1-mini"} + registry.applyOptions(client, "openrouter") + + if got := registry.current(); got != client { + t.Fatalf("expected configured client, got %#v", got) + } +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index b33147c..4e8a339 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -36,7 +36,6 @@ type Server struct { logContext bool configStore *runtimeconfig.Store cfg appconfig.App - llmClient llm.Client codeActionSubsystem chatSubsystem // LLM request stats @@ -58,12 +57,7 @@ type Server struct { } type completionSubsystem struct { - // Small LRU cache for recent code completion outputs (keyed by context) - compCache map[string]string - compCacheOrder []string // most-recent at end; cap ~10 - pendingCompletions map[string][]CompletionItem - lastLLMCall time.Time - completionsDisabled bool + completionState } type chatSubsystem struct { @@ -71,8 +65,7 @@ type chatSubsystem struct { } type codeActionSubsystem struct { - llmProvider string - altClients map[string]llm.Client + llmClientRegistry } // StatusSink receives status updates from the LSP server. @@ -105,10 +98,14 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) configStore: opts.ConfigStore, serverCtx: ctx, serverCancel: cancel, + codeActionSubsystem: codeActionSubsystem{ + llmClientRegistry: llmClientRegistry{}, + }, + completionSubsystem: completionSubsystem{ + completionState: completionState{}, + }, } s.startTime = time.Now() - s.compCache = make(map[string]string) - s.pendingCompletions = make(map[string][]CompletionItem) s.applyOptions(opts) // Initialize dispatch table s.handlers = map[string]func(Request){ @@ -142,19 +139,13 @@ func (s *Server) applyOptions(opts ServerOptions) { } else { s.cfg = appconfig.App{} } - s.llmClient = opts.Client - if opts.Client != nil { - s.llmProvider = canonicalProvider(opts.Client.Name()) - } else { - s.llmProvider = canonicalProvider(s.cfg.Provider) - } - s.altClients = make(map[string]llm.Client) if opts.IgnoreChecker != nil { s.ignoreChecker = opts.IgnoreChecker } if opts.StatusSink != nil { s.statusSink = opts.StatusSink } + s.llmClientRegistry.applyOptions(opts.Client, s.cfg.Provider) } // ApplyOptions updates the server's configuration at runtime. @@ -163,9 +154,7 @@ func (s *Server) ApplyOptions(opts ServerOptions) { } func (s *Server) currentLLMClient() llm.Client { - s.mu.RLock() - defer s.mu.RUnlock() - return s.llmClient + return s.llmClientRegistry.current() } func newClientForProvider(cfg appconfig.App, provider, modelOverride string) (llm.Client, error) { @@ -173,112 +162,18 @@ func newClientForProvider(cfg appconfig.App, provider, modelOverride string) (ll } func (s *Server) clientFor(spec requestSpec) llm.Client { - provider := canonicalProvider(spec.provider) - s.mu.RLock() - baseProvider := s.llmProvider - baseClient := s.llmClient - if baseClient != nil && strings.TrimSpace(baseProvider) == "" { - baseProvider = canonicalProvider(baseClient.Name()) - } - if provider == "" { - provider = baseProvider - } - if provider == baseProvider && baseClient != nil { - s.mu.RUnlock() - return baseClient - } - if c, ok := s.altClients[provider]; ok { - s.mu.RUnlock() - return c - } - cfg := s.cfg - store := s.configStore - s.mu.RUnlock() - if store != nil { - cfg = store.Snapshot() - } - modelOverride := strings.TrimSpace(spec.entry.Model) - if modelOverride == "" { - modelOverride = strings.TrimSpace(spec.fallbackModel) - } - client, err := newClientForProvider(cfg, provider, modelOverride) - if err != nil { - logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err) - if baseClient != nil { - return baseClient - } - return nil - } - s.mu.Lock() - defer s.mu.Unlock() - if provider == s.llmProvider { - if s.llmClient == nil { - s.llmClient = client - s.llmProvider = provider - } - return s.llmClient - } - if existing, ok := s.altClients[provider]; ok { - return existing - } - if s.altClients == nil { - s.altClients = make(map[string]llm.Client) - } - s.altClients[provider] = client - return client + return s.llmClientRegistry.clientFor(spec, s.currentConfig(), newClientForProvider) } func (s *Server) currentConfig() appconfig.App { - if s.configStore != nil { - return s.configStore.Snapshot() - } - s.mu.RLock() - defer s.mu.RUnlock() - return s.cfg -} - -func (s *Server) storePendingCompletion(key string, items []CompletionItem) { - if len(items) == 0 { - return - } - cpy := make([]CompletionItem, len(items)) - copy(cpy, items) - s.mu.Lock() - if s.pendingCompletions == nil { - s.pendingCompletions = make(map[string][]CompletionItem) - } - s.pendingCompletions[key] = cpy - s.mu.Unlock() -} - -func (s *Server) setCompletionsDisabled(disabled bool) bool { - s.mu.Lock() - prev := s.completionsDisabled - s.completionsDisabled = disabled - s.mu.Unlock() - return prev -} - -func (s *Server) completionDisabled() bool { s.mu.RLock() - defer s.mu.RUnlock() - return s.completionsDisabled -} - -func (s *Server) takePendingCompletion(key string) []CompletionItem { - s.mu.Lock() - defer s.mu.Unlock() - if len(s.pendingCompletions) == 0 { - return nil - } - items, ok := s.pendingCompletions[key] - if !ok { - return nil + store := s.configStore + cfg := s.cfg + s.mu.RUnlock() + if store != nil { + return store.Snapshot() } - delete(s.pendingCompletions, key) - cpy := make([]CompletionItem, len(items)) - copy(cpy, items) - return cpy + return cfg } func (s *Server) maxTokens() int { |
