diff options
Diffstat (limited to 'internal/lsp/handlers_utils.go')
| -rw-r--r-- | internal/lsp/handlers_utils.go | 166 |
1 files changed, 142 insertions, 24 deletions
diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index 5d5ca27..3bd13ee 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" @@ -14,24 +15,134 @@ import ( tmx "codeberg.org/snonux/hexai/internal/tmux" ) -// llmRequestOpts builds request options from server settings. -func (s *Server) llmRequestOpts() []llm.RequestOption { +type surfaceKind string + +const ( + surfaceCompletion surfaceKind = "completion" + surfaceCodeAction surfaceKind = "code_action" + surfaceChat surfaceKind = "chat" +) + +type requestSpec struct { + provider string + modelOverride string + fallbackModel string + options []llm.RequestOption +} + +func (r requestSpec) effectiveModel() string { + if s := strings.TrimSpace(r.modelOverride); s != "" { + return s + } + return strings.TrimSpace(r.fallbackModel) +} + +func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { + cfg := s.currentConfig() + providerOverride := strings.TrimSpace(surfaceProviderFromConfig(cfg, surface)) + provider := canonicalProvider(cfg.Provider) + if providerOverride != "" { + provider = canonicalProvider(providerOverride) + } + fallbackModel := strings.TrimSpace(resolveDefaultModel(cfg, provider)) + modelOverride := strings.TrimSpace(surfaceModelFromConfig(cfg, surface)) maxTokens := s.maxTokens() - client := s.currentLLMClient() - tempPtr := s.codingTemperature() opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)} - if tempPtr != nil { - temp := *tempPtr - if client != nil { - prov := strings.ToLower(strings.TrimSpace(client.Name())) - model := strings.ToLower(strings.TrimSpace(client.DefaultModel())) - if prov == "openai" && strings.HasPrefix(model, "gpt-5") { - temp = 1.0 - } + if tempVal, ok := chooseSurfaceTemperature(surface, cfg, provider, modelOverride, fallbackModel); ok { + opts = append(opts, llm.WithTemperature(tempVal)) + } + if modelOverride != "" { + opts = append(opts, llm.WithModel(modelOverride)) + } + return requestSpec{ + provider: provider, + modelOverride: modelOverride, + fallbackModel: fallbackModel, + options: opts, + } +} + +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func resolveDefaultModel(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return strings.TrimSpace(cfg.OllamaModel) + case "copilot": + return strings.TrimSpace(cfg.CopilotModel) + default: + return strings.TrimSpace(cfg.OpenAIModel) + } +} + +func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string { + switch surface { + case surfaceCompletion: + return cfg.CompletionModel + case surfaceCodeAction: + return cfg.CodeActionModel + case surfaceChat: + return cfg.ChatModel + default: + return "" + } +} + +func surfaceProviderFromConfig(cfg appconfig.App, surface surfaceKind) string { + switch surface { + case surfaceCompletion: + return cfg.CompletionProvider + case surfaceCodeAction: + return cfg.CodeActionProvider + case surfaceChat: + return cfg.ChatProvider + default: + return "" + } +} + +func surfaceTemperatureFromConfig(cfg appconfig.App, surface surfaceKind) *float64 { + switch surface { + case surfaceCompletion: + return cfg.CompletionTemperature + case surfaceCodeAction: + return cfg.CodeActionTemperature + case surfaceChat: + return cfg.ChatTemperature + default: + return nil + } +} + +func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider string, overrideModel, fallbackModel string) (float64, bool) { + if t := surfaceTemperatureFromConfig(cfg, surface); t != nil { + return *t, true + } + if cfg.CodingTemperature != nil { + temp := *cfg.CodingTemperature + effectiveModel := strings.TrimSpace(overrideModel) + if effectiveModel == "" { + effectiveModel = strings.TrimSpace(fallbackModel) + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") && temp == 0.2 { + temp = 1.0 } - opts = append(opts, llm.WithTemperature(temp)) + return temp, true } - return opts + effectiveModel := strings.TrimSpace(overrideModel) + if effectiveModel == "" { + effectiveModel = strings.TrimSpace(fallbackModel) + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") { + return 1.0, true + } + return 0, false } // small helpers for LLM traffic stats @@ -49,7 +160,7 @@ func (s *Server) incRecvCounters(n int) { s.mu.Unlock() } -func (s *Server) logLLMStats() { +func (s *Server) logLLMStats(model string) { s.mu.RLock() avgSent := int64(0) if s.llmReqTotal > 0 { @@ -75,11 +186,14 @@ func (s *Server) logLLMStats() { if err == nil { if client := s.currentLLMClient(); client != nil { provider := client.Name() - model := client.DefaultModel() + modelName := strings.TrimSpace(model) + if modelName == "" { + modelName = client.DefaultModel() + } // Per-scope rpm estimated from window scopeReqs := int64(0) if pe, ok := snap.Providers[provider]; ok { - if mc, ok2 := pe.Models[model]; ok2 { + if mc, ok2 := pe.Models[modelName]; ok2 { scopeReqs = mc.Reqs } } @@ -88,7 +202,7 @@ func (s *Server) logLLMStats() { minsWin = 0.001 } scopeRPM := float64(scopeReqs) / minsWin - status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, model, scopeRPM, scopeReqs, snap.Window) + status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, modelName, scopeRPM, scopeReqs, snap.Window) _ = tmx.SetStatus(status) } } @@ -154,7 +268,7 @@ func isIdentChar(ch byte) bool { } // chatWithStats wraps llmClient.Chat to increment counters and emit a tmux heartbeat. -func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ...llm.RequestOption) (string, error) { +func (s *Server) chatWithStats(ctx context.Context, surface surfaceKind, spec requestSpec, msgs []llm.Message) (string, error) { // Count bytes sent sent := 0 for _, m := range msgs { @@ -167,19 +281,23 @@ func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ... return "", context.Canceled } // Perform request - client := s.currentLLMClient() + client := s.clientFor(spec) if client == nil { return "", fmt.Errorf("llm client unavailable") } - txt, err := client.Chat(ctx, msgs, opts...) + txt, err := client.Chat(ctx, msgs, spec.options...) if err != nil { - s.logLLMStats() + s.logLLMStats(spec.effectiveModel()) return "", err } s.incRecvCounters(len(txt)) // Update global stats cache - _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, len(txt)) - s.logLLMStats() + model := spec.effectiveModel() + if model == "" { + model = client.DefaultModel() + } + _ = stats.Update(ctx, client.Name(), model, sent, len(txt)) + s.logLLMStats(model) return txt, nil } |
