diff options
| -rw-r--r-- | internal/appconfig/config_load.go | 76 | ||||
| -rw-r--r-- | internal/llm/ollama.go | 52 | ||||
| -rw-r--r-- | internal/lsp/handlers.go | 106 | ||||
| -rw-r--r-- | internal/lsp/handlers_completion.go | 67 |
4 files changed, 176 insertions, 125 deletions
diff --git a/internal/appconfig/config_load.go b/internal/appconfig/config_load.go index cb02a2e..4c6214c 100644 --- a/internal/appconfig/config_load.go +++ b/internal/appconfig/config_load.go @@ -610,6 +610,47 @@ func parseSurfaceEntries(raw any, path string, logger *log.Logger) ([]SurfaceCon } } +// decodeModelEntryFromMap decodes a map[string]any entry into a SurfaceConfig. +// It validates that model, provider, and temperature fields have the correct types. +func decodeModelEntryFromMap(v map[string]any, path string, logger *log.Logger) (*SurfaceConfig, bool) { + model := "" + provider := "" + if m, ok := v["model"]; ok { + s, ok := m.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.model must be a string", path) + } + return nil, false + } + model = strings.TrimSpace(s) + } + if pRaw, ok := v["provider"]; ok { + ps, ok := pRaw.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.provider must be a string", path) + } + return nil, false + } + provider = strings.TrimSpace(ps) + } + var tempPtr *float64 + if tRaw, ok := v["temperature"]; ok { + parsed, ok := parseTemperatureValue(tRaw, path, logger) + if !ok { + return nil, false + } + tempPtr = parsed + } + if model == "" && tempPtr == nil && provider == "" { + return nil, false + } + return &SurfaceConfig{Provider: provider, Model: model, Temperature: tempPtr}, true +} + +// decodeModelEntry converts a raw TOML value (string or table) into a SurfaceConfig. +// A plain string is treated as a model name; a table may carry model, provider and temperature. func decodeModelEntry(raw any, path string, logger *log.Logger) (*SurfaceConfig, bool) { if raw == nil { return nil, false @@ -622,40 +663,7 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (*SurfaceConfig, } return &SurfaceConfig{Model: model}, true case map[string]any: - model := "" - provider := "" - if m, ok := v["model"]; ok { - s, ok := m.(string) - if !ok { - if logger != nil { - logger.Printf("config: %s.model must be a string", path) - } - return nil, false - } - model = strings.TrimSpace(s) - } - if pRaw, ok := v["provider"]; ok { - ps, ok := pRaw.(string) - if !ok { - if logger != nil { - logger.Printf("config: %s.provider must be a string", path) - } - return nil, false - } - provider = strings.TrimSpace(ps) - } - var tempPtr *float64 - if tRaw, ok := v["temperature"]; ok { - parsed, ok := parseTemperatureValue(tRaw, path, logger) - if !ok { - return nil, false - } - tempPtr = parsed - } - if model == "" && tempPtr == nil && provider == "" { - return nil, false - } - return &SurfaceConfig{Provider: provider, Model: model, Temperature: tempPtr}, true + return decodeModelEntryFromMap(v, path, logger) default: if logger != nil { logger.Printf("config: %s must be a string or table, got %T", path, raw) diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index be93ab0..e212466 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -133,7 +133,35 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ func (c ollamaClient) Name() string { return "ollama" } func (c ollamaClient) DefaultModel() string { return c.defaultModel } -// Streaming support (optional) +// parseOllamaStream reads NDJSON streaming events from dec, calling onDelta for each +// non-empty content delta. Returns an error if decoding fails or the server signals +// an error event; returns nil when the done flag is received or the stream ends. +func parseOllamaStream(dec *json.Decoder, start time.Time, onDelta func(string)) error { + for { + var ev ollamaChatResponse + if err := dec.Decode(&ev); err != nil { + if errors.Is(err, io.EOF) { + break + } + logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) + return err + } + if strings.TrimSpace(ev.Error) != "" { + logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase) + return fmt.Errorf("ollama stream error: %s", ev.Error) + } + if s := ev.Message.Content; strings.TrimSpace(s) != "" { + onDelta(s) + } + if ev.Done { + break + } + } + return nil +} + +// ChatStream sends a streaming chat request to Ollama, calling onDelta for each +// received content delta. It blocks until the stream ends or an error occurs. func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error { o := Options{Model: c.defaultModel} for _, opt := range opts { @@ -167,26 +195,8 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt return err } - dec := json.NewDecoder(resp.Body) - for { - var ev ollamaChatResponse - if err := dec.Decode(&ev); err != nil { - if errors.Is(err, io.EOF) { - break - } - logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) - return err - } - if strings.TrimSpace(ev.Error) != "" { - logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase) - return fmt.Errorf("ollama stream error: %s", ev.Error) - } - if s := ev.Message.Content; strings.TrimSpace(s) != "" { - onDelta(s) - } - if ev.Done { - break - } + if err := parseOllamaStream(json.NewDecoder(resp.Body), start, onDelta); err != nil { + return err } logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start)) return nil diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 3b3f8e0..0f98715 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -141,61 +141,63 @@ func (s *Server) completionCacheKey(p CompletionParams, above, current, below, f }, "\x1f") // use unit separator to avoid collisions } -// isTriggerEvent returns true when the completion request appears to be caused -// by typing one of our configured trigger characters. It checks the LSP -// CompletionContext if provided and also falls back to inspecting the character -// immediately to the left of the cursor. -func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { - open, _, openChar, closeChar := s.inlineMarkers() - doubleSeqs := doubleOpenSequences(open, openChar, closeChar) - triggerChars := s.triggerCharacters() - // 1) Inspect LSP completion context if present - if p.Context != nil { - var ctx struct { - TriggerKind int `json:"triggerKind"` - TriggerCharacter string `json:"triggerCharacter,omitempty"` - } - if raw, ok := p.Context.(json.RawMessage); ok { - if err := json.Unmarshal(raw, &ctx); err != nil { - logging.Logf("lsp ", "handleCompletion: unmarshal raw context: %v", err) - } - } else { - b, _ := json.Marshal(p.Context) - if err := json.Unmarshal(b, &ctx); err != nil { - logging.Logf("lsp ", "handleCompletion: unmarshal context: %v", err) - } - } - // If configured and the line contains a bare double-open marker (e.g., '>>!' with no '>>!text>'), - // do not treat as a trigger source. - if containsAny(current, doubleSeqs) && !hasDoubleOpenTrigger(current, open, openChar, closeChar) { - return false +// checkTriggerFromContext inspects the LSP CompletionContext (if present) to decide if +// the completion was triggered by one of our configured trigger characters or by a manual +// invoke. Returns (result, decided): decided=true means the caller should use result +// directly; decided=false means the context was absent or inconclusive (TriggerKind 3). +func (s *Server) checkTriggerFromContext(p CompletionParams, current string, open string, openChar, closeChar byte, doubleSeqs, triggerChars []string) (result bool, decided bool) { + if p.Context == nil { + return false, false + } + var ctx struct { + TriggerKind int `json:"triggerKind"` + TriggerCharacter string `json:"triggerCharacter,omitempty"` + } + if raw, ok := p.Context.(json.RawMessage); ok { + if err := json.Unmarshal(raw, &ctx); err != nil { + logging.Logf("lsp ", "handleCompletion: unmarshal raw context: %v", err) } - // TriggerKind 1 = Invoked (manual). Always allow manual invoke. - if ctx.TriggerKind == 1 { - return true + } else { + b, _ := json.Marshal(p.Context) + if err := json.Unmarshal(b, &ctx); err != nil { + logging.Logf("lsp ", "handleCompletion: unmarshal context: %v", err) } - // TriggerKind 2 is TriggerCharacter per LSP spec - if ctx.TriggerKind == 2 { - if ctx.TriggerCharacter != "" { - for _, c := range triggerChars { - if c == ctx.TriggerCharacter { - return true - } + } + // Bare double-open markers must not be treated as a trigger source. + if containsAny(current, doubleSeqs) && !hasDoubleOpenTrigger(current, open, openChar, closeChar) { + return false, true + } + // TriggerKind 1 = Invoked (manual). Always allow. + if ctx.TriggerKind == 1 { + return true, true + } + // TriggerKind 2 = TriggerCharacter per LSP spec. + if ctx.TriggerKind == 2 { + if ctx.TriggerCharacter != "" { + for _, c := range triggerChars { + if c == ctx.TriggerCharacter { + return true, true } - return false } - // No character provided but reported as TriggerCharacter; be conservative - return false + return false, true } - // For TriggerForIncomplete (3), require manual char check below + // No character provided but reported as TriggerCharacter; be conservative. + return false, true } - // 2) Fallback: check the character immediately prior to cursor. + // TriggerKind 3 (TriggerForIncomplete): fall through to cursor-char check. + return false, false +} + +// checkTriggerFromCursorChar is the fallback check that looks at the character +// immediately to the left of the cursor position to decide whether it matches a +// configured trigger character. +func (s *Server) checkTriggerFromCursorChar(p CompletionParams, current string, open string, openChar, closeChar byte, doubleSeqs, triggerChars []string) bool { // Convert UTF-16 offset to byte offset for correct multi-byte handling. byteIdx := utf16OffsetToByteOffset(current, p.Position.Character) if byteIdx <= 0 || byteIdx > len(current) { return false } - // Bare double-open should not trigger via fallback char either (only when configured) + // Bare double-open should not trigger via fallback char check either. if containsAny(current, doubleSeqs) && !hasDoubleOpenTrigger(current, open, openChar, closeChar) { return false } @@ -209,6 +211,22 @@ func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { return false } +// isTriggerEvent returns true when the completion request appears to be caused +// by typing one of our configured trigger characters. It checks the LSP +// CompletionContext if provided and also falls back to inspecting the character +// immediately to the left of the cursor. +func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { + open, _, openChar, closeChar := s.inlineMarkers() + doubleSeqs := doubleOpenSequences(open, openChar, closeChar) + triggerChars := s.triggerCharacters() + // 1) Inspect LSP completion context if present. + if result, decided := s.checkTriggerFromContext(p, current, open, openChar, closeChar, doubleSeqs, triggerChars); decided { + return result + } + // 2) Fallback: check the character immediately prior to cursor. + return s.checkTriggerFromCursorChar(p, current, open, openChar, closeChar, doubleSeqs, triggerChars) +} + func (s *Server) makeCompletionItems(cleaned string, inParams bool, current string, p CompletionParams, docStr string, detail string, sortPrefix string) []CompletionItem { te, filter := computeTextEditAndFilter(cleaned, inParams, current, p) rm := s.collectPromptRemovalEdits(p.TextDocument.URI) diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index 527d020..d6529de 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -321,7 +321,6 @@ func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan, text, err := client.Chat(ctx, messages, spec.options...) if err != nil { logging.Logf("lsp ", "llm completion error: %v", err) - s.logLLMStats("") return nil, false } s.incRecvCounters(len(text)) @@ -426,6 +425,45 @@ func (s *Server) prefixHeuristicAllows(inlinePrompt bool, current string, p Comp return j-start >= min } +// buildNativeCompletionCacheKey constructs the per-provider cache key for native completions. +func buildNativeCompletionCacheKey(planCacheKey, provider, modelUsed string, clientName string) string { + providerKey := provider + if providerKey == "" { + providerKey = llmutils.CanonicalProvider(clientName) + } + return planCacheKey + "|" + providerKey + ":" + modelUsed +} + +// postProcessNativeCompletion strips duplicates and applies indentation to the raw suggestion. +// Returns the cleaned text, or an empty string when the suggestion should be discarded. +func (s *Server) postProcessNativeCompletion(raw, current string, charOffset int) string { + cleaned := strings.TrimSpace(raw) + if cleaned == "" { + return "" + } + openStr, _, openChar, closeChar := s.inlineMarkers() + cByte := utf16OffsetToByteOffset(current, charOffset) + leftOfCursor := current[:cByte] + cleaned = stripDuplicateAssignmentPrefix(leftOfCursor, cleaned) + if cleaned == "" { + return "" + } + cleaned = stripDuplicateGeneralPrefix(leftOfCursor, cleaned) + if cleaned == "" { + return "" + } + if strings.TrimSpace(cleaned) != "" && hasDoubleOpenTrigger(current, openStr, openChar, closeChar) { + if indent := leadingIndent(current); indent != "" { + cleaned = applyIndent(indent, cleaned) + } + } + // Guard against all-whitespace result without stripping intentional indentation. + if strings.TrimSpace(cleaned) == "" { + return "" + } + return cleaned +} + // tryProviderNativeCompletion attempts provider-native completion and returns items when successful. func (s *Server) tryProviderNativeCompletion(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client, sortPrefix string) ([]CompletionItem, bool) { cc, ok := client.(llm.CodeCompleter) @@ -437,7 +475,6 @@ func (s *Server) tryProviderNativeCompletion(ctx context.Context, plan completio before, after := s.docBeforeAfter(p.TextDocument.URI, p.Position) path := strings.TrimPrefix(p.TextDocument.URI, "file://") cfg := s.currentConfig() - openStr, _, openChar, closeChar := s.inlineMarkers() prompt := renderTemplate(cfg.PromptNativeCompletion, map[string]string{ "path": path, "before": before, @@ -466,34 +503,12 @@ func (s *Server) tryProviderNativeCompletion(ctx context.Context, plan completio s.incRecvCounters(len(suggestions[0])) _ = stats.Update(ctx2, client.Name(), modelUsed, sentBytes, len(suggestions[0])) s.logLLMStats(modelUsed) - cleaned := strings.TrimSpace(suggestions[0]) - if cleaned == "" { - return nil, false - } - cByte := utf16OffsetToByteOffset(current, p.Position.Character) - cleaned = stripDuplicateAssignmentPrefix(current[:cByte], cleaned) + cleaned := s.postProcessNativeCompletion(suggestions[0], current, p.Position.Character) if cleaned == "" { return nil, false } - cleaned = stripDuplicateGeneralPrefix(current[:cByte], cleaned) - if cleaned == "" { - return nil, false - } - if strings.TrimSpace(cleaned) != "" && hasDoubleOpenTrigger(current, openStr, openChar, closeChar) { - indent := leadingIndent(current) - if indent != "" { - cleaned = applyIndent(indent, cleaned) - } - } - if strings.TrimSpace(cleaned) == "" { - return nil, false - } detail := fmt.Sprintf("Hexai %s:%s", client.Name(), modelUsed) - providerKey := provider - if providerKey == "" { - providerKey = llmutils.CanonicalProvider(client.Name()) - } - cacheKey := plan.cacheKey + "|" + providerKey + ":" + modelUsed + cacheKey := buildNativeCompletionCacheKey(plan.cacheKey, provider, modelUsed, client.Name()) s.completionCachePut(cacheKey, cleaned) items := s.makeCompletionItems(cleaned, plan.inParams, current, p, plan.docStr, detail, sortPrefix) return items, true |
