| author | Paul Buetow <paul@buetow.org> | 2025-08-18 09:28:48 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-08-18 09:28:48 +0300 |
| commit | 96ace6c7019a914e21b25fa94ddfc4ee9239c2fb (patch) | |
| tree | 30550bcab30c91e917a4d8b3feccda829a364437 | |
| parent | 6d29ac7e4b2604b5c7df50f33f8ef2357709faf2 (diff) | |
refactor(lsp,llm,hexailsp,appconfig): split long funcs; add tests
- Extract helpers to keep funcs <=50 lines; no behavior changes
- Add tests for prompt removal, code actions, and LLM request builders
- Table-drive TestInParamList; run gofmt
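The "table-drive" bullet refers to the usual Go convention: cases become a slice of structs iterated with `t.Run`. The body of `TestInParamList` is not shown in this diff, so the sketch below only illustrates the shape; the `inParamList` signature and cases are assumptions.

```go
// Hypothetical sketch of a table-driven test; the real TestInParamList
// lives in internal/lsp and its helper signature and cases may differ.
func TestInParamList(t *testing.T) {
	cases := []struct {
		name string
		line string
		col  int
		want bool
	}{
		{"inside param list", "func f(a int, b int) {}", 8, true},
		{"outside param list", "x := 1", 2, false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := inParamList(tc.line, tc.col); got != tc.want {
				t.Fatalf("inParamList(%q, %d) = %v, want %v", tc.line, tc.col, got, tc.want)
			}
		})
	}
}
```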
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | internal/appconfig/config.go | 228 |
| -rw-r--r-- | internal/hexailsp/run.go | 94 |
| -rw-r--r-- | internal/llm/copilot.go | 166 |
| -rw-r--r-- | internal/llm/copilot_test.go | 15 |
| -rw-r--r-- | internal/llm/ollama.go | 192 |
| -rw-r--r-- | internal/llm/ollama_test.go | 18 |
| -rw-r--r-- | internal/llm/openai.go | 281 |
| -rw-r--r-- | internal/llm/openai_test.go | 44 |
| -rw-r--r-- | internal/lsp/codeaction_test.go | 63 |
| -rw-r--r-- | internal/lsp/handlers.go | 339 |
| -rw-r--r-- | internal/lsp/handlers_helpers_test.go | 52 |
| -rw-r--r-- | internal/lsp/handlers_test.go | 30 |
12 files changed, 864 insertions, 658 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go
index 4fa3441..3067dd1 100644
--- a/internal/appconfig/config.go
+++ b/internal/appconfig/config.go
@@ -13,57 +13,57 @@ import (
 // App holds user-configurable settings read from ~/.config/hexai/config.json.
 type App struct {
-    MaxTokens int `json:"max_tokens"`
+    MaxTokens int `json:"max_tokens"`
     ContextMode string `json:"context_mode"`
     ContextWindowLines int `json:"context_window_lines"`
     MaxContextTokens int `json:"max_context_tokens"`
-    LogPreviewLimit int `json:"log_preview_limit"`
-    // Single knob for LSP requests; if set, overrides hardcoded temps in LSP.
-    CodingTemperature *float64 `json:"coding_temperature"`
+    LogPreviewLimit int `json:"log_preview_limit"`
+    // Single knob for LSP requests; if set, overrides hardcoded temps in LSP.
+    CodingTemperature *float64 `json:"coding_temperature"`
     TriggerCharacters []string `json:"trigger_characters"`
     Provider string `json:"provider"`
-    // Provider-specific options
-    OpenAIBaseURL string `json:"openai_base_url"`
-    OpenAIModel string `json:"openai_model"`
-    // Default temperature for OpenAI requests (nil means use provider default)
-    OpenAITemperature *float64 `json:"openai_temperature"`
-    OllamaBaseURL string `json:"ollama_base_url"`
-    OllamaModel string `json:"ollama_model"`
-    // Default temperature for Ollama requests (nil means use provider default)
-    OllamaTemperature *float64 `json:"ollama_temperature"`
-    CopilotBaseURL string `json:"copilot_base_url"`
-    CopilotModel string `json:"copilot_model"`
-    // Default temperature for Copilot requests (nil means use provider default)
-    CopilotTemperature *float64 `json:"copilot_temperature"`
+    // Provider-specific options
+    OpenAIBaseURL string `json:"openai_base_url"`
+    OpenAIModel string `json:"openai_model"`
+    // Default temperature for OpenAI requests (nil means use provider default)
+    OpenAITemperature *float64 `json:"openai_temperature"`
+    OllamaBaseURL string `json:"ollama_base_url"`
+    OllamaModel string `json:"ollama_model"`
+    // Default temperature for Ollama requests (nil means use provider default)
+    OllamaTemperature *float64 `json:"ollama_temperature"`
+    CopilotBaseURL string `json:"copilot_base_url"`
+    CopilotModel string `json:"copilot_model"`
+    // Default temperature for Copilot requests (nil means use provider default)
+    CopilotTemperature *float64 `json:"copilot_temperature"`
 }

 // Constructor: defaults for App (kept first among functions)
 func newDefaultConfig() App {
-    // Coding-friendly default temperature across providers
-    // Users can override per provider in config.json (including 0.0).
-    t := 0.2
-    return App{
-        MaxTokens: 4000,
-        ContextMode: "always-full",
-        ContextWindowLines: 120,
-        MaxContextTokens: 4000,
-        LogPreviewLimit: 100,
-        CodingTemperature: &t,
-        OpenAITemperature: &t,
-        OllamaTemperature: &t,
-        CopilotTemperature: &t,
-    }
+    // Coding-friendly default temperature across providers
+    // Users can override per provider in config.json (including 0.0).
+    t := 0.2
+    return App{
+        MaxTokens: 4000,
+        ContextMode: "always-full",
+        ContextWindowLines: 120,
+        MaxContextTokens: 4000,
+        LogPreviewLimit: 100,
+        CodingTemperature: &t,
+        OpenAITemperature: &t,
+        OllamaTemperature: &t,
+        CopilotTemperature: &t,
+    }
 }

 // Load reads configuration from a file and merges with defaults.
 // It respects the XDG Base Directory Specification.
 func Load(logger *log.Logger) App {
-    cfg := newDefaultConfig()
-    if logger == nil {
-        return cfg // Return defaults if no logger is provided (e.g. in tests)
-    }
+    cfg := newDefaultConfig()
+    if logger == nil {
+        return cfg // Return defaults if no logger is provided (e.g. in tests)
+    }
     configPath, err := getConfigPath()
     if err != nil {
@@ -76,91 +76,101 @@ func Load(logger *log.Logger) App {
         return cfg
     }
-    cfg.mergeWith(fileCfg)
-    return cfg
+    cfg.mergeWith(fileCfg)
+    return cfg
 }

 // Private helpers
 func loadFromFile(path string, logger *log.Logger) (*App, error) {
-    f, err := os.Open(path)
-    if err != nil {
-        if !os.IsNotExist(err) && logger != nil {
-            logger.Printf("cannot open config file %s: %v", path, err)
-        }
-        return nil, err
-    }
-    defer f.Close()
+    f, err := os.Open(path)
+    if err != nil {
+        if !os.IsNotExist(err) && logger != nil {
+            logger.Printf("cannot open config file %s: %v", path, err)
+        }
+        return nil, err
+    }
+    defer f.Close()

-    dec := json.NewDecoder(f)
-    var fileCfg App
-    if err := dec.Decode(&fileCfg); err != nil {
-        if logger != nil {
-            logger.Printf("invalid config file %s: %v", path, err)
-        }
-        return nil, err
-    }
-    return &fileCfg, nil
+    dec := json.NewDecoder(f)
+    var fileCfg App
+    if err := dec.Decode(&fileCfg); err != nil {
+        if logger != nil {
+            logger.Printf("invalid config file %s: %v", path, err)
+        }
+        return nil, err
+    }
+    return &fileCfg, nil
 }

 func (a *App) mergeWith(other *App) {
-    if other.MaxTokens > 0 {
-        a.MaxTokens = other.MaxTokens
-    }
-    if strings.TrimSpace(other.ContextMode) != "" {
-        a.ContextMode = other.ContextMode
-    }
-    if other.ContextWindowLines > 0 {
-        a.ContextWindowLines = other.ContextWindowLines
-    }
-    if other.MaxContextTokens > 0 {
-        a.MaxContextTokens = other.MaxContextTokens
-    }
-    if other.LogPreviewLimit >= 0 {
-        a.LogPreviewLimit = other.LogPreviewLimit
-    }
-    if other.CodingTemperature != nil { // allow explicit 0.0
-        a.CodingTemperature = other.CodingTemperature
-    }
-    if len(other.TriggerCharacters) > 0 {
-        a.TriggerCharacters = slices.Clone(other.TriggerCharacters)
-    }
-    if strings.TrimSpace(other.Provider) != "" {
-        a.Provider = other.Provider
-    }
-    if strings.TrimSpace(other.OpenAIBaseURL) != "" {
-        a.OpenAIBaseURL = other.OpenAIBaseURL
-    }
-    if strings.TrimSpace(other.OpenAIModel) != "" {
-        a.OpenAIModel = other.OpenAIModel
-    }
-    if other.OpenAITemperature != nil { // allow explicit 0.0
-        a.OpenAITemperature = other.OpenAITemperature
-    }
-    if strings.TrimSpace(other.OllamaBaseURL) != "" {
-        a.OllamaBaseURL = other.OllamaBaseURL
-    }
-    if strings.TrimSpace(other.OllamaModel) != "" {
-        a.OllamaModel = other.OllamaModel
-    }
-    if other.OllamaTemperature != nil { // allow explicit 0.0
-        a.OllamaTemperature = other.OllamaTemperature
-    }
-    if strings.TrimSpace(other.CopilotBaseURL) != "" {
-        a.CopilotBaseURL = other.CopilotBaseURL
-    }
-    if strings.TrimSpace(other.CopilotModel) != "" {
-        a.CopilotModel = other.CopilotModel
-    }
-    if other.CopilotTemperature != nil { // allow explicit 0.0
-        a.CopilotTemperature = other.CopilotTemperature
-    }
+    a.mergeBasics(other)
+    a.mergeProviderFields(other)
+}
+
+// mergeBasics merges general (non-provider) fields.
+func (a *App) mergeBasics(other *App) {
+    if other.MaxTokens > 0 {
+        a.MaxTokens = other.MaxTokens
+    }
+    if s := strings.TrimSpace(other.ContextMode); s != "" {
+        a.ContextMode = s
+    }
+    if other.ContextWindowLines > 0 {
+        a.ContextWindowLines = other.ContextWindowLines
+    }
+    if other.MaxContextTokens > 0 {
+        a.MaxContextTokens = other.MaxContextTokens
+    }
+    if other.LogPreviewLimit >= 0 {
+        a.LogPreviewLimit = other.LogPreviewLimit
+    }
+    if other.CodingTemperature != nil { // allow explicit 0.0
+        a.CodingTemperature = other.CodingTemperature
+    }
+    if len(other.TriggerCharacters) > 0 {
+        a.TriggerCharacters = slices.Clone(other.TriggerCharacters)
+    }
+    if s := strings.TrimSpace(other.Provider); s != "" {
+        a.Provider = s
+    }
+}
+
+// mergeProviderFields merges per-provider configuration.
+func (a *App) mergeProviderFields(other *App) {
+    if s := strings.TrimSpace(other.OpenAIBaseURL); s != "" {
+        a.OpenAIBaseURL = s
+    }
+    if s := strings.TrimSpace(other.OpenAIModel); s != "" {
+        a.OpenAIModel = s
+    }
+    if other.OpenAITemperature != nil { // allow explicit 0.0
+        a.OpenAITemperature = other.OpenAITemperature
+    }
+    if s := strings.TrimSpace(other.OllamaBaseURL); s != "" {
+        a.OllamaBaseURL = s
+    }
+    if s := strings.TrimSpace(other.OllamaModel); s != "" {
+        a.OllamaModel = s
+    }
+    if other.OllamaTemperature != nil { // allow explicit 0.0
+        a.OllamaTemperature = other.OllamaTemperature
+    }
+    if s := strings.TrimSpace(other.CopilotBaseURL); s != "" {
+        a.CopilotBaseURL = s
+    }
+    if s := strings.TrimSpace(other.CopilotModel); s != "" {
+        a.CopilotModel = s
+    }
+    if other.CopilotTemperature != nil { // allow explicit 0.0
+        a.CopilotTemperature = other.CopilotTemperature
+    }
 }

 func getConfigPath() (string, error) {
-    var configPath string
-    if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" {
-        configPath = filepath.Join(xdgConfigHome, "hexai", "config.json")
-    } else {
+    var configPath string
+    if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" {
+        configPath = filepath.Join(xdgConfigHome, "hexai", "config.json")
+    } else {
         home, err := os.UserHomeDir()
         if err != nil {
             return "", fmt.Errorf("cannot find user home directory: %v", err)
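The merge above only lets file values override defaults when they are meaningfully set: positive ints (`LogPreviewLimit` also accepts 0), non-blank strings, and non-nil `*float64` pointers, which is what lets a user pin a temperature to exactly 0.0. A minimal sketch of the resulting behavior, from inside the hexai module; the config file contents shown are hypothetical:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"hexai/internal/appconfig"
)

func main() {
	// Assuming ~/.config/hexai/config.json contains:
	//   {"provider": "ollama", "ollama_temperature": 0.0}
	cfg := appconfig.Load(log.New(os.Stderr, "", log.LstdFlags))
	fmt.Println(cfg.Provider)           // "ollama": non-blank strings override
	fmt.Println(*cfg.OllamaTemperature) // 0.0: a non-nil *float64 overrides, even at zero
	fmt.Println(cfg.MaxTokens)          // 4000: absent from the file, so the default is kept
}
```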
diff --git a/internal/hexailsp/run.go b/internal/hexailsp/run.go
index 1beb93a..64607e3 100644
--- a/internal/hexailsp/run.go
+++ b/internal/hexailsp/run.go
@@ -40,56 +40,72 @@ func Run(logPath string, stdin io.Reader, stdout io.Writer, stderr io.Writer) er
 // RunWithFactory is the testable entrypoint. When client is nil, it is built from cfg+env.
 // When factory is nil, lsp.NewServer is used.
 func RunWithFactory(logPath string, stdin io.Reader, stdout io.Writer, logger *log.Logger, cfg appconfig.App, client llm.Client, factory ServerFactory) error {
-    // Normalize and apply logging config
+    normalizeLoggingConfig(&cfg)
+    client = buildClientIfNil(cfg, client)
+    factory = ensureFactory(factory)
+
+    opts := makeServerOptions(cfg, strings.TrimSpace(logPath) != "", client)
+    server := factory(stdin, stdout, logger, opts)
+    if err := server.Run(); err != nil {
+        logger.Fatalf("server error: %v", err)
+    }
+    return nil
+}
+
+// --- helpers to keep RunWithFactory small ---
+
+func normalizeLoggingConfig(cfg *appconfig.App) {
     cfg.ContextMode = strings.ToLower(strings.TrimSpace(cfg.ContextMode))
     if cfg.LogPreviewLimit >= 0 {
         logging.SetLogPreviewLimit(cfg.LogPreviewLimit)
     }
+}

-    // Build LLM client if not provided
-    if client == nil {
-        llmCfg := llm.Config{
-            Provider: cfg.Provider,
-            OpenAIBaseURL: cfg.OpenAIBaseURL,
-            OpenAIModel: cfg.OpenAIModel,
-            OpenAITemperature: cfg.OpenAITemperature,
-            OllamaBaseURL: cfg.OllamaBaseURL,
-            OllamaModel: cfg.OllamaModel,
-            OllamaTemperature: cfg.OllamaTemperature,
-            CopilotBaseURL: cfg.CopilotBaseURL,
-            CopilotModel: cfg.CopilotModel,
-            CopilotTemperature: cfg.CopilotTemperature,
-        }
-        oaKey := os.Getenv("OPENAI_API_KEY")
-        cpKey := os.Getenv("COPILOT_API_KEY")
-        if c, err := llm.NewFromConfig(llmCfg, oaKey, cpKey); err != nil {
-            logging.Logf("lsp ", "llm disabled: %v", err)
-        } else {
-            client = c
-            logging.Logf("lsp ", "llm enabled provider=%s model=%s", c.Name(), c.DefaultModel())
-        }
+func buildClientIfNil(cfg appconfig.App, client llm.Client) llm.Client {
+    if client != nil {
+        return client
     }
-
-    if factory == nil {
-        factory = func(r io.Reader, w io.Writer, logger *log.Logger, opts lsp.ServerOptions) ServerRunner {
-            return lsp.NewServer(r, w, logger, opts)
-        }
+    llmCfg := llm.Config{
+        Provider: cfg.Provider,
+        OpenAIBaseURL: cfg.OpenAIBaseURL,
+        OpenAIModel: cfg.OpenAIModel,
+        OpenAITemperature: cfg.OpenAITemperature,
+        OllamaBaseURL: cfg.OllamaBaseURL,
+        OllamaModel: cfg.OllamaModel,
+        OllamaTemperature: cfg.OllamaTemperature,
+        CopilotBaseURL: cfg.CopilotBaseURL,
+        CopilotModel: cfg.CopilotModel,
+        CopilotTemperature: cfg.CopilotTemperature,
+    }
+    oaKey := os.Getenv("OPENAI_API_KEY")
+    cpKey := os.Getenv("COPILOT_API_KEY")
+    if c, err := llm.NewFromConfig(llmCfg, oaKey, cpKey); err != nil {
+        logging.Logf("lsp ", "llm disabled: %v", err)
+        return nil
+    } else {
+        logging.Logf("lsp ", "llm enabled provider=%s model=%s", c.Name(), c.DefaultModel())
+        return c
+    }
+}

-    server := factory(stdin, stdout, logger, lsp.ServerOptions{
-        LogContext: strings.TrimSpace(logPath) != "",
-        MaxTokens: cfg.MaxTokens,
-        ContextMode: cfg.ContextMode,
-        WindowLines: cfg.ContextWindowLines,
-        MaxContextTokens: cfg.MaxContextTokens,
+func ensureFactory(factory ServerFactory) ServerFactory {
+    if factory != nil {
+        return factory
+    }
+    return func(r io.Reader, w io.Writer, logger *log.Logger, opts lsp.ServerOptions) ServerRunner {
+        return lsp.NewServer(r, w, logger, opts)
+    }
+}
+
+func makeServerOptions(cfg appconfig.App, logContext bool, client llm.Client) lsp.ServerOptions {
+    return lsp.ServerOptions{
+        LogContext: logContext,
+        MaxTokens: cfg.MaxTokens,
+        ContextMode: cfg.ContextMode,
+        WindowLines: cfg.ContextWindowLines,
+        MaxContextTokens: cfg.MaxContextTokens,
         CodingTemperature: cfg.CodingTemperature,
-        Client: client,
+        Client: client,
         TriggerCharacters: cfg.TriggerCharacters,
-    })
-    if err := server.Run(); err != nil {
-        logger.Fatalf("server error: %v", err)
     }
-    return nil
 }
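RunWithFactory exists precisely so tests can substitute both the LLM client and the server constructor. A sketch of how a test might use that seam; the fake types and assertions are illustrative, not the repository's actual tests, and assume ServerRunner requires only Run():

```go
// Illustrative only: a fake ServerRunner that does nothing.
type fakeRunner struct{}

func (f *fakeRunner) Run() error { return nil }

func TestRunWithFactory_BuildsOptions(t *testing.T) {
	var got lsp.ServerOptions
	factory := func(r io.Reader, w io.Writer, l *log.Logger, opts lsp.ServerOptions) ServerRunner {
		got = opts // capture what makeServerOptions produced
		return &fakeRunner{}
	}
	cfg := appconfig.App{MaxTokens: 99, ContextMode: "ALWAYS-Full"}
	logger := log.New(io.Discard, "", 0)
	if err := RunWithFactory("", strings.NewReader(""), io.Discard, logger, cfg, nil, factory); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// normalizeLoggingConfig lowercases/trims the mode before options are built.
	if got.ContextMode != "always-full" || got.MaxTokens != 99 {
		t.Fatalf("options not derived from config: %+v", got)
	}
}
```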
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go
index 47ce11e..67cffc9 100644
--- a/internal/llm/copilot.go
+++ b/internal/llm/copilot.go
@@ -17,20 +17,20 @@ import (
 // copilotClient implements Client against GitHub Copilot's Chat Completions API.
 type copilotClient struct {
-    httpClient *http.Client
-    apiKey string
-    baseURL string
-    defaultModel string
-    chatLogger logging.ChatLogger
-    defaultTemperature *float64
+    httpClient *http.Client
+    apiKey string
+    baseURL string
+    defaultModel string
+    chatLogger logging.ChatLogger
+    defaultTemperature *float64
 }

 type copilotChatRequest struct {
-    Model string `json:"model"`
-    Messages []copilotMessage `json:"messages"`
-    Temperature *float64 `json:"temperature,omitempty"`
-    MaxTokens *int `json:"max_tokens,omitempty"`
-    Stop []string `json:"stop,omitempty"`
+    Model string `json:"model"`
+    Messages []copilotMessage `json:"messages"`
+    Temperature *float64 `json:"temperature,omitempty"`
+    MaxTokens *int `json:"max_tokens,omitempty"`
+    Stop []string `json:"stop,omitempty"`
 }

 type copilotMessage struct {
@@ -57,20 +57,20 @@ type copilotChatResponse struct {
 // Constructor (kept among the first functions by convention)
 func newCopilot(baseURL, model, apiKey string, defaultTemp *float64) Client {
-    if strings.TrimSpace(baseURL) == "" {
-        baseURL = "https://api.githubcopilot.com"
-    }
-    if strings.TrimSpace(model) == "" {
-        model = "gpt-4.1"
-    }
-    return copilotClient{
-        httpClient: &http.Client{Timeout: 30 * time.Second},
-        apiKey: apiKey,
-        baseURL: strings.TrimRight(baseURL, "/"),
-        defaultModel: model,
-        chatLogger: logging.NewChatLogger("copilot"),
-        defaultTemperature: defaultTemp,
-    }
+    if strings.TrimSpace(baseURL) == "" {
+        baseURL = "https://api.githubcopilot.com"
+    }
+    if strings.TrimSpace(model) == "" {
+        model = "gpt-4.1"
+    }
+    return copilotClient{
+        httpClient: &http.Client{Timeout: 30 * time.Second},
+        apiKey: apiKey,
+        baseURL: strings.TrimRight(baseURL, "/"),
+        defaultModel: model,
+        chatLogger: logging.NewChatLogger("copilot"),
+        defaultTemperature: defaultTemp,
+    }
 }

 func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) {
@@ -84,38 +84,14 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req
     if o.Model == "" {
         o.Model = c.defaultModel
     }
     start := time.Now()
-    logMessages := make([]struct {
-        Role string
-        Content string
-    }, len(messages))
+    logMessages := make([]struct{ Role, Content string }, len(messages))
     for i, m := range messages {
-        logMessages[i] = struct {
-            Role string
-            Content string
-        }{Role: m.Role, Content: m.Content}
+        logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
     }
     c.chatLogger.LogStart(false, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)

-    req := copilotChatRequest{Model: o.Model}
-    req.Messages = make([]copilotMessage, len(messages))
-    for i, m := range messages {
-        req.Messages[i] = copilotMessage{Role: m.Role, Content: m.Content}
-    }
-    if o.Temperature != 0 {
-        req.Temperature = &o.Temperature
-    } else if c.defaultTemperature != nil {
-        t := *c.defaultTemperature
-        req.Temperature = &t
-    }
-    if o.MaxTokens > 0 {
-        req.MaxTokens = &o.MaxTokens
-    }
-    if len(o.Stop) > 0 {
-        req.Stop = o.Stop
-    }
-
+    req := buildCopilotChatRequest(o, messages, c.defaultTemperature)
     body, err := json.Marshal(req)
     if err != nil {
         logging.Logf("llm/copilot ", "marshal error: %v", err)
@@ -124,34 +100,19 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req
     endpoint := c.baseURL + "/chat/completions"
     logging.Logf("llm/copilot ", "POST %s", endpoint)
-    httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
-    if err != nil {
-        logging.Logf("llm/copilot ", "new request error: %v", err)
-        return "", err
-    }
-    httpReq.Header.Set("Content-Type", "application/json")
-    httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
-
-    resp, err := c.httpClient.Do(httpReq)
+    resp, err := c.doJSON(ctx, endpoint, body, map[string]string{
+        "Authorization": "Bearer " + c.apiKey,
+    })
     if err != nil {
         logging.Logf("llm/copilot ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
         return "", err
     }
     defer resp.Body.Close()
-    if resp.StatusCode < 200 || resp.StatusCode >= 300 {
-        var apiErr copilotChatResponse
-        _ = json.NewDecoder(resp.Body).Decode(&apiErr)
-        if apiErr.Error != nil && strings.TrimSpace(apiErr.Error.Message) != "" {
-            logging.Logf("llm/copilot ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
-            return "", fmt.Errorf("copilot error: %s (status %d)", apiErr.Error.Message, resp.StatusCode)
-        }
-        logging.Logf("llm/copilot ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
-        return "", fmt.Errorf("copilot http error: status %d", resp.StatusCode)
+    if err := handleCopilotNon2xx(resp, start); err != nil {
+        return "", err
     }
-
-    var out copilotChatResponse
-    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
-        logging.Logf("llm/copilot ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+    out, err := decodeCopilotChat(resp, start)
+    if err != nil {
         return "", err
     }
     if len(out.Choices) == 0 {
@@ -166,3 +127,60 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req
 // Provider metadata
 func (c copilotClient) Name() string { return "copilot" }
 func (c copilotClient) DefaultModel() string { return c.defaultModel }
+
+// helpers
+func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64) copilotChatRequest {
+    req := copilotChatRequest{Model: o.Model}
+    req.Messages = make([]copilotMessage, len(messages))
+    for i, m := range messages {
+        req.Messages[i] = copilotMessage{Role: m.Role, Content: m.Content}
+    }
+    if o.Temperature != 0 {
+        req.Temperature = &o.Temperature
+    } else if defaultTemp != nil {
+        t := *defaultTemp
+        req.Temperature = &t
+    }
+    if o.MaxTokens > 0 {
+        req.MaxTokens = &o.MaxTokens
+    }
+    if len(o.Stop) > 0 {
+        req.Stop = o.Stop
+    }
+    return req
+}
+
+func (c copilotClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
+    req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+    if err != nil {
+        return nil, err
+    }
+    req.Header.Set("Content-Type", "application/json")
+    for k, v := range headers {
+        req.Header.Set(k, v)
+    }
+    return c.httpClient.Do(req)
+}
+
+func handleCopilotNon2xx(resp *http.Response, start time.Time) error {
+    if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+        return nil
+    }
+    var apiErr copilotChatResponse
+    _ = json.NewDecoder(resp.Body).Decode(&apiErr)
+    if apiErr.Error != nil && strings.TrimSpace(apiErr.Error.Message) != "" {
+        logging.Logf("llm/copilot ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
+        return fmt.Errorf("copilot error: %s (status %d)", apiErr.Error.Message, resp.StatusCode)
+    }
+    logging.Logf("llm/copilot ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
+    return fmt.Errorf("copilot http error: status %d", resp.StatusCode)
+}
+
+func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatResponse, error) {
+    var out copilotChatResponse
+    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+        logging.Logf("llm/copilot ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+        return copilotChatResponse{}, err
+    }
+    return out, nil
+}
diff --git a/internal/llm/copilot_test.go b/internal/llm/copilot_test.go
new file mode 100644
index 0000000..5492713
--- /dev/null
+++ b/internal/llm/copilot_test.go
@@ -0,0 +1,15 @@
+package llm
+
+import "testing"
+
+func TestBuildCopilotChatRequest_FieldsAndDefaults(t *testing.T) {
+    o := Options{Model: "gpt-x", Temperature: 0, MaxTokens: 123, Stop: []string{"X"}}
+    msgs := []Message{{Role: "user", Content: "q"}}
+    req := buildCopilotChatRequest(o, msgs, f64p(0.5))
+    if req.Model != "gpt-x" { t.Fatalf("model mismatch: %q", req.Model) }
+    if req.Temperature == nil || *req.Temperature != 0.5 { t.Fatalf("default temp not applied") }
+    if req.MaxTokens == nil || *req.MaxTokens != 123 { t.Fatalf("max_tokens not applied") }
+    if len(req.Stop) != 1 || req.Stop[0] != "X" { t.Fatalf("stop not applied") }
+    if len(req.Messages) != 1 || req.Messages[0].Content != "q" { t.Fatalf("messages not copied") }
+}
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index 20dfe2a..50e9837 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -1,5 +1,4 @@
 // Summary: Ollama client against a local server; supports chat responses and streaming via /api/chat.
-// Not yet reviewed by a human
 package llm

 import (
@@ -18,11 +17,11 @@ import (
 // ollamaClient implements Client against a local Ollama server.
 type ollamaClient struct {
-    httpClient *http.Client
-    baseURL string
-    defaultModel string
-    chatLogger logging.ChatLogger
-    defaultTemperature *float64
+    httpClient *http.Client
+    baseURL string
+    defaultModel string
+    chatLogger logging.ChatLogger
+    defaultTemperature *float64
 }

 type ollamaChatRequest struct {
@@ -49,13 +48,13 @@ func newOllama(baseURL, model string, defaultTemp *float64) Client {
     if strings.TrimSpace(model) == "" {
         model = "qwen3-coder:30b-a3b-q4_K_M"
     }
-    return ollamaClient{
-        httpClient: &http.Client{Timeout: 30 * time.Second},
-        baseURL: strings.TrimRight(baseURL, "/"),
-        defaultModel: model,
-        chatLogger: logging.NewChatLogger("ollama"),
-        defaultTemperature: defaultTemp,
-    }
+    return ollamaClient{
+        httpClient: &http.Client{Timeout: 30 * time.Second},
+        baseURL: strings.TrimRight(baseURL, "/"),
+        defaultModel: model,
+        chatLogger: logging.NewChatLogger("ollama"),
+        defaultTemperature: defaultTemp,
+    }
 }

 // TODO: This function is too long and should be refactored for readability and maintainability.
@@ -69,41 +68,8 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ
     }
     start := time.Now()
-    logMessages := make([]struct {
-        Role string
-        Content string
-    }, len(messages))
-    for i, m := range messages {
-        logMessages[i] = struct {
-            Role string
-            Content string
-        }{Role: m.Role, Content: m.Content}
-    }
-    c.chatLogger.LogStart(false, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
-
-    req := ollamaChatRequest{Model: o.Model, Stream: false}
-    req.Messages = make([]oaMessage, len(messages))
-    for i, m := range messages {
-        req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
-    }
-
-    // Build options map only if any option is set
-    optsMap := map[string]any{}
-    if o.Temperature != 0 {
-        optsMap["temperature"] = o.Temperature
-    } else if c.defaultTemperature != nil {
-        optsMap["temperature"] = *c.defaultTemperature
-    }
-    if o.MaxTokens > 0 {
-        optsMap["num_predict"] = o.MaxTokens
-    }
-    if len(o.Stop) > 0 {
-        optsMap["stop"] = o.Stop
-    }
-    if len(optsMap) > 0 {
-        req.Options = optsMap
-    }
-
+    c.logStart(false, o, messages)
+    req := buildOllamaRequest(o, messages, c.defaultTemperature, false)
     body, err := json.Marshal(req)
     if err != nil {
         return "", err
@@ -111,27 +77,14 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ
     endpoint := c.baseURL + "/api/chat"
     logging.Logf("llm/ollama ", "POST %s", endpoint)
-    httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
-    if err != nil {
-        return "", err
-    }
-    httpReq.Header.Set("Content-Type", "application/json")
-
-    resp, err := c.httpClient.Do(httpReq)
+    resp, err := c.doJSON(ctx, endpoint, body)
     if err != nil {
         logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
        return "", err
     }
     defer resp.Body.Close()
-    if resp.StatusCode < 200 || resp.StatusCode >= 300 {
-        var apiErr ollamaChatResponse
-        _ = json.NewDecoder(resp.Body).Decode(&apiErr)
-        if strings.TrimSpace(apiErr.Error) != "" {
-            logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase)
-            return "", fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode)
-        }
-        logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
-        return "", fmt.Errorf("ollama http error: status %d", resp.StatusCode)
+    if err := handleOllamaNon2xx(resp, start); err != nil {
+        return "", err
     }

     var out ollamaChatResponse
@@ -163,40 +116,8 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
     }
     start := time.Now()
-    logMessages := make([]struct {
-        Role string
-        Content string
-    }, len(messages))
-    for i, m := range messages {
-        logMessages[i] = struct {
-            Role string
-            Content string
-        }{Role: m.Role, Content: m.Content}
-    }
-    c.chatLogger.LogStart(true, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
-
-    req := ollamaChatRequest{Model: o.Model, Stream: true}
-    req.Messages = make([]oaMessage, len(messages))
-    for i, m := range messages {
-        req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
-    }
-    // Build options map
-    optsMap := map[string]any{}
-    if o.Temperature != 0 {
-        optsMap["temperature"] = o.Temperature
-    } else if c.defaultTemperature != nil {
-        optsMap["temperature"] = *c.defaultTemperature
-    }
-    if o.MaxTokens > 0 {
-        optsMap["num_predict"] = o.MaxTokens
-    }
-    if len(o.Stop) > 0 {
-        optsMap["stop"] = o.Stop
-    }
-    if len(optsMap) > 0 {
-        req.Options = optsMap
-    }
-
+    c.logStart(true, o, messages)
+    req := buildOllamaRequest(o, messages, c.defaultTemperature, true)
     body, err := json.Marshal(req)
     if err != nil {
         return err
@@ -204,27 +125,14 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
     endpoint := c.baseURL + "/api/chat"
     logging.Logf("llm/ollama ", "POST %s (stream)", endpoint)
-    httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
-    if err != nil {
-        return err
-    }
-    httpReq.Header.Set("Content-Type", "application/json")
-
-    resp, err := c.httpClient.Do(httpReq)
+    resp, err := c.doJSON(ctx, endpoint, body)
     if err != nil {
         logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
         return err
     }
     defer resp.Body.Close()
-    if resp.StatusCode < 200 || resp.StatusCode >= 300 {
-        var apiErr ollamaChatResponse
-        _ = json.NewDecoder(resp.Body).Decode(&apiErr)
-        if strings.TrimSpace(apiErr.Error) != "" {
-            logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase)
-            return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode)
-        }
-        logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
-        return fmt.Errorf("ollama http error: status %d", resp.StatusCode)
+    if err := handleOllamaNon2xx(resp, start); err != nil {
+        return err
     }

     dec := json.NewDecoder(resp.Body)
@@ -251,3 +159,59 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
     logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start))
     return nil
 }
+
+// helpers to keep methods small
+func (c ollamaClient) logStart(stream bool, o Options, messages []Message) {
+    logMessages := make([]struct{ Role, Content string }, len(messages))
+    for i, m := range messages {
+        logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
+    }
+    c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+}
+
+func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, stream bool) ollamaChatRequest {
+    req := ollamaChatRequest{Model: o.Model, Stream: stream}
+    req.Messages = make([]oaMessage, len(messages))
+    for i, m := range messages {
+        req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
+    }
+    optsMap := map[string]any{}
+    if o.Temperature != 0 {
+        optsMap["temperature"] = o.Temperature
+    } else if defaultTemp != nil {
+        optsMap["temperature"] = *defaultTemp
+    }
+    if o.MaxTokens > 0 {
+        optsMap["num_predict"] = o.MaxTokens
+    }
+    if len(o.Stop) > 0 {
+        optsMap["stop"] = o.Stop
+    }
+    if len(optsMap) > 0 {
+        req.Options = optsMap
+    }
+    return req
+}
+
+func (c ollamaClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) {
+    req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+    if err != nil {
+        return nil, err
+    }
+    req.Header.Set("Content-Type", "application/json")
+    return c.httpClient.Do(req)
+}
+
+func handleOllamaNon2xx(resp *http.Response, start time.Time) error {
+    if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+        return nil
+    }
+    var apiErr ollamaChatResponse
+    _ = json.NewDecoder(resp.Body).Decode(&apiErr)
+    if strings.TrimSpace(apiErr.Error) != "" {
+        logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase)
+        return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode)
+    }
+    logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
+    return fmt.Errorf("ollama http error: status %d", resp.StatusCode)
+}
diff --git a/internal/llm/ollama_test.go b/internal/llm/ollama_test.go
new file mode 100644
index 0000000..4ad6fdf
--- /dev/null
+++ b/internal/llm/ollama_test.go
@@ -0,0 +1,18 @@
+package llm
+
+import "testing"
+
+func TestBuildOllamaRequest_OptionsAndStream(t *testing.T) {
+    o := Options{Model: "codemodel", Temperature: 0, MaxTokens: 256, Stop: []string{"STOP"}}
+    msgs := []Message{{Role: "user", Content: "hello"}}
+    req := buildOllamaRequest(o, msgs, f64p(0.2), false)
+    if req.Model != "codemodel" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) }
+    if req.Options == nil { t.Fatalf("expected options map") }
+    if req.Options.(map[string]any)["temperature"].(float64) != 0.2 { t.Fatalf("default temp not applied") }
+    if req.Options.(map[string]any)["num_predict"].(int) != 256 { t.Fatalf("num_predict not applied") }
+    if req.Options.(map[string]any)["stop"].([]string)[0] != "STOP" { t.Fatalf("stop not applied") }
+
+    req2 := buildOllamaRequest(o, msgs, f64p(0.2), true)
+    if !req2.Stream { t.Fatalf("expected stream=true") }
+}
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index 5348def..69c0cfc 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -13,26 +13,26 @@ import (
     "strings"
     "time"

-    "hexai/internal/logging"
+    "hexai/internal/logging"
 )

 // openAIClient implements Client against OpenAI's Chat Completions API.
 type openAIClient struct {
-    httpClient *http.Client
-    apiKey string
-    baseURL string
-    defaultModel string
-    chatLogger logging.ChatLogger
-    defaultTemperature *float64
+    httpClient *http.Client
+    apiKey string
+    baseURL string
+    defaultModel string
+    chatLogger logging.ChatLogger
+    defaultTemperature *float64
 }

 type oaChatRequest struct {
-    Model string `json:"model"`
-    Messages []oaMessage `json:"messages"`
-    Temperature *float64 `json:"temperature,omitempty"`
-    MaxTokens *int `json:"max_tokens,omitempty"`
-    Stop []string `json:"stop,omitempty"`
-    Stream bool `json:"stream,omitempty"`
+    Model string `json:"model"`
+    Messages []oaMessage `json:"messages"`
+    Temperature *float64 `json:"temperature,omitempty"`
+    MaxTokens *int `json:"max_tokens,omitempty"`
+    Stop []string `json:"stop,omitempty"`
+    Stream bool `json:"stream,omitempty"`
 }

 type oaMessage struct {
@@ -54,43 +54,43 @@ type oaChatResponse struct {
         Type string `json:"type"`
         Param any `json:"param"`
         Code any `json:"code"`
-    } `json:"error,omitempty"`
+    } `json:"error,omitempty"`
 }

 // Streaming response chunk type (SSE)
 type oaStreamChunk struct {
-    Choices []struct {
-        Delta struct {
-            Content string `json:"content"`
-        } `json:"delta"`
-        FinishReason string `json:"finish_reason"`
-    } `json:"choices"`
-    Error *struct {
-        Message string `json:"message"`
-        Type string `json:"type"`
-        Param any `json:"param"`
-        Code any `json:"code"`
-    } `json:"error,omitempty"`
+    Choices []struct {
+        Delta struct {
+            Content string `json:"content"`
+        } `json:"delta"`
+        FinishReason string `json:"finish_reason"`
+    } `json:"choices"`
+    Error *struct {
+        Message string `json:"message"`
+        Type string `json:"type"`
+        Param any `json:"param"`
+        Code any `json:"code"`
+    } `json:"error,omitempty"`
 }

 // Constructor (kept among the first functions by convention)
 // newOpenAI constructs an OpenAI client using explicit configuration values.
 // The apiKey may be empty; calls will fail until a valid key is supplied.
 func newOpenAI(baseURL, model, apiKey string, defaultTemp *float64) Client {
-    if strings.TrimSpace(baseURL) == "" {
-        baseURL = "https://api.openai.com/v1"
-    }
-    if strings.TrimSpace(model) == "" {
-        model = "gpt-4.1"
-    }
-    return openAIClient{
-        httpClient: &http.Client{Timeout: 30 * time.Second},
-        apiKey: apiKey,
-        baseURL: baseURL,
-        defaultModel: model,
-        chatLogger: logging.NewChatLogger("openai"),
-        defaultTemperature: defaultTemp,
-    }
+    if strings.TrimSpace(baseURL) == "" {
+        baseURL = "https://api.openai.com/v1"
+    }
+    if strings.TrimSpace(model) == "" {
+        model = "gpt-4.1"
+    }
+    return openAIClient{
+        httpClient: &http.Client{Timeout: 30 * time.Second},
+        apiKey: apiKey,
+        baseURL: baseURL,
+        defaultModel: model,
+        chatLogger: logging.NewChatLogger("openai"),
+        defaultTemperature: defaultTemp,
+    }
 }

 func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) {
@@ -105,37 +105,8 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ
         o.Model = c.defaultModel
     }
     start := time.Now()
-    logMessages := make([]struct {
-        Role string
-        Content string
-    }, len(messages))
-    for i, m := range messages {
-        logMessages[i] = struct {
-            Role string
-            Content string
-        }{Role: m.Role, Content: m.Content}
-    }
-    c.chatLogger.LogStart(false, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
-
-    req := oaChatRequest{Model: o.Model}
-    req.Messages = make([]oaMessage, len(messages))
-    for i, m := range messages {
-        req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
-    }
-    // Decide temperature: request option overrides config default.
-    if o.Temperature != 0 {
-        req.Temperature = &o.Temperature
-    } else if c.defaultTemperature != nil {
-        t := *c.defaultTemperature
-        req.Temperature = &t
-    }
-    if o.MaxTokens > 0 {
-        req.MaxTokens = &o.MaxTokens
-    }
-    if len(o.Stop) > 0 {
-        req.Stop = o.Stop
-    }
-
+    c.logStart(false, o, messages)
+    req := buildOAChatRequest(o, messages, c.defaultTemperature, false)
     body, err := json.Marshal(req)
     if err != nil {
         c.logf("marshal error: %v", err)
@@ -143,33 +114,19 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ
     }
     endpoint := c.baseURL + "/chat/completions"
     logging.Logf("llm/openai ", "POST %s", endpoint)
-    httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
-    if err != nil {
-        c.logf("new request error: %v", err)
-        return "", err
-    }
-    httpReq.Header.Set("Content-Type", "application/json")
-    httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
-
-    resp, err := c.httpClient.Do(httpReq)
+    resp, err := c.doJSON(ctx, endpoint, body, map[string]string{
+        "Authorization": "Bearer " + c.apiKey,
+    })
     if err != nil {
         logging.Logf("llm/openai ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
         return "", err
     }
     defer resp.Body.Close()
-    if resp.StatusCode < 200 || resp.StatusCode >= 300 {
-        var apiErr oaChatResponse
-        _ = json.NewDecoder(resp.Body).Decode(&apiErr)
-        if apiErr.Error != nil && apiErr.Error.Message != "" {
-            logging.Logf("llm/openai ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
-            return "", fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode)
-        }
-        logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
-        return "", fmt.Errorf("openai http error: status %d", resp.StatusCode)
+    if err := handleOpenAINon2xx(resp, start); err != nil {
+        return "", err
     }
-    var out oaChatResponse
-    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
-        logging.Logf("llm/openai ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+    out, err := decodeOpenAIChat(resp, start)
+    if err != nil {
         return "", err
     }
     if len(out.Choices) == 0 {
@@ -177,7 +134,6 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ
         return "", errors.New("openai: no choices returned")
     }
     content := out.Choices[0].Message.Content
-    // Received context (green)
     logging.Logf("llm/openai ", "success choice=0 finish=%s size=%d preview=%s%s%s duration=%s", out.Choices[0].FinishReason, len(content), logging.AnsiGreen, logging.PreviewForLog(content), logging.AnsiBase, time.Since(start))
     return content, nil
 }
@@ -187,7 +143,6 @@ func (c openAIClient) Name() string { return "openai" }
 func (c openAIClient) DefaultModel() string { return c.defaultModel }

 // Streaming support (optional)
-
 func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error {
     if c.apiKey == "" {
@@ -201,74 +156,118 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt
         o.Model = c.defaultModel
     }
     start := time.Now()
-    logMessages := make([]struct {
-        Role string
-        Content string
-    }, len(messages))
+    c.logStart(true, o, messages)
+    req := buildOAChatRequest(o, messages, c.defaultTemperature, true)
+    body, err := json.Marshal(req)
+    if err != nil {
+        c.logf("marshal error: %v", err)
+        return err
+    }
+    endpoint := c.baseURL + "/chat/completions"
+    logging.Logf("llm/openai ", "POST %s (stream)", endpoint)
+    resp, err := c.doJSONWithAccept(ctx, endpoint, body, map[string]string{
+        "Authorization": "Bearer " + c.apiKey,
+    }, "text/event-stream")
+    if err != nil {
+        logging.Logf("llm/openai ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+        return err
+    }
+    defer resp.Body.Close()
+    if err := handleOpenAINon2xx(resp, start); err != nil {
+        return err
+    }
+
+    if err := parseOpenAIStream(resp, start, onDelta); err != nil {
+        return err
+    }
+    logging.Logf("llm/openai ", "stream end duration=%s", time.Since(start))
+    return nil
+}
+
+// Private helpers
+func (c openAIClient) logf(format string, args ...any) { logging.Logf("llm/openai ", format, args...) }
+
+// helpers extracted to keep methods small
+func (c openAIClient) logStart(stream bool, o Options, messages []Message) {
+    logMessages := make([]struct{ Role, Content string }, len(messages))
     for i, m := range messages {
-        logMessages[i] = struct {
-            Role string
-            Content string
-        }{Role: m.Role, Content: m.Content}
+        logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
     }
-    c.chatLogger.LogStart(true, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+    c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+}

-    req := oaChatRequest{Model: o.Model, Stream: true}
+func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool) oaChatRequest {
+    req := oaChatRequest{Model: o.Model, Stream: stream}
     req.Messages = make([]oaMessage, len(messages))
     for i, m := range messages {
         req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
     }
-    if o.Temperature != 0 {
-        req.Temperature = &o.Temperature
-    } else if c.defaultTemperature != nil {
-        t := *c.defaultTemperature
-        req.Temperature = &t
-    }
+    if o.Temperature != 0 {
+        req.Temperature = &o.Temperature
+    } else if defaultTemp != nil {
+        t := *defaultTemp
+        req.Temperature = &t
+    }
     if o.MaxTokens > 0 {
         req.MaxTokens = &o.MaxTokens
     }
     if len(o.Stop) > 0 {
         req.Stop = o.Stop
     }
+    return req
+}

-    body, err := json.Marshal(req)
+func (c openAIClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
+    req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
     if err != nil {
-        c.logf("marshal error: %v", err)
-        return err
+        return nil, err
     }
-    endpoint := c.baseURL + "/chat/completions"
-    logging.Logf("llm/openai ", "POST %s (stream)", endpoint)
-    httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
-    if err != nil {
-        c.logf("new request error: %v", err)
-        return err
+    req.Header.Set("Content-Type", "application/json")
+    for k, v := range headers {
+        req.Header.Set(k, v)
     }
-    httpReq.Header.Set("Content-Type", "application/json")
-    httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
-    // Streaming uses SSE-style data lines
-    httpReq.Header.Set("Accept", "text/event-stream")
+    return c.httpClient.Do(req)
+}

-    resp, err := c.httpClient.Do(httpReq)
+func (c openAIClient) doJSONWithAccept(ctx context.Context, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) {
+    req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
     if err != nil {
-        logging.Logf("llm/openai ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
-        return err
+        return nil, err
     }
-    defer resp.Body.Close()
-    if resp.StatusCode < 200 || resp.StatusCode >= 300 {
-        // try to decode body to surface message
-        var apiErr oaChatResponse
-        _ = json.NewDecoder(resp.Body).Decode(&apiErr)
-        if apiErr.Error != nil && apiErr.Error.Message != "" {
-            logging.Logf("llm/openai ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
-            return fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode)
-        }
-        logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
-        return fmt.Errorf("openai http error: status %d", resp.StatusCode)
+    req.Header.Set("Content-Type", "application/json")
+    req.Header.Set("Accept", accept)
+    for k, v := range headers {
+        req.Header.Set(k, v)
+    }
+    return c.httpClient.Do(req)
+}
+
+func handleOpenAINon2xx(resp *http.Response, start time.Time) error {
+    if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+        return nil
+    }
+    var apiErr oaChatResponse
+    _ = json.NewDecoder(resp.Body).Decode(&apiErr)
+    if apiErr.Error != nil && apiErr.Error.Message != "" {
+        logging.Logf("llm/openai ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
+        return fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode)
+    }
+    logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
+    return fmt.Errorf("openai http error: status %d", resp.StatusCode)
+}
+
+func decodeOpenAIChat(resp *http.Response, start time.Time) (oaChatResponse, error) {
+    var out oaChatResponse
+    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+        logging.Logf("llm/openai ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+        return oaChatResponse{}, err
     }
+    return out, nil
+}

+func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string)) error {
     // Parse SSE: lines starting with "data: " containing JSON or [DONE]
     scanner := bufio.NewScanner(resp.Body)
-    // Increase buffer for long lines
     const maxBuf = 1024 * 1024
     buf := make([]byte, 0, 64*1024)
     scanner.Buffer(buf, maxBuf)
@@ -283,7 +282,7 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt
         }
         var chunk oaStreamChunk
         if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
-            continue // skip malformed lines
+            continue
         }
         if chunk.Error != nil && chunk.Error.Message != "" {
             logging.Logf("llm/openai ", "%sstream error: %s%s", logging.AnsiRed, chunk.Error.Message, logging.AnsiBase)
@@ -299,9 +298,5 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt
         logging.Logf("llm/openai ", "%sstream read error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
         return err
     }
-    logging.Logf("llm/openai ", "stream end duration=%s", time.Since(start))
     return nil
 }
-
-// Private helpers
-func (c openAIClient) logf(format string, args ...any) { logging.Logf("llm/openai ", format, args...) }
diff --git a/internal/llm/openai_test.go b/internal/llm/openai_test.go
new file mode 100644
index 0000000..f50b171
--- /dev/null
+++ b/internal/llm/openai_test.go
@@ -0,0 +1,44 @@
+package llm
+
+import (
+    "bytes"
+    "encoding/json"
+    "io"
+    "net/http"
+    "strings"
+    "testing"
+    "time"
+)
+
+func f64p(v float64) *float64 { return &v }
+
+func TestBuildOAChatRequest_TempFallbackAndFields(t *testing.T) {
+    o := Options{Model: "m1", Temperature: 0, MaxTokens: 42, Stop: []string{"END"}}
+    msgs := []Message{{Role: "user", Content: "hi"}}
+    req := buildOAChatRequest(o, msgs, f64p(0.3), false)
+    if req.Model != "m1" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) }
+    if req.Temperature == nil || *req.Temperature != 0.3 { t.Fatalf("expected default temp 0.3, got %#v", req.Temperature) }
+    if req.MaxTokens == nil || *req.MaxTokens != 42 { t.Fatalf("expected max tokens 42") }
+    if len(req.Stop) != 1 || req.Stop[0] != "END" { t.Fatalf("stop not propagated: %#v", req.Stop) }
+    if len(req.Messages) != 1 || req.Messages[0].Content != "hi" { t.Fatalf("messages not copied") }
+
+    // stream on
+    req2 := buildOAChatRequest(o, msgs, f64p(0.3), true)
+    if !req2.Stream { t.Fatalf("expected stream=true") }
+}
+
+func TestHandleOpenAINon2xx_WithAPIError(t *testing.T) {
+    api := oaChatResponse{Error: &struct {
+        Message string `json:"message"`
+        Type string `json:"type"`
+        Param any `json:"param"`
+        Code any `json:"code"`
+    }{Message: "bad", Type: "invalid"}}
+    b, _ := json.Marshal(api)
+    resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader(b))}
+    if err := handleOpenAINon2xx(resp, time.Now()); err == nil { t.Fatalf("expected error for non-2xx with body") }
+}
+
+func TestParseOpenAIStream_DeliversChunks(t *testing.T) {
+    stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" +
+        "data: [DONE]\n"
+    resp := &http.Response{Body: io.NopCloser(strings.NewReader(stream))}
+    var got strings.Builder
+    if err := parseOpenAIStream(resp, time.Now(), func(s string) { got.WriteString(s) }); err != nil { t.Fatalf("unexpected error: %v", err) }
+    if got.String() != "Hi" { t.Fatalf("got %q want %q", got.String(), "Hi") }
+}
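For orientation, this is roughly how the LSP side drives any of these providers after the refactor: build a client once via llm.NewFromConfig, then pass functional options per request. The identifiers come from the diffs above; the model string and prompt are placeholders, and since llm is an internal package the sketch assumes it runs inside the hexai module:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"
	"time"

	"hexai/internal/llm"
)

func main() {
	cfg := llm.Config{Provider: "openai", OpenAIModel: "gpt-4.1"}
	client, err := llm.NewFromConfig(cfg, os.Getenv("OPENAI_API_KEY"), os.Getenv("COPILOT_API_KEY"))
	if err != nil {
		log.Fatalf("llm disabled: %v", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	msgs := []llm.Message{
		{Role: "system", Content: "You are a precise code refactoring engine."},
		{Role: "user", Content: "Rewrite: ..."},
	}
	text, err := client.Chat(ctx, msgs, llm.WithMaxTokens(4000), llm.WithTemperature(0.2))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(text)
}
```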
non-empty replacement text") } +} + +func TestBuildRewriteCodeAction_NoInstruction(t *testing.T) { + s := newTestServer() + s.llmClient = fakeLLM{resp: "IGNORED"} + p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{}} + sel := "no instruction here" + if ca := s.buildRewriteCodeAction(p, sel); ca != nil { t.Fatalf("expected nil action when no instruction present") } +} + +func TestBuildDiagnosticsCodeAction_ReturnsEdit(t *testing.T) { + s := newTestServer() + s.llmClient = fakeLLM{resp: "FIXED"} + p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{Start: Position{Line: 10}, End: Position{Line: 12, Character: 5}}} + ctx := CodeActionContext{Diagnostics: []Diagnostic{ + {Range: Range{Start: Position{Line: 11}, End: Position{Line: 11, Character: 10}}, Message: "inside"}, + {Range: Range{Start: Position{Line: 2}, End: Position{Line: 3}}, Message: "outside"}, + }} + raw, _ := json.Marshal(ctx) + p.Context = json.RawMessage(raw) + sel := "some selected code" + ca := s.buildDiagnosticsCodeAction(p, sel) + if ca == nil { t.Fatalf("expected diagnostics code action") } + if ca.Edit == nil || len(ca.Edit.Changes) == 0 { t.Fatalf("expected workspace edit") } +} + +func TestBuildDiagnosticsCodeAction_NoDiagnostics(t *testing.T) { + s := newTestServer() + s.llmClient = fakeLLM{resp: "FIXED"} + p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{}} + // empty context + p.Context = json.RawMessage(nil) + if ca := s.buildDiagnosticsCodeAction(p, "sel"); ca != nil { t.Fatalf("expected nil action when no diagnostics") } +} diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index d21c5b3..43d42c8 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -68,16 +68,15 @@ func (s *Server) handleCodeAction(req Request) { } return } - // Extract selected text d := s.getDocument(p.TextDocument.URI) - if d == nil || len(d.lines) == 0 { + if d == nil || len(d.lines) == 0 || s.llmClient == nil { if len(req.ID) != 0 { s.reply(req.ID, []CodeAction{}, nil) } return } sel := extractRangeText(d, p.Range) - if strings.TrimSpace(sel) == "" || s.llmClient == nil { + if strings.TrimSpace(sel) == "" { if len(req.ID) != 0 { s.reply(req.ID, []CodeAction{}, nil) } @@ -85,67 +84,77 @@ func (s *Server) handleCodeAction(req Request) { } actions := make([]CodeAction, 0, 2) + if a := s.buildRewriteCodeAction(p, sel); a != nil { + actions = append(actions, *a) + } + if a := s.buildDiagnosticsCodeAction(p, sel); a != nil { + actions = append(actions, *a) + } + if len(req.ID) != 0 { + s.reply(req.ID, actions, nil) + } +} - // Action 1: Rewrite selection based on first instruction in selection +func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction { if instr, cleaned := instructionFromSelection(sel); strings.TrimSpace(instr) != "" { sys := "You are a precise code refactoring engine. Rewrite the given code strictly according to the instruction. Return only the updated code with no prose or backticks. Preserve formatting where reasonable." 
user := fmt.Sprintf("Instruction: %s\n\nSelected code to transform:\n%s", instr, cleaned) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - // Build request options from server settings - opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} - if s.codingTemperature != nil { - opts = append(opts, llm.WithTemperature(*s.codingTemperature)) - } - if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { - out := strings.TrimSpace(text) - if out != "" { + opts := s.llmRequestOpts() + if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out := strings.TrimSpace(text); out != "" { edit := WorkspaceEdit{Changes: map[string][]TextEdit{p.TextDocument.URI: {{Range: p.Range, NewText: out}}}} - actions = append(actions, CodeAction{Title: "Hexai: rewrite selection", Kind: "refactor.rewrite", Edit: &edit}) + ca := CodeAction{Title: "Hexai: rewrite selection", Kind: "refactor.rewrite", Edit: &edit} + return &ca } } else { logging.Logf("lsp ", "codeAction rewrite llm error: %v", err) } } + return nil +} - // Action 2: Resolve diagnostics within selection - if diags := s.diagnosticsInRange(p.Context, p.Range); len(diags) > 0 { - // Compose a prompt listing diagnostics relevant to the selected code - sys := "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes." - var b strings.Builder - b.WriteString("Diagnostics to resolve (selection only):\n") - for i, dgn := range diags { - // Minimal, user-facing summary; include source if present - if dgn.Source != "" { - fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) - } else { - fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) - } - } - b.WriteString("\nSelected code:\n") - b.WriteString(sel) - ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) - defer cancel() - messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: b.String()}} - opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} - if s.codingTemperature != nil { - opts = append(opts, llm.WithTemperature(*s.codingTemperature)) - } - if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { - out := strings.TrimSpace(text) - if out != "" { - edit := WorkspaceEdit{Changes: map[string][]TextEdit{p.TextDocument.URI: {{Range: p.Range, NewText: out}}}} - actions = append(actions, CodeAction{Title: "Hexai: resolve diagnostics", Kind: "quickfix", Edit: &edit}) - } +func (s *Server) buildDiagnosticsCodeAction(p CodeActionParams, sel string) *CodeAction { + diags := s.diagnosticsInRange(p.Context, p.Range) + if len(diags) == 0 { + return nil + } + sys := "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes." + var b strings.Builder + b.WriteString("Diagnostics to resolve (selection only):\n") + for i, dgn := range diags { + if dgn.Source != "" { + fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) } else { - logging.Logf("lsp ", "codeAction diagnostics llm error: %v", err) + fmt.Fprintf(&b, "%d. 
%s\n", i+1, dgn.Message) } } + b.WriteString("\nSelected code:\n") + b.WriteString(sel) + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + defer cancel() + messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: b.String()}} + opts := s.llmRequestOpts() + if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out := strings.TrimSpace(text); out != "" { + edit := WorkspaceEdit{Changes: map[string][]TextEdit{p.TextDocument.URI: {{Range: p.Range, NewText: out}}}} + ca := CodeAction{Title: "Hexai: resolve diagnostics", Kind: "quickfix", Edit: &edit} + return &ca + } + } else { + logging.Logf("lsp ", "codeAction diagnostics llm error: %v", err) + } + return nil +} - if len(req.ID) != 0 { - s.reply(req.ID, actions, nil) +func (s *Server) llmRequestOpts() []llm.RequestOption { + opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} + if s.codingTemperature != nil { + opts = append(opts, llm.WithTemperature(*s.codingTemperature)) } + return opts } // instructionFromSelection extracts the first instruction from selection text. @@ -457,64 +466,22 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun for _, m := range messages { sentSize += len(m.Content) } - // Update request counters (sent) - s.mu.Lock() - s.llmReqTotal++ - s.llmSentBytesTotal += int64(sentSize) - s.mu.Unlock() + s.incSentCounters(sentSize) - opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} - if s.codingTemperature != nil { - opts = append(opts, llm.WithTemperature(*s.codingTemperature)) - } - text, err := s.llmClient.Chat(ctx, messages, opts...) + opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} + if s.codingTemperature != nil { + opts = append(opts, llm.WithTemperature(*s.codingTemperature)) + } + text, err := s.llmClient.Chat(ctx, messages, opts...) 
if err != nil {
logging.Logf("lsp ", "llm completion error: %v", err)
// Log updated averages after this request (even if failed)
- s.mu.RLock()
- avgSent := int64(0)
- if s.llmReqTotal > 0 {
- avgSent = s.llmSentBytesTotal / s.llmReqTotal
- }
- avgRecv := int64(0)
- if s.llmRespTotal > 0 {
- avgRecv = s.llmRespBytesTotal / s.llmRespTotal
- }
- reqs, sentTot, recvTot := s.llmReqTotal, s.llmSentBytesTotal, s.llmRespBytesTotal
- s.mu.RUnlock()
- mins := time.Since(s.startTime).Minutes()
- if mins <= 0 {
- mins = 0.001
- }
- rpm := float64(reqs) / mins
- sentPerMin := float64(sentTot) / mins
- recvPerMin := float64(recvTot) / mins
- logging.Logf("lsp ", "llm stats reqs=%d avg_sent=%d avg_recv=%d sent_total=%d recv_total=%d rpm=%.2f sent_per_min=%.0f recv_per_min=%.0f", reqs, avgSent, avgRecv, sentTot, recvTot, rpm, sentPerMin, recvPerMin)
+ s.logLLMStats()
return nil, false
}
// Update response counters (received)
- recvSize := len(text)
- s.mu.Lock()
- s.llmRespTotal++
- s.llmRespBytesTotal += int64(recvSize)
- avgSent := int64(0)
- if s.llmReqTotal > 0 {
- avgSent = s.llmSentBytesTotal / s.llmReqTotal
- }
- avgRecv := int64(0)
- if s.llmRespTotal > 0 {
- avgRecv = s.llmRespBytesTotal / s.llmRespTotal
- }
- reqs, sentTot, recvTot := s.llmReqTotal, s.llmSentBytesTotal, s.llmRespBytesTotal
- s.mu.Unlock()
- mins := time.Since(s.startTime).Minutes()
- if mins <= 0 {
- mins = 0.001
- }
- rpm := float64(reqs) / mins
- sentPerMin := float64(sentTot) / mins
- recvPerMin := float64(recvTot) / mins
- logging.Logf("lsp ", "llm stats reqs=%d avg_sent=%d avg_recv=%d sent_total=%d recv_total=%d rpm=%.2f sent_per_min=%.0f recv_per_min=%.0f", reqs, avgSent, avgRecv, sentTot, recvTot, rpm, sentPerMin, recvPerMin)
+ s.incRecvCounters(len(text))
+ s.logLLMStats()
cleaned := strings.TrimSpace(text)
if cleaned != "" {
cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned)
@@ -523,15 +490,18 @@
return nil, false
}

+ return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true
+}
+
+func (s *Server) makeCompletionItems(cleaned string, inParams bool, current string, p CompletionParams, docStr string) []CompletionItem {
te, filter := computeTextEditAndFilter(cleaned, inParams, current, p)
rm := s.collectPromptRemovalEdits(p.TextDocument.URI)
label := labelForCompletion(cleaned, filter)
- // Detail shows provider/model for visibility in client UI
detail := "Hexai LLM completion"
if s.llmClient != nil {
detail = "Hexai " + s.llmClient.Name() + ":" + s.llmClient.DefaultModel()
}
- items := []CompletionItem{{
+ return []CompletionItem{{
Label: label,
Kind: 1,
Detail: detail,
@@ -542,7 +512,43 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun
SortText: "0000",
Documentation: docStr,
}}
- return items, true
+}
+
+// small helpers to keep tryLLMCompletion short
+func (s *Server) incSentCounters(n int) {
+ s.mu.Lock()
+ s.llmReqTotal++
+ s.llmSentBytesTotal += int64(n)
+ s.mu.Unlock()
+}
+
+func (s *Server) incRecvCounters(n int) {
+ s.mu.Lock()
+ s.llmRespTotal++
+ s.llmRespBytesTotal += int64(n)
+ s.mu.Unlock()
+}
+
+func (s *Server) logLLMStats() {
+ s.mu.RLock()
+ avgSent := int64(0)
+ if s.llmReqTotal > 0 {
+ avgSent = s.llmSentBytesTotal / s.llmReqTotal
+ }
+ avgRecv := int64(0)
+ if s.llmRespTotal > 0 {
+ avgRecv = s.llmRespBytesTotal / s.llmRespTotal
+ }
+ reqs, sentTot, recvTot := s.llmReqTotal, s.llmSentBytesTotal, s.llmRespBytesTotal
+ s.mu.RUnlock()
+ mins := time.Since(s.startTime).Minutes()
+ if mins <= 0 {
+ mins = 0.001
+ }
+ rpm := float64(reqs) / mins
+ sentPerMin := float64(sentTot) / mins
+ recvPerMin := float64(recvTot) / mins
+ logging.Logf("lsp ", "llm stats reqs=%d avg_sent=%d avg_recv=%d sent_total=%d recv_total=%d rpm=%.2f sent_per_min=%.0f recv_per_min=%.0f", reqs, avgSent, avgRecv, sentTot, recvTot, rpm, sentPerMin, recvPerMin)
}

// collectPromptRemovalEdits returns edits to remove all inline prompt markers.
@@ -559,83 +565,78 @@ func (s *Server) collectPromptRemovalEdits(uri string) []TextEdit {
}
var edits []TextEdit
for i, line := range d.lines {
- // If the line contains a double-semicolon trigger of the form
- // ";;text;" (no space after the ";;" and no space before the closing ';'),
- // remove the entire line.
- removeWholeLine := false
- {
- pos := 0
- for pos < len(line) {
- j := strings.Index(line[pos:], ";;")
- if j < 0 {
- break
- }
- j += pos
- // ensure there's a non-space after the two semicolons
- if j+2 >= len(line) || line[j+2] == ' ' {
- pos = j + 2
- continue
- }
- // find closing ';' after the content
- k := strings.Index(line[j+2:], ";")
- if k < 0 {
- break
- }
+ edits = append(edits, promptRemovalEditsForLine(line, i)...)
+ }
+ return edits
+}
+
+func promptRemovalEditsForLine(line string, lineNum int) []TextEdit {
+ if hasDoubleSemicolonTrigger(line) {
+ return []TextEdit{{Range: Range{Start: Position{Line: lineNum, Character: 0}, End: Position{Line: lineNum, Character: len(line)}}, NewText: ""}}
+ }
+ return collectSemicolonMarkers(line, lineNum)
+}
+
+func hasDoubleSemicolonTrigger(line string) bool {
+ pos := 0
+ for pos < len(line) {
+ j := strings.Index(line[pos:], ";;")
+ if j < 0 {
+ return false
+ }
+ j += pos
+ if j+2 < len(line) && line[j+2] != ' ' {
+ if k := strings.Index(line[j+2:], ";"); k >= 0 {
closeIdx := j + 2 + k
- // ensure char before closing ';' is not a space
- if closeIdx-1 < 0 || line[closeIdx-1] == ' ' {
- pos = closeIdx + 1
- continue
+ if closeIdx-1 >= 0 && line[closeIdx-1] != ' ' {
+ return true
}
- removeWholeLine = true
- break
+ pos = closeIdx + 1
+ continue
}
+ return false
+ }
+ pos = j + 2
+ }
+ return false
+}
+
+func collectSemicolonMarkers(line string, lineNum int) []TextEdit {
+ var edits []TextEdit
+ startSemi := 0
+ for startSemi < len(line) {
+ j := strings.Index(line[startSemi:], ";")
+ if j < 0 {
+ break
}
- if removeWholeLine {
- edits = append(edits, TextEdit{Range: Range{Start: Position{Line: i, Character: 0}, End: Position{Line: i, Character: len(line)}}, NewText: ""})
+ j += startSemi
+ k := strings.Index(line[j+1:], ";")
+ if k < 0 {
+ break
+ }
+ if j+1 >= len(line) || line[j+1] == ' ' {
+ startSemi = j + 1
continue
}
- // Scan for ;...; markers that have no spaces directly inside the semicolons
- startSemi := 0
- for startSemi < len(line) {
- j := strings.Index(line[startSemi:], ";")
- if j < 0 {
- break
- }
- j += startSemi
- k := strings.Index(line[j+1:], ";")
- if k < 0 {
- break
- }
- // Require no space immediately after the first ';'
- if j+1 >= len(line) || line[j+1] == ' ' {
- startSemi = j + 1
- continue
- }
- // Ignore patterns that start with double semicolon here; handled above
- if line[j+1] == ';' {
- startSemi = j + 2
- continue
- }
- // Index of the closing ';'
- closeIdx := j + 1 + k
- // Require no space immediately before the closing ';'
- if closeIdx-1 < 0 || line[closeIdx-1] == ' ' {
- startSemi = closeIdx + 1
- continue
- }
- // Require at least one character between the semicolons
- if closeIdx-(j+1) < 1 {
- startSemi = closeIdx + 1
- continue
- }
- endChar := closeIdx + 1 // include trailing ';'
- if endChar < len(line) && line[endChar] == ' ' {
- endChar++
- }
- edits = append(edits, TextEdit{Range: Range{Start: Position{Line: i, Character: j}, End: Position{Line: i, Character: endChar}}, NewText: ""})
- startSemi = endChar
+ if line[j+1] == ';' {
+ startSemi = j + 2
+ continue
+ }
+ closeIdx := j + 1 + k
+ if closeIdx-1 < 0 || line[closeIdx-1] == ' ' {
+ startSemi = closeIdx + 1
+ continue
+ }
+ if closeIdx-(j+1) < 1 {
+ startSemi = closeIdx + 1
+ continue
+ }
+ endChar := closeIdx + 1
+ if endChar < len(line) && line[endChar] == ' ' {
+ endChar++
}
+ edits = append(edits, TextEdit{Range: Range{Start: Position{Line: lineNum, Character: j}, End: Position{Line: lineNum, Character: endChar}}, NewText: ""})
+ startSemi = endChar
}
return edits
}
diff --git a/internal/lsp/handlers_helpers_test.go b/internal/lsp/handlers_helpers_test.go
new file mode 100644
index 0000000..84dce77
--- /dev/null
+++ b/internal/lsp/handlers_helpers_test.go
@@ -0,0 +1,52 @@
+package lsp
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestHasDoubleSemicolonTrigger(t *testing.T) {
+ cases := []struct{
+ line string
+ want bool
+ }{
+ {";;todo; remove this", true},
+ {"prefix ;;x; suffix", true},
+ {";; spaced ;", false},
+ {"no markers", false},
+ {";;x ; space before close", false},
+ }
+ for _, tc := range cases {
+ got := hasDoubleSemicolonTrigger(tc.line)
+ if got != tc.want {
+ t.Fatalf("hasDoubleSemicolonTrigger(%q)=%v want %v", tc.line, got, tc.want)
+ }
+ }
+}
+
+func TestCollectSemicolonMarkers(t *testing.T) {
+ line := "keep ;ok; this and ;another; that"
+ edits := collectSemicolonMarkers(line, 7)
+ if len(edits) != 2 {
+ t.Fatalf("expected 2 edits, got %d", len(edits))
+ }
+ // Validate the first edit aligns with ;ok;
+ start := strings.Index(line, ";ok;")
+ if start < 0 { t.Fatalf("test setup: missing ;ok;") }
+ if edits[0].Range.Start.Line != 7 || edits[0].Range.Start.Character != start {
+ t.Fatalf("first edit start got line=%d char=%d want line=7 char=%d", edits[0].Range.Start.Line, edits[0].Range.Start.Character, start)
+ }
+}
+
+func TestPromptRemovalEditsForLine_WholeLine(t *testing.T) {
+ line := ";;todo; remove this whole line"
+ edits := promptRemovalEditsForLine(line, 3)
+ if len(edits) != 1 {
+ t.Fatalf("expected 1 whole-line edit, got %d", len(edits))
+ }
+ e := edits[0]
+ if e.Range.Start.Line != 3 || e.Range.End.Line != 3 || e.Range.Start.Character != 0 || e.Range.End.Character != len(line) {
+ t.Fatalf("unexpected range for whole-line removal: %+v", e.Range)
+ }
+}
+
diff --git a/internal/lsp/handlers_test.go b/internal/lsp/handlers_test.go
index 9a490e3..10b704b 100644
--- a/internal/lsp/handlers_test.go
+++ b/internal/lsp/handlers_test.go
@@ -9,16 +9,26 @@ import (
)

func TestInParamList(t *testing.T) {
- line := "func foo(a int, b string) int {"
- if !inParamList(line, 15) { // inside params
- t.Fatalf("expected inParamList true for cursor inside params")
- }
- if inParamList(line, 2) { // before 'func'
- t.Fatalf("expected inParamList false for cursor before params")
- }
- if inParamList(line, len(line)) { // after ')'
- t.Fatalf("expected inParamList false for cursor after params")
- }
+ line := "func foo(a int, b string) int {"
+ cases := []struct{
+ name string
+ cursor int
+ want bool
+ }{
+ {"inside-params", 15, true},
+ {"before-func", 2, false},
+ {"after-paren", len(line), false},
+ {"at-open-paren", strings.Index(line, "(")+1, true},
+ {"at-close-paren", strings.Index(line, ")"), true},
+ }
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ got := inParamList(line, tc.cursor)
+ if got != tc.want {
+ t.Fatalf("cursor=%d got %v want %v", tc.cursor, got, tc.want)
+ }
+ })
+ }
}

func TestComputeWordStart(t *testing.T) {
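
The marker grammar handled by the prompt-removal refactor above is easiest to read off concrete lines. Below is a minimal test-style sketch, assuming it sits in `internal/lsp` next to the helpers from this patch; the sample strings are invented for illustration and are not part of the commit.

```go
package lsp

import "testing"

// Sketch only: exercises the grammar that promptRemovalEditsForLine
// splits between its two helpers. The sample lines are hypothetical.
func TestMarkerGrammarSketch(t *testing.T) {
	// ";;text;" with no space after ";;" and none before the closing ";"
	// marks the whole line for removal.
	if !hasDoubleSemicolonTrigger(";;add error handling; please") {
		t.Fatal("expected whole-line trigger")
	}
	// A space directly inside the semicolons disqualifies the trigger.
	if hasDoubleSemicolonTrigger(";; spaced out ;") {
		t.Fatal("unexpected trigger with inner spaces")
	}
	// A single ";text;" marker is cut out in place; the line survives.
	edits := collectSemicolonMarkers("x := compute() ;tighten this;", 0)
	if len(edits) != 1 {
		t.Fatalf("expected 1 inline marker edit, got %d", len(edits))
	}
}
```

Checking `hasDoubleSemicolonTrigger` first in `promptRemovalEditsForLine` means a whole-line trigger always wins over any inline markers on the same line.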

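Both code-action builders now funnel their settings through `llmRequestOpts`. As a rough sketch of that shared flow (not part of the patch): `chatOnce` is a hypothetical wrapper, and `hexai/internal/llm` is an assumed import path for the `llm` package.

```go
package lsp

import (
	"context"
	"time"

	"hexai/internal/llm" // assumed import path; the real module path may differ
)

// chatOnce is a hypothetical helper illustrating the request flow that
// buildRewriteCodeAction and buildDiagnosticsCodeAction now share: a
// timeout-scoped context, a system/user message pair, and options built
// once by llmRequestOpts (max tokens plus the optional coding temperature).
func (s *Server) chatOnce(sys, user string, timeout time.Duration) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	messages := []llm.Message{
		{Role: "system", Content: sys},
		{Role: "user", Content: user},
	}
	return s.llmClient.Chat(ctx, messages, s.llmRequestOpts()...)
}
```

Keeping the option slice in one helper means a future knob, say a top-p setting, would land in a single place instead of at every call site.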