summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2025-09-06 10:56:27 +0300
committerPaul Buetow <paul@buetow.org>2025-09-06 10:56:27 +0300
commit320de746f7a2985b60c8564a0e65bdf231e840b7 (patch)
treee70bcf50813dba411afa2934e774383124bbc99e
parent06247527d5170f329b454b42f59a3e4434ab1f4b (diff)
use gofumpt
-rw-r--r--AGENTS.md8
-rw-r--r--internal/appconfig/config.go413
-rw-r--r--internal/appconfig/config_test.go294
-rw-r--r--internal/hexaicli/run.go78
-rw-r--r--internal/hexaicli/run_test.go196
-rw-r--r--internal/hexaicli/testhelpers_test.go40
-rw-r--r--internal/hexailsp/run.go56
-rw-r--r--internal/llm/copilot.go297
-rw-r--r--internal/llm/copilot_http_test.go392
-rw-r--r--internal/llm/ollama_test.go296
-rw-r--r--internal/llm/openai_http_test.go250
-rw-r--r--internal/llm/openai_sse_negative_test.go46
-rw-r--r--internal/llm/openai_test.go79
-rw-r--r--internal/llm/provider.go108
-rw-r--r--internal/llm/provider_more_test.go37
-rw-r--r--internal/llm/provider_test.go30
-rw-r--r--internal/llm/util_test.go7
-rw-r--r--internal/logging/chatlogger.go3
-rw-r--r--internal/logging/logging.go14
-rw-r--r--internal/logging/logging_test.go72
-rw-r--r--internal/lsp/build_prompts_table_test.go24
-rw-r--r--internal/lsp/chat_history_test.go46
-rw-r--r--internal/lsp/chat_no_double_answer_test.go29
-rw-r--r--internal/lsp/code_fences_table_test.go45
-rw-r--r--internal/lsp/codeaction_more_test.go151
-rw-r--r--internal/lsp/codeaction_test.go11
-rw-r--r--internal/lsp/codegen_helpers_test.go19
-rw-r--r--internal/lsp/completion_cache_test.go10
-rw-r--r--internal/lsp/completion_codex_path_test.go8
-rw-r--r--internal/lsp/completion_helpers_more_test.go60
-rw-r--r--internal/lsp/completion_messages_test.go116
-rw-r--r--internal/lsp/completion_prefix_strip_test.go97
-rw-r--r--internal/lsp/completion_provider_fallback_test.go59
-rw-r--r--internal/lsp/compute_textedit_table_test.go53
-rw-r--r--internal/lsp/context.go3
-rw-r--r--internal/lsp/debounce_throttle_more_test.go51
-rw-r--r--internal/lsp/debounce_throttle_test.go123
-rw-r--r--internal/lsp/diagnostics_action_test.go49
-rw-r--r--internal/lsp/document.go2
-rw-r--r--internal/lsp/document_handlers_test.go102
-rw-r--r--internal/lsp/document_test.go42
-rw-r--r--internal/lsp/fallback_items_test.go11
-rw-r--r--internal/lsp/gotest_append_test.go51
-rw-r--r--internal/lsp/handlers.go28
-rw-r--r--internal/lsp/handlers_codeaction.go575
-rw-r--r--internal/lsp/handlers_completion.go219
-rw-r--r--internal/lsp/handlers_document.go160
-rw-r--r--internal/lsp/handlers_end_to_end_test.go454
-rw-r--r--internal/lsp/handlers_execute.go53
-rw-r--r--internal/lsp/handlers_helpers_test.go56
-rw-r--r--internal/lsp/handlers_init.go3
-rw-r--r--internal/lsp/handlers_test.go66
-rw-r--r--internal/lsp/handlers_utils.go265
-rw-r--r--internal/lsp/helpers_inline_prompt_test.go82
-rw-r--r--internal/lsp/helpers_more_test.go188
-rw-r--r--internal/lsp/init_and_trigger_test.go104
-rw-r--r--internal/lsp/init_shutdown_test.go27
-rw-r--r--internal/lsp/instruction_table_test.go36
-rw-r--r--internal/lsp/label_filter_table_test.go21
-rw-r--r--internal/lsp/llm_stats_test.go9
-rw-r--r--internal/lsp/log_context_test.go15
-rw-r--r--internal/lsp/postprocess_indent_test.go14
-rw-r--r--internal/lsp/prefix_table_test.go35
-rw-r--r--internal/lsp/provider_native_success_test.go66
-rw-r--r--internal/lsp/rewrite_diagnostics_realism_test.go113
-rw-r--r--internal/lsp/server.go163
-rw-r--r--internal/lsp/testfakes_test.go5
-rw-r--r--internal/lsp/transport.go3
-rw-r--r--internal/lsp/transport_test.go71
-rw-r--r--internal/lsp/triggers_config_test.go118
-rw-r--r--internal/lsp/types.go32
-rw-r--r--internal/testutil/fixtures.go11
72 files changed, 3769 insertions, 3101 deletions
diff --git a/AGENTS.md b/AGENTS.md
index fe3f8ca..0729682 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -9,16 +9,10 @@
- `tests/`: Future test suites mirroring `src/` paths.
- `scripts/`: Helper tools and maintenance scripts.
-## Build, Test, and Development Commands
-
-- Lint Markdown: `markdownlint **/*.md` — checks heading/style rules.
-- Spellcheck: `codespell` — catches common typos.
-- Optimize images: `pngquant --quality=70-85 input.png -o assets/input.png`.
-- No build step required for docs-only changes.
-
## Coding Style & Naming Conventions
- Aim for at least 85% unit test coverage of all source code.
+- Always run the gofumpt code reformater on all go files modified.
- Ensure that all unit tests pass before merging any changes.
- If possible, construct individual methods so that they can be unit tested. But only if it doesn't add too much boilerplate to the code base.
- There should be no source code file larger than 1000 lines. If so, split it up into multiple.
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go
index d19ea18..92fdf19 100644
--- a/internal/appconfig/config.go
+++ b/internal/appconfig/config.go
@@ -2,14 +2,14 @@
package appconfig
import (
- "encoding/json"
- "fmt"
- "log"
- "os"
- "path/filepath"
- "slices"
- "strconv"
- "strings"
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "slices"
+ "strconv"
+ "strings"
)
// App holds user-configurable settings read from ~/.config/hexai/config.json.
@@ -20,25 +20,25 @@ type App struct {
MaxContextTokens int `json:"max_context_tokens"`
LogPreviewLimit int `json:"log_preview_limit"`
// Single knob for LSP requests; if set, overrides hardcoded temps in LSP.
- CodingTemperature *float64 `json:"coding_temperature"`
- // Minimum identifier characters required for manual (TriggerKind=1) invoke
- // to proceed without structural triggers. 0 means always allow.
- ManualInvokeMinPrefix int `json:"manual_invoke_min_prefix"`
+ CodingTemperature *float64 `json:"coding_temperature"`
+ // Minimum identifier characters required for manual (TriggerKind=1) invoke
+ // to proceed without structural triggers. 0 means always allow.
+ ManualInvokeMinPrefix int `json:"manual_invoke_min_prefix"`
- // Completion debounce in milliseconds. When > 0, the server waits until
- // there has been no text change for at least this duration before sending
- // an LLM completion request.
- CompletionDebounceMs int `json:"completion_debounce_ms"`
- // Completion throttle in milliseconds. When > 0, caps the minimum spacing
- // between LLM requests (both chat and code-completer paths).
- CompletionThrottleMs int `json:"completion_throttle_ms"`
+ // Completion debounce in milliseconds. When > 0, the server waits until
+ // there has been no text change for at least this duration before sending
+ // an LLM completion request.
+ CompletionDebounceMs int `json:"completion_debounce_ms"`
+ // Completion throttle in milliseconds. When > 0, caps the minimum spacing
+ // between LLM requests (both chat and code-completer paths).
+ CompletionThrottleMs int `json:"completion_throttle_ms"`
TriggerCharacters []string `json:"trigger_characters"`
Provider string `json:"provider"`
// Inline prompt trigger characters (default: >text> and >>text>)
- InlineOpen string `json:"inline_open"`
- InlineClose string `json:"inline_close"`
+ InlineOpen string `json:"inline_open"`
+ InlineClose string `json:"inline_close"`
// In-editor chat triggers (default: suffix ">" after one of [?, !, :, ;])
ChatSuffix string `json:"chat_suffix"`
ChatPrefixes []string `json:"chat_prefixes"`
@@ -64,51 +64,51 @@ func newDefaultConfig() App {
// Users can override per provider in config.json (including 0.0).
t := 0.2
return App{
- MaxTokens: 4000,
- ContextMode: "always-full",
- ContextWindowLines: 120,
- MaxContextTokens: 4000,
- LogPreviewLimit: 100,
- CodingTemperature: &t,
- OpenAITemperature: &t,
- OllamaTemperature: &t,
- CopilotTemperature: &t,
- ManualInvokeMinPrefix: 0,
- CompletionDebounceMs: 200,
- CompletionThrottleMs: 0,
- // Inline/chat trigger defaults
- InlineOpen: ">",
- InlineClose: ">",
- ChatSuffix: ">",
- ChatPrefixes: []string{"?", "!", ":", ";"},
- }
+ MaxTokens: 4000,
+ ContextMode: "always-full",
+ ContextWindowLines: 120,
+ MaxContextTokens: 4000,
+ LogPreviewLimit: 100,
+ CodingTemperature: &t,
+ OpenAITemperature: &t,
+ OllamaTemperature: &t,
+ CopilotTemperature: &t,
+ ManualInvokeMinPrefix: 0,
+ CompletionDebounceMs: 200,
+ CompletionThrottleMs: 0,
+ // Inline/chat trigger defaults
+ InlineOpen: ">",
+ InlineClose: ">",
+ ChatSuffix: ">",
+ ChatPrefixes: []string{"?", "!", ":", ";"},
+ }
}
// Load reads configuration from a file and merges with defaults.
// It respects the XDG Base Directory Specification.
func Load(logger *log.Logger) App {
- cfg := newDefaultConfig()
- if logger == nil {
- return cfg // Return defaults if no logger is provided (e.g. in tests)
- }
+ cfg := newDefaultConfig()
+ if logger == nil {
+ return cfg // Return defaults if no logger is provided (e.g. in tests)
+ }
- configPath, err := getConfigPath()
- if err != nil {
- logger.Printf("%v", err)
- // Even if config path cannot be resolved, still allow env overrides below.
- } else {
- if fileCfg, err := loadFromFile(configPath, logger); err == nil && fileCfg != nil {
- cfg.mergeWith(fileCfg)
- }
- // When the config file is missing or invalid, we keep defaults and still
- // apply any environment overrides below.
- }
+ configPath, err := getConfigPath()
+ if err != nil {
+ logger.Printf("%v", err)
+ // Even if config path cannot be resolved, still allow env overrides below.
+ } else {
+ if fileCfg, err := loadFromFile(configPath, logger); err == nil && fileCfg != nil {
+ cfg.mergeWith(fileCfg)
+ }
+ // When the config file is missing or invalid, we keep defaults and still
+ // apply any environment overrides below.
+ }
- // Environment overrides (take precedence over file)
- if envCfg := loadFromEnv(logger); envCfg != nil {
- cfg.mergeWith(envCfg)
- }
- return cfg
+ // Environment overrides (take precedence over file)
+ if envCfg := loadFromEnv(logger); envCfg != nil {
+ cfg.mergeWith(envCfg)
+ }
+ return cfg
}
// Private helpers
@@ -134,8 +134,8 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) {
}
func (a *App) mergeWith(other *App) {
- a.mergeBasics(other)
- a.mergeProviderFields(other)
+ a.mergeBasics(other)
+ a.mergeProviderFields(other)
}
// mergeBasics merges general (non-provider) fields.
@@ -155,32 +155,36 @@ func (a *App) mergeBasics(other *App) {
if other.LogPreviewLimit >= 0 {
a.LogPreviewLimit = other.LogPreviewLimit
}
- if other.CodingTemperature != nil { // allow explicit 0.0
- a.CodingTemperature = other.CodingTemperature
- }
- if other.ManualInvokeMinPrefix >= 0 {
- a.ManualInvokeMinPrefix = other.ManualInvokeMinPrefix
- }
- if other.CompletionDebounceMs > 0 { a.CompletionDebounceMs = other.CompletionDebounceMs }
- if other.CompletionThrottleMs > 0 { a.CompletionThrottleMs = other.CompletionThrottleMs }
- if len(other.TriggerCharacters) > 0 {
- a.TriggerCharacters = slices.Clone(other.TriggerCharacters)
- }
- if s := strings.TrimSpace(other.InlineOpen); s != "" {
- a.InlineOpen = s
- }
- if s := strings.TrimSpace(other.InlineClose); s != "" {
- a.InlineClose = s
- }
- if s := strings.TrimSpace(other.ChatSuffix); s != "" {
- a.ChatSuffix = s
- }
- if len(other.ChatPrefixes) > 0 {
- a.ChatPrefixes = slices.Clone(other.ChatPrefixes)
- }
- if s := strings.TrimSpace(other.Provider); s != "" {
- a.Provider = s
- }
+ if other.CodingTemperature != nil { // allow explicit 0.0
+ a.CodingTemperature = other.CodingTemperature
+ }
+ if other.ManualInvokeMinPrefix >= 0 {
+ a.ManualInvokeMinPrefix = other.ManualInvokeMinPrefix
+ }
+ if other.CompletionDebounceMs > 0 {
+ a.CompletionDebounceMs = other.CompletionDebounceMs
+ }
+ if other.CompletionThrottleMs > 0 {
+ a.CompletionThrottleMs = other.CompletionThrottleMs
+ }
+ if len(other.TriggerCharacters) > 0 {
+ a.TriggerCharacters = slices.Clone(other.TriggerCharacters)
+ }
+ if s := strings.TrimSpace(other.InlineOpen); s != "" {
+ a.InlineOpen = s
+ }
+ if s := strings.TrimSpace(other.InlineClose); s != "" {
+ a.InlineClose = s
+ }
+ if s := strings.TrimSpace(other.ChatSuffix); s != "" {
+ a.ChatSuffix = s
+ }
+ if len(other.ChatPrefixes) > 0 {
+ a.ChatPrefixes = slices.Clone(other.ChatPrefixes)
+ }
+ if s := strings.TrimSpace(other.Provider); s != "" {
+ a.Provider = s
+ }
}
// mergeProviderFields merges per-provider configuration.
@@ -225,7 +229,7 @@ func getConfigPath() (string, error) {
}
configPath = filepath.Join(home, ".config", "hexai", "config.json")
}
- return configPath, nil
+ return configPath, nil
}
// --- Environment overrides ---
@@ -233,98 +237,155 @@ func getConfigPath() (string, error) {
// loadFromEnv constructs an App containing only fields set via HEXAI_* env vars.
// These values should take precedence over file config when merged.
func loadFromEnv(logger *log.Logger) *App {
- var out App
- var any bool
+ var out App
+ var any bool
- // helpers
- getenv := func(k string) string { return strings.TrimSpace(os.Getenv(k)) }
- parseInt := func(k string) (int, bool) {
- v := getenv(k)
- if v == "" { return 0, false }
- n, err := strconv.Atoi(v)
- if err != nil { if logger != nil { logger.Printf("invalid %s: %v", k, err) } ; return 0, false }
- return n, true
- }
- parseFloatPtr := func(k string) (*float64, bool) {
- v := getenv(k)
- if v == "" { return nil, false }
- f, err := strconv.ParseFloat(v, 64)
- if err != nil {
- if logger != nil { logger.Printf("invalid %s: %v", k, err) }
- return nil, false
- }
- return &f, true
- }
+ // helpers
+ getenv := func(k string) string { return strings.TrimSpace(os.Getenv(k)) }
+ parseInt := func(k string) (int, bool) {
+ v := getenv(k)
+ if v == "" {
+ return 0, false
+ }
+ n, err := strconv.Atoi(v)
+ if err != nil {
+ if logger != nil {
+ logger.Printf("invalid %s: %v", k, err)
+ }
+ return 0, false
+ }
+ return n, true
+ }
+ parseFloatPtr := func(k string) (*float64, bool) {
+ v := getenv(k)
+ if v == "" {
+ return nil, false
+ }
+ f, err := strconv.ParseFloat(v, 64)
+ if err != nil {
+ if logger != nil {
+ logger.Printf("invalid %s: %v", k, err)
+ }
+ return nil, false
+ }
+ return &f, true
+ }
- if n, ok := parseInt("HEXAI_MAX_TOKENS"); ok {
- out.MaxTokens = n; any = true
- }
- if s := getenv("HEXAI_CONTEXT_MODE"); s != "" {
- out.ContextMode = s; any = true
- }
- if n, ok := parseInt("HEXAI_CONTEXT_WINDOW_LINES"); ok {
- out.ContextWindowLines = n; any = true
- }
- if n, ok := parseInt("HEXAI_MAX_CONTEXT_TOKENS"); ok {
- out.MaxContextTokens = n; any = true
- }
- if n, ok := parseInt("HEXAI_LOG_PREVIEW_LIMIT"); ok {
- out.LogPreviewLimit = n; any = true
- }
- if n, ok := parseInt("HEXAI_MANUAL_INVOKE_MIN_PREFIX"); ok {
- out.ManualInvokeMinPrefix = n; any = true
- }
- if n, ok := parseInt("HEXAI_COMPLETION_DEBOUNCE_MS"); ok {
- out.CompletionDebounceMs = n; any = true
- }
- if n, ok := parseInt("HEXAI_COMPLETION_THROTTLE_MS"); ok {
- out.CompletionThrottleMs = n; any = true
- }
- if f, ok := parseFloatPtr("HEXAI_CODING_TEMPERATURE"); ok {
- out.CodingTemperature = f; any = true
- }
- if s := getenv("HEXAI_TRIGGER_CHARACTERS"); s != "" {
- parts := strings.Split(s, ",")
- out.TriggerCharacters = nil
- for _, p := range parts {
- if t := strings.TrimSpace(p); t != "" {
- out.TriggerCharacters = append(out.TriggerCharacters, t)
- }
- }
- any = true
- }
- if s := getenv("HEXAI_INLINE_OPEN"); s != "" { out.InlineOpen = s; any = true }
- if s := getenv("HEXAI_INLINE_CLOSE"); s != "" { out.InlineClose = s; any = true }
- if s := getenv("HEXAI_CHAT_SUFFIX"); s != "" { out.ChatSuffix = s; any = true }
- if s := getenv("HEXAI_CHAT_PREFIXES"); s != "" {
- parts := strings.Split(s, ",")
- out.ChatPrefixes = nil
- for _, p := range parts {
- if t := strings.TrimSpace(p); t != "" {
- out.ChatPrefixes = append(out.ChatPrefixes, t)
- }
- }
- any = true
- }
- if s := getenv("HEXAI_PROVIDER"); s != "" {
- out.Provider = s; any = true
- }
+ if n, ok := parseInt("HEXAI_MAX_TOKENS"); ok {
+ out.MaxTokens = n
+ any = true
+ }
+ if s := getenv("HEXAI_CONTEXT_MODE"); s != "" {
+ out.ContextMode = s
+ any = true
+ }
+ if n, ok := parseInt("HEXAI_CONTEXT_WINDOW_LINES"); ok {
+ out.ContextWindowLines = n
+ any = true
+ }
+ if n, ok := parseInt("HEXAI_MAX_CONTEXT_TOKENS"); ok {
+ out.MaxContextTokens = n
+ any = true
+ }
+ if n, ok := parseInt("HEXAI_LOG_PREVIEW_LIMIT"); ok {
+ out.LogPreviewLimit = n
+ any = true
+ }
+ if n, ok := parseInt("HEXAI_MANUAL_INVOKE_MIN_PREFIX"); ok {
+ out.ManualInvokeMinPrefix = n
+ any = true
+ }
+ if n, ok := parseInt("HEXAI_COMPLETION_DEBOUNCE_MS"); ok {
+ out.CompletionDebounceMs = n
+ any = true
+ }
+ if n, ok := parseInt("HEXAI_COMPLETION_THROTTLE_MS"); ok {
+ out.CompletionThrottleMs = n
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_CODING_TEMPERATURE"); ok {
+ out.CodingTemperature = f
+ any = true
+ }
+ if s := getenv("HEXAI_TRIGGER_CHARACTERS"); s != "" {
+ parts := strings.Split(s, ",")
+ out.TriggerCharacters = nil
+ for _, p := range parts {
+ if t := strings.TrimSpace(p); t != "" {
+ out.TriggerCharacters = append(out.TriggerCharacters, t)
+ }
+ }
+ any = true
+ }
+ if s := getenv("HEXAI_INLINE_OPEN"); s != "" {
+ out.InlineOpen = s
+ any = true
+ }
+ if s := getenv("HEXAI_INLINE_CLOSE"); s != "" {
+ out.InlineClose = s
+ any = true
+ }
+ if s := getenv("HEXAI_CHAT_SUFFIX"); s != "" {
+ out.ChatSuffix = s
+ any = true
+ }
+ if s := getenv("HEXAI_CHAT_PREFIXES"); s != "" {
+ parts := strings.Split(s, ",")
+ out.ChatPrefixes = nil
+ for _, p := range parts {
+ if t := strings.TrimSpace(p); t != "" {
+ out.ChatPrefixes = append(out.ChatPrefixes, t)
+ }
+ }
+ any = true
+ }
+ if s := getenv("HEXAI_PROVIDER"); s != "" {
+ out.Provider = s
+ any = true
+ }
- // Provider-specific
- if s := getenv("HEXAI_OPENAI_BASE_URL"); s != "" { out.OpenAIBaseURL = s; any = true }
- if s := getenv("HEXAI_OPENAI_MODEL"); s != "" { out.OpenAIModel = s; any = true }
- if f, ok := parseFloatPtr("HEXAI_OPENAI_TEMPERATURE"); ok { out.OpenAITemperature = f; any = true }
+ // Provider-specific
+ if s := getenv("HEXAI_OPENAI_BASE_URL"); s != "" {
+ out.OpenAIBaseURL = s
+ any = true
+ }
+ if s := getenv("HEXAI_OPENAI_MODEL"); s != "" {
+ out.OpenAIModel = s
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_OPENAI_TEMPERATURE"); ok {
+ out.OpenAITemperature = f
+ any = true
+ }
- if s := getenv("HEXAI_OLLAMA_BASE_URL"); s != "" { out.OllamaBaseURL = s; any = true }
- if s := getenv("HEXAI_OLLAMA_MODEL"); s != "" { out.OllamaModel = s; any = true }
- if f, ok := parseFloatPtr("HEXAI_OLLAMA_TEMPERATURE"); ok { out.OllamaTemperature = f; any = true }
+ if s := getenv("HEXAI_OLLAMA_BASE_URL"); s != "" {
+ out.OllamaBaseURL = s
+ any = true
+ }
+ if s := getenv("HEXAI_OLLAMA_MODEL"); s != "" {
+ out.OllamaModel = s
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_OLLAMA_TEMPERATURE"); ok {
+ out.OllamaTemperature = f
+ any = true
+ }
- if s := getenv("HEXAI_COPILOT_BASE_URL"); s != "" { out.CopilotBaseURL = s; any = true }
- if s := getenv("HEXAI_COPILOT_MODEL"); s != "" { out.CopilotModel = s; any = true }
- if f, ok := parseFloatPtr("HEXAI_COPILOT_TEMPERATURE"); ok { out.CopilotTemperature = f; any = true }
+ if s := getenv("HEXAI_COPILOT_BASE_URL"); s != "" {
+ out.CopilotBaseURL = s
+ any = true
+ }
+ if s := getenv("HEXAI_COPILOT_MODEL"); s != "" {
+ out.CopilotModel = s
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_COPILOT_TEMPERATURE"); ok {
+ out.CopilotTemperature = f
+ any = true
+ }
- if !any {
- return nil
- }
- return &out
+ if !any {
+ return nil
+ }
+ return &out
}
diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go
index 30898a6..f2e3f7a 100644
--- a/internal/appconfig/config_test.go
+++ b/internal/appconfig/config_test.go
@@ -1,167 +1,185 @@
package appconfig
import (
- "encoding/json"
- "io"
- "log"
- "os"
- "path/filepath"
- "reflect"
- "strings"
- "testing"
+ "encoding/json"
+ "io"
+ "log"
+ "os"
+ "path/filepath"
+ "reflect"
+ "strings"
+ "testing"
)
func newLogger() *log.Logger { return log.New(io.Discard, "", 0) }
func writeJSON(t *testing.T, path string, v any) {
- t.Helper()
- if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
- t.Fatalf("mkdir: %v", err)
- }
- f, err := os.Create(path)
- if err != nil { t.Fatalf("create: %v", err) }
- defer f.Close()
- enc := json.NewEncoder(f)
- if err := enc.Encode(v); err != nil {
- t.Fatalf("encode json: %v", err)
- }
+ t.Helper()
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ t.Fatalf("mkdir: %v", err)
+ }
+ f, err := os.Create(path)
+ if err != nil {
+ t.Fatalf("create: %v", err)
+ }
+ defer f.Close()
+ enc := json.NewEncoder(f)
+ if err := enc.Encode(v); err != nil {
+ t.Fatalf("encode json: %v", err)
+ }
}
-func withEnv(t *testing.T, k, v string) { t.Helper(); old := os.Getenv(k); _ = os.Setenv(k, v); t.Cleanup(func(){ _ = os.Setenv(k, old) }) }
+func withEnv(t *testing.T, k, v string) {
+ t.Helper()
+ old := os.Getenv(k)
+ _ = os.Setenv(k, v)
+ t.Cleanup(func() { _ = os.Setenv(k, old) })
+}
func TestLoad_Defaults_NoLogger(t *testing.T) {
- cfg := Load(nil)
- if cfg.MaxTokens == 0 || cfg.ContextMode == "" || cfg.ContextWindowLines == 0 || cfg.MaxContextTokens == 0 {
- t.Fatalf("expected defaults populated, got %+v", cfg)
- }
- if cfg.CodingTemperature == nil { t.Fatalf("expected default CodingTemperature") }
+ cfg := Load(nil)
+ if cfg.MaxTokens == 0 || cfg.ContextMode == "" || cfg.ContextWindowLines == 0 || cfg.MaxContextTokens == 0 {
+ t.Fatalf("expected defaults populated, got %+v", cfg)
+ }
+ if cfg.CodingTemperature == nil {
+ t.Fatalf("expected default CodingTemperature")
+ }
}
func TestLoad_Defaults_WithLogger_NoFile_NoEnv(t *testing.T) {
- t.Setenv("XDG_CONFIG_HOME", t.TempDir())
- logger := newLogger()
- cfg := Load(logger)
- def := newDefaultConfig()
- if cfg.MaxTokens != def.MaxTokens || cfg.ContextMode != def.ContextMode || cfg.ContextWindowLines != def.ContextWindowLines {
- t.Fatalf("expected defaults; got %+v want %+v", cfg, def)
- }
+ t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+ logger := newLogger()
+ cfg := Load(logger)
+ def := newDefaultConfig()
+ if cfg.MaxTokens != def.MaxTokens || cfg.ContextMode != def.ContextMode || cfg.ContextWindowLines != def.ContextWindowLines {
+ t.Fatalf("expected defaults; got %+v want %+v", cfg, def)
+ }
}
func TestLoad_FileMerge_And_EnvOverride(t *testing.T) {
- dir := t.TempDir()
- t.Setenv("XDG_CONFIG_HOME", dir)
- cfgPath := filepath.Join(dir, "hexai", "config.json")
- temp0 := 0.0
- fileCfg := App{
- MaxTokens: 123,
- ContextMode: "file-on-new-func",
- ContextWindowLines: 50,
- MaxContextTokens: 999,
- LogPreviewLimit: 0,
- CodingTemperature: &temp0,
- ManualInvokeMinPrefix: 2,
- CompletionDebounceMs: 150,
- CompletionThrottleMs: 300,
- TriggerCharacters: []string{".", ":"},
- Provider: "openai",
- OpenAIBaseURL: "https://api.example",
- OpenAIModel: "gpt-x",
- OpenAITemperature: &temp0,
- OllamaBaseURL: "http://ollama",
- OllamaModel: "llama",
- OllamaTemperature: &temp0,
- CopilotBaseURL: "http://copilot",
- CopilotModel: "ghost",
- CopilotTemperature: &temp0,
- }
- writeJSON(t, cfgPath, fileCfg)
+ dir := t.TempDir()
+ t.Setenv("XDG_CONFIG_HOME", dir)
+ cfgPath := filepath.Join(dir, "hexai", "config.json")
+ temp0 := 0.0
+ fileCfg := App{
+ MaxTokens: 123,
+ ContextMode: "file-on-new-func",
+ ContextWindowLines: 50,
+ MaxContextTokens: 999,
+ LogPreviewLimit: 0,
+ CodingTemperature: &temp0,
+ ManualInvokeMinPrefix: 2,
+ CompletionDebounceMs: 150,
+ CompletionThrottleMs: 300,
+ TriggerCharacters: []string{".", ":"},
+ Provider: "openai",
+ OpenAIBaseURL: "https://api.example",
+ OpenAIModel: "gpt-x",
+ OpenAITemperature: &temp0,
+ OllamaBaseURL: "http://ollama",
+ OllamaModel: "llama",
+ OllamaTemperature: &temp0,
+ CopilotBaseURL: "http://copilot",
+ CopilotModel: "ghost",
+ CopilotTemperature: &temp0,
+ }
+ writeJSON(t, cfgPath, fileCfg)
- // Env overrides take precedence
- withEnv(t, "HEXAI_MAX_TOKENS", "321")
- withEnv(t, "HEXAI_CONTEXT_MODE", "always-full")
- withEnv(t, "HEXAI_CONTEXT_WINDOW_LINES", "77")
- withEnv(t, "HEXAI_MAX_CONTEXT_TOKENS", "888")
- withEnv(t, "HEXAI_LOG_PREVIEW_LIMIT", "7")
- withEnv(t, "HEXAI_CODING_TEMPERATURE", "0.7")
- withEnv(t, "HEXAI_MANUAL_INVOKE_MIN_PREFIX", "5")
- withEnv(t, "HEXAI_COMPLETION_DEBOUNCE_MS", "333")
- withEnv(t, "HEXAI_COMPLETION_THROTTLE_MS", "444")
- withEnv(t, "HEXAI_TRIGGER_CHARACTERS", "., / ,_")
- withEnv(t, "HEXAI_PROVIDER", "ollama")
- withEnv(t, "HEXAI_OPENAI_BASE_URL", "https://override")
- withEnv(t, "HEXAI_OPENAI_MODEL", "gpt-override")
- withEnv(t, "HEXAI_OPENAI_TEMPERATURE", "0.4")
- withEnv(t, "HEXAI_OLLAMA_BASE_URL", "http://ollama-override")
- withEnv(t, "HEXAI_OLLAMA_MODEL", "mistral")
- withEnv(t, "HEXAI_OLLAMA_TEMPERATURE", "0.6")
- withEnv(t, "HEXAI_COPILOT_BASE_URL", "http://copilot-override")
- withEnv(t, "HEXAI_COPILOT_MODEL", "ghost-override")
- withEnv(t, "HEXAI_COPILOT_TEMPERATURE", "0.3")
+ // Env overrides take precedence
+ withEnv(t, "HEXAI_MAX_TOKENS", "321")
+ withEnv(t, "HEXAI_CONTEXT_MODE", "always-full")
+ withEnv(t, "HEXAI_CONTEXT_WINDOW_LINES", "77")
+ withEnv(t, "HEXAI_MAX_CONTEXT_TOKENS", "888")
+ withEnv(t, "HEXAI_LOG_PREVIEW_LIMIT", "7")
+ withEnv(t, "HEXAI_CODING_TEMPERATURE", "0.7")
+ withEnv(t, "HEXAI_MANUAL_INVOKE_MIN_PREFIX", "5")
+ withEnv(t, "HEXAI_COMPLETION_DEBOUNCE_MS", "333")
+ withEnv(t, "HEXAI_COMPLETION_THROTTLE_MS", "444")
+ withEnv(t, "HEXAI_TRIGGER_CHARACTERS", "., / ,_")
+ withEnv(t, "HEXAI_PROVIDER", "ollama")
+ withEnv(t, "HEXAI_OPENAI_BASE_URL", "https://override")
+ withEnv(t, "HEXAI_OPENAI_MODEL", "gpt-override")
+ withEnv(t, "HEXAI_OPENAI_TEMPERATURE", "0.4")
+ withEnv(t, "HEXAI_OLLAMA_BASE_URL", "http://ollama-override")
+ withEnv(t, "HEXAI_OLLAMA_MODEL", "mistral")
+ withEnv(t, "HEXAI_OLLAMA_TEMPERATURE", "0.6")
+ withEnv(t, "HEXAI_COPILOT_BASE_URL", "http://copilot-override")
+ withEnv(t, "HEXAI_COPILOT_MODEL", "ghost-override")
+ withEnv(t, "HEXAI_COPILOT_TEMPERATURE", "0.3")
- logger := newLogger()
- cfg := Load(logger)
+ logger := newLogger()
+ cfg := Load(logger)
- // Check overrides
- if cfg.MaxTokens != 321 || cfg.ContextMode != "always-full" || cfg.ContextWindowLines != 77 || cfg.MaxContextTokens != 888 {
- t.Fatalf("env overrides (basic) not applied: %+v", cfg)
- }
- if cfg.LogPreviewLimit != 7 || cfg.ManualInvokeMinPrefix != 5 || cfg.CompletionDebounceMs != 333 || cfg.CompletionThrottleMs != 444 {
- t.Fatalf("env overrides (ints) not applied: %+v", cfg)
- }
- if cfg.CodingTemperature == nil || *cfg.CodingTemperature != 0.7 {
- t.Fatalf("env override (CodingTemperature) not applied: %+v", cfg.CodingTemperature)
- }
- if want := []string{".", "/", "_"}; !reflect.DeepEqual(cfg.TriggerCharacters, want) {
- t.Fatalf("env override (TriggerCharacters), got %v want %v", cfg.TriggerCharacters, want)
- }
- if cfg.Provider != "ollama" {
- t.Fatalf("provider override failed: %q", cfg.Provider)
- }
- // Provider-specific
- if cfg.OpenAIBaseURL != "https://override" || cfg.OpenAIModel != "gpt-override" || cfg.OpenAITemperature == nil || *cfg.OpenAITemperature != 0.4 {
- t.Fatalf("openai overrides not applied: %+v", cfg)
- }
- if cfg.OllamaBaseURL != "http://ollama-override" || cfg.OllamaModel != "mistral" || cfg.OllamaTemperature == nil || *cfg.OllamaTemperature != 0.6 {
- t.Fatalf("ollama overrides not applied: %+v", cfg)
- }
- if cfg.CopilotBaseURL != "http://copilot-override" || cfg.CopilotModel != "ghost-override" || cfg.CopilotTemperature == nil || *cfg.CopilotTemperature != 0.3 {
- t.Fatalf("copilot overrides not applied: %+v", cfg)
- }
+ // Check overrides
+ if cfg.MaxTokens != 321 || cfg.ContextMode != "always-full" || cfg.ContextWindowLines != 77 || cfg.MaxContextTokens != 888 {
+ t.Fatalf("env overrides (basic) not applied: %+v", cfg)
+ }
+ if cfg.LogPreviewLimit != 7 || cfg.ManualInvokeMinPrefix != 5 || cfg.CompletionDebounceMs != 333 || cfg.CompletionThrottleMs != 444 {
+ t.Fatalf("env overrides (ints) not applied: %+v", cfg)
+ }
+ if cfg.CodingTemperature == nil || *cfg.CodingTemperature != 0.7 {
+ t.Fatalf("env override (CodingTemperature) not applied: %+v", cfg.CodingTemperature)
+ }
+ if want := []string{".", "/", "_"}; !reflect.DeepEqual(cfg.TriggerCharacters, want) {
+ t.Fatalf("env override (TriggerCharacters), got %v want %v", cfg.TriggerCharacters, want)
+ }
+ if cfg.Provider != "ollama" {
+ t.Fatalf("provider override failed: %q", cfg.Provider)
+ }
+ // Provider-specific
+ if cfg.OpenAIBaseURL != "https://override" || cfg.OpenAIModel != "gpt-override" || cfg.OpenAITemperature == nil || *cfg.OpenAITemperature != 0.4 {
+ t.Fatalf("openai overrides not applied: %+v", cfg)
+ }
+ if cfg.OllamaBaseURL != "http://ollama-override" || cfg.OllamaModel != "mistral" || cfg.OllamaTemperature == nil || *cfg.OllamaTemperature != 0.6 {
+ t.Fatalf("ollama overrides not applied: %+v", cfg)
+ }
+ if cfg.CopilotBaseURL != "http://copilot-override" || cfg.CopilotModel != "ghost-override" || cfg.CopilotTemperature == nil || *cfg.CopilotTemperature != 0.3 {
+ t.Fatalf("copilot overrides not applied: %+v", cfg)
+ }
- // Ensure file values would have applied absent env
- // Spot-check: reset env and reload
- for _, k := range []string{
- "HEXAI_MAX_TOKENS","HEXAI_CONTEXT_MODE","HEXAI_CONTEXT_WINDOW_LINES","HEXAI_MAX_CONTEXT_TOKENS","HEXAI_LOG_PREVIEW_LIMIT","HEXAI_CODING_TEMPERATURE","HEXAI_MANUAL_INVOKE_MIN_PREFIX","HEXAI_COMPLETION_DEBOUNCE_MS","HEXAI_COMPLETION_THROTTLE_MS","HEXAI_TRIGGER_CHARACTERS","HEXAI_PROVIDER","HEXAI_OPENAI_BASE_URL","HEXAI_OPENAI_MODEL","HEXAI_OPENAI_TEMPERATURE","HEXAI_OLLAMA_BASE_URL","HEXAI_OLLAMA_MODEL","HEXAI_OLLAMA_TEMPERATURE","HEXAI_COPILOT_BASE_URL","HEXAI_COPILOT_MODEL","HEXAI_COPILOT_TEMPERATURE",
- } { t.Setenv(k, "") }
- cfg2 := Load(logger)
- if cfg2.MaxTokens != 123 || cfg2.ContextMode != "file-on-new-func" || cfg2.ContextWindowLines != 50 || cfg2.MaxContextTokens != 999 || cfg2.LogPreviewLimit != 0 {
- t.Fatalf("file merge not applied: %+v", cfg2)
- }
- if cfg2.CodingTemperature == nil || *cfg2.CodingTemperature != 0.0 {
- t.Fatalf("file merge (CodingTemperature) not applied: %+v", cfg2.CodingTemperature)
- }
- if cfg2.OpenAIBaseURL != "https://api.example" || cfg2.OpenAIModel != "gpt-x" || cfg2.OpenAITemperature == nil || *cfg2.OpenAITemperature != 0.0 {
- t.Fatalf("file merge (openai) not applied: %+v", cfg2)
- }
+ // Ensure file values would have applied absent env
+ // Spot-check: reset env and reload
+ for _, k := range []string{
+ "HEXAI_MAX_TOKENS", "HEXAI_CONTEXT_MODE", "HEXAI_CONTEXT_WINDOW_LINES", "HEXAI_MAX_CONTEXT_TOKENS", "HEXAI_LOG_PREVIEW_LIMIT", "HEXAI_CODING_TEMPERATURE", "HEXAI_MANUAL_INVOKE_MIN_PREFIX", "HEXAI_COMPLETION_DEBOUNCE_MS", "HEXAI_COMPLETION_THROTTLE_MS", "HEXAI_TRIGGER_CHARACTERS", "HEXAI_PROVIDER", "HEXAI_OPENAI_BASE_URL", "HEXAI_OPENAI_MODEL", "HEXAI_OPENAI_TEMPERATURE", "HEXAI_OLLAMA_BASE_URL", "HEXAI_OLLAMA_MODEL", "HEXAI_OLLAMA_TEMPERATURE", "HEXAI_COPILOT_BASE_URL", "HEXAI_COPILOT_MODEL", "HEXAI_COPILOT_TEMPERATURE",
+ } {
+ t.Setenv(k, "")
+ }
+ cfg2 := Load(logger)
+ if cfg2.MaxTokens != 123 || cfg2.ContextMode != "file-on-new-func" || cfg2.ContextWindowLines != 50 || cfg2.MaxContextTokens != 999 || cfg2.LogPreviewLimit != 0 {
+ t.Fatalf("file merge not applied: %+v", cfg2)
+ }
+ if cfg2.CodingTemperature == nil || *cfg2.CodingTemperature != 0.0 {
+ t.Fatalf("file merge (CodingTemperature) not applied: %+v", cfg2.CodingTemperature)
+ }
+ if cfg2.OpenAIBaseURL != "https://api.example" || cfg2.OpenAIModel != "gpt-x" || cfg2.OpenAITemperature == nil || *cfg2.OpenAITemperature != 0.0 {
+ t.Fatalf("file merge (openai) not applied: %+v", cfg2)
+ }
}
func TestGetConfigPath_XDG(t *testing.T) {
- dir := t.TempDir()
- t.Setenv("XDG_CONFIG_HOME", dir)
- path, err := getConfigPath()
- if err != nil { t.Fatalf("getConfigPath: %v", err) }
- if !strings.HasPrefix(path, filepath.Join(dir, "hexai")) || !strings.HasSuffix(path, "config.json") {
- t.Fatalf("unexpected path: %s", path)
- }
+ dir := t.TempDir()
+ t.Setenv("XDG_CONFIG_HOME", dir)
+ path, err := getConfigPath()
+ if err != nil {
+ t.Fatalf("getConfigPath: %v", err)
+ }
+ if !strings.HasPrefix(path, filepath.Join(dir, "hexai")) || !strings.HasSuffix(path, "config.json") {
+ t.Fatalf("unexpected path: %s", path)
+ }
}
func TestLoadFromFile_InvalidJSON(t *testing.T) {
- dir := t.TempDir()
- t.Setenv("XDG_CONFIG_HOME", dir)
- cfgPath := filepath.Join(dir, "hexai", "config.json")
- if err := os.MkdirAll(filepath.Dir(cfgPath), 0o755); err != nil { t.Fatal(err) }
- if err := os.WriteFile(cfgPath, []byte("{ invalid"), 0o644); err != nil { t.Fatal(err) }
- _, err := loadFromFile(cfgPath, newLogger())
- if err == nil { t.Fatalf("expected error for invalid JSON") }
+ dir := t.TempDir()
+ t.Setenv("XDG_CONFIG_HOME", dir)
+ cfgPath := filepath.Join(dir, "hexai", "config.json")
+ if err := os.MkdirAll(filepath.Dir(cfgPath), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(cfgPath, []byte("{ invalid"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ _, err := loadFromFile(cfgPath, newLogger())
+ if err == nil {
+ t.Fatalf("expected error for invalid JSON")
+ }
}
-
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 7471816..54cb3ff 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -3,14 +3,14 @@
package hexaicli
import (
- "bufio"
- "context"
- "fmt"
- "io"
- "log"
- "os"
- "strings"
- "time"
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "strings"
+ "time"
"codeberg.org/snonux/hexai/internal/appconfig"
"codeberg.org/snonux/hexai/internal/llm"
@@ -20,14 +20,14 @@ import (
// Run executes the Hexai CLI behavior given arguments and I/O streams.
// It assumes flags have already been parsed by the caller.
func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
- // Load configuration with a logger so file-based config is respected.
- logger := log.New(stderr, "hexai ", log.LstdFlags|log.Lmsgprefix)
- cfg := appconfig.Load(logger)
- client, err := newClientFromConfig(cfg)
- if err != nil {
- fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
- return err
- }
+ // Load configuration with a logger so file-based config is respected.
+ logger := log.New(stderr, "hexai ", log.LstdFlags|log.Lmsgprefix)
+ cfg := appconfig.Load(logger)
+ client, err := newClientFromConfig(cfg)
+ if err != nil {
+ fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
+ return err
+ }
return RunWithClient(ctx, args, stdin, stdout, stderr, client)
}
@@ -71,29 +71,29 @@ func readInput(stdin io.Reader, args []string) (string, error) {
// newClientFromConfig builds an LLM client from the app config and env keys.
func newClientFromConfig(cfg appconfig.App) (llm.Client, error) {
- llmCfg := llm.Config{
- Provider: cfg.Provider,
- OpenAIBaseURL: cfg.OpenAIBaseURL,
- OpenAIModel: cfg.OpenAIModel,
- OpenAITemperature: cfg.OpenAITemperature,
- OllamaBaseURL: cfg.OllamaBaseURL,
- OllamaModel: cfg.OllamaModel,
- OllamaTemperature: cfg.OllamaTemperature,
- CopilotBaseURL: cfg.CopilotBaseURL,
- CopilotModel: cfg.CopilotModel,
- CopilotTemperature: cfg.CopilotTemperature,
- }
- // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY
- oaKey := os.Getenv("HEXAI_OPENAI_API_KEY")
- if strings.TrimSpace(oaKey) == "" {
- oaKey = os.Getenv("OPENAI_API_KEY")
- }
- // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY
- cpKey := os.Getenv("HEXAI_COPILOT_API_KEY")
- if strings.TrimSpace(cpKey) == "" {
- cpKey = os.Getenv("COPILOT_API_KEY")
- }
- return llm.NewFromConfig(llmCfg, oaKey, cpKey)
+ llmCfg := llm.Config{
+ Provider: cfg.Provider,
+ OpenAIBaseURL: cfg.OpenAIBaseURL,
+ OpenAIModel: cfg.OpenAIModel,
+ OpenAITemperature: cfg.OpenAITemperature,
+ OllamaBaseURL: cfg.OllamaBaseURL,
+ OllamaModel: cfg.OllamaModel,
+ OllamaTemperature: cfg.OllamaTemperature,
+ CopilotBaseURL: cfg.CopilotBaseURL,
+ CopilotModel: cfg.CopilotModel,
+ CopilotTemperature: cfg.CopilotTemperature,
+ }
+ // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY
+ oaKey := os.Getenv("HEXAI_OPENAI_API_KEY")
+ if strings.TrimSpace(oaKey) == "" {
+ oaKey = os.Getenv("OPENAI_API_KEY")
+ }
+ // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY
+ cpKey := os.Getenv("HEXAI_COPILOT_API_KEY")
+ if strings.TrimSpace(cpKey) == "" {
+ cpKey = os.Getenv("COPILOT_API_KEY")
+ }
+ return llm.NewFromConfig(llmCfg, oaKey, cpKey)
}
// buildMessages creates system and user messages based on input content.
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
index 0d77e19..77daa8b 100644
--- a/internal/hexaicli/run_test.go
+++ b/internal/hexaicli/run_test.go
@@ -1,122 +1,150 @@
package hexaicli
import (
- "bytes"
- "context"
- "io"
- "path/filepath"
- "strings"
- "testing"
+ "bytes"
+ "context"
+ "io"
+ "path/filepath"
+ "strings"
+ "testing"
- "codeberg.org/snonux/hexai/internal/appconfig"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/appconfig"
+ "codeberg.org/snonux/hexai/internal/llm"
)
func TestReadInput_Combinations(t *testing.T) {
- // stdin + arg
- restore, f := setStdin(t, "from-stdin")
- defer restore()
- s, err := readInput(f, []string{"from-arg"})
- if err != nil || !strings.HasPrefix(s, "from-arg:\n\nfrom-stdin") { t.Fatalf("stdin+arg failed: %q %v", s, err) }
- // stdin only
- restore2, f2 := setStdin(t, "from-stdin")
- defer restore2()
- s, err = readInput(f2, nil)
- if err != nil || s != "from-stdin" { t.Fatalf("stdin only failed: %q %v", s, err) }
- // arg only
- s, err = readInput(strings.NewReader(""), []string{"arg1","arg2"})
- if err != nil || s != "arg1 arg2" { t.Fatalf("arg only failed: %q %v", s, err) }
- // no input
- restore3, f3 := setStdin(t, "")
- defer restore3()
- _, err = readInput(f3, nil)
- if err == nil { t.Fatalf("expected error for no input") }
+ // stdin + arg
+ restore, f := setStdin(t, "from-stdin")
+ defer restore()
+ s, err := readInput(f, []string{"from-arg"})
+ if err != nil || !strings.HasPrefix(s, "from-arg:\n\nfrom-stdin") {
+ t.Fatalf("stdin+arg failed: %q %v", s, err)
+ }
+ // stdin only
+ restore2, f2 := setStdin(t, "from-stdin")
+ defer restore2()
+ s, err = readInput(f2, nil)
+ if err != nil || s != "from-stdin" {
+ t.Fatalf("stdin only failed: %q %v", s, err)
+ }
+ // arg only
+ s, err = readInput(strings.NewReader(""), []string{"arg1", "arg2"})
+ if err != nil || s != "arg1 arg2" {
+ t.Fatalf("arg only failed: %q %v", s, err)
+ }
+ // no input
+ restore3, f3 := setStdin(t, "")
+ defer restore3()
+ _, err = readInput(f3, nil)
+ if err == nil {
+ t.Fatalf("expected error for no input")
+ }
}
func TestBuildMessages_Explain(t *testing.T) {
- msgs := buildMessages("please explain this")
- if len(msgs) != 2 || msgs[0].Role != "system" || !strings.Contains(strings.ToLower(msgs[0].Content), "explanation") {
- t.Fatalf("unexpected system prompt: %#v", msgs)
- }
+ msgs := buildMessages("please explain this")
+ if len(msgs) != 2 || msgs[0].Role != "system" || !strings.Contains(strings.ToLower(msgs[0].Content), "explanation") {
+ t.Fatalf("unexpected system prompt: %#v", msgs)
+ }
}
func TestBuildMessages_Default(t *testing.T) {
- msgs := buildMessages("just do it")
- if len(msgs) != 2 || msgs[0].Role != "system" || strings.Contains(msgs[0].Content, "requested an explanation") {
- t.Fatalf("unexpected system prompt: %#v", msgs)
- }
+ msgs := buildMessages("just do it")
+ if len(msgs) != 2 || msgs[0].Role != "system" || strings.Contains(msgs[0].Content, "requested an explanation") {
+ t.Fatalf("unexpected system prompt: %#v", msgs)
+ }
}
func TestRunChat_StreamAndNonStream(t *testing.T) {
- // stream path
- fc := &fakeStreamer{fakeClient: fakeClient{name: "p", model: "m"}, chunks: []string{"H","i","!"}}
- var out, errb bytes.Buffer
- if err := runChat(context.Background(), fc, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("stream: %v", err) }
- if out.String() != "Hi!" || !strings.Contains(errb.String(), "provider=p model=m") { t.Fatalf("bad output or summary: %q %q", out.String(), errb.String()) }
- // non-stream path
- fc2 := &fakeClient{name: "p2", model: "m2", resp: "Yo"}
- out.Reset(); errb.Reset()
- if err := runChat(context.Background(), fc2, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("non-stream: %v", err) }
- if out.String() != "Yo" || !strings.Contains(errb.String(), "provider=p2 model=m2") { t.Fatalf("bad output or summary (non-stream)") }
+ // stream path
+ fc := &fakeStreamer{fakeClient: fakeClient{name: "p", model: "m"}, chunks: []string{"H", "i", "!"}}
+ var out, errb bytes.Buffer
+ if err := runChat(context.Background(), fc, buildMessages("hello"), "hello", &out, &errb); err != nil {
+ t.Fatalf("stream: %v", err)
+ }
+ if out.String() != "Hi!" || !strings.Contains(errb.String(), "provider=p model=m") {
+ t.Fatalf("bad output or summary: %q %q", out.String(), errb.String())
+ }
+ // non-stream path
+ fc2 := &fakeClient{name: "p2", model: "m2", resp: "Yo"}
+ out.Reset()
+ errb.Reset()
+ if err := runChat(context.Background(), fc2, buildMessages("hello"), "hello", &out, &errb); err != nil {
+ t.Fatalf("non-stream: %v", err)
+ }
+ if out.String() != "Yo" || !strings.Contains(errb.String(), "provider=p2 model=m2") {
+ t.Fatalf("bad output or summary (non-stream)")
+ }
}
type clientErr struct{ name, model string }
-func (c clientErr) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) { return "", io.EOF }
-func (c clientErr) Name() string { return c.name }
+
+func (c clientErr) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) {
+ return "", io.EOF
+}
+func (c clientErr) Name() string { return c.name }
func (c clientErr) DefaultModel() string { return c.model }
func TestRunChat_ErrorPaths(t *testing.T) {
- ctx := context.Background()
- out, errb := &bytes.Buffer{}, &bytes.Buffer{}
- if err := runChat(ctx, clientErr{"p","m"}, buildMessages("hi"), "hi", out, errb); err == nil {
- t.Fatalf("expected error from Chat")
- }
+ ctx := context.Background()
+ out, errb := &bytes.Buffer{}, &bytes.Buffer{}
+ if err := runChat(ctx, clientErr{"p", "m"}, buildMessages("hi"), "hi", out, errb); err == nil {
+ t.Fatalf("expected error from Chat")
+ }
}
func TestRunWithClient_ErrorPrint(t *testing.T) {
- var out, errb bytes.Buffer
- err := RunWithClient(context.Background(), []string{"hi"}, strings.NewReader(""), &out, &errb, clientErr{"p","m"})
- if err == nil { t.Fatalf("expected error") }
- if !strings.Contains(errb.String(), "hexai: error:") {
- t.Fatalf("expected error line, got %q", errb.String())
- }
+ var out, errb bytes.Buffer
+ err := RunWithClient(context.Background(), []string{"hi"}, strings.NewReader(""), &out, &errb, clientErr{"p", "m"})
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+ if !strings.Contains(errb.String(), "hexai: error:") {
+ t.Fatalf("expected error line, got %q", errb.String())
+ }
}
func TestRun_OpenAI_NoKey_ShowsError(t *testing.T) {
- dir := testingTempDir(t)
- // write config with provider=openai
- writeJSON(t, filepath.Join(dir, "hexai", "config.json"), map[string]any{"provider":"openai", "openai_model":"gpt-x"})
- t.Setenv("XDG_CONFIG_HOME", dir)
- // Ensure no OpenAI API key is present in environment
- t.Setenv("HEXAI_OPENAI_API_KEY", "")
- t.Setenv("OPENAI_API_KEY", "")
- var out, errb bytes.Buffer
- // Run expects parsed flags; here args irrelevant
- err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb)
- if err == nil { t.Fatalf("expected error due to missing API key") }
- // Accept either explicit "LLM disabled" or a generic provider error emitted by Run.
- if !(strings.Contains(errb.String(), "LLM disabled") || strings.Contains(errb.String(), "openai error") || strings.Contains(errb.String(), "hexai: error:")) {
- t.Fatalf("expected disabled-or-error message, got %q", errb.String())
- }
+ dir := testingTempDir(t)
+ // write config with provider=openai
+ writeJSON(t, filepath.Join(dir, "hexai", "config.json"), map[string]any{"provider": "openai", "openai_model": "gpt-x"})
+ t.Setenv("XDG_CONFIG_HOME", dir)
+ // Ensure no OpenAI API key is present in environment
+ t.Setenv("HEXAI_OPENAI_API_KEY", "")
+ t.Setenv("OPENAI_API_KEY", "")
+ var out, errb bytes.Buffer
+ // Run expects parsed flags; here args irrelevant
+ err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb)
+ if err == nil {
+ t.Fatalf("expected error due to missing API key")
+ }
+ // Accept either explicit "LLM disabled" or a generic provider error emitted by Run.
+ if !(strings.Contains(errb.String(), "LLM disabled") || strings.Contains(errb.String(), "openai error") || strings.Contains(errb.String(), "hexai: error:")) {
+ t.Fatalf("expected disabled-or-error message, got %q", errb.String())
+ }
}
func TestPrintProviderInfo(t *testing.T) {
- var b bytes.Buffer
- printProviderInfo(&b, &fakeClient{name:"x", model:"y"})
- if !strings.Contains(b.String(), "provider=x model=y") { t.Fatalf("missing provider line: %q", b.String()) }
+ var b bytes.Buffer
+ printProviderInfo(&b, &fakeClient{name: "x", model: "y"})
+ if !strings.Contains(b.String(), "provider=x model=y") {
+ t.Fatalf("missing provider line: %q", b.String())
+ }
}
func TestNewClientFromConfig_Ollama(t *testing.T) {
- cfg := appconfig.App{ Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m" }
- c, err := newClientFromConfig(cfg)
- if err != nil || c == nil { t.Fatalf("expected client: %v %v", c, err) }
+ cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"}
+ c, err := newClientFromConfig(cfg)
+ if err != nil || c == nil {
+ t.Fatalf("expected client: %v %v", c, err)
+ }
}
func TestNewClientFromConfig_OpenAI_MissingKey(t *testing.T) {
- cfg := appconfig.App{ Provider: "openai", OpenAIBaseURL: "https://api", OpenAIModel: "gpt" }
- t.Setenv("HEXAI_OPENAI_API_KEY", "")
- t.Setenv("OPENAI_API_KEY", "")
- if _, err := newClientFromConfig(cfg); err == nil {
- t.Fatalf("expected error for missing openai key")
- }
+ cfg := appconfig.App{Provider: "openai", OpenAIBaseURL: "https://api", OpenAIModel: "gpt"}
+ t.Setenv("HEXAI_OPENAI_API_KEY", "")
+ t.Setenv("OPENAI_API_KEY", "")
+ if _, err := newClientFromConfig(cfg); err == nil {
+ t.Fatalf("expected error for missing openai key")
+ }
}
diff --git a/internal/hexaicli/testhelpers_test.go b/internal/hexaicli/testhelpers_test.go
index 1f75916..512a3ba 100644
--- a/internal/hexaicli/testhelpers_test.go
+++ b/internal/hexaicli/testhelpers_test.go
@@ -2,13 +2,13 @@
package hexaicli
import (
- "context"
- "encoding/json"
- "os"
- "path/filepath"
- "testing"
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "testing"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/llm"
)
// setStdin sets os.Stdin from a string and returns a restore func and reader.
@@ -55,21 +55,27 @@ type fakeStreamer struct {
}
func (s *fakeStreamer) ChatStream(ctx context.Context, messages []llm.Message, onDelta func(string), opts ...llm.RequestOption) error {
- s.sMsgs = append([]llm.Message{}, messages...)
- for _, c := range s.chunks {
- onDelta(c)
- }
- return nil
+ s.sMsgs = append([]llm.Message{}, messages...)
+ for _, c := range s.chunks {
+ onDelta(c)
+ }
+ return nil
}
// small JSON writer for tests
func writeJSON(t *testing.T, path string, v any) {
- t.Helper()
- if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { t.Fatalf("mkdir: %v", err) }
- f, err := os.Create(path)
- if err != nil { t.Fatalf("create: %v", err) }
- defer f.Close()
- if err := json.NewEncoder(f).Encode(v); err != nil { t.Fatalf("encode: %v", err) }
+ t.Helper()
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ t.Fatalf("mkdir: %v", err)
+ }
+ f, err := os.Create(path)
+ if err != nil {
+ t.Fatalf("create: %v", err)
+ }
+ defer f.Close()
+ if err := json.NewEncoder(f).Encode(v); err != nil {
+ t.Fatalf("encode: %v", err)
+ }
}
func testingTempDir(t *testing.T) string { t.Helper(); return t.TempDir() }
diff --git a/internal/hexailsp/run.go b/internal/hexailsp/run.go
index c12018f..a1be5aa 100644
--- a/internal/hexailsp/run.go
+++ b/internal/hexailsp/run.go
@@ -25,7 +25,7 @@ type ServerFactory func(r io.Reader, w io.Writer, logger *log.Logger, opts lsp.S
func Run(logPath string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error {
logger := log.New(stderr, "hexai-lsp ", log.LstdFlags|log.Lmsgprefix)
if strings.TrimSpace(logPath) != "" {
- f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
+ f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
logger.Fatalf("failed to open log file: %v", err)
}
@@ -77,16 +77,16 @@ func buildClientIfNil(cfg appconfig.App, client llm.Client) llm.Client {
CopilotModel: cfg.CopilotModel,
CopilotTemperature: cfg.CopilotTemperature,
}
- // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY
- oaKey := os.Getenv("HEXAI_OPENAI_API_KEY")
- if strings.TrimSpace(oaKey) == "" {
- oaKey = os.Getenv("OPENAI_API_KEY")
- }
- // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY
- cpKey := os.Getenv("HEXAI_COPILOT_API_KEY")
- if strings.TrimSpace(cpKey) == "" {
- cpKey = os.Getenv("COPILOT_API_KEY")
- }
+ // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY
+ oaKey := os.Getenv("HEXAI_OPENAI_API_KEY")
+ if strings.TrimSpace(oaKey) == "" {
+ oaKey = os.Getenv("OPENAI_API_KEY")
+ }
+ // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY
+ cpKey := os.Getenv("HEXAI_COPILOT_API_KEY")
+ if strings.TrimSpace(cpKey) == "" {
+ cpKey = os.Getenv("COPILOT_API_KEY")
+ }
if c, err := llm.NewFromConfig(llmCfg, oaKey, cpKey); err != nil {
logging.Logf("lsp ", "llm disabled: %v", err)
return nil
@@ -106,21 +106,21 @@ func ensureFactory(factory ServerFactory) ServerFactory {
}
func makeServerOptions(cfg appconfig.App, logContext bool, client llm.Client) lsp.ServerOptions {
- return lsp.ServerOptions{
- LogContext: logContext,
- MaxTokens: cfg.MaxTokens,
- ContextMode: cfg.ContextMode,
- WindowLines: cfg.ContextWindowLines,
- MaxContextTokens: cfg.MaxContextTokens,
- CodingTemperature: cfg.CodingTemperature,
- Client: client,
- TriggerCharacters: cfg.TriggerCharacters,
- ManualInvokeMinPrefix: cfg.ManualInvokeMinPrefix,
- CompletionDebounceMs: cfg.CompletionDebounceMs,
- CompletionThrottleMs: cfg.CompletionThrottleMs,
- InlineOpen: cfg.InlineOpen,
- InlineClose: cfg.InlineClose,
- ChatSuffix: cfg.ChatSuffix,
- ChatPrefixes: cfg.ChatPrefixes,
- }
+ return lsp.ServerOptions{
+ LogContext: logContext,
+ MaxTokens: cfg.MaxTokens,
+ ContextMode: cfg.ContextMode,
+ WindowLines: cfg.ContextWindowLines,
+ MaxContextTokens: cfg.MaxContextTokens,
+ CodingTemperature: cfg.CodingTemperature,
+ Client: client,
+ TriggerCharacters: cfg.TriggerCharacters,
+ ManualInvokeMinPrefix: cfg.ManualInvokeMinPrefix,
+ CompletionDebounceMs: cfg.CompletionDebounceMs,
+ CompletionThrottleMs: cfg.CompletionThrottleMs,
+ InlineOpen: cfg.InlineOpen,
+ InlineClose: cfg.InlineClose,
+ ChatSuffix: cfg.ChatSuffix,
+ ChatPrefixes: cfg.ChatPrefixes,
+ }
}
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go
index 16eeda6..d3b1a9d 100644
--- a/internal/llm/copilot.go
+++ b/internal/llm/copilot.go
@@ -4,6 +4,7 @@ package llm
import (
"bytes"
"context"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -13,7 +14,6 @@ import (
"strings"
"time"
- "encoding/base64"
appver "codeberg.org/snonux/hexai/internal"
"codeberg.org/snonux/hexai/internal/logging"
)
@@ -162,10 +162,14 @@ func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64
}
func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil { return nil, err }
- for k, v := range headers { req.Header.Set(k, v) }
- return c.httpClient.Do(req)
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ return c.httpClient.Do(req)
}
func handleCopilotNon2xx(resp *http.Response, start time.Time) error {
@@ -194,55 +198,73 @@ func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatRespons
// --- Copilot session token management ---
type ghCopilotTokenResp struct {
- Token string `json:"token"`
+ Token string `json:"token"`
}
func (c *copilotClient) ensureSession(ctx context.Context) error {
- // If token valid for >60s, reuse
- if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) {
- return nil
- }
- if strings.TrimSpace(c.apiKey) == "" {
- return errors.New("missing Copilot API key")
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil)
- if err != nil { return err }
- req.Header.Set("Authorization", "Bearer "+c.apiKey)
- req.Header.Set("Accept", "application/json")
- req.Header.Set("User-Agent", "hexai/"+appver.Version)
- resp, err := c.httpClient.Do(req)
- if err != nil { return err }
- defer resp.Body.Close()
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return fmt.Errorf("copilot token http error: %d", resp.StatusCode)
- }
- var out ghCopilotTokenResp
- if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return err }
- if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") }
- // Parse JWT exp
- exp := parseJWTExp(out.Token)
- if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) }
- c.sessionToken = out.Token
- c.tokenExpiry = exp
- return nil
+ // If token valid for >60s, reuse
+ if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) {
+ return nil
+ }
+ if strings.TrimSpace(c.apiKey) == "" {
+ return errors.New("missing Copilot API key")
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil)
+ if err != nil {
+ return err
+ }
+ req.Header.Set("Authorization", "Bearer "+c.apiKey)
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", "hexai/"+appver.Version)
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("copilot token http error: %d", resp.StatusCode)
+ }
+ var out ghCopilotTokenResp
+ if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+ return err
+ }
+ if strings.TrimSpace(out.Token) == "" {
+ return errors.New("empty copilot session token")
+ }
+ // Parse JWT exp
+ exp := parseJWTExp(out.Token)
+ if exp.IsZero() {
+ exp = time.Now().Add(10 * time.Minute)
+ }
+ c.sessionToken = out.Token
+ c.tokenExpiry = exp
+ return nil
}
var jwtExpRe = regexp.MustCompile(`"exp"\s*:\s*([0-9]+)`) // fallback if we can't base64 decode
func parseJWTExp(token string) time.Time {
- parts := strings.Split(token, ".")
- if len(parts) < 2 { return time.Time{} }
- b, err := base64.RawURLEncoding.DecodeString(parts[1])
- if err != nil {
- if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 {
- if n, err2 := parseInt64(m[1]); err2 == nil { return time.Unix(n, 0) }
- }
- return time.Time{}
- }
- var payload struct{ Exp int64 `json:"exp"` }
- _ = json.Unmarshal(b, &payload)
- if payload.Exp == 0 { return time.Time{} }
- return time.Unix(payload.Exp, 0)
+ parts := strings.Split(token, ".")
+ if len(parts) < 2 {
+ return time.Time{}
+ }
+ b, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 {
+ if n, err2 := parseInt64(m[1]); err2 == nil {
+ return time.Unix(n, 0)
+ }
+ }
+ return time.Time{}
+ }
+ var payload struct {
+ Exp int64 `json:"exp"`
+ }
+ _ = json.Unmarshal(b, &payload)
+ if payload.Exp == 0 {
+ return time.Time{}
+ }
+ return time.Unix(payload.Exp, 0)
}
func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err }
@@ -250,99 +272,120 @@ func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &
// --- Copilot headers ---
func (c *copilotClient) headersChat() map[string]string {
- _ = c.ensureSession(context.Background())
- h := map[string]string{
- "Content-Type": "application/json; charset=utf-8",
- "Accept": "application/json",
- "Authorization": "Bearer " + c.sessionToken,
- "User-Agent": "GitHubCopilotChat/0.8.0",
- "Editor-Plugin-Version": "copilot-chat/0.8.0",
- "Editor-Version": "vscode/1.85.1",
- "Openai-Intent": "conversation-panel",
- "Openai-Organization": "github-copilot",
- "VScode-MachineId": randHex(64),
- "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- }
- return h
+ _ = c.ensureSession(context.Background())
+ h := map[string]string{
+ "Content-Type": "application/json; charset=utf-8",
+ "Accept": "application/json",
+ "Authorization": "Bearer " + c.sessionToken,
+ "User-Agent": "GitHubCopilotChat/0.8.0",
+ "Editor-Plugin-Version": "copilot-chat/0.8.0",
+ "Editor-Version": "vscode/1.85.1",
+ "Openai-Intent": "conversation-panel",
+ "Openai-Organization": "github-copilot",
+ "VScode-MachineId": randHex(64),
+ "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ }
+ return h
}
func (c *copilotClient) headersGhost() map[string]string {
- _ = c.ensureSession(context.Background())
- h := map[string]string{
- "Content-Type": "application/json; charset=utf-8",
- "Accept": "*/*",
- "Authorization": "Bearer " + c.sessionToken,
- "User-Agent": "GithubCopilot/1.155.0",
- "Editor-Plugin-Version": "copilot/1.155.0",
- "Editor-Version": "vscode/1.85.1",
- "Openai-Intent": "copilot-ghost",
- "Openai-Organization": "github-copilot",
- "VScode-MachineId": randHex(64),
- "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- }
- return h
+ _ = c.ensureSession(context.Background())
+ h := map[string]string{
+ "Content-Type": "application/json; charset=utf-8",
+ "Accept": "*/*",
+ "Authorization": "Bearer " + c.sessionToken,
+ "User-Agent": "GithubCopilot/1.155.0",
+ "Editor-Plugin-Version": "copilot/1.155.0",
+ "Editor-Version": "vscode/1.85.1",
+ "Openai-Intent": "copilot-ghost",
+ "Openai-Organization": "github-copilot",
+ "VScode-MachineId": randHex(64),
+ "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ }
+ return h
}
func randHex(n int) string {
- const hex = "0123456789abcdef"
- b := make([]byte, n)
- for i := range b {
- b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)]
- }
- return string(b)
+ const hex = "0123456789abcdef"
+ b := make([]byte, n)
+ for i := range b {
+ b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)]
+ }
+ return string(b)
}
// --- Codex-style code completion ---
// CodeCompletion implements CodeCompleter; returns up to n suggestions.
func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) {
- if strings.TrimSpace(c.apiKey) == "" { return nil, errors.New("missing Copilot API key") }
- if err := c.ensureSession(ctx); err != nil { return nil, err }
- if n <= 0 { n = 1 }
- maxTokens := 500
- body := map[string]any{
- "extra": map[string]any{
- "language": language,
- "next_indent": 0,
- "prompt_tokens": 500,
- "suffix_tokens": 400,
- "trim_by_indentation": true,
- },
- "max_tokens": maxTokens,
- "n": n,
- "nwo": "hexai",
- "prompt": prompt,
- "stop": []string{"\n\n"},
- "stream": true,
- "suffix": suffix,
- "temperature": temperature,
- "top_p": 1,
- }
- buf, _ := json.Marshal(body)
- url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions"
- resp, err := c.postJSON(ctx, url, buf, c.headersGhost())
- if err != nil { return nil, err }
- defer resp.Body.Close()
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode)
- }
- // Read all and parse lines that start with "data: " accumulating by index
- raw, _ := io.ReadAll(resp.Body)
- byIndex := make(map[int]string)
- lines := strings.Split(string(raw), "\n")
- for _, ln := range lines {
- if !strings.HasPrefix(ln, "data: ") { continue }
- var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` }
- if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue }
- for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text }
- }
- out := make([]string, 0, len(byIndex))
- for i := 0; i < n; i++ {
- if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) }
- }
- return out, nil
+ if strings.TrimSpace(c.apiKey) == "" {
+ return nil, errors.New("missing Copilot API key")
+ }
+ if err := c.ensureSession(ctx); err != nil {
+ return nil, err
+ }
+ if n <= 0 {
+ n = 1
+ }
+ maxTokens := 500
+ body := map[string]any{
+ "extra": map[string]any{
+ "language": language,
+ "next_indent": 0,
+ "prompt_tokens": 500,
+ "suffix_tokens": 400,
+ "trim_by_indentation": true,
+ },
+ "max_tokens": maxTokens,
+ "n": n,
+ "nwo": "hexai",
+ "prompt": prompt,
+ "stop": []string{"\n\n"},
+ "stream": true,
+ "suffix": suffix,
+ "temperature": temperature,
+ "top_p": 1,
+ }
+ buf, _ := json.Marshal(body)
+ url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions"
+ resp, err := c.postJSON(ctx, url, buf, c.headersGhost())
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode)
+ }
+ // Read all and parse lines that start with "data: " accumulating by index
+ raw, _ := io.ReadAll(resp.Body)
+ byIndex := make(map[int]string)
+ lines := strings.Split(string(raw), "\n")
+ for _, ln := range lines {
+ if !strings.HasPrefix(ln, "data: ") {
+ continue
+ }
+ var evt struct {
+ Choices []struct {
+ Index int `json:"index"`
+ Text string `json:"text"`
+ } `json:"choices"`
+ }
+ if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil {
+ continue
+ }
+ for _, ch := range evt.Choices {
+ byIndex[ch.Index] += ch.Text
+ }
+ }
+ out := make([]string, 0, len(byIndex))
+ for i := 0; i < n; i++ {
+ if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" {
+ out = append(out, s)
+ }
+ }
+ return out, nil
}
// newLineDataReader wraps a streaming body and exposes a JSON decoder that
diff --git a/internal/llm/copilot_http_test.go b/internal/llm/copilot_http_test.go
index 180e43e..d66311c 100644
--- a/internal/llm/copilot_http_test.go
+++ b/internal/llm/copilot_http_test.go
@@ -1,205 +1,261 @@
package llm
import (
- "context"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
- "encoding/base64"
- "os"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
)
type rtFunc2 func(*http.Request) (*http.Response, error)
+
func (f rtFunc2) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
func TestCopilot_EnsureSession_AndChat_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Mock chat endpoint
- chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != "/chat/completions" { t.Fatalf("unexpected path: %s", r.URL.Path) }
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index":0, "message": map[string]string{"role":"assistant","content":"OK"}}}})
- }))
- defer chatSrv.Close()
- c := newCopilot(chatSrv.URL, "gpt-4o-mini", "APIKEY", f64p(0.1)).(copilotClient)
- // Intercept token endpoint to return a session token
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder()
- _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"})
- res := rw.Result()
- res.StatusCode = 200
- return res, nil
- }
- // Fallback to default transport for chatSrv
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}})
- if err != nil || out != "OK" { t.Fatalf("copilot chat failed: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Mock chat endpoint
+ chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/chat/completions" {
+ t.Fatalf("unexpected path: %s", r.URL.Path)
+ }
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index": 0, "message": map[string]string{"role": "assistant", "content": "OK"}}}})
+ }))
+ defer chatSrv.Close()
+ c := newCopilot(chatSrv.URL, "gpt-4o-mini", "APIKEY", f64p(0.1)).(copilotClient)
+ // Intercept token endpoint to return a session token
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ // Fallback to default transport for chatSrv
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "OK" {
+ t.Fatalf("copilot chat failed: %v %q", err, out)
+ }
}
func TestCopilot_HandleNon2xx(t *testing.T) {
- b, _ := json.Marshal(map[string]any{"error": map[string]any{"message":"bad","type":"invalid"}})
- resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytesReader(b))}
- if err := handleCopilotNon2xx(resp, time.Now()); err == nil { t.Fatalf("expected error") }
+ b, _ := json.Marshal(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}})
+ resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytesReader(b))}
+ if err := handleCopilotNon2xx(resp, time.Now()); err == nil {
+ t.Fatalf("expected error")
+ }
}
func TestCopilot_CodeCompletion_Success(t *testing.T) {
- c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- // Token endpoint
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder()
- _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"})
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- // Codex completion endpoint
- if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
- rw := httptest.NewRecorder()
- // two choices for index 0 and 1
- rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"A\"}]}\n")
- rw.WriteString("data: {\"choices\":[{\"index\":1,\"text\":\"B\"}]}\n")
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.CodeCompletion(context.Background(), "p", "s", 2, "go", 0.1)
- if err != nil || len(out) != 2 || out[0] != "A" || out[1] != "B" {
- t.Fatalf("codex: %v %#v", err, out)
- }
+ c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ // Token endpoint
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ // Codex completion endpoint
+ if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
+ rw := httptest.NewRecorder()
+ // two choices for index 0 and 1
+ rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"A\"}]}\n")
+ rw.WriteString("data: {\"choices\":[{\"index\":1,\"text\":\"B\"}]}\n")
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.CodeCompletion(context.Background(), "p", "s", 2, "go", 0.1)
+ if err != nil || len(out) != 2 || out[0] != "A" || out[1] != "B" {
+ t.Fatalf("codex: %v %#v", err, out)
+ }
}
func TestCopilot_Chat_MultiChoice_And_ErrorBody(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Chat multi-choice: return two choices; client returns first content
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{
- "choices": []map[string]any{
- {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
- {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
- },
- })
- }))
- defer srv.Close()
- c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- // Token success
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
- if err != nil || out != "FIRST" { t.Fatalf("copilot multi-choice: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Chat multi-choice: return two choices; client returns first content
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "choices": []map[string]any{
+ {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
+ {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
+ },
+ })
+ }))
+ defer srv.Close()
+ c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ // Token success
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "FIRST" {
+ t.Fatalf("copilot multi-choice: %v %q", err, out)
+ }
- // Non-2xx with error body
- srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(403)
- _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message":"denied","type":"forbidden"}})
- }))
- defer srv2.Close()
- c2 := newCopilot(srv2.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- c2.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- if _, err := c2.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected error for copilot non-2xx with error body")
- }
+ // Non-2xx with error body
+ srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(403)
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "denied", "type": "forbidden"}})
+ }))
+ defer srv2.Close()
+ c2 := newCopilot(srv2.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ c2.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error for copilot non-2xx with error body")
+ }
}
func TestCopilot_Chat_NoChoices_Error(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
- }))
- defer srv.Close()
- c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected error when no choices returned")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
+ }))
+ defer srv.Close()
+ c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error when no choices returned")
+ }
}
func TestCopilot_Chat_DecodeError_StatusOK(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Chat returns 200 but invalid JSON; expect decode error
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- io.WriteString(w, "{invalid")
- }))
- defer srv.Close()
- c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected decode error for invalid body")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Chat returns 200 but invalid JSON; expect decode error
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "{invalid")
+ }))
+ defer srv.Close()
+ c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected decode error for invalid body")
+ }
}
func TestCopilot_CodeCompletion_MalformedAndEmpty(t *testing.T) {
- c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
- rw := httptest.NewRecorder()
- // malformed line
- rw.WriteString("data: {bad}\n")
- // done; should produce empty suggestions
- rw.WriteString("data: [DONE]\n")
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
- if err != nil { t.Fatalf("unexpected error: %v", err) }
- if len(out) != 0 { t.Fatalf("expected empty suggestions, got %#v", out) }
+ c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
+ rw := httptest.NewRecorder()
+ // malformed line
+ rw.WriteString("data: {bad}\n")
+ // done; should produce empty suggestions
+ rw.WriteString("data: [DONE]\n")
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(out) != 0 {
+ t.Fatalf("expected empty suggestions, got %#v", out)
+ }
- // Now include one good chunk after malformed
- tr2 := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
- rw := httptest.NewRecorder()
- rw.WriteString("data: {bad}\n")
- rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"OK\"}]}\n")
- rw.WriteString("data: [DONE]\n")
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr2, Timeout: 5 * time.Second}
- out2, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
- if err != nil || len(out2) != 1 || out2[0] != "OK" { t.Fatalf("unexpected: %v %#v", err, out2) }
+ // Now include one good chunk after malformed
+ tr2 := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
+ rw := httptest.NewRecorder()
+ rw.WriteString("data: {bad}\n")
+ rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"OK\"}]}\n")
+ rw.WriteString("data: [DONE]\n")
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr2, Timeout: 5 * time.Second}
+ out2, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
+ if err != nil || len(out2) != 1 || out2[0] != "OK" {
+ t.Fatalf("unexpected: %v %#v", err, out2)
+ }
}
func TestParseJWTExp_AndParseInt64(t *testing.T) {
- // Valid base64 payload
- payload := `{"exp": 1700000000}`
- b := base64.RawURLEncoding.EncodeToString([]byte(payload))
- tok := "x." + b + ".y"
- if tm := parseJWTExp(tok); tm.IsZero() { t.Fatalf("expected non-zero time") }
- if n, err := parseInt64("123"); err != nil || n != 123 { t.Fatalf("parseInt64: %v %d", err, n) }
+ // Valid base64 payload
+ payload := `{"exp": 1700000000}`
+ b := base64.RawURLEncoding.EncodeToString([]byte(payload))
+ tok := "x." + b + ".y"
+ if tm := parseJWTExp(tok); tm.IsZero() {
+ t.Fatalf("expected non-zero time")
+ }
+ if n, err := parseInt64("123"); err != nil || n != 123 {
+ t.Fatalf("parseInt64: %v %d", err, n)
+ }
}
// bytesReader wraps a byte slice with an io.ReadCloser without importing extra.
type bytesReader []byte
+
func (b bytesReader) Read(p []byte) (int, error) { n := copy(p, b); return n, io.EOF }
-func (b bytesReader) Close() error { return nil }
+func (b bytesReader) Close() error { return nil }
diff --git a/internal/llm/ollama_test.go b/internal/llm/ollama_test.go
index 15f9cff..8bd33ca 100644
--- a/internal/llm/ollama_test.go
+++ b/internal/llm/ollama_test.go
@@ -1,173 +1,217 @@
package llm
import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
- "os"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
)
func TestBuildOllamaRequest_OptionsAndStream(t *testing.T) {
- o := Options{Model: "codemodel", Temperature: 0, MaxTokens: 256, Stop: []string{"STOP"}}
- msgs := []Message{{Role: "user", Content: "hello"}}
- req := buildOllamaRequest(o, msgs, f64p(0.2), false)
- if req.Model != "codemodel" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) }
- if req.Options == nil { t.Fatalf("expected options map") }
- if req.Options.(map[string]any)["temperature"].(float64) != 0.2 { t.Fatalf("default temp not applied") }
- if req.Options.(map[string]any)["num_predict"].(int) != 256 { t.Fatalf("num_predict not applied") }
- if req.Options.(map[string]any)["stop"].([]string)[0] != "STOP" { t.Fatalf("stop not applied") }
-
- req2 := buildOllamaRequest(o, msgs, f64p(0.2), true)
- if !req2.Stream { t.Fatalf("expected stream=true") }
+ o := Options{Model: "codemodel", Temperature: 0, MaxTokens: 256, Stop: []string{"STOP"}}
+ msgs := []Message{{Role: "user", Content: "hello"}}
+ req := buildOllamaRequest(o, msgs, f64p(0.2), false)
+ if req.Model != "codemodel" || req.Stream {
+ t.Fatalf("model/stream mismatch: %+v", req)
+ }
+ if req.Options == nil {
+ t.Fatalf("expected options map")
+ }
+ if req.Options.(map[string]any)["temperature"].(float64) != 0.2 {
+ t.Fatalf("default temp not applied")
+ }
+ if req.Options.(map[string]any)["num_predict"].(int) != 256 {
+ t.Fatalf("num_predict not applied")
+ }
+ if req.Options.(map[string]any)["stop"].([]string)[0] != "STOP" {
+ t.Fatalf("stop not applied")
+ }
+
+ req2 := buildOllamaRequest(o, msgs, f64p(0.2), true)
+ if !req2.Stream {
+ t.Fatalf("expected stream=true")
+ }
}
func TestBuildOllamaRequest_TempOverride(t *testing.T) {
- o := Options{Model: "m", Temperature: 0.9}
- msgs := []Message{{Role: "user", Content: "hi"}}
- req := buildOllamaRequest(o, msgs, f64p(0.2), false)
- m := req.Options.(map[string]any)
- if m["temperature"].(float64) != 0.9 { t.Fatalf("explicit temp should override default") }
+ o := Options{Model: "m", Temperature: 0.9}
+ msgs := []Message{{Role: "user", Content: "hi"}}
+ req := buildOllamaRequest(o, msgs, f64p(0.2), false)
+ m := req.Options.(map[string]any)
+ if m["temperature"].(float64) != 0.9 {
+ t.Fatalf("explicit temp should override default")
+ }
}
func TestOllama_NameAndModel(t *testing.T) {
- c := newOllama("http://x", "model-x", nil).(ollamaClient)
- if c.Name() != "ollama" { t.Fatalf("name: %q", c.Name()) }
- if c.DefaultModel() != "model-x" { t.Fatalf("default model: %q", c.DefaultModel()) }
+ c := newOllama("http://x", "model-x", nil).(ollamaClient)
+ if c.Name() != "ollama" {
+ t.Fatalf("name: %q", c.Name())
+ }
+ if c.DefaultModel() != "model-x" {
+ t.Fatalf("default model: %q", c.DefaultModel())
+ }
}
func TestOllamaChat_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost || r.URL.Path != "/api/chat" { t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) }
- w.Header().Set("Content-Type", "application/json")
- _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role":"assistant","content":"Hello"}, "done": true})
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", f64p(0.1)).(ollamaClient)
- c.httpClient = ts.Client()
- out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
- if err != nil { t.Fatalf("unexpected err: %v", err) }
- if out != "Hello" { t.Fatalf("got %q", out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost || r.URL.Path != "/api/chat" {
+ t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path)
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": "Hello"}, "done": true})
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", f64p(0.1)).(ollamaClient)
+ c.httpClient = ts.Client()
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil {
+ t.Fatalf("unexpected err: %v", err)
+ }
+ if out != "Hello" {
+ t.Fatalf("got %q", out)
+ }
}
func TestOllamaChat_EmptyContent(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role":"assistant","content":""}, "done": true})
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected error for empty content")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": ""}, "done": true})
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected error for empty content")
+ }
}
func TestOllamaChat_Non2xx(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // API error string
- ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(400)
- _ = json.NewEncoder(w).Encode(map[string]any{"error":"bad"})
- }))
- defer ts1.Close()
- c1 := newOllama(ts1.URL, "m", nil).(ollamaClient)
- c1.httpClient = ts1.Client()
- if _, err := c1.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected error for 400 with api body")
- }
- // Plain HTTP error without api message
- ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(500)
- _, _ = w.Write([]byte("{}"))
- }))
- defer ts2.Close()
- c2 := newOllama(ts2.URL, "m", nil).(ollamaClient)
- c2.httpClient = ts2.Client()
- if _, err := c2.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected error for 500")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // API error string
+ ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(400)
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": "bad"})
+ }))
+ defer ts1.Close()
+ c1 := newOllama(ts1.URL, "m", nil).(ollamaClient)
+ c1.httpClient = ts1.Client()
+ if _, err := c1.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected error for 400 with api body")
+ }
+ // Plain HTTP error without api message
+ ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(500)
+ _, _ = w.Write([]byte("{}"))
+ }))
+ defer ts2.Close()
+ c2 := newOllama(ts2.URL, "m", nil).(ollamaClient)
+ c2.httpClient = ts2.Client()
+ if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected error for 500")
+ }
}
type rtFunc func(*http.Request) (*http.Response, error)
+
func (f rtFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
func TestOllamaChat_HTTPError(t *testing.T) {
- c := newOllama("http://127.0.0.1:0", "m", nil).(ollamaClient)
- c.httpClient = &http.Client{Transport: rtFunc(func(*http.Request)(*http.Response,error){ return nil, fmt.Errorf("boom") })}
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected http error path")
- }
+ c := newOllama("http://127.0.0.1:0", "m", nil).(ollamaClient)
+ c.httpClient = &http.Client{Transport: rtFunc(func(*http.Request) (*http.Response, error) { return nil, fmt.Errorf("boom") })}
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected http error path")
+ }
}
func TestOllamaChat_DecodeError(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, _ = w.Write([]byte("{bad json}"))
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected decode error")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("{bad json}"))
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected decode error")
+ }
}
func TestHandleOllamaNon2xx_OK(t *testing.T) {
- resp := &http.Response{StatusCode: 200, Body: ioNopCloser(strings.NewReader(""))}
- if err := handleOllamaNon2xx(resp, time.Now()); err != nil { t.Fatalf("unexpected: %v", err) }
+ resp := &http.Response{StatusCode: 200, Body: ioNopCloser(strings.NewReader(""))}
+ if err := handleOllamaNon2xx(resp, time.Now()); err != nil {
+ t.Fatalf("unexpected: %v", err)
+ }
}
func TestOllamaChatStream_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- // two JSON objects back-to-back
- _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"Hi"},"done":false}`))
- _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"!"},"done":true}`))
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- var got strings.Builder
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(s string){ got.WriteString(s) }); err != nil {
- t.Fatalf("unexpected: %v", err)
- }
- if got.String() != "Hi!" { t.Fatalf("got %q", got.String()) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ // two JSON objects back-to-back
+ _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"Hi"},"done":false}`))
+ _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"!"},"done":true}`))
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ var got strings.Builder
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(s string) { got.WriteString(s) }); err != nil {
+ t.Fatalf("unexpected: %v", err)
+ }
+ if got.String() != "Hi!" {
+ t.Fatalf("got %q", got.String())
+ }
}
func TestOllamaChatStream_ErrorEvent(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"error":"oops"})
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(string){}); err == nil {
- t.Fatalf("expected stream error")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": "oops"})
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil {
+ t.Fatalf("expected stream error")
+ }
}
func TestOllamaChatStream_DecodeError(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, _ = w.Write([]byte("{not json}"))
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(string){}); err == nil {
- t.Fatalf("expected decode error")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("{not json}"))
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil {
+ t.Fatalf("expected decode error")
+ }
}
// small helper to construct an io.ReadCloser without importing extra packages
type readCloser struct{ *strings.Reader }
-func (readCloser) Close() error { return nil }
+
+func (readCloser) Close() error { return nil }
func ioNopCloser(r *strings.Reader) *readCloser { return &readCloser{r} }
diff --git a/internal/llm/openai_http_test.go b/internal/llm/openai_http_test.go
index ac7b897..cb4bfcb 100644
--- a/internal/llm/openai_http_test.go
+++ b/internal/llm/openai_http_test.go
@@ -1,143 +1,171 @@
package llm
import (
- "context"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- "strings"
- "time"
- "os"
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
)
func TestOpenAI_Chat_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != "/chat/completions" { t.Fatalf("unexpected path: %s", r.URL.Path) }
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index":0, "message": map[string]string{"role":"assistant","content":"OK"}}}})
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- out, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}})
- if err != nil || out != "OK" { t.Fatalf("openai chat: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/chat/completions" {
+ t.Fatalf("unexpected path: %s", r.URL.Path)
+ }
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index": 0, "message": map[string]string{"role": "assistant", "content": "OK"}}}})
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "OK" {
+ t.Fatalf("openai chat: %v %q", err, out)
+ }
}
func TestOpenAI_Chat_MissingKey(t *testing.T) {
- c := newOpenAI("http://x", "g", "", f64p(0.2)).(openAIClient)
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil { t.Fatalf("expected error for missing key") }
+ c := newOpenAI("http://x", "g", "", f64p(0.2)).(openAIClient)
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error for missing key")
+ }
}
func TestOpenAI_ChatStream_SSE(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Return SSE-like stream
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n")
- io.WriteString(w, "data: [DONE]\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s })
- if err != nil || got != "Hi" { t.Fatalf("chat stream: %v %q", err, got) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Return SSE-like stream
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n")
+ io.WriteString(w, "data: [DONE]\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s })
+ if err != nil || got != "Hi" {
+ t.Fatalf("chat stream: %v %q", err, got)
+ }
}
func TestHandleOpenAINon2xx_NoErrorBody(t *testing.T) {
- resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))}
- if err := handleOpenAINon2xx(resp, time.Now()); err == nil { t.Fatalf("expected http error") }
+ resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))}
+ if err := handleOpenAINon2xx(resp, time.Now()); err == nil {
+ t.Fatalf("expected http error")
+ }
}
func TestOpenAI_ChatStream_SSE_ErrorChunk(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {\"error\":{\"message\":\"oops\"}}\n\n")
- io.WriteString(w, "data: [DONE]\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s }); err == nil {
- t.Fatalf("expected error due to error chunk")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {\"error\":{\"message\":\"oops\"}}\n\n")
+ io.WriteString(w, "data: [DONE]\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err == nil {
+ t.Fatalf("expected error due to error chunk")
+ }
}
func TestOpenAI_Chat_NoChoices_Error(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected error when choices empty")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error when choices empty")
+ }
}
func TestOpenAI_ChatStream_SSE_EmptyDelta_NoError(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {\\\"choices\\\":[{\\\"delta\\\":{\\\"content\\\":\\\"\\\"}}]}\\n\\n")
- io.WriteString(w, "data: [DONE]\\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s }); err != nil {
- t.Fatalf("unexpected error for empty delta: %v", err)
- }
- if got != "" { t.Fatalf("expected no output for empty delta, got %q", got) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {\\\"choices\\\":[{\\\"delta\\\":{\\\"content\\\":\\\"\\\"}}]}\\n\\n")
+ io.WriteString(w, "data: [DONE]\\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err != nil {
+ t.Fatalf("unexpected error for empty delta: %v", err)
+ }
+ if got != "" {
+ t.Fatalf("expected no output for empty delta, got %q", got)
+ }
}
func TestOpenAI_Chat_DecodeError_StatusOK(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Return status 200 but invalid JSON body; Chat should return an error
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(200)
- io.WriteString(w, "{invalid")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
- t.Fatalf("expected decode error for invalid JSON body")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Return status 200 but invalid JSON body; Chat should return an error
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ io.WriteString(w, "{invalid")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected decode error for invalid JSON body")
+ }
}
func TestOpenAI_Chat_MultiChoiceAndErrorBody(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Multi-choice success: return two choices with different finish reasons
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{
- "choices": []map[string]any{
- {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
- {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
- },
- })
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
- if err != nil || out != "FIRST" { t.Fatalf("openai multi-choice: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Multi-choice success: return two choices with different finish reasons
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "choices": []map[string]any{
+ {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
+ {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
+ },
+ })
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "FIRST" {
+ t.Fatalf("openai multi-choice: %v %q", err, out)
+ }
- // Error body case: non-2xx with error message
- srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(400)
- _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}})
- }))
- defer srv2.Close()
- c2 := newOpenAI(srv2.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c2.httpClient = srv2.Client()
- if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
- t.Fatalf("expected error from non-2xx with error body")
- }
+ // Error body case: non-2xx with error message
+ srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(400)
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}})
+ }))
+ defer srv2.Close()
+ c2 := newOpenAI(srv2.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c2.httpClient = srv2.Client()
+ if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error from non-2xx with error body")
+ }
}
diff --git a/internal/llm/openai_sse_negative_test.go b/internal/llm/openai_sse_negative_test.go
index 8da5526..de2ff71 100644
--- a/internal/llm/openai_sse_negative_test.go
+++ b/internal/llm/openai_sse_negative_test.go
@@ -1,28 +1,32 @@
package llm
import (
- "context"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- "os"
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
)
func TestOpenAI_ChatStream_SSE_MalformedChunk(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Malformed JSON chunk should be skipped; no onDelta calls; no error.
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {not json}\n\n")
- io.WriteString(w, "data: [DONE]\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string){ got += s }); err != nil {
- t.Fatalf("unexpected error for malformed chunk: %v", err)
- }
- if got != "" { t.Fatalf("expected no deltas for malformed chunk, got %q", got) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Malformed JSON chunk should be skipped; no onDelta calls; no error.
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {not json}\n\n")
+ io.WriteString(w, "data: [DONE]\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err != nil {
+ t.Fatalf("unexpected error for malformed chunk: %v", err)
+ }
+ if got != "" {
+ t.Fatalf("expected no deltas for malformed chunk, got %q", got)
+ }
}
diff --git a/internal/llm/openai_test.go b/internal/llm/openai_test.go
index f50b171..f7ce080 100644
--- a/internal/llm/openai_test.go
+++ b/internal/llm/openai_test.go
@@ -1,44 +1,67 @@
package llm
import (
- "bytes"
- "encoding/json"
- "io"
- "net/http"
- "strings"
- "testing"
- "time"
+ "bytes"
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
)
func f64p(v float64) *float64 { return &v }
func TestBuildOAChatRequest_TempFallbackAndFields(t *testing.T) {
- o := Options{Model: "m1", Temperature: 0, MaxTokens: 42, Stop: []string{"END"}}
- msgs := []Message{{Role: "user", Content: "hi"}}
- req := buildOAChatRequest(o, msgs, f64p(0.3), false)
- if req.Model != "m1" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) }
- if req.Temperature == nil || *req.Temperature != 0.3 { t.Fatalf("expected default temp 0.3, got %#v", req.Temperature) }
- if req.MaxTokens == nil || *req.MaxTokens != 42 { t.Fatalf("expected max tokens 42") }
- if len(req.Stop) != 1 || req.Stop[0] != "END" { t.Fatalf("stop not propagated: %#v", req.Stop) }
- if len(req.Messages) != 1 || req.Messages[0].Content != "hi" { t.Fatalf("messages not copied") }
+ o := Options{Model: "m1", Temperature: 0, MaxTokens: 42, Stop: []string{"END"}}
+ msgs := []Message{{Role: "user", Content: "hi"}}
+ req := buildOAChatRequest(o, msgs, f64p(0.3), false)
+ if req.Model != "m1" || req.Stream {
+ t.Fatalf("model/stream mismatch: %+v", req)
+ }
+ if req.Temperature == nil || *req.Temperature != 0.3 {
+ t.Fatalf("expected default temp 0.3, got %#v", req.Temperature)
+ }
+ if req.MaxTokens == nil || *req.MaxTokens != 42 {
+ t.Fatalf("expected max tokens 42")
+ }
+ if len(req.Stop) != 1 || req.Stop[0] != "END" {
+ t.Fatalf("stop not propagated: %#v", req.Stop)
+ }
+ if len(req.Messages) != 1 || req.Messages[0].Content != "hi" {
+ t.Fatalf("messages not copied")
+ }
- // stream on
- req2 := buildOAChatRequest(o, msgs, f64p(0.3), true)
- if !req2.Stream { t.Fatalf("expected stream=true") }
+ // stream on
+ req2 := buildOAChatRequest(o, msgs, f64p(0.3), true)
+ if !req2.Stream {
+ t.Fatalf("expected stream=true")
+ }
}
func TestHandleOpenAINon2xx_WithAPIError(t *testing.T) {
- api := oaChatResponse{Error: &struct{ Message string `json:"message"`; Type string `json:"type"`; Param any `json:"param"`; Code any `json:"code"` }{Message: "bad", Type: "invalid"}}
- b, _ := json.Marshal(api)
- resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader(b))}
- if err := handleOpenAINon2xx(resp, time.Now()); err == nil { t.Fatalf("expected error for non-2xx with body") }
+ api := oaChatResponse{Error: &struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+ Param any `json:"param"`
+ Code any `json:"code"`
+ }{Message: "bad", Type: "invalid"}}
+ b, _ := json.Marshal(api)
+ resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader(b))}
+ if err := handleOpenAINon2xx(resp, time.Now()); err == nil {
+ t.Fatalf("expected error for non-2xx with body")
+ }
}
func TestParseOpenAIStream_DeliversChunks(t *testing.T) {
- stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" +
- "data: [DONE]\n"
- resp := &http.Response{Body: io.NopCloser(strings.NewReader(stream))}
- var got strings.Builder
- if err := parseOpenAIStream(resp, time.Now(), func(s string){ got.WriteString(s) }); err != nil { t.Fatalf("unexpected error: %v", err) }
- if got.String() != "Hi" { t.Fatalf("got %q want %q", got.String(), "Hi") }
+ stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" +
+ "data: [DONE]\n"
+ resp := &http.Response{Body: io.NopCloser(strings.NewReader(stream))}
+ var got strings.Builder
+ if err := parseOpenAIStream(resp, time.Now(), func(s string) { got.WriteString(s) }); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got.String() != "Hi" {
+ t.Fatalf("got %q want %q", got.String(), "Hi")
+ }
}
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index 7ab58c6..88c280c 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -28,20 +28,20 @@ type Client interface {
// token-by-token streaming responses. Callers can type-assert to Streamer and
// fall back to Client.Chat when not implemented.
type Streamer interface {
- // ChatStream sends chat messages and invokes onDelta with incremental text
- // chunks as they are produced by the model. Implementations should call
- // onDelta with empty strings sparingly (prefer only non-empty chunks).
- ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error
+ // ChatStream sends chat messages and invokes onDelta with incremental text
+ // chunks as they are produced by the model. Implementations should call
+ // onDelta with empty strings sparingly (prefer only non-empty chunks).
+ ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error
}
// CodeCompleter is an optional interface for providers that support a
// prompt/suffix code-completion API (e.g., Copilot Codex endpoint). Clients
// can type-assert to this and prefer it over chat when available.
type CodeCompleter interface {
- // CodeCompletion requests up to n suggestions given a left-hand prompt and
- // right-hand suffix around the cursor. Language is advisory and may be
- // ignored. Temperature applies when provider supports it.
- CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error)
+ // CodeCompletion requests up to n suggestions given a left-hand prompt and
+ // right-hand suffix around the cursor. Language is advisory and may be
+ // ignored. Temperature applies when provider supports it.
+ CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error)
}
// Options for a request. Providers may ignore unsupported fields.
@@ -64,56 +64,56 @@ func WithStop(stop ...string) RequestOption {
// Config defines provider configuration read from the Hexai config file.
type Config struct {
- Provider string
- // OpenAI options
- OpenAIBaseURL string
- OpenAIModel string
- OpenAITemperature *float64
- // Ollama options
- OllamaBaseURL string
- OllamaModel string
- OllamaTemperature *float64
- // Copilot options
- CopilotBaseURL string
- CopilotModel string
- CopilotTemperature *float64
+ Provider string
+ // OpenAI options
+ OpenAIBaseURL string
+ OpenAIModel string
+ OpenAITemperature *float64
+ // Ollama options
+ OllamaBaseURL string
+ OllamaModel string
+ OllamaTemperature *float64
+ // Copilot options
+ CopilotBaseURL string
+ CopilotModel string
+ CopilotTemperature *float64
}
// NewFromConfig creates an LLM client using only the supplied configuration.
// The OpenAI API key is supplied separately and may be read from the environment
// by the caller; other environment-based configuration is not used.
func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) {
- p := strings.ToLower(strings.TrimSpace(cfg.Provider))
- if p == "" {
- p = "openai"
- }
- switch p {
- case "openai":
- if strings.TrimSpace(openAIAPIKey) == "" {
- return nil, errors.New("missing OPENAI_API_KEY for provider openai")
- }
- // Set coding-friendly default temperature if none provided
- if cfg.OpenAITemperature == nil {
- t := 0.2
- cfg.OpenAITemperature = &t
- }
- return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil
- case "ollama":
- if cfg.OllamaTemperature == nil {
- t := 0.2
- cfg.OllamaTemperature = &t
- }
- return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil
- case "copilot":
- if strings.TrimSpace(copilotAPIKey) == "" {
- return nil, errors.New("missing COPILOT_API_KEY for provider copilot")
- }
- if cfg.CopilotTemperature == nil {
- t := 0.2
- cfg.CopilotTemperature = &t
- }
- return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil
- default:
- return nil, errors.New("unknown LLM provider: " + p)
- }
+ p := strings.ToLower(strings.TrimSpace(cfg.Provider))
+ if p == "" {
+ p = "openai"
+ }
+ switch p {
+ case "openai":
+ if strings.TrimSpace(openAIAPIKey) == "" {
+ return nil, errors.New("missing OPENAI_API_KEY for provider openai")
+ }
+ // Set coding-friendly default temperature if none provided
+ if cfg.OpenAITemperature == nil {
+ t := 0.2
+ cfg.OpenAITemperature = &t
+ }
+ return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil
+ case "ollama":
+ if cfg.OllamaTemperature == nil {
+ t := 0.2
+ cfg.OllamaTemperature = &t
+ }
+ return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil
+ case "copilot":
+ if strings.TrimSpace(copilotAPIKey) == "" {
+ return nil, errors.New("missing COPILOT_API_KEY for provider copilot")
+ }
+ if cfg.CopilotTemperature == nil {
+ t := 0.2
+ cfg.CopilotTemperature = &t
+ }
+ return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil
+ default:
+ return nil, errors.New("unknown LLM provider: " + p)
+ }
}
diff --git a/internal/llm/provider_more_test.go b/internal/llm/provider_more_test.go
index bd08552..d7469af 100644
--- a/internal/llm/provider_more_test.go
+++ b/internal/llm/provider_more_test.go
@@ -3,24 +3,27 @@ package llm
import "testing"
func TestWithOptions_Apply(t *testing.T) {
- o := Options{}
- WithModel("m")(&o)
- WithTemperature(0.7)(&o)
- WithMaxTokens(123)(&o)
- WithStop("END")(&o)
- if o.Model != "m" || o.Temperature != 0.7 || o.MaxTokens != 123 || len(o.Stop) != 1 || o.Stop[0] != "END" {
- t.Fatalf("options not applied correctly: %+v", o)
- }
+ o := Options{}
+ WithModel("m")(&o)
+ WithTemperature(0.7)(&o)
+ WithMaxTokens(123)(&o)
+ WithStop("END")(&o)
+ if o.Model != "m" || o.Temperature != 0.7 || o.MaxTokens != 123 || len(o.Stop) != 1 || o.Stop[0] != "END" {
+ t.Fatalf("options not applied correctly: %+v", o)
+ }
}
func TestNewFromConfig_Success_OpenAI_And_Copilot(t *testing.T) {
- // OpenAI success
- oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"}
- c, err := NewFromConfig(oc, "KEY", "")
- if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" { t.Fatalf("openai new: %v %v", c, err) }
- // Copilot success
- cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"}
- c2, err := NewFromConfig(cc, "", "KEY")
- if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" { t.Fatalf("copilot new: %v %v", c2, err) }
+ // OpenAI success
+ oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"}
+ c, err := NewFromConfig(oc, "KEY", "")
+ if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" {
+ t.Fatalf("openai new: %v %v", c, err)
+ }
+ // Copilot success
+ cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"}
+ c2, err := NewFromConfig(cc, "", "KEY")
+ if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" {
+ t.Fatalf("copilot new: %v %v", c2, err)
+ }
}
-
diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go
index 1412b3c..29e2514 100644
--- a/internal/llm/provider_test.go
+++ b/internal/llm/provider_test.go
@@ -1,21 +1,29 @@
package llm
import (
- "context"
- "testing"
+ "context"
+ "testing"
)
func TestNewFromConfig_DefaultsAndErrors(t *testing.T) {
- // Unknown provider
- if _, err := NewFromConfig(Config{Provider:"bogus"}, "", ""); err == nil { t.Fatalf("expected error for unknown provider") }
- // OpenAI missing key
- if _, err := NewFromConfig(Config{Provider:"openai", OpenAIModel:"g"}, "", ""); err == nil { t.Fatalf("expected key error") }
- // Copilot missing key
- if _, err := NewFromConfig(Config{Provider:"copilot", CopilotModel:"m"}, "", ""); err == nil { t.Fatalf("expected key error") }
+ // Unknown provider
+ if _, err := NewFromConfig(Config{Provider: "bogus"}, "", ""); err == nil {
+ t.Fatalf("expected error for unknown provider")
+ }
+ // OpenAI missing key
+ if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", ""); err == nil {
+ t.Fatalf("expected key error")
+ }
+ // Copilot missing key
+ if _, err := NewFromConfig(Config{Provider: "copilot", CopilotModel: "m"}, "", ""); err == nil {
+ t.Fatalf("expected key error")
+ }
}
type fakeClientMin struct{}
-func (fakeClientMin) Chat(context.Context, []Message, ...RequestOption) (string, error) { return "", nil }
-func (fakeClientMin) Name() string { return "x" }
-func (fakeClientMin) DefaultModel() string { return "m" }
+func (fakeClientMin) Chat(context.Context, []Message, ...RequestOption) (string, error) {
+ return "", nil
+}
+func (fakeClientMin) Name() string { return "x" }
+func (fakeClientMin) DefaultModel() string { return "m" }
diff --git a/internal/llm/util_test.go b/internal/llm/util_test.go
index acffe5a..137e149 100644
--- a/internal/llm/util_test.go
+++ b/internal/llm/util_test.go
@@ -3,7 +3,8 @@ package llm
import "testing"
func TestNilStringErr(t *testing.T) {
- s, err := nilStringErr("boom")
- if s != "" || err == nil { t.Fatalf("expected empty string and error") }
+ s, err := nilStringErr("boom")
+ if s != "" || err == nil {
+ t.Fatalf("expected empty string and error")
+ }
}
-
diff --git a/internal/logging/chatlogger.go b/internal/logging/chatlogger.go
index 2f2fc99..b2d8684 100644
--- a/internal/logging/chatlogger.go
+++ b/internal/logging/chatlogger.go
@@ -14,7 +14,8 @@ func NewChatLogger(provider string) ChatLogger {
func (cl ChatLogger) LogStart(stream bool, model string, temp float64, maxTokens int, stop []string, messages []struct {
Role string
Content string
-}) {
+},
+) {
chatOrStream := "chat"
if stream {
chatOrStream = "stream"
diff --git a/internal/logging/logging.go b/internal/logging/logging.go
index f90562f..259ad68 100644
--- a/internal/logging/logging.go
+++ b/internal/logging/logging.go
@@ -8,13 +8,13 @@ import (
// ANSI color utilities shared across Hexai.
const (
- AnsiBgBlack = "\x1b[40m"
- AnsiGrey = "\x1b[90m"
- AnsiCyan = "\x1b[36m"
- AnsiGreen = "\x1b[32m"
- AnsiYellow = "\x1b[33m"
- AnsiRed = "\x1b[31m"
- AnsiReset = "\x1b[0m"
+ AnsiBgBlack = "\x1b[40m"
+ AnsiGrey = "\x1b[90m"
+ AnsiCyan = "\x1b[36m"
+ AnsiGreen = "\x1b[32m"
+ AnsiYellow = "\x1b[33m"
+ AnsiRed = "\x1b[31m"
+ AnsiReset = "\x1b[0m"
)
// AnsiBase is the default style: black background + grey foreground.
diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go
index be8c93f..adeefde 100644
--- a/internal/logging/logging_test.go
+++ b/internal/logging/logging_test.go
@@ -1,48 +1,48 @@
package logging
import (
- "bytes"
- "log"
- "strings"
- "testing"
+ "bytes"
+ "log"
+ "strings"
+ "testing"
)
func TestLogf_WithBindAndPreview(t *testing.T) {
- var buf bytes.Buffer
- l := log.New(&buf, "", 0)
- Bind(l)
+ var buf bytes.Buffer
+ l := log.New(&buf, "", 0)
+ Bind(l)
- SetLogPreviewLimit(5)
- if got := PreviewForLog("abcdef"); got != "abcde…" {
- t.Fatalf("preview truncation failed: %q", got)
- }
- if got := PreviewForLog("abcd"); got != "abcd" {
- t.Fatalf("preview (no trunc) failed: %q", got)
- }
- SetLogPreviewLimit(0)
- if got := PreviewForLog("abcdef"); got != "abcdef" {
- t.Fatalf("preview unlimited failed: %q", got)
- }
+ SetLogPreviewLimit(5)
+ if got := PreviewForLog("abcdef"); got != "abcde…" {
+ t.Fatalf("preview truncation failed: %q", got)
+ }
+ if got := PreviewForLog("abcd"); got != "abcd" {
+ t.Fatalf("preview (no trunc) failed: %q", got)
+ }
+ SetLogPreviewLimit(0)
+ if got := PreviewForLog("abcdef"); got != "abcdef" {
+ t.Fatalf("preview unlimited failed: %q", got)
+ }
- Logf("mod ", "hello %s", "world")
- out := buf.String()
- if !strings.Contains(out, "hello world") || !strings.Contains(out, AnsiBase) || !strings.Contains(out, AnsiReset) {
- t.Fatalf("log output missing parts: %q", out)
- }
+ Logf("mod ", "hello %s", "world")
+ out := buf.String()
+ if !strings.Contains(out, "hello world") || !strings.Contains(out, AnsiBase) || !strings.Contains(out, AnsiReset) {
+ t.Fatalf("log output missing parts: %q", out)
+ }
}
func TestChatLogger_LogStart(t *testing.T) {
- var buf bytes.Buffer
- Bind(log.New(&buf, "", 0))
- SetLogPreviewLimit(3)
- cl := NewChatLogger("prov")
- msgs := []struct{ Role, Content string }{{"user", "abcdef"}, {"assistant", "xyz"}}
- cl.LogStart(true, "m", 0.2, 128, []string{"END"}, msgs)
- out := buf.String()
- if !strings.Contains(out, "stream start model=m") || !strings.Contains(out, "messages=2") {
- t.Fatalf("missing header log: %q", out)
- }
- if !strings.Contains(out, "msg[0] role=user") || !strings.Contains(out, "preview=") {
- t.Fatalf("missing message logs: %q", out)
- }
+ var buf bytes.Buffer
+ Bind(log.New(&buf, "", 0))
+ SetLogPreviewLimit(3)
+ cl := NewChatLogger("prov")
+ msgs := []struct{ Role, Content string }{{"user", "abcdef"}, {"assistant", "xyz"}}
+ cl.LogStart(true, "m", 0.2, 128, []string{"END"}, msgs)
+ out := buf.String()
+ if !strings.Contains(out, "stream start model=m") || !strings.Contains(out, "messages=2") {
+ t.Fatalf("missing header log: %q", out)
+ }
+ if !strings.Contains(out, "msg[0] role=user") || !strings.Contains(out, "preview=") {
+ t.Fatalf("missing message logs: %q", out)
+ }
}
diff --git a/internal/lsp/build_prompts_table_test.go b/internal/lsp/build_prompts_table_test.go
index b0092e2..7e8e5e7 100644
--- a/internal/lsp/build_prompts_table_test.go
+++ b/internal/lsp/build_prompts_table_test.go
@@ -3,14 +3,18 @@ package lsp
import "testing"
func TestBuildPrompts_Table(t *testing.T) {
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line:5, Character:7}}
- cases := []struct{ name string; inParams bool }{
- {"generic", false},
- {"in_params", true},
- }
- for _, c := range cases {
- sys, user := buildPrompts(c.inParams, p, "above", "current", "below", "func ctx")
- if sys == "" || user == "" { t.Fatalf("%s: prompts empty", c.name) }
- }
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 5, Character: 7}}
+ cases := []struct {
+ name string
+ inParams bool
+ }{
+ {"generic", false},
+ {"in_params", true},
+ }
+ for _, c := range cases {
+ sys, user := buildPrompts(c.inParams, p, "above", "current", "below", "func ctx")
+ if sys == "" || user == "" {
+ t.Fatalf("%s: prompts empty", c.name)
+ }
+ }
}
-
diff --git a/internal/lsp/chat_history_test.go b/internal/lsp/chat_history_test.go
index 0e9fed5..b1cae80 100644
--- a/internal/lsp/chat_history_test.go
+++ b/internal/lsp/chat_history_test.go
@@ -3,25 +3,35 @@ package lsp
import "testing"
func TestStripTrailingTrigger(t *testing.T) {
- if got := stripTrailingTrigger("what?"); got != "what" { t.Fatalf("should remove trailing ?") }
- if got := stripTrailingTrigger("what?>"); got != "what?" { t.Fatalf("should drop trailing > when preceded by ?") }
- if got := stripTrailingTrigger("ok!>"); got != "ok!" { t.Fatalf("should drop > after !") }
- if got := stripTrailingTrigger("note:>"); got != "note:" { t.Fatalf("should drop > after :") }
- if got := stripTrailingTrigger("go;>"); got != "go;" { t.Fatalf("should drop > after ;") }
+ if got := stripTrailingTrigger("what?"); got != "what" {
+ t.Fatalf("should remove trailing ?")
+ }
+ if got := stripTrailingTrigger("what?>"); got != "what?" {
+ t.Fatalf("should drop trailing > when preceded by ?")
+ }
+ if got := stripTrailingTrigger("ok!>"); got != "ok!" {
+ t.Fatalf("should drop > after !")
+ }
+ if got := stripTrailingTrigger("note:>"); got != "note:" {
+ t.Fatalf("should drop > after :")
+ }
+ if got := stripTrailingTrigger("go;>"); got != "go;" {
+ t.Fatalf("should drop > after ;")
+ }
}
func TestBuildChatHistory_OrderAndLimit(t *testing.T) {
- s := newTestServer()
- uri := "file:///chat.txt"
- // Conversation: q1, > a1, blank, q2, > a2 lines, then current prompt
- doc := "q1\n> a1\n\nq2\n> a2\n\n"
- s.setDocument(uri, doc)
- msgs := s.buildChatHistory(uri, 5, "q3")
- // Expect: user q1, assistant a1, user q2, assistant a2, user q3
- if len(msgs) != 5 || msgs[0].Role != "user" || msgs[1].Role != "assistant" || msgs[2].Role != "user" || msgs[3].Role != "assistant" || msgs[4].Role != "user" {
- t.Fatalf("unexpected roles: %+v", msgs)
- }
- if msgs[0].Content != "q1" || msgs[1].Content != "a1" || msgs[2].Content != "q2" || msgs[3].Content != "a2" || msgs[4].Content != "q3" {
- t.Fatalf("unexpected contents: %+v", msgs)
- }
+ s := newTestServer()
+ uri := "file:///chat.txt"
+ // Conversation: q1, > a1, blank, q2, > a2 lines, then current prompt
+ doc := "q1\n> a1\n\nq2\n> a2\n\n"
+ s.setDocument(uri, doc)
+ msgs := s.buildChatHistory(uri, 5, "q3")
+ // Expect: user q1, assistant a1, user q2, assistant a2, user q3
+ if len(msgs) != 5 || msgs[0].Role != "user" || msgs[1].Role != "assistant" || msgs[2].Role != "user" || msgs[3].Role != "assistant" || msgs[4].Role != "user" {
+ t.Fatalf("unexpected roles: %+v", msgs)
+ }
+ if msgs[0].Content != "q1" || msgs[1].Content != "a1" || msgs[2].Content != "q2" || msgs[3].Content != "a2" || msgs[4].Content != "q3" {
+ t.Fatalf("unexpected contents: %+v", msgs)
+ }
}
diff --git a/internal/lsp/chat_no_double_answer_test.go b/internal/lsp/chat_no_double_answer_test.go
index 9898ad9..8821cd0 100644
--- a/internal/lsp/chat_no_double_answer_test.go
+++ b/internal/lsp/chat_no_double_answer_test.go
@@ -1,22 +1,21 @@
package lsp
import (
- "bytes"
- "io"
- "log"
- "testing"
+ "bytes"
+ "io"
+ "log"
+ "testing"
)
func TestDetectAndHandleChat_NoDoubleAnswer(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- s.llmClient = fakeLLM{resp: "IGNORED"}
- uri := "file:///x.go"
- // Question line with trigger, followed by an existing answer line starting with '>'
- s.setDocument(uri, "What?>\n> already answered\n")
- s.detectAndHandleChat(uri)
- if out.Len() != 0 {
- t.Fatalf("expected no applyEdit request when answer exists; got %d bytes", out.Len())
- }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ s.llmClient = fakeLLM{resp: "IGNORED"}
+ uri := "file:///x.go"
+ // Question line with trigger, followed by an existing answer line starting with '>'
+ s.setDocument(uri, "What?>\n> already answered\n")
+ s.detectAndHandleChat(uri)
+ if out.Len() != 0 {
+ t.Fatalf("expected no applyEdit request when answer exists; got %d bytes", out.Len())
+ }
}
-
diff --git a/internal/lsp/code_fences_table_test.go b/internal/lsp/code_fences_table_test.go
index c217bce..340ed61 100644
--- a/internal/lsp/code_fences_table_test.go
+++ b/internal/lsp/code_fences_table_test.go
@@ -3,30 +3,29 @@ package lsp
import "testing"
func TestStripCodeFences_Table(t *testing.T) {
- cases := []struct{ name, in, want string }{
- {"no_fence", "return x", "return x"},
- {"plain_fence", "```\nA\nB\n```", "A\nB"},
- {"lang_fence", "```go\nfmt.Println()\n```", "fmt.Println()"},
- {"spaces", " \n```python\nprint('x')\n```\n ", "print('x')"},
- }
- for _, c := range cases {
- if got := stripCodeFences(c.in); got != c.want {
- t.Fatalf("%s: got %q want %q", c.name, got, c.want)
- }
- }
+ cases := []struct{ name, in, want string }{
+ {"no_fence", "return x", "return x"},
+ {"plain_fence", "```\nA\nB\n```", "A\nB"},
+ {"lang_fence", "```go\nfmt.Println()\n```", "fmt.Println()"},
+ {"spaces", " \n```python\nprint('x')\n```\n ", "print('x')"},
+ }
+ for _, c := range cases {
+ if got := stripCodeFences(c.in); got != c.want {
+ t.Fatalf("%s: got %q want %q", c.name, got, c.want)
+ }
+ }
}
func TestStripInlineCodeSpan_Table(t *testing.T) {
- cases := []struct{ name, in, want string }{
- {"no_ticks", "text", "text"},
- {"single_span", "Use `foo()` here", "foo()"},
- {"multiple", "`a` + `b`", "a"},
- {"unmatched", "`missing end", "`missing end"},
- }
- for _, c := range cases {
- if got := stripInlineCodeSpan(c.in); got != c.want {
- t.Fatalf("%s: got %q want %q", c.name, got, c.want)
- }
- }
+ cases := []struct{ name, in, want string }{
+ {"no_ticks", "text", "text"},
+ {"single_span", "Use `foo()` here", "foo()"},
+ {"multiple", "`a` + `b`", "a"},
+ {"unmatched", "`missing end", "`missing end"},
+ }
+ for _, c := range cases {
+ if got := stripInlineCodeSpan(c.in); got != c.want {
+ t.Fatalf("%s: got %q want %q", c.name, got, c.want)
+ }
+ }
}
-
diff --git a/internal/lsp/codeaction_more_test.go b/internal/lsp/codeaction_more_test.go
index 412d988..82972d8 100644
--- a/internal/lsp/codeaction_more_test.go
+++ b/internal/lsp/codeaction_more_test.go
@@ -1,86 +1,109 @@
package lsp
import (
- "encoding/json"
- "path/filepath"
- "strings"
- "testing"
- tut "codeberg.org/snonux/hexai/internal/testutil"
+ "encoding/json"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ tut "codeberg.org/snonux/hexai/internal/testutil"
)
func TestBuildDocumentCodeAction_AndResolve(t *testing.T) {
- s := newTestServer()
- s.llmClient = fakeLLM{resp: tut.MultilineDocBlock()+"\n"+"func add(a,b int) int { return a+b }"}
- uri := "file:///doc.go"
- s.setDocument(uri, "package x\nfunc add(a,b int) int {return a+b}")
- p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line:1, Character:0}, End: Position{Line:1, Character:10}}}
- sel := "func add(a,b int) int {return a+b}"
- ca := s.buildDocumentCodeAction(p, sel)
- if ca == nil { t.Fatalf("expected document code action") }
- resolved, ok := s.resolveCodeAction(*ca)
- if !ok || resolved.Edit == nil { t.Fatalf("expected resolved edit") }
- edits := resolved.Edit.Changes[uri]
- if len(edits) != 1 || strings.TrimSpace(edits[0].NewText) == "" { t.Fatalf("expected replacement text") }
+ s := newTestServer()
+ s.llmClient = fakeLLM{resp: tut.MultilineDocBlock() + "\n" + "func add(a,b int) int { return a+b }"}
+ uri := "file:///doc.go"
+ s.setDocument(uri, "package x\nfunc add(a,b int) int {return a+b}")
+ p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line: 1, Character: 0}, End: Position{Line: 1, Character: 10}}}
+ sel := "func add(a,b int) int {return a+b}"
+ ca := s.buildDocumentCodeAction(p, sel)
+ if ca == nil {
+ t.Fatalf("expected document code action")
+ }
+ resolved, ok := s.resolveCodeAction(*ca)
+ if !ok || resolved.Edit == nil {
+ t.Fatalf("expected resolved edit")
+ }
+ edits := resolved.Edit.Changes[uri]
+ if len(edits) != 1 || strings.TrimSpace(edits[0].NewText) == "" {
+ t.Fatalf("expected replacement text")
+ }
}
func TestResolveCodeAction_Rewrite(t *testing.T) {
- s := newTestServer()
- s.llmClient = fakeLLM{resp: "rewritten"}
- uri := "file:///x.go"
- s.setDocument(uri, "package p\nvar a=1\n")
- payload := struct {
- Type string `json:"type"`
- URI string `json:"uri"`
- Range Range `json:"range"`
- Instruction string `json:"instruction"`
- Selection string `json:"selection"`
- }{Type: "rewrite", URI: uri, Range: Range{Start: Position{Line:1}, End: Position{Line:1, Character: 5}}, Instruction: "do it", Selection: "var a"}
- raw, _ := json.Marshal(payload)
- ca := CodeAction{Title: "Hexai: rewrite selection", Data: raw}
- if resolved, ok := s.resolveCodeAction(ca); !ok || resolved.Edit == nil { t.Fatalf("expected resolved rewrite edit") }
+ s := newTestServer()
+ s.llmClient = fakeLLM{resp: "rewritten"}
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\nvar a=1\n")
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ Instruction string `json:"instruction"`
+ Selection string `json:"selection"`
+ }{Type: "rewrite", URI: uri, Range: Range{Start: Position{Line: 1}, End: Position{Line: 1, Character: 5}}, Instruction: "do it", Selection: "var a"}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: rewrite selection", Data: raw}
+ if resolved, ok := s.resolveCodeAction(ca); !ok || resolved.Edit == nil {
+ t.Fatalf("expected resolved rewrite edit")
+ }
}
func TestBuildGoUnitTestCodeAction_AndResolveCreate(t *testing.T) {
- s := newTestServer()
- // place files under a temp dir to avoid collisions
- dir := t.TempDir()
- srcPath := filepath.Join(dir, "calc.go")
- uri := "file://" + srcPath
- src := "package calc\n\nfunc Sum(a, b int) int { return a+b }\n"
- s.setDocument(uri, src)
- // Offer action (not a _test.go)
- p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line:2}}}
- if a := s.buildGoUnitTestCodeAction(p); a == nil { t.Fatalf("expected go unit test action") }
- // Resolve should create new test file with package+import and a test function
- we, testURI, _, ok := s.resolveGoTest(uri, Position{Line:2})
- if !ok { t.Fatalf("resolveGoTest failed") }
- if len(we.DocumentChanges) != 2 { t.Fatalf("expected create + edits, got %d", len(we.DocumentChanges)) }
- if !strings.HasSuffix(testURI, "_test.go") { t.Fatalf("unexpected test URI: %s", testURI) }
+ s := newTestServer()
+ // place files under a temp dir to avoid collisions
+ dir := t.TempDir()
+ srcPath := filepath.Join(dir, "calc.go")
+ uri := "file://" + srcPath
+ src := "package calc\n\nfunc Sum(a, b int) int { return a+b }\n"
+ s.setDocument(uri, src)
+ // Offer action (not a _test.go)
+ p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line: 2}}}
+ if a := s.buildGoUnitTestCodeAction(p); a == nil {
+ t.Fatalf("expected go unit test action")
+ }
+ // Resolve should create new test file with package+import and a test function
+ we, testURI, _, ok := s.resolveGoTest(uri, Position{Line: 2})
+ if !ok {
+ t.Fatalf("resolveGoTest failed")
+ }
+ if len(we.DocumentChanges) != 2 {
+ t.Fatalf("expected create + edits, got %d", len(we.DocumentChanges))
+ }
+ if !strings.HasSuffix(testURI, "_test.go") {
+ t.Fatalf("unexpected test URI: %s", testURI)
+ }
}
func TestBuildGoUnitTestCodeAction_SkipOnTestFile(t *testing.T) {
- s := newTestServer()
- uri := "file:///tmp/x_test.go"
- s.setDocument(uri, "package p\nfunc T(){}")
- p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}}
- if a := s.buildGoUnitTestCodeAction(p); a != nil { t.Fatalf("expected no action on _test.go") }
+ s := newTestServer()
+ uri := "file:///tmp/x_test.go"
+ s.setDocument(uri, "package p\nfunc T(){}")
+ p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}}
+ if a := s.buildGoUnitTestCodeAction(p); a != nil {
+ t.Fatalf("expected no action on _test.go")
+ }
}
func TestDiagnosticsInRange(t *testing.T) {
- s := newTestServer()
- ctx := CodeActionContext{Diagnostics: []Diagnostic{
- {Range: Range{Start: Position{Line: 3}, End: Position{Line: 3, Character: 5}}, Message: "in"},
- {Range: Range{Start: Position{Line: 10}, End: Position{Line: 11}}, Message: "out"},
- }}
- raw, _ := json.Marshal(ctx)
- got := s.diagnosticsInRange(json.RawMessage(raw), Range{Start: Position{Line:2}, End: Position{Line:4}})
- if len(got) != 1 || got[0].Message != "in" { t.Fatalf("unexpected diags: %+v", got) }
+ s := newTestServer()
+ ctx := CodeActionContext{Diagnostics: []Diagnostic{
+ {Range: Range{Start: Position{Line: 3}, End: Position{Line: 3, Character: 5}}, Message: "in"},
+ {Range: Range{Start: Position{Line: 10}, End: Position{Line: 11}}, Message: "out"},
+ }}
+ raw, _ := json.Marshal(ctx)
+ got := s.diagnosticsInRange(json.RawMessage(raw), Range{Start: Position{Line: 2}, End: Position{Line: 4}})
+ if len(got) != 1 || got[0].Message != "in" {
+ t.Fatalf("unexpected diags: %+v", got)
+ }
}
func TestDocBeforeAfter(t *testing.T) {
- s := newTestServer()
- uri := "file:///d.go"
- s.setDocument(uri, "ab\ncd\nef")
- before, after := s.docBeforeAfter(uri, Position{Line:1, Character:1})
- if before != "ab\nc" || after != "d\nef" { t.Fatalf("before=%q after=%q", before, after) }
+ s := newTestServer()
+ uri := "file:///d.go"
+ s.setDocument(uri, "ab\ncd\nef")
+ before, after := s.docBeforeAfter(uri, Position{Line: 1, Character: 1})
+ if before != "ab\nc" || after != "d\nef" {
+ t.Fatalf("before=%q after=%q", before, after)
+ }
}
diff --git a/internal/lsp/codeaction_test.go b/internal/lsp/codeaction_test.go
index 4de0790..29cb416 100644
--- a/internal/lsp/codeaction_test.go
+++ b/internal/lsp/codeaction_test.go
@@ -1,10 +1,11 @@
package lsp
import (
- "context"
- "encoding/json"
- "codeberg.org/snonux/hexai/internal/llm"
- "testing"
+ "context"
+ "encoding/json"
+ "testing"
+
+ "codeberg.org/snonux/hexai/internal/llm"
)
type fakeLLM struct {
@@ -22,7 +23,7 @@ func TestBuildRewriteCodeAction_LazyAndResolves(t *testing.T) {
s := newTestServer()
s.llmClient = fakeLLM{resp: "REWRITTEN"}
p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{Start: Position{Line: 1, Character: 2}, End: Position{Line: 3, Character: 4}}}
- sel := ">rewrite>\nold code"
+ sel := ">rewrite>\nold code"
ca := s.buildRewriteCodeAction(p, sel)
if ca == nil {
t.Fatalf("expected code action")
diff --git a/internal/lsp/codegen_helpers_test.go b/internal/lsp/codegen_helpers_test.go
index d897953..de43b7d 100644
--- a/internal/lsp/codegen_helpers_test.go
+++ b/internal/lsp/codegen_helpers_test.go
@@ -3,13 +3,20 @@ package lsp
import "testing"
func TestParseGoPackageName(t *testing.T) {
- lines := []string{"// comment", "package mypkg // trailing"}
- if got := parseGoPackageName(lines); got != "mypkg" { t.Fatalf("got %q", got) }
- if got := parseGoPackageName([]string{"no package"}); got != "" { t.Fatalf("expected empty") }
+ lines := []string{"// comment", "package mypkg // trailing"}
+ if got := parseGoPackageName(lines); got != "mypkg" {
+ t.Fatalf("got %q", got)
+ }
+ if got := parseGoPackageName([]string{"no package"}); got != "" {
+ t.Fatalf("expected empty")
+ }
}
func TestDeriveGoFuncName(t *testing.T) {
- if got := deriveGoFuncName("func Sum(a int) int { return a }"); got != "Sum" { t.Fatalf("got %q", got) }
- if got := deriveGoFuncName("func (t *Type) Method(x int) {}"); got != "Method" { t.Fatalf("got %q", got) }
+ if got := deriveGoFuncName("func Sum(a int) int { return a }"); got != "Sum" {
+ t.Fatalf("got %q", got)
+ }
+ if got := deriveGoFuncName("func (t *Type) Method(x int) {}"); got != "Method" {
+ t.Fatalf("got %q", got)
+ }
}
-
diff --git a/internal/lsp/completion_cache_test.go b/internal/lsp/completion_cache_test.go
index 9ef0f00..65631f9 100644
--- a/internal/lsp/completion_cache_test.go
+++ b/internal/lsp/completion_cache_test.go
@@ -1,12 +1,12 @@
package lsp
import (
- "bytes"
- "log"
- "strings"
- "testing"
+ "bytes"
+ "log"
+ "strings"
+ "testing"
- "codeberg.org/snonux/hexai/internal/logging"
+ "codeberg.org/snonux/hexai/internal/logging"
)
func TestCompletionCache_IgnoresWhitespaceBeforeCursor(t *testing.T) {
diff --git a/internal/lsp/completion_codex_path_test.go b/internal/lsp/completion_codex_path_test.go
index 6030d92..bd3b3f4 100644
--- a/internal/lsp/completion_codex_path_test.go
+++ b/internal/lsp/completion_codex_path_test.go
@@ -1,11 +1,11 @@
package lsp
import (
- "context"
- "errors"
- "testing"
+ "context"
+ "errors"
+ "testing"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/llm"
)
// fakeCodeLLM implements both llm.Client and llm.CodeCompleter.
diff --git a/internal/lsp/completion_helpers_more_test.go b/internal/lsp/completion_helpers_more_test.go
index 02fe9f3..79d2523 100644
--- a/internal/lsp/completion_helpers_more_test.go
+++ b/internal/lsp/completion_helpers_more_test.go
@@ -1,35 +1,51 @@
package lsp
import (
- "encoding/json"
- "testing"
+ "encoding/json"
+ "testing"
)
func TestExtractTriggerInfo_ParseManualInvoke(t *testing.T) {
- // Compose a CompletionParams with a raw JSON context
- ctx := struct{ TriggerKind int `json:"triggerKind"`; TriggerCharacter string `json:"triggerCharacter"` }{TriggerKind: 1, TriggerCharacter: "."}
- raw, _ := json.Marshal(ctx)
- p := CompletionParams{Context: json.RawMessage(raw)}
- kind, ch := extractTriggerInfo(p)
- if kind != 1 || ch != "." { t.Fatalf("unexpected trigger info: %d %q", kind, ch) }
- if !parseManualInvoke(json.RawMessage(raw)) { t.Fatalf("expected manual invoke true") }
+ // Compose a CompletionParams with a raw JSON context
+ ctx := struct {
+ TriggerKind int `json:"triggerKind"`
+ TriggerCharacter string `json:"triggerCharacter"`
+ }{TriggerKind: 1, TriggerCharacter: "."}
+ raw, _ := json.Marshal(ctx)
+ p := CompletionParams{Context: json.RawMessage(raw)}
+ kind, ch := extractTriggerInfo(p)
+ if kind != 1 || ch != "." {
+ t.Fatalf("unexpected trigger info: %d %q", kind, ch)
+ }
+ if !parseManualInvoke(json.RawMessage(raw)) {
+ t.Fatalf("expected manual invoke true")
+ }
}
func TestShouldSuppressForChatTriggerEOL(t *testing.T) {
- s := newTestServer()
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line:0, Character:10}}
- line := "say hi;>"
- if !s.shouldSuppressForChatTriggerEOL(line, p) { t.Fatalf("expected suppression when ;> at EOL") }
- if s.shouldSuppressForChatTriggerEOL("plain>", p) { t.Fatalf("should not suppress for plain >") }
+ s := newTestServer()
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line: 0, Character: 10}}
+ line := "say hi;>"
+ if !s.shouldSuppressForChatTriggerEOL(line, p) {
+ t.Fatalf("expected suppression when ;> at EOL")
+ }
+ if s.shouldSuppressForChatTriggerEOL("plain>", p) {
+ t.Fatalf("should not suppress for plain >")
+ }
}
func TestPrefixHeuristicAllows(t *testing.T) {
- s := newTestServer()
- // inline prompt allows zero prefix
- if !s.prefixHeuristicAllows(true, "", CompletionParams{Position: Position{Line:0, Character:0}}, false) { t.Fatalf("inline prompt should allow") }
- // structural triggers like '.' allow without prefix
- if !s.prefixHeuristicAllows(false, "fmt.", CompletionParams{Position: Position{Line:0, Character:4}}, false) { t.Fatalf("dot trigger should allow") }
- // otherwise need at least minimal prefix (default min=1)
- if s.prefixHeuristicAllows(false, " ", CompletionParams{Position: Position{Line:0, Character:0}}, false) { t.Fatalf("should not allow with no prefix") }
+ s := newTestServer()
+ // inline prompt allows zero prefix
+ if !s.prefixHeuristicAllows(true, "", CompletionParams{Position: Position{Line: 0, Character: 0}}, false) {
+ t.Fatalf("inline prompt should allow")
+ }
+ // structural triggers like '.' allow without prefix
+ if !s.prefixHeuristicAllows(false, "fmt.", CompletionParams{Position: Position{Line: 0, Character: 4}}, false) {
+ t.Fatalf("dot trigger should allow")
+ }
+ // otherwise need at least minimal prefix (default min=1)
+ if s.prefixHeuristicAllows(false, " ", CompletionParams{Position: Position{Line: 0, Character: 0}}, false) {
+ t.Fatalf("should not allow with no prefix")
+ }
}
-
diff --git a/internal/lsp/completion_messages_test.go b/internal/lsp/completion_messages_test.go
index e9ec3e5..28908d5 100644
--- a/internal/lsp/completion_messages_test.go
+++ b/internal/lsp/completion_messages_test.go
@@ -1,73 +1,99 @@
package lsp
import (
- "testing"
+ "testing"
)
func TestBuildCompletionMessages_InlinePromptOverridesSys(t *testing.T) {
- s := newTestServer()
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line:0, Character:1}}
- msgs := s.buildCompletionMessages(true, false, "", false, p, "above", "current", "below", "func f")
- if len(msgs) < 2 { t.Fatalf("expected messages") }
- if msgs[0].Role != "system" || msgs[1].Role != "user" { t.Fatalf("unexpected roles") }
- if want := "precise code completion/refactoring engine"; !contains(msgs[0].Content, want) {
- t.Fatalf("inline sys not applied")
- }
+ s := newTestServer()
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line: 0, Character: 1}}
+ msgs := s.buildCompletionMessages(true, false, "", false, p, "above", "current", "below", "func f")
+ if len(msgs) < 2 {
+ t.Fatalf("expected messages")
+ }
+ if msgs[0].Role != "system" || msgs[1].Role != "user" {
+ t.Fatalf("unexpected roles")
+ }
+ if want := "precise code completion/refactoring engine"; !contains(msgs[0].Content, want) {
+ t.Fatalf("inline sys not applied")
+ }
}
func TestBuildCompletionMessages_ExtraContextIncluded(t *testing.T) {
- s := newTestServer()
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line:0, Character:1}}
- msgs := s.buildCompletionMessages(false, true, "EXTRA", false, p, "a", "b", "c", "f")
- found := false
- for _, m := range msgs { if m.Role == "user" && contains(m.Content, "Additional context:") { found = true } }
- if !found { t.Fatalf("missing extra context message") }
+ s := newTestServer()
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line: 0, Character: 1}}
+ msgs := s.buildCompletionMessages(false, true, "EXTRA", false, p, "a", "b", "c", "f")
+ found := false
+ for _, m := range msgs {
+ if m.Role == "user" && contains(m.Content, "Additional context:") {
+ found = true
+ }
+ }
+ if !found {
+ t.Fatalf("missing extra context message")
+ }
}
func TestPrefixHeuristic_AllVariants(t *testing.T) {
- s := newTestServer()
- // manual invoke requires at least min prefix; set to 2
- s.manualInvokeMinPrefix = 2
- cur := "a"
- p := CompletionParams{Position: Position{Line:0, Character:1}}
- if s.prefixHeuristicAllows(false, cur, p, true) { t.Fatalf("should require >=2 prefix on manual invoke") }
- // structural triggers allow without prefix
- if !s.prefixHeuristicAllows(false, "fmt.", CompletionParams{Position: Position{Line:0, Character:4}}, false) { t.Fatalf("dot trigger should allow") }
+ s := newTestServer()
+ // manual invoke requires at least min prefix; set to 2
+ s.manualInvokeMinPrefix = 2
+ cur := "a"
+ p := CompletionParams{Position: Position{Line: 0, Character: 1}}
+ if s.prefixHeuristicAllows(false, cur, p, true) {
+ t.Fatalf("should require >=2 prefix on manual invoke")
+ }
+ // structural triggers allow without prefix
+ if !s.prefixHeuristicAllows(false, "fmt.", CompletionParams{Position: Position{Line: 0, Character: 4}}, false) {
+ t.Fatalf("dot trigger should allow")
+ }
}
func TestBuildDocString_Contents(t *testing.T) {
- s := newTestServer()
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line:3, Character:7}}
- got := s.buildDocString(p, "above", "current", "below", "func ctx")
- if !contains(got, "file: file:///x") || !contains(got, "line: 3") || !contains(got, "function: func ctx") {
- t.Fatalf("unexpected doc string: %q", got)
- }
+ s := newTestServer()
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line: 3, Character: 7}}
+ got := s.buildDocString(p, "above", "current", "below", "func ctx")
+ if !contains(got, "file: file:///x") || !contains(got, "line: 3") || !contains(got, "function: func ctx") {
+ t.Fatalf("unexpected doc string: %q", got)
+ }
}
func TestBuildPrompts_InParams(t *testing.T) {
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line:0, Character:5}}
- sys, user := buildPrompts(true, p, "a", "func f(x)", "c", "func f(x)")
- if !contains(sys, "function signatures") || !contains(user, "parameter list") { t.Fatalf("unexpected in-params prompts") }
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line: 0, Character: 5}}
+ sys, user := buildPrompts(true, p, "a", "func f(x)", "c", "func f(x)")
+ if !contains(sys, "function signatures") || !contains(user, "parameter list") {
+ t.Fatalf("unexpected in-params prompts")
+ }
}
func TestPostProcessCompletion_CodeFencesAndDuplicates(t *testing.T) {
- s := newTestServer()
- // code fences
- cleaned := s.postProcessCompletion("```go\nname := value\n```", "", "")
- if cleaned == "" { t.Fatalf("expected non-empty after fence removal") }
- // duplicate assignment prefix strip
- cleaned2 := s.postProcessCompletion("name := other", "name := ", "name := ")
- if cleaned2 == "" || cleaned2 == "name := other" { t.Fatalf("expected duplicate assignment prefix stripped: %q", cleaned2) }
+ s := newTestServer()
+ // code fences
+ cleaned := s.postProcessCompletion("```go\nname := value\n```", "", "")
+ if cleaned == "" {
+ t.Fatalf("expected non-empty after fence removal")
+ }
+ // duplicate assignment prefix strip
+ cleaned2 := s.postProcessCompletion("name := other", "name := ", "name := ")
+ if cleaned2 == "" || cleaned2 == "name := other" {
+ t.Fatalf("expected duplicate assignment prefix stripped: %q", cleaned2)
+ }
}
-func contains(s, sub string) bool { return len(s) >= len(sub) && (s == sub || (len(sub) > 0 && (stringIndex(s, sub) >= 0))) }
-func stringIndex(s, sub string) int { return len([]rune(s[:])) - len([]rune(s[:])) + (func() int { return intIndex(s, sub) })() }
+func contains(s, sub string) bool {
+ return len(s) >= len(sub) && (s == sub || (len(sub) > 0 && (stringIndex(s, sub) >= 0)))
+}
+func stringIndex(s, sub string) int {
+ return len([]rune(s[:])) - len([]rune(s[:])) + (func() int { return intIndex(s, sub) })()
+}
func intIndex(s, sub string) int { return Index(s, sub) }
// Go's strings.Index is fine; wrapped to avoid extra imports in this small test.
func Index(s, sub string) int {
- for i := 0; i+len(sub) <= len(s); i++ {
- if s[i:i+len(sub)] == sub { return i }
- }
- return -1
+ for i := 0; i+len(sub) <= len(s); i++ {
+ if s[i:i+len(sub)] == sub {
+ return i
+ }
+ }
+ return -1
}
diff --git a/internal/lsp/completion_prefix_strip_test.go b/internal/lsp/completion_prefix_strip_test.go
index e8e70f5..6af87a0 100644
--- a/internal/lsp/completion_prefix_strip_test.go
+++ b/internal/lsp/completion_prefix_strip_test.go
@@ -1,9 +1,10 @@
package lsp
import (
- "encoding/json"
- "testing"
- tut "codeberg.org/snonux/hexai/internal/testutil"
+ "encoding/json"
+ "testing"
+
+ tut "codeberg.org/snonux/hexai/internal/testutil"
)
func TestStripDuplicateGeneralPrefix_ExactOverlap(t *testing.T) {
@@ -41,7 +42,7 @@ func TestStripDuplicateAssignmentPrefix_AssignAndWalrus(t *testing.T) {
func TestTryLLMCompletion_ManualInvokeAfterWhitespace_Allows(t *testing.T) {
s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
- s.llmClient = fakeLLM{resp: tut.MultilineFunctionSuggestion()}
+ s.llmClient = fakeLLM{resp: tut.MultilineFunctionSuggestion()}
line := "func fib(i int) " // cursor after space
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}}
// Simulate manual user invocation (TriggerKind=1)
@@ -56,15 +57,15 @@ func TestTryLLMCompletion_ManualInvokeAfterWhitespace_Allows(t *testing.T) {
}
func TestTryLLMCompletion_InlinePromptAlwaysTriggers(t *testing.T) {
- s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
- s.llmClient = fakeLLM{resp: "replacement"}
- line := "prefix >do something> suffix"
- // No trigger char immediately before cursor; place cursor at end
- p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://inline.go"}}
- items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
- if !ok || len(items) == 0 {
- t.Fatalf("expected completion to trigger on inline >text> prompt")
- }
+ s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
+ s.llmClient = fakeLLM{resp: "replacement"}
+ line := "prefix >do something> suffix"
+ // No trigger char immediately before cursor; place cursor at end
+ p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://inline.go"}}
+ items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ if !ok || len(items) == 0 {
+ t.Fatalf("expected completion to trigger on inline >text> prompt")
+ }
}
func TestTryLLMCompletion_DoubleOpenEmpty_DoesNotAutoTrigger(t *testing.T) {
@@ -86,63 +87,63 @@ func TestTryLLMCompletion_DoubleOpenEmpty_DoesNotAutoTrigger(t *testing.T) {
}
func TestHasDoubleSemicolonTrigger_Variants(t *testing.T) {
- if hasDoubleOpenTrigger(">>") {
- t.Fatalf("bare double-open should not trigger")
- }
- if hasDoubleOpenTrigger(">> ") {
- t.Fatalf("double-open followed by space should not trigger")
- }
- if hasDoubleOpenTrigger(">>>") {
- t.Fatalf("';;;' should not trigger (no content)")
- }
- if !hasDoubleOpenTrigger(">>x>") {
- t.Fatalf("expected trigger for ';;x;' pattern")
- }
+ if hasDoubleOpenTrigger(">>") {
+ t.Fatalf("bare double-open should not trigger")
+ }
+ if hasDoubleOpenTrigger(">> ") {
+ t.Fatalf("double-open followed by space should not trigger")
+ }
+ if hasDoubleOpenTrigger(">>>") {
+ t.Fatalf("';;;' should not trigger (no content)")
+ }
+ if !hasDoubleOpenTrigger(">>x>") {
+ t.Fatalf("expected trigger for ';;x;' pattern")
+ }
}
func TestBareDoubleOpenPreventsAutoTriggerEvenWithOtherTriggers(t *testing.T) {
- s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
- fake := &countingLLM{}
- s.llmClient = fake
- // Place a '.' earlier but also include bare double-open at end; should not auto-trigger
- line := "obj. call >>"
+ s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
+ fake := &countingLLM{}
+ s.llmClient = fake
+ // Place a '.' earlier but also include bare double-open at end; should not auto-trigger
+ line := "obj. call >>"
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds.go"}}
items, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
if !ok {
t.Fatalf("expected ok=true (handled), but not auto-triggering")
}
- if len(items) != 0 {
- t.Fatalf("expected no items due to bare double-open")
- }
+ if len(items) != 0 {
+ t.Fatalf("expected no items due to bare double-open")
+ }
if fake.calls != 0 {
t.Fatalf("LLM should not be called; calls=%d", fake.calls)
}
}
func TestBareDoubleOpenOnNextLine_PreventsAutoTrigger(t *testing.T) {
- s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
- fake := &countingLLM{}
- s.llmClient = fake
- current := "expression := flag.String(\"expression\", \"\", \"Expression to evaluate\")"
- below := ">>"
+ s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
+ fake := &countingLLM{}
+ s.llmClient = fake
+ current := "expression := flag.String(\"expression\", \"\", \"Expression to evaluate\")"
+ below := ">>"
p := CompletionParams{Position: Position{Line: 0, Character: len(current)}, TextDocument: TextDocumentIdentifier{URI: "file://nextline.go"}}
items, ok := s.tryLLMCompletion(p, "", current, below, "", "", false, "")
if !ok {
t.Fatalf("expected ok=true handled")
}
- if len(items) != 0 {
- t.Fatalf("expected no items due to bare double-open on next line")
- }
+ if len(items) != 0 {
+ t.Fatalf("expected no items due to bare double-open on next line")
+ }
if fake.calls != 0 {
t.Fatalf("LLM should not be called; calls=%d", fake.calls)
}
}
func TestBareDoubleOpenPreventsManualInvoke(t *testing.T) {
- s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
- fake := &countingLLM{}
- s.llmClient = fake
- line := ">>"
+ s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)}
+ fake := &countingLLM{}
+ s.llmClient = fake
+ line := ">>"
p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://bare-ds-manual.go"}}
// Simulate manual invoke
p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
@@ -150,9 +151,9 @@ func TestBareDoubleOpenPreventsManualInvoke(t *testing.T) {
if !ok {
t.Fatalf("expected ok=true (handled)")
}
- if len(items) != 0 {
- t.Fatalf("expected no items for bare double-open even with manual invoke")
- }
+ if len(items) != 0 {
+ t.Fatalf("expected no items for bare double-open even with manual invoke")
+ }
if fake.calls != 0 {
t.Fatalf("LLM should not be called; calls=%d", fake.calls)
}
diff --git a/internal/lsp/completion_provider_fallback_test.go b/internal/lsp/completion_provider_fallback_test.go
index 04ca7a4..67dc78b 100644
--- a/internal/lsp/completion_provider_fallback_test.go
+++ b/internal/lsp/completion_provider_fallback_test.go
@@ -1,41 +1,50 @@
package lsp
import (
- "context"
- "encoding/json"
- "io"
- "testing"
+ "context"
+ "encoding/json"
+ "io"
+ "testing"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/llm"
)
// fakeCompleterErr implements both Client and CodeCompleter; CodeCompletion errors,
// forcing tryProviderNativeCompletion to take the error path and fall back to chat.
type fakeCompleterErr struct{}
-func (fakeCompleterErr) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) { return "X", nil }
-func (fakeCompleterErr) Name() string { return "prov" }
+
+func (fakeCompleterErr) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) {
+ return "X", nil
+}
+func (fakeCompleterErr) Name() string { return "prov" }
func (fakeCompleterErr) DefaultModel() string { return "m" }
-func (fakeCompleterErr) CodeCompletion(context.Context, string, string, int, string, float64) ([]string, error) { return nil, io.EOF }
+func (fakeCompleterErr) CodeCompletion(context.Context, string, string, int, string, float64) ([]string, error) {
+ return nil, io.EOF
+}
func TestCompletion_FallbackOnProviderError(t *testing.T) {
- s := newTestServer()
- s.llmClient = fakeCompleterErr{}
- // Provide simple document
- uri := "file:///x.go"
- s.setDocument(uri, "package p\nfunc f(){\nfmt.\n}\n")
- // Position after 'fmt.' to satisfy prefix heuristics
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Position: Position{Line:2, Character:4}}
- // Build context for trigger character '.'
- ctx := struct{ TriggerKind int `json:"triggerKind"`; TriggerCharacter string `json:"triggerCharacter"` }{TriggerKind: 2, TriggerCharacter: "."}
- bctx, _ := json.Marshal(ctx)
- p.Context = json.RawMessage(bctx)
-
- // Call handleCompletion and ensure it returns at least one item from chat fallback
- var buf nopWriter
- s.out = &buf
- s.handleCompletion(Request{JSONRPC: "2.0", ID: json.RawMessage("6"), Method: "textDocument/completion", Params: mustJSON(p)})
- // No panic implies path executed; detailed decode not needed here
+ s := newTestServer()
+ s.llmClient = fakeCompleterErr{}
+ // Provide simple document
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\nfunc f(){\nfmt.\n}\n")
+ // Position after 'fmt.' to satisfy prefix heuristics
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Position: Position{Line: 2, Character: 4}}
+ // Build context for trigger character '.'
+ ctx := struct {
+ TriggerKind int `json:"triggerKind"`
+ TriggerCharacter string `json:"triggerCharacter"`
+ }{TriggerKind: 2, TriggerCharacter: "."}
+ bctx, _ := json.Marshal(ctx)
+ p.Context = json.RawMessage(bctx)
+
+ // Call handleCompletion and ensure it returns at least one item from chat fallback
+ var buf nopWriter
+ s.out = &buf
+ s.handleCompletion(Request{JSONRPC: "2.0", ID: json.RawMessage("6"), Method: "textDocument/completion", Params: mustJSON(p)})
+ // No panic implies path executed; detailed decode not needed here
}
type nopWriter struct{}
+
func (nopWriter) Write(p []byte) (int, error) { return len(p), nil }
diff --git a/internal/lsp/compute_textedit_table_test.go b/internal/lsp/compute_textedit_table_test.go
index d82e91d..6ed5330 100644
--- a/internal/lsp/compute_textedit_table_test.go
+++ b/internal/lsp/compute_textedit_table_test.go
@@ -3,31 +3,30 @@ package lsp
import "testing"
func TestComputeTextEditAndFilter_Table(t *testing.T) {
- cases := []struct{
- name string
- inParams bool
- current string
- pos Position
- cleaned string
- }{
- {"ident_replace", false, "ab cd", Position{Line:1, Character:4}, "X"},
- {"params_inside", true, "func add(a int, b string)", Position{Line:0, Character:15}, "c bool"},
- {"params_at_close", true, "func add(a int)", Position{Line:0, Character:len("func add(a int)")}, "b string"},
- }
- for _, c := range cases {
- te, filter := computeTextEditAndFilter(c.cleaned, c.inParams, c.current, CompletionParams{Position: c.pos})
- if te == nil {
- t.Fatalf("%s: expected edit", c.name)
- }
- if c.inParams && te.Range.Start.Character == 0 {
- t.Fatalf("%s: expected param range (non-zero start)", c.name)
- }
- if filter == "" && c.current != "" {
- // For ident_replace, filter may be non-empty; for params, it can be empty when replacing entire segment
- }
- if te.NewText != c.cleaned {
- t.Fatalf("%s: newText got %q want %q", c.name, te.NewText, c.cleaned)
- }
- }
+ cases := []struct {
+ name string
+ inParams bool
+ current string
+ pos Position
+ cleaned string
+ }{
+ {"ident_replace", false, "ab cd", Position{Line: 1, Character: 4}, "X"},
+ {"params_inside", true, "func add(a int, b string)", Position{Line: 0, Character: 15}, "c bool"},
+ {"params_at_close", true, "func add(a int)", Position{Line: 0, Character: len("func add(a int)")}, "b string"},
+ }
+ for _, c := range cases {
+ te, filter := computeTextEditAndFilter(c.cleaned, c.inParams, c.current, CompletionParams{Position: c.pos})
+ if te == nil {
+ t.Fatalf("%s: expected edit", c.name)
+ }
+ if c.inParams && te.Range.Start.Character == 0 {
+ t.Fatalf("%s: expected param range (non-zero start)", c.name)
+ }
+ if filter == "" && c.current != "" {
+ // For ident_replace, filter may be non-empty; for params, it can be empty when replacing entire segment
+ }
+ if te.NewText != c.cleaned {
+ t.Fatalf("%s: newText got %q want %q", c.name, te.NewText, c.cleaned)
+ }
+ }
}
-
diff --git a/internal/lsp/context.go b/internal/lsp/context.go
index 72331a8..5a4983c 100644
--- a/internal/lsp/context.go
+++ b/internal/lsp/context.go
@@ -2,8 +2,9 @@
package lsp
import (
- "codeberg.org/snonux/hexai/internal/logging"
"strings"
+
+ "codeberg.org/snonux/hexai/internal/logging"
)
// buildAdditionalContext builds extra context messages based on the configured mode.
diff --git a/internal/lsp/debounce_throttle_more_test.go b/internal/lsp/debounce_throttle_more_test.go
index cb11ea4..ed61336 100644
--- a/internal/lsp/debounce_throttle_more_test.go
+++ b/internal/lsp/debounce_throttle_more_test.go
@@ -1,36 +1,35 @@
package lsp
import (
- "context"
- "testing"
- "time"
+ "context"
+ "testing"
+ "time"
)
func TestWaitForDebounce_WaitsRoughlyDebounce(t *testing.T) {
- s := newTestServer()
- s.completionDebounce = 20 * time.Millisecond
- s.mu.Lock()
- s.lastInput = time.Now()
- s.mu.Unlock()
- start := time.Now()
- s.waitForDebounce(context.Background())
- if elapsed := time.Since(start); elapsed < 15*time.Millisecond {
- t.Fatalf("debounce did not wait long enough: %v", elapsed)
- }
+ s := newTestServer()
+ s.completionDebounce = 20 * time.Millisecond
+ s.mu.Lock()
+ s.lastInput = time.Now()
+ s.mu.Unlock()
+ start := time.Now()
+ s.waitForDebounce(context.Background())
+ if elapsed := time.Since(start); elapsed < 15*time.Millisecond {
+ t.Fatalf("debounce did not wait long enough: %v", elapsed)
+ }
}
func TestWaitForThrottle_WaitsRoughlyInterval(t *testing.T) {
- s := newTestServer()
- s.throttleInterval = 20 * time.Millisecond
- s.mu.Lock()
- s.lastLLMCall = time.Now()
- s.mu.Unlock()
- start := time.Now()
- if !s.waitForThrottle(context.Background()) {
- t.Fatalf("waitForThrottle returned false")
- }
- if elapsed := time.Since(start); elapsed < 15*time.Millisecond {
- t.Fatalf("throttle did not wait long enough: %v", elapsed)
- }
+ s := newTestServer()
+ s.throttleInterval = 20 * time.Millisecond
+ s.mu.Lock()
+ s.lastLLMCall = time.Now()
+ s.mu.Unlock()
+ start := time.Now()
+ if !s.waitForThrottle(context.Background()) {
+ t.Fatalf("waitForThrottle returned false")
+ }
+ if elapsed := time.Since(start); elapsed < 15*time.Millisecond {
+ t.Fatalf("throttle did not wait long enough: %v", elapsed)
+ }
}
-
diff --git a/internal/lsp/debounce_throttle_test.go b/internal/lsp/debounce_throttle_test.go
index 012ec68..0b49b1b 100644
--- a/internal/lsp/debounce_throttle_test.go
+++ b/internal/lsp/debounce_throttle_test.go
@@ -1,84 +1,85 @@
package lsp
import (
- "context"
- "encoding/json"
- "testing"
- "time"
- "codeberg.org/snonux/hexai/internal/llm"
+ "context"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/llm"
)
// timeLLM records the time when Chat is invoked.
type timeLLM struct{ t time.Time }
func (t *timeLLM) Chat(ctx context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) {
- t.t = time.Now()
- return "ok", nil
+ t.t = time.Now()
+ return "ok", nil
}
func (t *timeLLM) Name() string { return "fake" }
func (t *timeLLM) DefaultModel() string { return "m" }
func TestCompletionDebounce_WaitsUntilQuiet(t *testing.T) {
- s := newTestServer()
- s.compCache = make(map[string]string)
- s.triggerChars = []string{".", ":", "/", "_"}
- s.maxTokens = 32
- s.completionDebounce = 30 * time.Millisecond
- s.markActivity() // simulate recent input
+ s := newTestServer()
+ s.compCache = make(map[string]string)
+ s.triggerChars = []string{".", ":", "/", "_"}
+ s.maxTokens = 32
+ s.completionDebounce = 30 * time.Millisecond
+ s.markActivity() // simulate recent input
- f := &timeLLM{}
- s.llmClient = f
+ f := &timeLLM{}
+ s.llmClient = f
- line := "func f(i int) "
- p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://debounce.go"}}
- p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
+ line := "func f(i int) "
+ p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://debounce.go"}}
+ p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
- start := time.Now()
- _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
- if !ok {
- t.Fatalf("expected ok=true")
- }
- if f.t.IsZero() {
- t.Fatalf("expected LLM to be called")
- }
- if f.t.Sub(start) < 25*time.Millisecond { // allow minor timing noise
- t.Fatalf("expected debounce delay, got %s", f.t.Sub(start))
- }
+ start := time.Now()
+ _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, "")
+ if !ok {
+ t.Fatalf("expected ok=true")
+ }
+ if f.t.IsZero() {
+ t.Fatalf("expected LLM to be called")
+ }
+ if f.t.Sub(start) < 25*time.Millisecond { // allow minor timing noise
+ t.Fatalf("expected debounce delay, got %s", f.t.Sub(start))
+ }
}
func TestCompletionThrottle_SerializesCalls(t *testing.T) {
- s := newTestServer()
- s.compCache = make(map[string]string)
- s.triggerChars = []string{".", ":", "/", "_"}
- s.maxTokens = 32
- s.throttleInterval = 25 * time.Millisecond
+ s := newTestServer()
+ s.compCache = make(map[string]string)
+ s.triggerChars = []string{".", ":", "/", "_"}
+ s.maxTokens = 32
+ s.throttleInterval = 25 * time.Millisecond
- // first call uses timeLLM to record time
- f1 := &timeLLM{}
- s.llmClient = f1
- line := "func f(i int) "
- p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://throttle.go"}}
- p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
- start := time.Now()
- if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
- t.Fatalf("first call expected ok=true")
- }
- if f1.t.IsZero() {
- t.Fatalf("expected first call time recorded")
- }
+ // first call uses timeLLM to record time
+ f1 := &timeLLM{}
+ s.llmClient = f1
+ line := "func f(i int) "
+ p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://throttle.go"}}
+ p.Context = json.RawMessage([]byte(`{"triggerKind":1}`))
+ start := time.Now()
+ if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
+ t.Fatalf("first call expected ok=true")
+ }
+ if f1.t.IsZero() {
+ t.Fatalf("expected first call time recorded")
+ }
- // second call immediately after; should be delayed by ~interval.
- // Clear cache to ensure we actually call the LLM again.
- s.compCache = make(map[string]string)
- f2 := &timeLLM{}
- s.llmClient = f2
- if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
- t.Fatalf("second call expected ok=true")
- }
- if f2.t.IsZero() {
- t.Fatalf("expected second call time recorded")
- }
- if f2.t.Sub(start) < s.throttleInterval {
- t.Fatalf("expected throttle spacing >= %s, got %s", s.throttleInterval, f2.t.Sub(start))
- }
+ // second call immediately after; should be delayed by ~interval.
+ // Clear cache to ensure we actually call the LLM again.
+ s.compCache = make(map[string]string)
+ f2 := &timeLLM{}
+ s.llmClient = f2
+ if _, ok := s.tryLLMCompletion(p, "", line, "", "", "", false, ""); !ok {
+ t.Fatalf("second call expected ok=true")
+ }
+ if f2.t.IsZero() {
+ t.Fatalf("expected second call time recorded")
+ }
+ if f2.t.Sub(start) < s.throttleInterval {
+ t.Fatalf("expected throttle spacing >= %s, got %s", s.throttleInterval, f2.t.Sub(start))
+ }
}
diff --git a/internal/lsp/diagnostics_action_test.go b/internal/lsp/diagnostics_action_test.go
index 1a9201f..a607b86 100644
--- a/internal/lsp/diagnostics_action_test.go
+++ b/internal/lsp/diagnostics_action_test.go
@@ -1,30 +1,33 @@
package lsp
import (
- "encoding/json"
- "io"
- "log"
- "testing"
+ "encoding/json"
+ "io"
+ "log"
+ "testing"
)
func TestHandleCodeAction_ListsDiagnosticsActionWhenOverlap(t *testing.T) {
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document)}
- s.llmClient = fakeLLM{resp: "fixed"}
- uri := "file:///x.go"
- s.setDocument(uri, "package p\nvar a=1\n")
- // Selection overlaps line 1
- sel := Range{Start: Position{Line:1, Character:0}, End: Position{Line:1, Character:5}}
- // Provide diagnostics in the action context with one overlapping
- ctx := CodeActionContext{Diagnostics: []Diagnostic{
- {Range: Range{Start: Position{Line:1, Character:0}, End: Position{Line:1, Character:3}}, Message: "in"},
- {Range: Range{Start: Position{Line:0, Character:0}, End: Position{Line:0, Character:1}}, Message: "out"},
- }}
- rawCtx, _ := json.Marshal(ctx)
- p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: sel, Context: json.RawMessage(rawCtx)}
- ca := s.buildDiagnosticsCodeAction(p, "var a=1")
- if ca == nil { t.Fatalf("expected diagnostics action") }
- // Resolve should produce an edit
- resolved, ok := s.resolveCodeAction(*ca)
- if !ok || resolved.Edit == nil { t.Fatalf("expected resolved edit from diagnostics") }
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document)}
+ s.llmClient = fakeLLM{resp: "fixed"}
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\nvar a=1\n")
+ // Selection overlaps line 1
+ sel := Range{Start: Position{Line: 1, Character: 0}, End: Position{Line: 1, Character: 5}}
+ // Provide diagnostics in the action context with one overlapping
+ ctx := CodeActionContext{Diagnostics: []Diagnostic{
+ {Range: Range{Start: Position{Line: 1, Character: 0}, End: Position{Line: 1, Character: 3}}, Message: "in"},
+ {Range: Range{Start: Position{Line: 0, Character: 0}, End: Position{Line: 0, Character: 1}}, Message: "out"},
+ }}
+ rawCtx, _ := json.Marshal(ctx)
+ p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: sel, Context: json.RawMessage(rawCtx)}
+ ca := s.buildDiagnosticsCodeAction(p, "var a=1")
+ if ca == nil {
+ t.Fatalf("expected diagnostics action")
+ }
+ // Resolve should produce an edit
+ resolved, ok := s.resolveCodeAction(*ca)
+ if !ok || resolved.Edit == nil {
+ t.Fatalf("expected resolved edit from diagnostics")
+ }
}
-
diff --git a/internal/lsp/document.go b/internal/lsp/document.go
index a5ece7e..1ef1a5b 100644
--- a/internal/lsp/document.go
+++ b/internal/lsp/document.go
@@ -68,7 +68,7 @@ func (s *Server) lineContext(uri string, pos Position) (above, current, below, f
break
}
}
- return
+ return above, current, below, funcCtx
}
// isDefiningNewFunction returns true when the cursor appears to be within
diff --git a/internal/lsp/document_handlers_test.go b/internal/lsp/document_handlers_test.go
index bb12dd2..eae5020 100644
--- a/internal/lsp/document_handlers_test.go
+++ b/internal/lsp/document_handlers_test.go
@@ -1,61 +1,73 @@
package lsp
import (
- "bytes"
- "encoding/json"
- "io"
- "log"
- "testing"
- "time"
+ "bytes"
+ "encoding/json"
+ "io"
+ "log"
+ "testing"
+ "time"
)
func TestDidOpenChangeClose_UpdateDocs(t *testing.T) {
- s := newTestServer()
- uri := "file:///x.go"
- // didOpen
- open := DidOpenTextDocumentParams{TextDocument: TextDocumentItem{URI: uri, Text: "a\n"}}
- s.handleDidOpen(Request{JSONRPC: "2.0", Method: "textDocument/didOpen", Params: mustJSON(open)})
- if s.getDocument(uri) == nil { t.Fatalf("doc not opened") }
- // didChange
- ch := DidChangeTextDocumentParams{TextDocument: VersionedTextDocumentIdentifier{URI: uri}, ContentChanges: []TextDocumentContentChangeEvent{{Text: "b\n"}}}
- s.handleDidChange(Request{JSONRPC: "2.0", Method: "textDocument/didChange", Params: mustJSON(ch)})
- if d := s.getDocument(uri); d == nil || d.text != "b\n" { t.Fatalf("doc not changed") }
- // didClose
- s.handleDidClose(Request{JSONRPC: "2.0", Method: "textDocument/didClose", Params: mustJSON(DidCloseTextDocumentParams{TextDocument: TextDocumentIdentifier{URI: uri}})})
- if s.getDocument(uri) != nil { t.Fatalf("doc not closed") }
+ s := newTestServer()
+ uri := "file:///x.go"
+ // didOpen
+ open := DidOpenTextDocumentParams{TextDocument: TextDocumentItem{URI: uri, Text: "a\n"}}
+ s.handleDidOpen(Request{JSONRPC: "2.0", Method: "textDocument/didOpen", Params: mustJSON(open)})
+ if s.getDocument(uri) == nil {
+ t.Fatalf("doc not opened")
+ }
+ // didChange
+ ch := DidChangeTextDocumentParams{TextDocument: VersionedTextDocumentIdentifier{URI: uri}, ContentChanges: []TextDocumentContentChangeEvent{{Text: "b\n"}}}
+ s.handleDidChange(Request{JSONRPC: "2.0", Method: "textDocument/didChange", Params: mustJSON(ch)})
+ if d := s.getDocument(uri); d == nil || d.text != "b\n" {
+ t.Fatalf("doc not changed")
+ }
+ // didClose
+ s.handleDidClose(Request{JSONRPC: "2.0", Method: "textDocument/didClose", Params: mustJSON(DidCloseTextDocumentParams{TextDocument: TextDocumentIdentifier{URI: uri}})})
+ if s.getDocument(uri) != nil {
+ t.Fatalf("doc not closed")
+ }
}
func TestClientShowDocument_WritesRequest(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- uri := "file:///x.go"
- sel := Range{Start: Position{Line: 1}, End: Position{Line: 2}}
- out.Reset()
- s.clientShowDocument(uri, &sel)
- req := captureRequest(t, &out)
- if req.Method != "window/showDocument" { t.Fatalf("got %s", req.Method) }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ uri := "file:///x.go"
+ sel := Range{Start: Position{Line: 1}, End: Position{Line: 2}}
+ out.Reset()
+ s.clientShowDocument(uri, &sel)
+ req := captureRequest(t, &out)
+ if req.Method != "window/showDocument" {
+ t.Fatalf("got %s", req.Method)
+ }
}
func TestHandleExecuteCommand_ShowDocument(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- uri := "file:///x.go"
- r := Range{Start: Position{Line:0}, End: Position{Line:0}}
- args := []any{uri, r}
- params := ExecuteCommandParams{Command: "hexai.showDocument", Arguments: args}
- s.handleExecuteCommand(Request{JSONRPC: "2.0", ID: json.RawMessage("11"), Method: "workspace/executeCommand", Params: mustJSON(params)})
- req := captureRequest(t, &out)
- if req.Method != "window/showDocument" { t.Fatalf("expected showDocument after executeCommand, got %s", req.Method) }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ uri := "file:///x.go"
+ r := Range{Start: Position{Line: 0}, End: Position{Line: 0}}
+ args := []any{uri, r}
+ params := ExecuteCommandParams{Command: "hexai.showDocument", Arguments: args}
+ s.handleExecuteCommand(Request{JSONRPC: "2.0", ID: json.RawMessage("11"), Method: "workspace/executeCommand", Params: mustJSON(params)})
+ req := captureRequest(t, &out)
+ if req.Method != "window/showDocument" {
+ t.Fatalf("expected showDocument after executeCommand, got %s", req.Method)
+ }
}
func TestDeferShowDocument_WritesLater(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- uri := "file:///x.go"
- out.Reset()
- s.deferShowDocument(uri, Range{Start: Position{Line:0}, End: Position{Line:0}})
- // wait >120ms per implementation
- time.Sleep(160 * time.Millisecond)
- req := captureRequest(t, &out)
- if req.Method != "window/showDocument" { t.Fatalf("expected showDocument, got %s", req.Method) }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ uri := "file:///x.go"
+ out.Reset()
+ s.deferShowDocument(uri, Range{Start: Position{Line: 0}, End: Position{Line: 0}})
+ // wait >120ms per implementation
+ time.Sleep(160 * time.Millisecond)
+ req := captureRequest(t, &out)
+ if req.Method != "window/showDocument" {
+ t.Fatalf("expected showDocument, got %s", req.Method)
+ }
}
diff --git a/internal/lsp/document_test.go b/internal/lsp/document_test.go
index 5fee18b..00e4548 100644
--- a/internal/lsp/document_test.go
+++ b/internal/lsp/document_test.go
@@ -9,20 +9,20 @@ import (
)
func newTestServer() *Server {
- s := &Server{
- logger: log.New(io.Discard, "", 0),
- docs: make(map[string]*document),
- inlineOpen: ">",
- inlineClose: ">",
- chatSuffix: ">",
- chatPrefixes: []string{"?","!",":",";"},
- }
- // Keep package-level helpers in sync for tests using free functions
- inlineOpenChar = '>'
- inlineCloseChar = '>'
- chatSuffixChar = '>'
- chatPrefixSingles = []string{"?","!",":",";"}
- return s
+ s := &Server{
+ logger: log.New(io.Discard, "", 0),
+ docs: make(map[string]*document),
+ inlineOpen: ">",
+ inlineClose: ">",
+ chatSuffix: ">",
+ chatPrefixes: []string{"?", "!", ":", ";"},
+ }
+ // Keep package-level helpers in sync for tests using free functions
+ inlineOpenChar = '>'
+ inlineCloseChar = '>'
+ chatSuffixChar = '>'
+ chatPrefixSingles = []string{"?", "!", ":", ";"}
+ return s
}
func TestSplitLines(t *testing.T) {
@@ -71,12 +71,14 @@ func TestLineContext_EmptyDoc(t *testing.T) {
}
func TestDocBeforeAfter_ClampsIndices(t *testing.T) {
- s := newTestServer()
- uri := "file:///clamp.go"
- s.setDocument(uri, "abc\nxyz")
- // Position beyond document length should be clamped safely
- before, after := s.docBeforeAfter(uri, Position{Line: 99, Character: 99})
- if before == "" && after == "" { t.Fatalf("expected some text with clamped indices") }
+ s := newTestServer()
+ uri := "file:///clamp.go"
+ s.setDocument(uri, "abc\nxyz")
+ // Position beyond document length should be clamped safely
+ before, after := s.docBeforeAfter(uri, Position{Line: 99, Character: 99})
+ if before == "" && after == "" {
+ t.Fatalf("expected some text with clamped indices")
+ }
}
func TestTrimLen(t *testing.T) {
diff --git a/internal/lsp/fallback_items_test.go b/internal/lsp/fallback_items_test.go
index 0ce3542..0a08a0d 100644
--- a/internal/lsp/fallback_items_test.go
+++ b/internal/lsp/fallback_items_test.go
@@ -3,10 +3,9 @@ package lsp
import "testing"
func TestFallbackCompletionItems(t *testing.T) {
- s := newTestServer()
- items := s.fallbackCompletionItems("doc")
- if len(items) != 1 || items[0].Label != "hexai-complete" || items[0].InsertText != "hexai" {
- t.Fatalf("unexpected fallback items: %+v", items)
- }
+ s := newTestServer()
+ items := s.fallbackCompletionItems("doc")
+ if len(items) != 1 || items[0].Label != "hexai-complete" || items[0].InsertText != "hexai" {
+ t.Fatalf("unexpected fallback items: %+v", items)
+ }
}
-
diff --git a/internal/lsp/gotest_append_test.go b/internal/lsp/gotest_append_test.go
index 4fff684..7ceb9e6 100644
--- a/internal/lsp/gotest_append_test.go
+++ b/internal/lsp/gotest_append_test.go
@@ -1,28 +1,37 @@
package lsp
import (
- "os"
- "path/filepath"
- "strings"
- "testing"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
)
func TestResolveGoTest_AppendsToExisting(t *testing.T) {
- s := newTestServer()
- dir := t.TempDir()
- src := filepath.Join(dir, "m.go")
- uri := "file://" + src
- s.setDocument(uri, "package m\n\nfunc F(){}\n")
- // Create existing test file
- testPath := filepath.Join(dir, "m_test.go")
- if err := os.WriteFile(testPath, []byte("package m\n\nimport \"testing\"\n\n"), 0o644); err != nil { t.Fatal(err) }
- // LLM path to increase generateGoTestFunction coverage
- s.llmClient = fakeLLM{resp: "func TestF(t *testing.T) {}"}
- we, testURI, jump, ok := s.resolveGoTest(uri, Position{Line:2})
- if !ok || len(we.Changes) == 0 { t.Fatalf("expected append edit") }
- if !strings.HasSuffix(testURI, "_test.go") { t.Fatalf("unexpected uri: %s", testURI) }
- edits := we.Changes[testURI]
- if len(edits) != 1 || !strings.Contains(edits[0].NewText, "TestF") { t.Fatalf("expected append with TestF") }
- if jump.Start.Line < 0 { t.Fatalf("expected non-negative jump line") }
+ s := newTestServer()
+ dir := t.TempDir()
+ src := filepath.Join(dir, "m.go")
+ uri := "file://" + src
+ s.setDocument(uri, "package m\n\nfunc F(){}\n")
+ // Create existing test file
+ testPath := filepath.Join(dir, "m_test.go")
+ if err := os.WriteFile(testPath, []byte("package m\n\nimport \"testing\"\n\n"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ // LLM path to increase generateGoTestFunction coverage
+ s.llmClient = fakeLLM{resp: "func TestF(t *testing.T) {}"}
+ we, testURI, jump, ok := s.resolveGoTest(uri, Position{Line: 2})
+ if !ok || len(we.Changes) == 0 {
+ t.Fatalf("expected append edit")
+ }
+ if !strings.HasSuffix(testURI, "_test.go") {
+ t.Fatalf("unexpected uri: %s", testURI)
+ }
+ edits := we.Changes[testURI]
+ if len(edits) != 1 || !strings.Contains(edits[0].NewText, "TestF") {
+ t.Fatalf("expected append with TestF")
+ }
+ if jump.Start.Line < 0 {
+ t.Fatalf("expected non-negative jump line")
+ }
}
-
diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go
index 5e7d86d..e85065b 100644
--- a/internal/lsp/handlers.go
+++ b/internal/lsp/handlers.go
@@ -51,7 +51,7 @@ func findFirstInstructionInLine(line string) (instr string, cleaned string, ok b
text string
}
cands := []cand{}
- if t, l, r, ok := findStrictInlineTag(line); ok {
+ if t, l, r, ok := findStrictInlineTag(line); ok {
cands = append(cands, cand{start: l, end: r, text: t})
}
if i := strings.Index(line, "/*"); i >= 0 {
@@ -298,15 +298,15 @@ func (s *Server) isTriggerEvent(p CompletionParams, current string) bool {
b, _ := json.Marshal(p.Context)
_ = json.Unmarshal(b, &ctx)
}
- // If configured and the line contains a bare double-open marker (e.g., '>>' with no '>>text>'),
- // do not treat as a trigger source.
- if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current) {
- return false
- }
- // TriggerKind 1 = Invoked (manual). Always allow manual invoke.
- if ctx.TriggerKind == 1 {
- return true
- }
+ // If configured and the line contains a bare double-open marker (e.g., '>>' with no '>>text>'),
+ // do not treat as a trigger source.
+ if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current) {
+ return false
+ }
+ // TriggerKind 1 = Invoked (manual). Always allow manual invoke.
+ if ctx.TriggerKind == 1 {
+ return true
+ }
// TriggerKind 2 is TriggerCharacter per LSP spec
if ctx.TriggerKind == 2 {
if ctx.TriggerCharacter != "" {
@@ -327,10 +327,10 @@ func (s *Server) isTriggerEvent(p CompletionParams, current string) bool {
if idx <= 0 || idx > len(current) {
return false
}
- // Bare double-open should not trigger via fallback char either (only when configured)
- if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current) {
- return false
- }
+ // Bare double-open should not trigger via fallback char either (only when configured)
+ if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current) {
+ return false
+ }
ch := string(current[idx-1])
for _, c := range s.triggerChars {
if c == ch {
diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go
index 5740264..27020a0 100644
--- a/internal/lsp/handlers_codeaction.go
+++ b/internal/lsp/handlers_codeaction.go
@@ -2,15 +2,16 @@
package lsp
import (
- "context"
- "encoding/json"
- "fmt"
- "codeberg.org/snonux/hexai/internal/llm"
- "codeberg.org/snonux/hexai/internal/logging"
- "strings"
- "time"
- "os"
- "path/filepath"
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/logging"
)
func (s *Server) handleCodeAction(req Request) {
@@ -28,24 +29,24 @@ func (s *Server) handleCodeAction(req Request) {
}
return
}
- sel := extractRangeText(d, p.Range)
+ sel := extractRangeText(d, p.Range)
- actions := make([]CodeAction, 0, 4)
- if a := s.buildRewriteCodeAction(p, sel); a != nil {
- actions = append(actions, *a)
- }
- if a := s.buildDiagnosticsCodeAction(p, sel); a != nil {
- actions = append(actions, *a)
- }
- if a := s.buildDocumentCodeAction(p, sel); a != nil {
- actions = append(actions, *a)
- }
- if a := s.buildGoUnitTestCodeAction(p); a != nil {
- actions = append(actions, *a)
- }
- if len(req.ID) != 0 {
- s.reply(req.ID, actions, nil)
- }
+ actions := make([]CodeAction, 0, 4)
+ if a := s.buildRewriteCodeAction(p, sel); a != nil {
+ actions = append(actions, *a)
+ }
+ if a := s.buildDiagnosticsCodeAction(p, sel); a != nil {
+ actions = append(actions, *a)
+ }
+ if a := s.buildDocumentCodeAction(p, sel); a != nil {
+ actions = append(actions, *a)
+ }
+ if a := s.buildGoUnitTestCodeAction(p); a != nil {
+ actions = append(actions, *a)
+ }
+ if len(req.ID) != 0 {
+ s.reply(req.ID, actions, nil)
+ }
}
func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction {
@@ -96,8 +97,8 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) {
if err := json.Unmarshal(ca.Data, &payload); err != nil {
return ca, false
}
- switch payload.Type {
- case "rewrite":
+ switch payload.Type {
+ case "rewrite":
sys := "You are a precise code refactoring engine. Rewrite the given code strictly according to the instruction. Return only the updated code with no prose or backticks. Preserve formatting where reasonable."
user := fmt.Sprintf("Instruction: %s\n\nSelected code to transform:\n%s", payload.Instruction, payload.Selection)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -113,7 +114,7 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) {
} else {
logging.Logf("lsp ", "codeAction rewrite llm error: %v", err)
}
- case "diagnostics":
+ case "diagnostics":
sys := "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes."
var b strings.Builder
b.WriteString("Diagnostics to resolve (selection only):\n")
@@ -139,34 +140,34 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) {
} else {
logging.Logf("lsp ", "codeAction diagnostics llm error: %v", err)
}
- case "document":
- sys := "You are a precise code documentation engine. Add idiomatic documentation comments to the given code. Preserve exact behavior and formatting as much as possible. Return only the updated code with comments, no prose or backticks."
- user := "Add documentation comments to this code:\n" + payload.Selection
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
- opts := s.llmRequestOpts()
- if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil {
- if out := stripCodeFences(strings.TrimSpace(text)); out != "" {
- edit := WorkspaceEdit{Changes: map[string][]TextEdit{payload.URI: {{Range: payload.Range, NewText: out}}}}
- ca.Edit = &edit
- return ca, true
- }
- } else {
- logging.Logf("lsp ", "codeAction document llm error: %v", err)
- }
- case "go_test":
- if edit, jumpURI, jumpRange, ok := s.resolveGoTest(payload.URI, payload.Range.Start); ok {
- ca.Edit = &edit
- // After edit is applied, ask client to jump to new test function
- ca.Command = &Command{Title: "Jump to generated test", Command: "hexai.showDocument", Arguments: []any{jumpURI, jumpRange}}
- // Also send a server-initiated showDocument shortly after resolve to cover
- // clients that do not execute commands from code actions.
- s.deferShowDocument(jumpURI, jumpRange)
- return ca, true
- }
- }
- return ca, false
+ case "document":
+ sys := "You are a precise code documentation engine. Add idiomatic documentation comments to the given code. Preserve exact behavior and formatting as much as possible. Return only the updated code with comments, no prose or backticks."
+ user := "Add documentation comments to this code:\n" + payload.Selection
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
+ opts := s.llmRequestOpts()
+ if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil {
+ if out := stripCodeFences(strings.TrimSpace(text)); out != "" {
+ edit := WorkspaceEdit{Changes: map[string][]TextEdit{payload.URI: {{Range: payload.Range, NewText: out}}}}
+ ca.Edit = &edit
+ return ca, true
+ }
+ } else {
+ logging.Logf("lsp ", "codeAction document llm error: %v", err)
+ }
+ case "go_test":
+ if edit, jumpURI, jumpRange, ok := s.resolveGoTest(payload.URI, payload.Range.Start); ok {
+ ca.Edit = &edit
+ // After edit is applied, ask client to jump to new test function
+ ca.Command = &Command{Title: "Jump to generated test", Command: "hexai.showDocument", Arguments: []any{jumpURI, jumpRange}}
+ // Also send a server-initiated showDocument shortly after resolve to cover
+ // clients that do not execute commands from code actions.
+ s.deferShowDocument(jumpURI, jumpRange)
+ return ca, true
+ }
+ }
+ return ca, false
}
func (s *Server) handleCodeActionResolve(req Request) {
@@ -244,254 +245,282 @@ func greaterPos(p, q Position) bool {
// --- Go unit test code action ---
func (s *Server) buildGoUnitTestCodeAction(p CodeActionParams) *CodeAction {
- uri := p.TextDocument.URI
- if uri == "" || !strings.HasSuffix(strings.TrimPrefix(uri, "file://"), ".go") {
- return nil
- }
- // Skip if already a _test.go file
- if strings.HasSuffix(strings.TrimPrefix(uri, "file://"), "_test.go") {
- return nil
- }
- // Heuristic: only offer when a function context is found above the cursor
- _, _, _, funcCtx := s.lineContext(uri, p.Range.Start)
- if !strings.Contains(funcCtx, "func ") {
- return nil
- }
- payload := struct {
- Type string `json:"type"`
- URI string `json:"uri"`
- Range Range `json:"range"`
- }{Type: "go_test", URI: uri, Range: p.Range}
- raw, _ := json.Marshal(payload)
- ca := CodeAction{Title: "Hexai: implement unit test", Kind: "quickfix", Data: raw}
- return &ca
+ uri := p.TextDocument.URI
+ if uri == "" || !strings.HasSuffix(strings.TrimPrefix(uri, "file://"), ".go") {
+ return nil
+ }
+ // Skip if already a _test.go file
+ if strings.HasSuffix(strings.TrimPrefix(uri, "file://"), "_test.go") {
+ return nil
+ }
+ // Heuristic: only offer when a function context is found above the cursor
+ _, _, _, funcCtx := s.lineContext(uri, p.Range.Start)
+ if !strings.Contains(funcCtx, "func ") {
+ return nil
+ }
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ }{Type: "go_test", URI: uri, Range: p.Range}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: implement unit test", Kind: "quickfix", Data: raw}
+ return &ca
}
// buildDocumentCodeAction offers to document the selected code by injecting comments.
func (s *Server) buildDocumentCodeAction(p CodeActionParams, sel string) *CodeAction {
- if s.llmClient == nil {
- return nil
- }
- if strings.TrimSpace(sel) == "" {
- return nil
- }
- payload := struct {
- Type string `json:"type"`
- URI string `json:"uri"`
- Range Range `json:"range"`
- Selection string `json:"selection"`
- }{Type: "document", URI: p.TextDocument.URI, Range: p.Range, Selection: sel}
- raw, _ := json.Marshal(payload)
- ca := CodeAction{Title: "Hexai: document code", Kind: "refactor.rewrite", Data: raw}
- return &ca
+ if s.llmClient == nil {
+ return nil
+ }
+ if strings.TrimSpace(sel) == "" {
+ return nil
+ }
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ Selection string `json:"selection"`
+ }{Type: "document", URI: p.TextDocument.URI, Range: p.Range, Selection: sel}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: document code", Kind: "refactor.rewrite", Data: raw}
+ return &ca
}
func (s *Server) resolveGoTest(uri string, pos Position) (WorkspaceEdit, string, Range, bool) {
- path := strings.TrimPrefix(uri, "file://")
- if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") {
- return WorkspaceEdit{}, "", Range{}, false
- }
- // Load source text
- _, lines := s.loadFileText(uri)
- if len(lines) == 0 {
- return WorkspaceEdit{}, "", Range{}, false
- }
- pkg := parseGoPackageName(lines)
- fnStart, fnEnd := findGoFunctionAtLine(lines, pos.Line)
- if fnStart < 0 || fnEnd < fnStart {
- return WorkspaceEdit{}, "", Range{}, false
- }
- funcCode := strings.Join(lines[fnStart:fnEnd+1], "\n")
- testFunc := s.generateGoTestFunction(funcCode)
- if strings.TrimSpace(testFunc) == "" {
- return WorkspaceEdit{}, "", Range{}, false
- }
- // Determine test file target
- testPath := strings.TrimSuffix(path, ".go") + "_test.go"
- testURI := "file://" + testPath
+ path := strings.TrimPrefix(uri, "file://")
+ if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ // Load source text
+ _, lines := s.loadFileText(uri)
+ if len(lines) == 0 {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ pkg := parseGoPackageName(lines)
+ fnStart, fnEnd := findGoFunctionAtLine(lines, pos.Line)
+ if fnStart < 0 || fnEnd < fnStart {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ funcCode := strings.Join(lines[fnStart:fnEnd+1], "\n")
+ testFunc := s.generateGoTestFunction(funcCode)
+ if strings.TrimSpace(testFunc) == "" {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ // Determine test file target
+ testPath := strings.TrimSuffix(path, ".go") + "_test.go"
+ testURI := "file://" + testPath
- // If test file exists, append test at EOF; otherwise, create a new file with package+import
- if fileExists(testPath) {
- // Build an insertion at end of file
- _, tLines := s.loadFileText(testURI)
- // Fallback when not open and cannot read: still insert at line 0
- lineIdx := 0
- col := 0
- if len(tLines) > 0 {
- lineIdx = len(tLines) - 1
- col = len(tLines[lineIdx])
- }
- var b strings.Builder
- // Ensure at least two newlines before the new test
- if len(tLines) == 0 || (len(tLines) > 0 && !strings.HasSuffix(strings.Join(tLines, "\n"), "\n\n")) {
- b.WriteString("\n\n")
- }
- b.WriteString(testFunc)
- insert := b.String()
- edit := TextEdit{Range: Range{Start: Position{Line: lineIdx, Character: col}, End: Position{Line: lineIdx, Character: col}}, NewText: insert}
- we := WorkspaceEdit{Changes: map[string][]TextEdit{testURI: {edit}}}
- // Compute jump range start
- // Count how many prefix newlines added before the test function
- prefixNL := 0
- if strings.HasPrefix(insert, "\n\n") { prefixNL = 2 }
- startLine := lineIdx + prefixNL
- // If we inserted with two newlines and last line wasn't blank, first newline moves to next line
- if prefixNL > 0 { startLine = lineIdx + prefixNL }
- jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}}
- return we, testURI, jump, true
- }
- // Create new file content
- var content strings.Builder
- if pkg == "" { pkg = filepath.Base(filepath.Dir(path)) }
- content.WriteString("package ")
- content.WriteString(pkg)
- content.WriteString("\n\n")
- content.WriteString("import (\n\t\"testing\"\n)\n\n")
- content.WriteString(testFunc)
- full := content.String()
- // Use documentChanges with create + full content insert
- create := CreateFile{Kind: "create", URI: testURI}
- tde := TextDocumentEdit{TextDocument: VersionedTextDocumentIdentifier{URI: testURI}, Edits: []TextEdit{{Range: Range{Start: Position{Line: 0, Character: 0}, End: Position{Line: 0, Character: 0}}, NewText: full}}}
- we := WorkspaceEdit{DocumentChanges: []any{create, tde}}
- // Find start line of first test function
- // Count lines before the substring "func Test"
- pre := content.String()
- idx := strings.Index(pre, "func Test")
- startLine := 0
- if idx > 0 {
- before := pre[:idx]
- startLine = strings.Count(before, "\n")
- }
- jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}}
- return we, testURI, jump, true
+ // If test file exists, append test at EOF; otherwise, create a new file with package+import
+ if fileExists(testPath) {
+ // Build an insertion at end of file
+ _, tLines := s.loadFileText(testURI)
+ // Fallback when not open and cannot read: still insert at line 0
+ lineIdx := 0
+ col := 0
+ if len(tLines) > 0 {
+ lineIdx = len(tLines) - 1
+ col = len(tLines[lineIdx])
+ }
+ var b strings.Builder
+ // Ensure at least two newlines before the new test
+ if len(tLines) == 0 || (len(tLines) > 0 && !strings.HasSuffix(strings.Join(tLines, "\n"), "\n\n")) {
+ b.WriteString("\n\n")
+ }
+ b.WriteString(testFunc)
+ insert := b.String()
+ edit := TextEdit{Range: Range{Start: Position{Line: lineIdx, Character: col}, End: Position{Line: lineIdx, Character: col}}, NewText: insert}
+ we := WorkspaceEdit{Changes: map[string][]TextEdit{testURI: {edit}}}
+ // Compute jump range start
+ // Count how many prefix newlines added before the test function
+ prefixNL := 0
+ if strings.HasPrefix(insert, "\n\n") {
+ prefixNL = 2
+ }
+ startLine := lineIdx + prefixNL
+ // If we inserted with two newlines and last line wasn't blank, first newline moves to next line
+ if prefixNL > 0 {
+ startLine = lineIdx + prefixNL
+ }
+ jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}}
+ return we, testURI, jump, true
+ }
+ // Create new file content
+ var content strings.Builder
+ if pkg == "" {
+ pkg = filepath.Base(filepath.Dir(path))
+ }
+ content.WriteString("package ")
+ content.WriteString(pkg)
+ content.WriteString("\n\n")
+ content.WriteString("import (\n\t\"testing\"\n)\n\n")
+ content.WriteString(testFunc)
+ full := content.String()
+ // Use documentChanges with create + full content insert
+ create := CreateFile{Kind: "create", URI: testURI}
+ tde := TextDocumentEdit{TextDocument: VersionedTextDocumentIdentifier{URI: testURI}, Edits: []TextEdit{{Range: Range{Start: Position{Line: 0, Character: 0}, End: Position{Line: 0, Character: 0}}, NewText: full}}}
+ we := WorkspaceEdit{DocumentChanges: []any{create, tde}}
+ // Find start line of first test function
+ // Count lines before the substring "func Test"
+ pre := content.String()
+ idx := strings.Index(pre, "func Test")
+ startLine := 0
+ if idx > 0 {
+ before := pre[:idx]
+ startLine = strings.Count(before, "\n")
+ }
+ jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}}
+ return we, testURI, jump, true
}
// loadFileText returns the file content and lines. It prefers the open document; otherwise reads from disk.
func (s *Server) loadFileText(uri string) (string, []string) {
- if d := s.getDocument(uri); d != nil {
- return d.text, append([]string{}, d.lines...)
- }
- path := strings.TrimPrefix(uri, "file://")
- b, err := os.ReadFile(path)
- if err != nil {
- return "", nil
- }
- txt := string(b)
- return txt, splitLines(txt)
+ if d := s.getDocument(uri); d != nil {
+ return d.text, append([]string{}, d.lines...)
+ }
+ path := strings.TrimPrefix(uri, "file://")
+ b, err := os.ReadFile(path)
+ if err != nil {
+ return "", nil
+ }
+ txt := string(b)
+ return txt, splitLines(txt)
}
func fileExists(path string) bool {
- if _, err := os.Stat(path); err == nil {
- return true
- }
- return false
+ if _, err := os.Stat(path); err == nil {
+ return true
+ }
+ return false
}
// parseGoPackageName returns the package name from file lines, or empty if not found.
func parseGoPackageName(lines []string) string {
- for _, ln := range lines {
- t := strings.TrimSpace(ln)
- if strings.HasPrefix(t, "package ") {
- name := strings.TrimSpace(strings.TrimPrefix(t, "package "))
- // strip inline comments
- if i := strings.Index(name, " "); i >= 0 { name = name[:i] }
- if i := strings.Index(name, "\t"); i >= 0 { name = name[:i] }
- if i := strings.Index(name, "//"); i >= 0 { name = strings.TrimSpace(name[:i]) }
- return name
- }
- }
- return ""
+ for _, ln := range lines {
+ t := strings.TrimSpace(ln)
+ if strings.HasPrefix(t, "package ") {
+ name := strings.TrimSpace(strings.TrimPrefix(t, "package "))
+ // strip inline comments
+ if i := strings.Index(name, " "); i >= 0 {
+ name = name[:i]
+ }
+ if i := strings.Index(name, "\t"); i >= 0 {
+ name = name[:i]
+ }
+ if i := strings.Index(name, "//"); i >= 0 {
+ name = strings.TrimSpace(name[:i])
+ }
+ return name
+ }
+ }
+ return ""
}
// findGoFunctionAtLine finds the function enclosing or preceding line idx. Returns start and end line indexes.
func findGoFunctionAtLine(lines []string, idx int) (int, int) {
- if idx < 0 { idx = 0 }
- if idx >= len(lines) { idx = len(lines)-1 }
- // find signature start
- start := -1
- for i := idx; i >= 0; i-- {
- if strings.Contains(lines[i], "func ") {
- start = i
- break
- }
- if strings.Contains(lines[i], "}") {
- break
- }
- }
- if start == -1 { return -1, -1 }
- // find first '{'
- depth := 0
- seenOpen := false
- for i := start; i < len(lines); i++ {
- ln := lines[i]
- for j := 0; j < len(ln); j++ {
- switch ln[j] {
- case '{':
- depth++
- seenOpen = true
- case '}':
- if depth > 0 { depth-- }
- if seenOpen && depth == 0 {
- return start, i
- }
- }
- }
- }
- // if never saw '{', assume single-line prototype; return that line
- if !seenOpen {
- return start, start
- }
- return start, -1
+ if idx < 0 {
+ idx = 0
+ }
+ if idx >= len(lines) {
+ idx = len(lines) - 1
+ }
+ // find signature start
+ start := -1
+ for i := idx; i >= 0; i-- {
+ if strings.Contains(lines[i], "func ") {
+ start = i
+ break
+ }
+ if strings.Contains(lines[i], "}") {
+ break
+ }
+ }
+ if start == -1 {
+ return -1, -1
+ }
+ // find first '{'
+ depth := 0
+ seenOpen := false
+ for i := start; i < len(lines); i++ {
+ ln := lines[i]
+ for j := 0; j < len(ln); j++ {
+ switch ln[j] {
+ case '{':
+ depth++
+ seenOpen = true
+ case '}':
+ if depth > 0 {
+ depth--
+ }
+ if seenOpen && depth == 0 {
+ return start, i
+ }
+ }
+ }
+ }
+ // if never saw '{', assume single-line prototype; return that line
+ if !seenOpen {
+ return start, start
+ }
+ return start, -1
}
// generateGoTestFunction uses LLM to produce a test function; falls back to a stub when unavailable.
func (s *Server) generateGoTestFunction(funcCode string) string {
- if s.llmClient != nil {
- sys := "You are a precise Go unit test generator. Given a Go function, write one or more Test* functions using the testing package. Do NOT include package or imports, only the test function(s). Prefer table-driven tests. Keep it minimal and idiomatic."
- user := "Function under test:\n" + funcCode
- ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
- defer cancel()
- messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
- opts := s.llmRequestOpts()
- if out, err := s.llmClient.Chat(ctx, messages, opts...); err == nil {
- cleaned := strings.TrimSpace(stripCodeFences(out))
- if cleaned != "" { return cleaned }
- } else {
- logging.Logf("lsp ", "codeAction go_test llm error: %v", err)
- }
- }
- // Fallback stub
- name := deriveGoFuncName(funcCode)
- if name == "" { name = "Function" }
- return fmt.Sprintf("func Test%s(t *testing.T) {\n\t// TODO: implement tests for %s\n}\n", exportName(name), name)
+ if s.llmClient != nil {
+ sys := "You are a precise Go unit test generator. Given a Go function, write one or more Test* functions using the testing package. Do NOT include package or imports, only the test function(s). Prefer table-driven tests. Keep it minimal and idiomatic."
+ user := "Function under test:\n" + funcCode
+ ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
+ defer cancel()
+ messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
+ opts := s.llmRequestOpts()
+ if out, err := s.llmClient.Chat(ctx, messages, opts...); err == nil {
+ cleaned := strings.TrimSpace(stripCodeFences(out))
+ if cleaned != "" {
+ return cleaned
+ }
+ } else {
+ logging.Logf("lsp ", "codeAction go_test llm error: %v", err)
+ }
+ }
+ // Fallback stub
+ name := deriveGoFuncName(funcCode)
+ if name == "" {
+ name = "Function"
+ }
+ return fmt.Sprintf("func Test%s(t *testing.T) {\n\t// TODO: implement tests for %s\n}\n", exportName(name), name)
}
// deriveGoFuncName extracts function or method name from code.
func deriveGoFuncName(code string) string {
- // look for line starting with func
- line := firstLine(code)
- line = strings.TrimSpace(line)
- if !strings.HasPrefix(line, "func ") { return "" }
- rest := strings.TrimSpace(strings.TrimPrefix(line, "func "))
- // method receiver
- if strings.HasPrefix(rest, "(") {
- // find ")"
- if i := strings.Index(rest, ")"); i >= 0 && i+1 < len(rest) {
- rest = strings.TrimSpace(rest[i+1:])
- }
- }
- // now rest should start with Name(
- if i := strings.Index(rest, "("); i > 0 {
- return strings.TrimSpace(rest[:i])
- }
- return ""
+ // look for line starting with func
+ line := firstLine(code)
+ line = strings.TrimSpace(line)
+ if !strings.HasPrefix(line, "func ") {
+ return ""
+ }
+ rest := strings.TrimSpace(strings.TrimPrefix(line, "func "))
+ // method receiver
+ if strings.HasPrefix(rest, "(") {
+ // find ")"
+ if i := strings.Index(rest, ")"); i >= 0 && i+1 < len(rest) {
+ rest = strings.TrimSpace(rest[i+1:])
+ }
+ }
+ // now rest should start with Name(
+ if i := strings.Index(rest, "("); i > 0 {
+ return strings.TrimSpace(rest[:i])
+ }
+ return ""
}
func exportName(name string) string {
- if name == "" { return name }
- r := []rune(name)
- if r[0] >= 'a' && r[0] <= 'z' {
- r[0] = r[0] - ('a' - 'A')
- }
- return string(r)
+ if name == "" {
+ return name
+ }
+ r := []rune(name)
+ if r[0] >= 'a' && r[0] <= 'z' {
+ r[0] = r[0] - ('a' - 'A')
+ }
+ return string(r)
}
diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go
index 036e591..c6b7d3d 100644
--- a/internal/lsp/handlers_completion.go
+++ b/internal/lsp/handlers_completion.go
@@ -2,13 +2,14 @@
package lsp
import (
- "context"
- "encoding/json"
- "fmt"
- "codeberg.org/snonux/hexai/internal/llm"
- "codeberg.org/snonux/hexai/internal/logging"
- "strings"
- "time"
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/logging"
)
func (s *Server) handleCompletion(req Request) {
@@ -70,8 +71,8 @@ func (s *Server) logCompletionContext(p CompletionParams, above, current, below,
}
func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool) {
- ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second)
- defer cancel()
+ ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second)
+ defer cancel()
inlinePrompt := lineHasInlinePrompt(current)
if !inlinePrompt && !s.isTriggerEvent(p, current) {
@@ -93,20 +94,20 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun
logging.AnsiGreen, logging.PreviewForLog(cleaned), logging.AnsiBase)
return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true
}
- if (isBareDoubleOpen(current) || isBareDoubleOpen(below)) {
- logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase)
- return []CompletionItem{}, true
- }
+ if isBareDoubleOpen(current) || isBareDoubleOpen(below) {
+ logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase)
+ return []CompletionItem{}, true
+ }
if !inParams && !s.prefixHeuristicAllows(inlinePrompt, current, p, manualInvoke) {
logging.Logf("lsp ", "%scompletion skip=short-prefix line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase)
return []CompletionItem{}, true
}
- // Provider-native path
- if items, ok := s.tryProviderNativeCompletion(current, p, above, below, funcCtx, docStr, hasExtra, extraText, inParams); ok {
- return items, true
- }
+ // Provider-native path
+ if items, ok := s.tryProviderNativeCompletion(current, p, above, below, funcCtx, docStr, hasExtra, extraText, inParams); ok {
+ return items, true
+ }
// Chat path
messages := s.buildCompletionMessages(inlinePrompt, hasExtra, extraText, inParams, p, above, current, below, funcCtx)
@@ -120,12 +121,12 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun
if s.codingTemperature != nil {
opts = append(opts, llm.WithTemperature(*s.codingTemperature))
}
- // Debounce and throttle before making the LLM call
- s.waitForDebounce(ctx)
- if !s.waitForThrottle(ctx) {
- return nil, false
- }
- logging.Logf("lsp ", "completion llm=requesting model=%s", s.llmClient.DefaultModel())
+ // Debounce and throttle before making the LLM call
+ s.waitForDebounce(ctx)
+ if !s.waitForThrottle(ctx) {
+ return nil, false
+ }
+ logging.Logf("lsp ", "completion llm=requesting model=%s", s.llmClient.DefaultModel())
text, err := s.llmClient.Chat(ctx, messages, opts...)
if err != nil {
@@ -163,19 +164,23 @@ func parseManualInvoke(ctx any) bool {
// shouldSuppressForChatTriggerEOL returns true when a chat trigger like ">" follows ?, !, :, or ; at EOL.
func (s *Server) shouldSuppressForChatTriggerEOL(current string, p CompletionParams) bool {
- t := strings.TrimRight(current, " \t")
- if s.chatSuffix == "" { return false }
- if strings.HasSuffix(t, s.chatSuffix) {
- if len(t) < len(s.chatSuffix)+1 { return false }
- prev := string(t[len(t)-len(s.chatSuffix)-1])
- for _, pf := range s.chatPrefixes {
- if prev == pf {
- logging.Logf("lsp ", "completion skip=chat-trigger-eol uri=%s line=%d", p.TextDocument.URI, p.Position.Line)
- return true
- }
- }
- }
- return false
+ t := strings.TrimRight(current, " \t")
+ if s.chatSuffix == "" {
+ return false
+ }
+ if strings.HasSuffix(t, s.chatSuffix) {
+ if len(t) < len(s.chatSuffix)+1 {
+ return false
+ }
+ prev := string(t[len(t)-len(s.chatSuffix)-1])
+ for _, pf := range s.chatPrefixes {
+ if prev == pf {
+ logging.Logf("lsp ", "completion skip=chat-trigger-eol uri=%s line=%d", p.TextDocument.URI, p.Position.Line)
+ return true
+ }
+ }
+ }
+ return false
}
// prefixHeuristicAllows applies minimal prefix rules unless inlinePrompt or structural triggers apply.
@@ -233,15 +238,15 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams,
prov = s.llmClient.Name()
}
logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", prov, path)
- ctx2, cancel2 := context.WithTimeout(context.Background(), 8*time.Second)
- defer cancel2()
+ ctx2, cancel2 := context.WithTimeout(context.Background(), 8*time.Second)
+ defer cancel2()
- // Debounce and throttle prior to provider-native call
- s.waitForDebounce(ctx2)
- if !s.waitForThrottle(ctx2) {
- return nil, false
- }
- suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp)
+ // Debounce and throttle prior to provider-native call
+ s.waitForDebounce(ctx2)
+ if !s.waitForThrottle(ctx2) {
+ return nil, false
+ }
+ suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp)
if err == nil && len(suggestions) > 0 {
cleaned := strings.TrimSpace(suggestions[0])
if cleaned != "" {
@@ -249,12 +254,12 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams,
if cleaned != "" {
cleaned = stripDuplicateGeneralPrefix(current[:p.Position.Character], cleaned)
}
- if cleaned != "" && hasDoubleOpenTrigger(current) {
- indent := leadingIndent(current)
- if indent != "" {
- cleaned = applyIndent(indent, cleaned)
- }
- }
+ if cleaned != "" && hasDoubleOpenTrigger(current) {
+ indent := leadingIndent(current)
+ if indent != "" {
+ cleaned = applyIndent(indent, cleaned)
+ }
+ }
if strings.TrimSpace(cleaned) != "" {
key := s.completionCacheKey(p, above, current, below, funcCtx, inParams, hasExtra, extraText)
s.completionCachePut(key, cleaned)
@@ -270,63 +275,63 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams,
// waitForDebounce sleeps until there has been no input activity for at least
// completionDebounce. If debounce is zero or ctx is done, it returns promptly.
func (s *Server) waitForDebounce(ctx context.Context) {
- d := s.completionDebounce
- if d <= 0 {
- return
- }
- for {
- s.mu.RLock()
- last := s.lastInput
- s.mu.RUnlock()
- if last.IsZero() {
- return
- }
- since := time.Since(last)
- if since >= d {
- return
- }
- rem := d - since
- timer := time.NewTimer(rem)
- select {
- case <-ctx.Done():
- timer.Stop()
- return
- case <-timer.C:
- // loop and re-evaluate in case input occurred during sleep
- }
- }
+ d := s.completionDebounce
+ if d <= 0 {
+ return
+ }
+ for {
+ s.mu.RLock()
+ last := s.lastInput
+ s.mu.RUnlock()
+ if last.IsZero() {
+ return
+ }
+ since := time.Since(last)
+ if since >= d {
+ return
+ }
+ rem := d - since
+ timer := time.NewTimer(rem)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ case <-timer.C:
+ // loop and re-evaluate in case input occurred during sleep
+ }
+ }
}
// waitForThrottle enforces a minimum spacing between LLM calls. Returns false
// if the context is canceled while waiting.
func (s *Server) waitForThrottle(ctx context.Context) bool {
- interval := s.throttleInterval
- if interval <= 0 {
- return true
- }
- var wait time.Duration
- for {
- s.mu.Lock()
- next := s.lastLLMCall.Add(interval)
- now := time.Now()
- if now.Before(next) {
- wait = next.Sub(now)
- s.mu.Unlock()
- timer := time.NewTimer(wait)
- select {
- case <-ctx.Done():
- timer.Stop()
- return false
- case <-timer.C:
- // try again to set the next call time
- continue
- }
- }
- // we are allowed to proceed now; record this call as the latest
- s.lastLLMCall = now
- s.mu.Unlock()
- return true
- }
+ interval := s.throttleInterval
+ if interval <= 0 {
+ return true
+ }
+ var wait time.Duration
+ for {
+ s.mu.Lock()
+ next := s.lastLLMCall.Add(interval)
+ now := time.Now()
+ if now.Before(next) {
+ wait = next.Sub(now)
+ s.mu.Unlock()
+ timer := time.NewTimer(wait)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return false
+ case <-timer.C:
+ // try again to set the next call time
+ continue
+ }
+ }
+ // we are allowed to proceed now; record this call as the latest
+ s.lastLLMCall = now
+ s.mu.Unlock()
+ return true
+ }
}
// buildCompletionMessages constructs the LLM messages for completion.
@@ -359,10 +364,10 @@ func (s *Server) postProcessCompletion(text string, leftOfCursor string, current
if cleaned != "" {
cleaned = stripDuplicateGeneralPrefix(leftOfCursor, cleaned)
}
- if cleaned != "" && hasDoubleOpenTrigger(currentLine) {
- if indent := leadingIndent(currentLine); indent != "" {
- cleaned = applyIndent(indent, cleaned)
- }
- }
+ if cleaned != "" && hasDoubleOpenTrigger(currentLine) {
+ if indent := leadingIndent(currentLine); indent != "" {
+ cleaned = applyIndent(indent, cleaned)
+ }
+ }
return cleaned
}
diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go
index 3f9d4b0..6a90919 100644
--- a/internal/lsp/handlers_document.go
+++ b/internal/lsp/handlers_document.go
@@ -2,18 +2,21 @@
package lsp
import (
- "context"
- "encoding/json"
- "codeberg.org/snonux/hexai/internal/llm"
- "codeberg.org/snonux/hexai/internal/logging"
- "strings"
- "time"
+ "context"
+ "encoding/json"
+ "strings"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/logging"
)
// Package-level chat trigger vars for helpers without Server receiver.
// NewServer assigns these from configuration on startup.
-var chatSuffixChar byte = '>'
-var chatPrefixSingles = []string{"?", "!", ":", ";"}
+var (
+ chatSuffixChar byte = '>'
+ chatPrefixSingles = []string{"?", "!", ":", ";"}
+)
func (s *Server) handleDidOpen(req Request) {
var p DidOpenTextDocumentParams
@@ -97,7 +100,7 @@ func (s *Server) detectAndHandleChat(uri string) {
if d == nil || len(d.lines) == 0 {
return
}
- for i, raw := range d.lines {
+ for i, raw := range d.lines {
// Find last non-space character index
j := len(raw) - 1
for j >= 0 {
@@ -107,25 +110,32 @@ func (s *Server) detectAndHandleChat(uri string) {
}
break
}
- if j < 0 {
- continue
- }
- // Check suffix/prefix according to configuration
- if s.chatSuffix == "" {
- continue
- }
- // Last non-space must equal suffix
- if string(raw[j]) != s.chatSuffix {
- continue
- }
- // Require at least one char before suffix and that char must be in chatPrefixes
- if j < 1 { continue }
- prev := string(raw[j-1])
- isTrigger := false
- for _, pfx := range s.chatPrefixes {
- if prev == pfx { isTrigger = true; break }
- }
- if !isTrigger { continue }
+ if j < 0 {
+ continue
+ }
+ // Check suffix/prefix according to configuration
+ if s.chatSuffix == "" {
+ continue
+ }
+ // Last non-space must equal suffix
+ if string(raw[j]) != s.chatSuffix {
+ continue
+ }
+ // Require at least one char before suffix and that char must be in chatPrefixes
+ if j < 1 {
+ continue
+ }
+ prev := string(raw[j-1])
+ isTrigger := false
+ for _, pfx := range s.chatPrefixes {
+ if prev == pfx {
+ isTrigger = true
+ break
+ }
+ }
+ if !isTrigger {
+ continue
+ }
// Avoid double-answering: if the next non-empty line starts with '>' we skip.
k := i + 1
for k < len(d.lines) && strings.TrimSpace(d.lines[k]) == "" {
@@ -135,9 +145,9 @@ func (s *Server) detectAndHandleChat(uri string) {
continue
}
// Derive prompt by removing only the trailing '>'
- removeCount := len(s.chatSuffix)
+ removeCount := len(s.chatSuffix)
base := raw[:j+1-removeCount]
- prompt := strings.TrimSpace(base)
+ prompt := strings.TrimSpace(base)
if prompt == "" {
continue
}
@@ -246,37 +256,37 @@ func (s *Server) buildChatHistory(uri string, lineIdx int, currentPrompt string)
// stripTrailingTrigger removes the trailing chat trigger punctuation from a line if present.
func stripTrailingTrigger(sx string) string {
- s := strings.TrimRight(sx, " \t")
- if len(s) == 0 {
- return sx
- }
- // Configurable suffix removal when preceded by configured prefixes
- if len(s) >= 2 && s[len(s)-1] == chatSuffixChar {
- prev := string(s[len(s)-2])
- for _, pf := range chatPrefixSingles {
- if prev == pf {
- return strings.TrimRight(s[:len(s)-1], " \t")
- }
- }
- }
- // Legacy: remove one trailing punctuation (?, !, :) to build history nicely
- last := s[len(s)-1]
- switch last {
- case '?', '!', ':':
- return strings.TrimRight(s[:len(s)-1], " \t")
- default:
- return sx
- }
+ s := strings.TrimRight(sx, " \t")
+ if len(s) == 0 {
+ return sx
+ }
+ // Configurable suffix removal when preceded by configured prefixes
+ if len(s) >= 2 && s[len(s)-1] == chatSuffixChar {
+ prev := string(s[len(s)-2])
+ for _, pf := range chatPrefixSingles {
+ if prev == pf {
+ return strings.TrimRight(s[:len(s)-1], " \t")
+ }
+ }
+ }
+ // Legacy: remove one trailing punctuation (?, !, :) to build history nicely
+ last := s[len(s)-1]
+ switch last {
+ case '?', '!', ':':
+ return strings.TrimRight(s[:len(s)-1], " \t")
+ default:
+ return sx
+ }
}
// clientApplyEdit sends a workspace/applyEdit request to the client.
func (s *Server) clientApplyEdit(label string, edit WorkspaceEdit) {
- params := ApplyWorkspaceEditParams{Label: label, Edit: edit}
- id := s.nextReqID()
- req := Request{JSONRPC: "2.0", ID: id, Method: "workspace/applyEdit"}
- b, _ := json.Marshal(params)
- req.Params = b
- s.writeMessage(req)
+ params := ApplyWorkspaceEditParams{Label: label, Edit: edit}
+ id := s.nextReqID()
+ req := Request{JSONRPC: "2.0", ID: id, Method: "workspace/applyEdit"}
+ b, _ := json.Marshal(params)
+ req.Params = b
+ s.writeMessage(req)
}
// nextReqID returns a unique json.RawMessage id for server-initiated requests.
@@ -291,27 +301,27 @@ func (s *Server) nextReqID() json.RawMessage {
// clientShowDocument asks the client to open/focus a document and select a range.
func (s *Server) clientShowDocument(uri string, sel *Range) {
- var params struct {
- URI string `json:"uri"`
- External bool `json:"external,omitempty"`
- TakeFocus bool `json:"takeFocus,omitempty"`
- Selection *Range `json:"selection,omitempty"`
- }
- params.URI = uri
- params.TakeFocus = true
- params.Selection = sel
- id := s.nextReqID()
- req := Request{JSONRPC: "2.0", ID: id, Method: "window/showDocument"}
- b, _ := json.Marshal(params)
- req.Params = b
- s.writeMessage(req)
+ var params struct {
+ URI string `json:"uri"`
+ External bool `json:"external,omitempty"`
+ TakeFocus bool `json:"takeFocus,omitempty"`
+ Selection *Range `json:"selection,omitempty"`
+ }
+ params.URI = uri
+ params.TakeFocus = true
+ params.Selection = sel
+ id := s.nextReqID()
+ req := Request{JSONRPC: "2.0", ID: id, Method: "window/showDocument"}
+ b, _ := json.Marshal(params)
+ req.Params = b
+ s.writeMessage(req)
}
// deferShowDocument schedules a showDocument after a short delay to allow the client
// time to apply any pending edits (e.g., create the file before focusing it).
func (s *Server) deferShowDocument(uri string, sel Range) {
- go func() {
- time.Sleep(120 * time.Millisecond)
- s.clientShowDocument(uri, &sel)
- }()
+ go func() {
+ time.Sleep(120 * time.Millisecond)
+ s.clientShowDocument(uri, &sel)
+ }()
}
diff --git a/internal/lsp/handlers_end_to_end_test.go b/internal/lsp/handlers_end_to_end_test.go
index fd66a3c..32cb488 100644
--- a/internal/lsp/handlers_end_to_end_test.go
+++ b/internal/lsp/handlers_end_to_end_test.go
@@ -1,243 +1,281 @@
package lsp
import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "log"
- "strings"
- "testing"
- "time"
- tut "codeberg.org/snonux/hexai/internal/testutil"
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "strings"
+ "testing"
+ "time"
+
+ tut "codeberg.org/snonux/hexai/internal/testutil"
)
// captureResponse decodes a single LSP Response from the server's output buffer.
func captureResponse(t *testing.T, buf *bytes.Buffer) Response {
- t.Helper()
- raw := buf.String()
- // strip Content-Length header framing
- idx := strings.Index(raw, "\r\n\r\n")
- if idx < 0 { t.Fatalf("no header/body separator in %q", raw) }
- body := raw[idx+4:]
- var resp Response
- if err := json.Unmarshal([]byte(body), &resp); err != nil {
- t.Fatalf("unmarshal response: %v", err)
- }
- return resp
+ t.Helper()
+ raw := buf.String()
+ // strip Content-Length header framing
+ idx := strings.Index(raw, "\r\n\r\n")
+ if idx < 0 {
+ t.Fatalf("no header/body separator in %q", raw)
+ }
+ body := raw[idx+4:]
+ var resp Response
+ if err := json.Unmarshal([]byte(body), &resp); err != nil {
+ t.Fatalf("unmarshal response: %v", err)
+ }
+ return resp
}
// captureRequest decodes a single JSON-RPC Request from the server's output buffer.
func captureRequest(t *testing.T, buf *bytes.Buffer) Request {
- t.Helper()
- raw := buf.String()
- // There may be multiple framed messages concatenated; scan for each
- off := 0
- for off < len(raw) {
- rest := raw[off:]
- idx := strings.Index(rest, "\r\n\r\n")
- if idx < 0 { break }
- body := rest[idx+4:]
- // Content-Length header indicates body length; parse length from header
- hdr := rest[:idx]
- clen := 0
- for _, line := range strings.Split(hdr, "\r\n") {
- if strings.HasPrefix(strings.ToLower(line), "content-length:") {
- var n int
- _, _ = fmt.Sscanf(line, "Content-Length: %d", &n)
- clen = n
- break
- }
- }
- if clen <= 0 || clen > len(body) { clen = len(body) }
- piece := body[:clen]
- var req Request
- _ = json.Unmarshal([]byte(piece), &req)
- if req.Method != "" {
- return req
- }
- off += idx + 4 + clen
- }
- t.Fatalf("no request found in output")
- return Request{}
+ t.Helper()
+ raw := buf.String()
+ // There may be multiple framed messages concatenated; scan for each
+ off := 0
+ for off < len(raw) {
+ rest := raw[off:]
+ idx := strings.Index(rest, "\r\n\r\n")
+ if idx < 0 {
+ break
+ }
+ body := rest[idx+4:]
+ // Content-Length header indicates body length; parse length from header
+ hdr := rest[:idx]
+ clen := 0
+ for _, line := range strings.Split(hdr, "\r\n") {
+ if strings.HasPrefix(strings.ToLower(line), "content-length:") {
+ var n int
+ _, _ = fmt.Sscanf(line, "Content-Length: %d", &n)
+ clen = n
+ break
+ }
+ }
+ if clen <= 0 || clen > len(body) {
+ clen = len(body)
+ }
+ piece := body[:clen]
+ var req Request
+ _ = json.Unmarshal([]byte(piece), &req)
+ if req.Method != "" {
+ return req
+ }
+ off += idx + 4 + clen
+ }
+ t.Fatalf("no request found in output")
+ return Request{}
}
func TestHandleCodeAction_ListsHexaiActions(t *testing.T) {
- // Prepare server
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- s.chatSuffix = ">"
- s.chatPrefixes = []string{"?","!",":",";"}
- s.llmClient = fakeLLM{resp: "// doc\nfunc add(a,b int) int { return a+b }"}
-
- // Document with a function
- uri := "file:///x.go"
- src := "package p\n\nfunc add(a,b int) int { return a+b }\n"
- s.setDocument(uri, src)
-
- // Select the function line
- p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line:2, Character:0}, End: Position{Line:2, Character:len("func add(a,b int) int { return a+b }")}}}
- b, _ := json.Marshal(p)
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("1"), Method: "textDocument/codeAction", Params: b}
-
- // Invoke directly
- out.Reset()
- s.handleCodeAction(req)
- resp := captureResponse(t, &out)
- // Decode result into []CodeAction
- var actions []CodeAction
- rb, _ := json.Marshal(resp.Result)
- if err := json.Unmarshal(rb, &actions); err != nil {
- t.Fatalf("decode actions: %v", err)
- }
- if len(actions) == 0 { t.Fatalf("expected some actions") }
- // Ensure our Hexai actions are present
- hasDoc := false
- hasGoTest := false
- for _, a := range actions {
- if strings.Contains(strings.ToLower(a.Title), "hexai:") {
- if strings.Contains(a.Title, "document code") { hasDoc = true }
- if strings.Contains(a.Title, "implement unit test") { hasGoTest = true }
- }
- }
- if !hasDoc || !hasGoTest {
- t.Fatalf("expected both Hexai actions, got %+v", actions)
- }
+ // Prepare server
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ s.chatSuffix = ">"
+ s.chatPrefixes = []string{"?", "!", ":", ";"}
+ s.llmClient = fakeLLM{resp: "// doc\nfunc add(a,b int) int { return a+b }"}
+
+ // Document with a function
+ uri := "file:///x.go"
+ src := "package p\n\nfunc add(a,b int) int { return a+b }\n"
+ s.setDocument(uri, src)
+
+ // Select the function line
+ p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line: 2, Character: 0}, End: Position{Line: 2, Character: len("func add(a,b int) int { return a+b }")}}}
+ b, _ := json.Marshal(p)
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("1"), Method: "textDocument/codeAction", Params: b}
+
+ // Invoke directly
+ out.Reset()
+ s.handleCodeAction(req)
+ resp := captureResponse(t, &out)
+ // Decode result into []CodeAction
+ var actions []CodeAction
+ rb, _ := json.Marshal(resp.Result)
+ if err := json.Unmarshal(rb, &actions); err != nil {
+ t.Fatalf("decode actions: %v", err)
+ }
+ if len(actions) == 0 {
+ t.Fatalf("expected some actions")
+ }
+ // Ensure our Hexai actions are present
+ hasDoc := false
+ hasGoTest := false
+ for _, a := range actions {
+ if strings.Contains(strings.ToLower(a.Title), "hexai:") {
+ if strings.Contains(a.Title, "document code") {
+ hasDoc = true
+ }
+ if strings.Contains(a.Title, "implement unit test") {
+ hasGoTest = true
+ }
+ }
+ }
+ if !hasDoc || !hasGoTest {
+ t.Fatalf("expected both Hexai actions, got %+v", actions)
+ }
}
func TestHandleCodeActionResolve_Document(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- s.llmClient = fakeLLM{resp: "// doc\nfunc f(){}"}
- uri := "file:///x.go"
- s.setDocument(uri, "package p\nfunc f(){}\n")
- // Build a document code action payload
- payload := struct {
- Type string `json:"type"`
- URI string `json:"uri"`
- Range Range `json:"range"`
- Selection string `json:"selection"`
- }{Type: "document", URI: uri, Range: Range{Start: Position{Line:1}, End: Position{Line:1, Character: 10}}, Selection: "func f(){}"}
- raw, _ := json.Marshal(payload)
- ca := CodeAction{Title: "Hexai: document code", Data: raw}
- b, _ := json.Marshal(ca)
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("2"), Method: "codeAction/resolve", Params: b}
-
- out.Reset()
- s.handleCodeActionResolve(req)
- resp := captureResponse(t, &out)
- var resolved CodeAction
- rb, _ := json.Marshal(resp.Result)
- if err := json.Unmarshal(rb, &resolved); err != nil {
- t.Fatalf("decode resolved: %v", err)
- }
- if resolved.Edit == nil { t.Fatalf("expected resolved edit") }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ s.llmClient = fakeLLM{resp: "// doc\nfunc f(){}"}
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\nfunc f(){}\n")
+ // Build a document code action payload
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ Selection string `json:"selection"`
+ }{Type: "document", URI: uri, Range: Range{Start: Position{Line: 1}, End: Position{Line: 1, Character: 10}}, Selection: "func f(){}"}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: document code", Data: raw}
+ b, _ := json.Marshal(ca)
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("2"), Method: "codeAction/resolve", Params: b}
+
+ out.Reset()
+ s.handleCodeActionResolve(req)
+ resp := captureResponse(t, &out)
+ var resolved CodeAction
+ rb, _ := json.Marshal(resp.Result)
+ if err := json.Unmarshal(rb, &resolved); err != nil {
+ t.Fatalf("decode resolved: %v", err)
+ }
+ if resolved.Edit == nil {
+ t.Fatalf("expected resolved edit")
+ }
}
func TestHandleCodeAction_NoLLMOrEmptySelection_ReturnsEmpty(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- uri := "file:///x.go"
- s.setDocument(uri, "package p\n\n")
- // Empty selection
- p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line:1}, End: Position{Line:1}}}
- b, _ := json.Marshal(p)
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("4"), Method: "textDocument/codeAction", Params: b}
- out.Reset()
- s.handleCodeAction(req)
- resp := captureResponse(t, &out)
- var actions []CodeAction
- rb, _ := json.Marshal(resp.Result)
- _ = json.Unmarshal(rb, &actions)
- if len(actions) != 0 { t.Fatalf("expected no actions for empty selection, got %d", len(actions)) }
-
- // No llm client: should also return empty even if selection non-empty
- p2 := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line:0}, End: Position{Line:0, Character:7}}}
- out.Reset()
- req2 := Request{JSONRPC: "2.0", ID: json.RawMessage("5"), Method: "textDocument/codeAction", Params: mustJSON(p2)}
- s.handleCodeAction(req2)
- resp2 := captureResponse(t, &out)
- var actions2 []CodeAction
- rb2, _ := json.Marshal(resp2.Result)
- _ = json.Unmarshal(rb2, &actions2)
- if len(actions2) != 0 { t.Fatalf("expected no actions when llm is nil") }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\n\n")
+ // Empty selection
+ p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line: 1}, End: Position{Line: 1}}}
+ b, _ := json.Marshal(p)
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("4"), Method: "textDocument/codeAction", Params: b}
+ out.Reset()
+ s.handleCodeAction(req)
+ resp := captureResponse(t, &out)
+ var actions []CodeAction
+ rb, _ := json.Marshal(resp.Result)
+ _ = json.Unmarshal(rb, &actions)
+ if len(actions) != 0 {
+ t.Fatalf("expected no actions for empty selection, got %d", len(actions))
+ }
+
+ // No llm client: should also return empty even if selection non-empty
+ p2 := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Range: Range{Start: Position{Line: 0}, End: Position{Line: 0, Character: 7}}}
+ out.Reset()
+ req2 := Request{JSONRPC: "2.0", ID: json.RawMessage("5"), Method: "textDocument/codeAction", Params: mustJSON(p2)}
+ s.handleCodeAction(req2)
+ resp2 := captureResponse(t, &out)
+ var actions2 []CodeAction
+ rb2, _ := json.Marshal(resp2.Result)
+ _ = json.Unmarshal(rb2, &actions2)
+ if len(actions2) != 0 {
+ t.Fatalf("expected no actions when llm is nil")
+ }
}
func mustJSON(v any) json.RawMessage { b, _ := json.Marshal(v); return b }
func TestHandle_UnknownMethod_ReturnsError(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out, handlers: map[string]func(Request){}}
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("9"), Method: "no/such"}
- out.Reset()
- s.handle(req)
- resp := captureResponse(t, &out)
- if resp.Error == nil || resp.Error.Code != -32601 { t.Fatalf("expected method not found error, got %+v", resp.Error) }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out, handlers: map[string]func(Request){}}
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("9"), Method: "no/such"}
+ out.Reset()
+ s.handle(req)
+ resp := captureResponse(t, &out)
+ if resp.Error == nil || resp.Error.Code != -32601 {
+ t.Fatalf("expected method not found error, got %+v", resp.Error)
+ }
}
func TestHandle_Dispatch_Initialize(t *testing.T) {
- var out bytes.Buffer
- // Build a server via constructor to ensure handlers map is populated
- s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{})
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("13"), Method: "initialize"}
- out.Reset()
- s.handle(req)
- resp := captureResponse(t, &out)
- var init InitializeResult
- b, _ := json.Marshal(resp.Result)
- _ = json.Unmarshal(b, &init)
- if init.Capabilities.CodeActionProvider == nil || init.Capabilities.CompletionProvider == nil { t.Fatalf("missing capabilities") }
+ var out bytes.Buffer
+ // Build a server via constructor to ensure handlers map is populated
+ s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{})
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("13"), Method: "initialize"}
+ out.Reset()
+ s.handle(req)
+ resp := captureResponse(t, &out)
+ var init InitializeResult
+ b, _ := json.Marshal(resp.Result)
+ _ = json.Unmarshal(b, &init)
+ if init.Capabilities.CodeActionProvider == nil || init.Capabilities.CompletionProvider == nil {
+ t.Fatalf("missing capabilities")
+ }
}
-
func TestDetectAndHandleChat_InsertsReply(t *testing.T) {
- var out bytes.Buffer
- s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{})
- s.llmClient = fakeLLM{resp: tut.MultilineChatReply()}
- uri := "file:///chat.go"
- // Place a prompt line with a supported trigger at EOL, then a blank line
- s.setDocument(uri, "What time?>\n\n")
- out.Reset()
- s.detectAndHandleChat(uri)
- // Allow async goroutine to write the request
- for i := 0; i < 20 && out.Len() == 0; i++ { time.Sleep(10 * time.Millisecond) }
- if out.Len() == 0 { t.Fatalf("no output written by detectAndHandleChat") }
- // Expect a workspace/applyEdit request to be written
- req := captureRequest(t, &out)
- if req.Method != "workspace/applyEdit" { t.Fatalf("expected workspace/applyEdit, got %s", req.Method) }
- var params ApplyWorkspaceEditParams
- if err := json.Unmarshal(req.Params, &params); err != nil { t.Fatalf("decode params: %v", err) }
- we := params.Edit
- if len(we.Changes) == 0 { t.Fatalf("expected changes in edit") }
- edits := we.Changes[uri]
- if len(edits) != 2 { t.Fatalf("expected 2 edits (delete+insert), got %d", len(edits)) }
- if !strings.Contains(edits[1].NewText, "> Hello") || !strings.Contains(edits[1].NewText, "multi-line reply") {
- t.Fatalf("expected multi-line reply insertion, got %q", edits[1].NewText)
- }
+ var out bytes.Buffer
+ s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{})
+ s.llmClient = fakeLLM{resp: tut.MultilineChatReply()}
+ uri := "file:///chat.go"
+ // Place a prompt line with a supported trigger at EOL, then a blank line
+ s.setDocument(uri, "What time?>\n\n")
+ out.Reset()
+ s.detectAndHandleChat(uri)
+ // Allow async goroutine to write the request
+ for i := 0; i < 20 && out.Len() == 0; i++ {
+ time.Sleep(10 * time.Millisecond)
+ }
+ if out.Len() == 0 {
+ t.Fatalf("no output written by detectAndHandleChat")
+ }
+ // Expect a workspace/applyEdit request to be written
+ req := captureRequest(t, &out)
+ if req.Method != "workspace/applyEdit" {
+ t.Fatalf("expected workspace/applyEdit, got %s", req.Method)
+ }
+ var params ApplyWorkspaceEditParams
+ if err := json.Unmarshal(req.Params, &params); err != nil {
+ t.Fatalf("decode params: %v", err)
+ }
+ we := params.Edit
+ if len(we.Changes) == 0 {
+ t.Fatalf("expected changes in edit")
+ }
+ edits := we.Changes[uri]
+ if len(edits) != 2 {
+ t.Fatalf("expected 2 edits (delete+insert), got %d", len(edits))
+ }
+ if !strings.Contains(edits[1].NewText, "> Hello") || !strings.Contains(edits[1].NewText, "multi-line reply") {
+ t.Fatalf("expected multi-line reply insertion, got %q", edits[1].NewText)
+ }
}
func TestHandleCodeActionResolve_Diagnostics(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- s.llmClient = fakeLLM{resp: "fixed"}
- uri := "file:///x.go"
- s.setDocument(uri, "package p\nvar x = 1\n")
- payload := struct {
- Type string `json:"type"`
- URI string `json:"uri"`
- Range Range `json:"range"`
- Selection string `json:"selection"`
- Diagnostics []Diagnostic `json:"diagnostics"`
- }{Type: "diagnostics", URI: uri, Range: Range{Start: Position{Line:1}, End: Position{Line:1, Character: 10}}, Selection: "var x = 1", Diagnostics: []Diagnostic{{Range: Range{Start: Position{Line:1}, End: Position{Line:1, Character:5}}, Message: "bad"}}}
- raw, _ := json.Marshal(payload)
- ca := CodeAction{Title: "Hexai: resolve diagnostics", Data: raw}
- b, _ := json.Marshal(ca)
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("3"), Method: "codeAction/resolve", Params: b}
- out.Reset()
- s.handleCodeActionResolve(req)
- resp := captureResponse(t, &out)
- var resolved CodeAction
- rb, _ := json.Marshal(resp.Result)
- if err := json.Unmarshal(rb, &resolved); err != nil { t.Fatalf("decode resolved: %v", err) }
- if resolved.Edit == nil { t.Fatalf("expected resolved edit for diagnostics") }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ s.llmClient = fakeLLM{resp: "fixed"}
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\nvar x = 1\n")
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ Selection string `json:"selection"`
+ Diagnostics []Diagnostic `json:"diagnostics"`
+ }{Type: "diagnostics", URI: uri, Range: Range{Start: Position{Line: 1}, End: Position{Line: 1, Character: 10}}, Selection: "var x = 1", Diagnostics: []Diagnostic{{Range: Range{Start: Position{Line: 1}, End: Position{Line: 1, Character: 5}}, Message: "bad"}}}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: resolve diagnostics", Data: raw}
+ b, _ := json.Marshal(ca)
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("3"), Method: "codeAction/resolve", Params: b}
+ out.Reset()
+ s.handleCodeActionResolve(req)
+ resp := captureResponse(t, &out)
+ var resolved CodeAction
+ rb, _ := json.Marshal(resp.Result)
+ if err := json.Unmarshal(rb, &resolved); err != nil {
+ t.Fatalf("decode resolved: %v", err)
+ }
+ if resolved.Edit == nil {
+ t.Fatalf("expected resolved edit for diagnostics")
+ }
}
diff --git a/internal/lsp/handlers_execute.go b/internal/lsp/handlers_execute.go
index 2e3ec52..d0bc8fc 100644
--- a/internal/lsp/handlers_execute.go
+++ b/internal/lsp/handlers_execute.go
@@ -2,34 +2,33 @@
package lsp
import (
- "encoding/json"
+ "encoding/json"
)
func (s *Server) handleExecuteCommand(req Request) {
- var p ExecuteCommandParams
- if err := json.Unmarshal(req.Params, &p); err != nil {
- s.reply(req.ID, nil, nil)
- return
- }
- switch p.Command {
- case "hexai.showDocument":
- if len(p.Arguments) >= 2 {
- uri, _ := p.Arguments[0].(string)
- var r Range
- // Convert second arg to Range via re-marshal to be robust across clients
- if b, err := json.Marshal(p.Arguments[1]); err == nil {
- _ = json.Unmarshal(b, &r)
- }
- if uri != "" {
- s.clientShowDocument(uri, &r)
- }
- }
- s.reply(req.ID, nil, nil)
- return
- default:
- // Unknown command; no-op
- s.reply(req.ID, nil, nil)
- return
- }
+ var p ExecuteCommandParams
+ if err := json.Unmarshal(req.Params, &p); err != nil {
+ s.reply(req.ID, nil, nil)
+ return
+ }
+ switch p.Command {
+ case "hexai.showDocument":
+ if len(p.Arguments) >= 2 {
+ uri, _ := p.Arguments[0].(string)
+ var r Range
+ // Convert second arg to Range via re-marshal to be robust across clients
+ if b, err := json.Marshal(p.Arguments[1]); err == nil {
+ _ = json.Unmarshal(b, &r)
+ }
+ if uri != "" {
+ s.clientShowDocument(uri, &r)
+ }
+ }
+ s.reply(req.ID, nil, nil)
+ return
+ default:
+ // Unknown command; no-op
+ s.reply(req.ID, nil, nil)
+ return
+ }
}
-
diff --git a/internal/lsp/handlers_helpers_test.go b/internal/lsp/handlers_helpers_test.go
index 24a9690..0120cc3 100644
--- a/internal/lsp/handlers_helpers_test.go
+++ b/internal/lsp/handlers_helpers_test.go
@@ -6,32 +6,32 @@ import (
)
func TestHasDoubleSemicolonTrigger(t *testing.T) {
- cases := []struct {
- line string
- want bool
- }{
- {">>todo> remove this", true},
- {"prefix >>x> suffix", true},
- {">> spaced >", false},
- {"no markers", false},
- {">>x > space before close", false},
- }
- for _, tc := range cases {
- got := hasDoubleOpenTrigger(tc.line)
- if got != tc.want {
- t.Fatalf("hasDoubleOpenTrigger(%q)=%v want %v", tc.line, got, tc.want)
- }
- }
+ cases := []struct {
+ line string
+ want bool
+ }{
+ {">>todo> remove this", true},
+ {"prefix >>x> suffix", true},
+ {">> spaced >", false},
+ {"no markers", false},
+ {">>x > space before close", false},
+ }
+ for _, tc := range cases {
+ got := hasDoubleOpenTrigger(tc.line)
+ if got != tc.want {
+ t.Fatalf("hasDoubleOpenTrigger(%q)=%v want %v", tc.line, got, tc.want)
+ }
+ }
}
func TestCollectSemicolonMarkers(t *testing.T) {
- line := "keep >ok> this and >another> that"
- edits := collectSemicolonMarkers(line, 7)
- if len(edits) != 2 {
- t.Fatalf("expected 2 edits, got %d", len(edits))
- }
- // Validate the first edit aligns with ;ok;
- start := strings.Index(line, ">ok>")
+ line := "keep >ok> this and >another> that"
+ edits := collectSemicolonMarkers(line, 7)
+ if len(edits) != 2 {
+ t.Fatalf("expected 2 edits, got %d", len(edits))
+ }
+ // Validate the first edit aligns with ;ok;
+ start := strings.Index(line, ">ok>")
if start < 0 {
t.Fatalf("test setup: missing ;ok;")
}
@@ -41,11 +41,11 @@ func TestCollectSemicolonMarkers(t *testing.T) {
}
func TestPromptRemovalEditsForLine_WholeLine(t *testing.T) {
- line := ">>todo> remove this whole line"
- edits := promptRemovalEditsForLine(line, 3)
- if len(edits) != 1 {
- t.Fatalf("expected 1 whole-line edit, got %d", len(edits))
- }
+ line := ">>todo> remove this whole line"
+ edits := promptRemovalEditsForLine(line, 3)
+ if len(edits) != 1 {
+ t.Fatalf("expected 1 whole-line edit, got %d", len(edits))
+ }
e := edits[0]
if e.Range.Start.Line != 3 || e.Range.End.Line != 3 || e.Range.Start.Character != 0 || e.Range.End.Character != len(line) {
t.Fatalf("unexpected range for whole-line removal: %+v", e.Range)
diff --git a/internal/lsp/handlers_init.go b/internal/lsp/handlers_init.go
index 99ab026..ac1d566 100644
--- a/internal/lsp/handlers_init.go
+++ b/internal/lsp/handlers_init.go
@@ -2,9 +2,10 @@
package lsp
import (
+ "os"
+
"codeberg.org/snonux/hexai/internal"
"codeberg.org/snonux/hexai/internal/logging"
- "os"
)
func (s *Server) handleInitialize(req Request) {
diff --git a/internal/lsp/handlers_test.go b/internal/lsp/handlers_test.go
index 8fdd34f..a171143 100644
--- a/internal/lsp/handlers_test.go
+++ b/internal/lsp/handlers_test.go
@@ -15,7 +15,7 @@ func TestFindFirstInstructionInLine_NoMarker(t *testing.T) {
}
func TestFindFirstInstructionInLine_StrictInline_Basic(t *testing.T) {
- line := "prefix >rename var> suffix"
+ line := "prefix >rename var> suffix"
instr, cleaned, ok := findFirstInstructionInLine(line)
if !ok {
t.Fatalf("expected ok=true")
@@ -30,7 +30,7 @@ func TestFindFirstInstructionInLine_StrictInline_Basic(t *testing.T) {
}
func TestFindFirstInstructionInLine_StrictInline_TrailingSpacesTrimmed(t *testing.T) {
- line := "code>fix> \t\t"
+ line := "code>fix> \t\t"
instr, cleaned, ok := findFirstInstructionInLine(line)
if !ok {
t.Fatalf("expected ok=true")
@@ -44,16 +44,16 @@ func TestFindFirstInstructionInLine_StrictInline_TrailingSpacesTrimmed(t *testin
}
func TestFindFirstInstructionInLine_Inline_InvalidPatterns(t *testing.T) {
- cases := []string{
- "prefix > bad> suffix", // space after first '>' ⇒ invalid
- "prefix >bad > suffix", // space before closing '>' ⇒ invalid
- "prefix > > suffix", // empty inner ⇒ invalid
- }
- for _, line := range cases {
- if instr, _, ok := findFirstInstructionInLine(line); ok && instr != "" {
- t.Fatalf("%q: expected no inline instruction; got instr=%q", line, instr)
- }
- }
+ cases := []string{
+ "prefix > bad> suffix", // space after first '>' ⇒ invalid
+ "prefix >bad > suffix", // space before closing '>' ⇒ invalid
+ "prefix > > suffix", // empty inner ⇒ invalid
+ }
+ for _, line := range cases {
+ if instr, _, ok := findFirstInstructionInLine(line); ok && instr != "" {
+ t.Fatalf("%q: expected no inline instruction; got instr=%q", line, instr)
+ }
+ }
}
func TestFindFirstInstructionInLine_CBlockComment(t *testing.T) {
@@ -127,21 +127,21 @@ func TestFindFirstInstructionInLine_DoubleDash(t *testing.T) {
}
func TestFindFirstInstructionInLine_EarliestWins_CommentOverInline(t *testing.T) {
- line := "aa // comment >not this> trailing"
+ line := "aa // comment >not this> trailing"
instr, cleaned, ok := findFirstInstructionInLine(line)
if !ok {
t.Fatalf("expected ok=true")
}
- if instr != "comment >not this> trailing" {
- t.Fatalf("instr got %q want %q", instr, "comment >not this> trailing")
- }
+ if instr != "comment >not this> trailing" {
+ t.Fatalf("instr got %q want %q", instr, "comment >not this> trailing")
+ }
if cleaned != "aa" {
t.Fatalf("cleaned got %q want %q", cleaned, "aa")
}
}
func TestFindFirstInstructionInLine_EarliestWins_InlineOverComment(t *testing.T) {
- line := "aa >short> // comment"
+ line := "aa >short> // comment"
instr, cleaned, ok := findFirstInstructionInLine(line)
if !ok {
t.Fatalf("expected ok=true")
@@ -156,20 +156,20 @@ func TestFindFirstInstructionInLine_EarliestWins_InlineOverComment(t *testing.T)
}
func TestFindStrictInlineTag_Various(t *testing.T) {
- // basic
- if text, l, r, ok := findStrictInlineTag("pre>do it>post"); !ok || text != "do it" || l != 3 || r != 10 {
- t.Fatalf("unexpected: ok=%v text=%q l=%d r=%d", ok, text, l, r)
- }
- // at start
- if text, l, r, ok := findStrictInlineTag(">x>"); !ok || text != "x" || l != 0 || r != 3 {
- t.Fatalf("unexpected at start: ok=%v text=%q l=%d r=%d", ok, text, l, r)
- }
- // double opening '>>' should still allow a tag starting at the second '>'
- if text, _, _, ok := findStrictInlineTag("prefix >>bad> suffix"); !ok || text != "bad" {
- t.Fatalf("unexpected double-open handling: ok=%v text=%q", ok, text)
- }
- // inner spaces directly after first '>' or before last '>' invalidate the tag
- if _, _, _, ok := findStrictInlineTag("a> inner >b"); ok {
- t.Fatalf("expected invalid strict tag due to spaces at boundaries")
- }
+ // basic
+ if text, l, r, ok := findStrictInlineTag("pre>do it>post"); !ok || text != "do it" || l != 3 || r != 10 {
+ t.Fatalf("unexpected: ok=%v text=%q l=%d r=%d", ok, text, l, r)
+ }
+ // at start
+ if text, l, r, ok := findStrictInlineTag(">x>"); !ok || text != "x" || l != 0 || r != 3 {
+ t.Fatalf("unexpected at start: ok=%v text=%q l=%d r=%d", ok, text, l, r)
+ }
+ // double opening '>>' should still allow a tag starting at the second '>'
+ if text, _, _, ok := findStrictInlineTag("prefix >>bad> suffix"); !ok || text != "bad" {
+ t.Fatalf("unexpected double-open handling: ok=%v text=%q", ok, text)
+ }
+ // inner spaces directly after first '>' or before last '>' invalidate the tag
+ if _, _, _, ok := findStrictInlineTag("a> inner >b"); ok {
+ t.Fatalf("expected invalid strict tag due to spaces at boundaries")
+ }
}
diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go
index e2c35e3..30a21a5 100644
--- a/internal/lsp/handlers_utils.go
+++ b/internal/lsp/handlers_utils.go
@@ -2,17 +2,20 @@
package lsp
import (
- "fmt"
- "codeberg.org/snonux/hexai/internal/llm"
- "codeberg.org/snonux/hexai/internal/logging"
- "strings"
- "time"
+ "fmt"
+ "strings"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/logging"
)
// Configurable inline trigger characters (default to '>') used by free helpers below.
// NewServer assigns these based on ServerOptions.
-var inlineOpenChar byte = '>'
-var inlineCloseChar byte = '>'
+var (
+ inlineOpenChar byte = '>'
+ inlineCloseChar byte = '>'
+)
// llmRequestOpts builds request options from server settings.
func (s *Server) llmRequestOpts() []llm.RequestOption {
@@ -129,10 +132,10 @@ func isIdentChar(ch byte) bool {
// Inline prompt utilities
func lineHasInlinePrompt(line string) bool {
- if _, _, _, ok := findStrictInlineTag(line); ok {
- return true
- }
- return hasDoubleOpenTrigger(line)
+ if _, _, _, ok := findStrictInlineTag(line); ok {
+ return true
+ }
+ return hasDoubleOpenTrigger(line)
}
func leadingIndent(line string) string {
@@ -173,60 +176,60 @@ func applyIndent(indent, suggestion string) string {
// opening marker and no space immediately before the closing marker. Returns the
// text between markers, the start index, the end index just after closing, and ok.
func findStrictInlineTag(line string) (string, int, int, bool) {
- pos := 0
- for pos < len(line) {
- // find opening marker
- j := strings.IndexByte(line[pos:], inlineOpenChar)
- if j < 0 {
- return "", 0, 0, false
- }
- j += pos
- // ensure single open (not double) and non-space after
- if j+1 >= len(line) || line[j+1] == inlineOpenChar || line[j+1] == ' ' {
- pos = j + 1
- continue
- }
- // find closing marker
- k := strings.IndexByte(line[j+1:], inlineCloseChar)
- if k < 0 {
- return "", 0, 0, false
- }
- closeIdx := j + 1 + k
- if closeIdx-1 < 0 || line[closeIdx-1] == ' ' {
- pos = closeIdx + 1
- continue
- }
- inner := strings.TrimSpace(line[j+1 : closeIdx])
- if inner == "" {
- pos = closeIdx + 1
- continue
- }
- end := closeIdx + 1
- return inner, j, end, true
- }
- return "", 0, 0, false
+ pos := 0
+ for pos < len(line) {
+ // find opening marker
+ j := strings.IndexByte(line[pos:], inlineOpenChar)
+ if j < 0 {
+ return "", 0, 0, false
+ }
+ j += pos
+ // ensure single open (not double) and non-space after
+ if j+1 >= len(line) || line[j+1] == inlineOpenChar || line[j+1] == ' ' {
+ pos = j + 1
+ continue
+ }
+ // find closing marker
+ k := strings.IndexByte(line[j+1:], inlineCloseChar)
+ if k < 0 {
+ return "", 0, 0, false
+ }
+ closeIdx := j + 1 + k
+ if closeIdx-1 < 0 || line[closeIdx-1] == ' ' {
+ pos = closeIdx + 1
+ continue
+ }
+ inner := strings.TrimSpace(line[j+1 : closeIdx])
+ if inner == "" {
+ pos = closeIdx + 1
+ continue
+ }
+ end := closeIdx + 1
+ return inner, j, end, true
+ }
+ return "", 0, 0, false
}
// isBareDoubleSemicolon reports whether the line contains a standalone
// double-semicolon marker with no inline content (";;" possibly with only
// whitespace after it). It explicitly excludes the valid form ";;text;".
func isBareDoubleOpen(line string) bool {
- t := strings.TrimSpace(line)
- // check for double-open pattern
- dbl := string([]byte{inlineOpenChar, inlineOpenChar})
- if !strings.Contains(t, dbl) {
- return false
- }
- if hasDoubleOpenTrigger(t) {
- return false
- }
- if strings.HasPrefix(t, dbl) {
- rest := strings.TrimSpace(t[len(dbl):])
- if rest == "" || rest == ";" {
- return true
- }
- }
- return false
+ t := strings.TrimSpace(line)
+ // check for double-open pattern
+ dbl := string([]byte{inlineOpenChar, inlineOpenChar})
+ if !strings.Contains(t, dbl) {
+ return false
+ }
+ if hasDoubleOpenTrigger(t) {
+ return false
+ }
+ if strings.HasPrefix(t, dbl) {
+ rest := strings.TrimSpace(t[len(dbl):])
+ if rest == "" || rest == ";" {
+ return true
+ }
+ }
+ return false
}
// stripDuplicateAssignmentPrefix removes a duplicated assignment prefix from the suggestion.
@@ -409,82 +412,82 @@ func (s *Server) collectPromptRemovalEdits(uri string) []TextEdit {
}
func promptRemovalEditsForLine(line string, lineNum int) []TextEdit {
- if hasDoubleOpenTrigger(line) {
- return []TextEdit{{Range: Range{Start: Position{Line: lineNum, Character: 0}, End: Position{Line: lineNum, Character: len(line)}}, NewText: ""}}
- }
- return collectSemicolonMarkers(line, lineNum)
+ if hasDoubleOpenTrigger(line) {
+ return []TextEdit{{Range: Range{Start: Position{Line: lineNum, Character: 0}, End: Position{Line: lineNum, Character: len(line)}}, NewText: ""}}
+ }
+ return collectSemicolonMarkers(line, lineNum)
}
func hasDoubleOpenTrigger(line string) bool {
- pos := 0
- for pos < len(line) {
- // look for double-open sequence
- dbl := string([]byte{inlineOpenChar, inlineOpenChar})
- j := strings.Index(line[pos:], dbl)
- if j < 0 {
- return false
- }
- j += pos
- contentStart := j + len(dbl)
- if contentStart >= len(line) {
- return false
- }
- first := line[contentStart]
- if first == ' ' || first == inlineOpenChar {
- pos = contentStart + 1
- continue
- }
- // find closing
- k := strings.IndexByte(line[contentStart+1:], inlineCloseChar)
- if k < 0 {
- return false
- }
- closeIdx := contentStart + 1 + k
- if closeIdx-1 >= 0 && line[closeIdx-1] == ' ' {
- pos = closeIdx + 1
- continue
- }
- return true
- }
- return false
+ pos := 0
+ for pos < len(line) {
+ // look for double-open sequence
+ dbl := string([]byte{inlineOpenChar, inlineOpenChar})
+ j := strings.Index(line[pos:], dbl)
+ if j < 0 {
+ return false
+ }
+ j += pos
+ contentStart := j + len(dbl)
+ if contentStart >= len(line) {
+ return false
+ }
+ first := line[contentStart]
+ if first == ' ' || first == inlineOpenChar {
+ pos = contentStart + 1
+ continue
+ }
+ // find closing
+ k := strings.IndexByte(line[contentStart+1:], inlineCloseChar)
+ if k < 0 {
+ return false
+ }
+ closeIdx := contentStart + 1 + k
+ if closeIdx-1 >= 0 && line[closeIdx-1] == ' ' {
+ pos = closeIdx + 1
+ continue
+ }
+ return true
+ }
+ return false
}
func collectSemicolonMarkers(line string, lineNum int) []TextEdit {
- var edits []TextEdit
- startSemi := 0
- for startSemi < len(line) {
- j := strings.IndexByte(line[startSemi:], inlineOpenChar)
- if j < 0 {
- break
- }
- j += startSemi
- k := strings.IndexByte(line[j+1:], inlineCloseChar)
- if k < 0 {
- break
- }
- if j+1 >= len(line) || line[j+1] == ' ' {
- startSemi = j + 1
- continue
- }
- if line[j+1] == inlineOpenChar { // skip double-open start
- startSemi = j + 2
- continue
- }
- closeIdx := j + 1 + k
- if closeIdx-1 < 0 || line[closeIdx-1] == ' ' {
- startSemi = closeIdx + 1
- continue
- }
- if closeIdx-(j+1) < 1 {
- startSemi = closeIdx + 1
- continue
- }
- endChar := closeIdx + 1
- if endChar < len(line) && line[endChar] == ' ' {
- endChar++
- }
- edits = append(edits, TextEdit{Range: Range{Start: Position{Line: lineNum, Character: j}, End: Position{Line: lineNum, Character: endChar}}, NewText: ""})
- startSemi = endChar
- }
- return edits
+ var edits []TextEdit
+ startSemi := 0
+ for startSemi < len(line) {
+ j := strings.IndexByte(line[startSemi:], inlineOpenChar)
+ if j < 0 {
+ break
+ }
+ j += startSemi
+ k := strings.IndexByte(line[j+1:], inlineCloseChar)
+ if k < 0 {
+ break
+ }
+ if j+1 >= len(line) || line[j+1] == ' ' {
+ startSemi = j + 1
+ continue
+ }
+ if line[j+1] == inlineOpenChar { // skip double-open start
+ startSemi = j + 2
+ continue
+ }
+ closeIdx := j + 1 + k
+ if closeIdx-1 < 0 || line[closeIdx-1] == ' ' {
+ startSemi = closeIdx + 1
+ continue
+ }
+ if closeIdx-(j+1) < 1 {
+ startSemi = closeIdx + 1
+ continue
+ }
+ endChar := closeIdx + 1
+ if endChar < len(line) && line[endChar] == ' ' {
+ endChar++
+ }
+ edits = append(edits, TextEdit{Range: Range{Start: Position{Line: lineNum, Character: j}, End: Position{Line: lineNum, Character: endChar}}, NewText: ""})
+ startSemi = endChar
+ }
+ return edits
}
diff --git a/internal/lsp/helpers_inline_prompt_test.go b/internal/lsp/helpers_inline_prompt_test.go
index 81312b4..4aaf892 100644
--- a/internal/lsp/helpers_inline_prompt_test.go
+++ b/internal/lsp/helpers_inline_prompt_test.go
@@ -1,58 +1,62 @@
package lsp
import (
- "encoding/json"
- "testing"
+ "encoding/json"
+ "testing"
)
func TestLineHasInlinePrompt_BasicAndDoubleOpen(t *testing.T) {
- // Basic inline
- if !lineHasInlinePrompt("do >task> now") {
- t.Fatalf("expected inline prompt detection for >text>")
- }
- // Double-open variant should be recognized as inline prompt too
- if !lineHasInlinePrompt(">>replace>") {
- t.Fatalf("expected inline prompt detection for >>text>")
- }
+ // Basic inline
+ if !lineHasInlinePrompt("do >task> now") {
+ t.Fatalf("expected inline prompt detection for >text>")
+ }
+ // Double-open variant should be recognized as inline prompt too
+ if !lineHasInlinePrompt(">>replace>") {
+ t.Fatalf("expected inline prompt detection for >>text>")
+ }
}
func TestIsTriggerEvent_TriggerCharNotAllowed(t *testing.T) {
- s := newTestServer()
- s.triggerChars = []string{"."} // only dot allowed
- p := CompletionParams{Position: Position{Line:0, Character:3}}
- if s.isTriggerEvent(p, "ab:") { // ':' not in triggerChars
- t.Fatalf("expected false when TriggerCharacter not configured")
- }
+ s := newTestServer()
+ s.triggerChars = []string{"."} // only dot allowed
+ p := CompletionParams{Position: Position{Line: 0, Character: 3}}
+ if s.isTriggerEvent(p, "ab:") { // ':' not in triggerChars
+ t.Fatalf("expected false when TriggerCharacter not configured")
+ }
}
func TestShouldSuppressForChatTriggerEOL_EmptySuffix_NoSuppression(t *testing.T) {
- s := newTestServer()
- s.chatSuffix = "" // disabled
- p := CompletionParams{Position: Position{Line:0, Character:5}}
- if s.shouldSuppressForChatTriggerEOL("What?>", p) {
- t.Fatalf("expected no suppression when chat suffix is empty")
- }
+ s := newTestServer()
+ s.chatSuffix = "" // disabled
+ p := CompletionParams{Position: Position{Line: 0, Character: 5}}
+ if s.shouldSuppressForChatTriggerEOL("What?>", p) {
+ t.Fatalf("expected no suppression when chat suffix is empty")
+ }
}
func TestIsTriggerEvent_TriggerCharacterMissing_ReturnsFalse(t *testing.T) {
- s := newTestServer()
- // Context says TriggerCharacter, but none provided
- ctx := struct{ TriggerKind int `json:"triggerKind"` }{TriggerKind: 2}
- raw, _ := json.Marshal(ctx)
- p := CompletionParams{Position: Position{Line:0, Character:1}, Context: json.RawMessage(raw)}
- if s.isTriggerEvent(p, "a") {
- t.Fatalf("expected false when TriggerCharacter kind with empty char")
- }
+ s := newTestServer()
+ // Context says TriggerCharacter, but none provided
+ ctx := struct {
+ TriggerKind int `json:"triggerKind"`
+ }{TriggerKind: 2}
+ raw, _ := json.Marshal(ctx)
+ p := CompletionParams{Position: Position{Line: 0, Character: 1}, Context: json.RawMessage(raw)}
+ if s.isTriggerEvent(p, "a") {
+ t.Fatalf("expected false when TriggerCharacter kind with empty char")
+ }
}
func TestIsTriggerEvent_TriggerForIncomplete_FallsBackToChar(t *testing.T) {
- s := newTestServer()
- s.triggerChars = []string{"."}
- // TriggerKind=3 should consult fallback char check
- ctx := struct{ TriggerKind int `json:"triggerKind"` }{TriggerKind: 3}
- raw, _ := json.Marshal(ctx)
- p := CompletionParams{Position: Position{Line:0, Character:2}, Context: json.RawMessage(raw)}
- if !s.isTriggerEvent(p, "x.") {
- t.Fatalf("expected true via fallback char for TriggerForIncomplete")
- }
+ s := newTestServer()
+ s.triggerChars = []string{"."}
+ // TriggerKind=3 should consult fallback char check
+ ctx := struct {
+ TriggerKind int `json:"triggerKind"`
+ }{TriggerKind: 3}
+ raw, _ := json.Marshal(ctx)
+ p := CompletionParams{Position: Position{Line: 0, Character: 2}, Context: json.RawMessage(raw)}
+ if !s.isTriggerEvent(p, "x.") {
+ t.Fatalf("expected true via fallback char for TriggerForIncomplete")
+ }
}
diff --git a/internal/lsp/helpers_more_test.go b/internal/lsp/helpers_more_test.go
index 28d78a4..a0b0c26 100644
--- a/internal/lsp/helpers_more_test.go
+++ b/internal/lsp/helpers_more_test.go
@@ -1,111 +1,163 @@
package lsp
-import ("testing")
+import (
+ "testing"
+)
func TestComputeWordStart(t *testing.T) {
- s := "fooBar 123"
- if i := computeWordStart(s, 5); i != 0 { t.Fatalf("start=%d", i) }
- if i := computeWordStart(s, len(s)); i != 7 { t.Fatalf("end start=%d", i) }
+ s := "fooBar 123"
+ if i := computeWordStart(s, 5); i != 0 {
+ t.Fatalf("start=%d", i)
+ }
+ if i := computeWordStart(s, len(s)); i != 7 {
+ t.Fatalf("end start=%d", i)
+ }
}
func TestLeadingAndApplyIndent(t *testing.T) {
- if got := leadingIndent("\t abc"); got == "" { t.Fatalf("expected indent") }
- out := applyIndent(" ", "x\n y\n\n z")
- if out == "" || out[:2] != " " { t.Fatalf("applyIndent failed: %q", out) }
+ if got := leadingIndent("\t abc"); got == "" {
+ t.Fatalf("expected indent")
+ }
+ out := applyIndent(" ", "x\n y\n\n z")
+ if out == "" || out[:2] != " " {
+ t.Fatalf("applyIndent failed: %q", out)
+ }
}
func TestFindStrictInlineTag(t *testing.T) {
- if _, _, _, ok := findStrictInlineTag(">do this> next"); !ok { t.Fatalf("expected strict tag") }
- if _, _, _, ok := findStrictInlineTag("> spaced >"); ok { t.Fatalf("should ignore spaced tag") }
+ if _, _, _, ok := findStrictInlineTag(">do this> next"); !ok {
+ t.Fatalf("expected strict tag")
+ }
+ if _, _, _, ok := findStrictInlineTag("> spaced >"); ok {
+ t.Fatalf("should ignore spaced tag")
+ }
}
// hasDoubleSemicolonTrigger tested elsewhere
func TestStripDuplicatePrefixes(t *testing.T) {
- if got := stripDuplicateAssignmentPrefix("name := ", "name := 123"); got == "name := 123" { t.Fatalf("expected trim") }
- if got := stripDuplicateGeneralPrefix("fmt.", "fmt.Println"); got == "fmt.Println" { t.Fatalf("expected trim general") }
+ if got := stripDuplicateAssignmentPrefix("name := ", "name := 123"); got == "name := 123" {
+ t.Fatalf("expected trim")
+ }
+ if got := stripDuplicateGeneralPrefix("fmt.", "fmt.Println"); got == "fmt.Println" {
+ t.Fatalf("expected trim general")
+ }
}
func TestExtractRangeText(t *testing.T) {
- d := &document{text: "a\nbc\nxyz", lines: []string{"a","bc","xyz"}}
- // single line
- got := extractRangeText(d, Range{Start: Position{Line:1, Character:0}, End: Position{Line:1, Character:2}})
- if got != "bc" { t.Fatalf("got %q", got) }
- // multi-line
- got = extractRangeText(d, Range{Start: Position{Line:0, Character:0}, End: Position{Line:2, Character:2}})
- if got != "a\nbc\nxy" { t.Fatalf("got %q", got) }
- // invalid range (start after end) returns empty string
- if got := extractRangeText(d, Range{Start: Position{Line:1, Character:5}, End: Position{Line:1, Character:2}}); got != "" {
- t.Fatalf("expected empty for invalid range, got %q", got)
- }
+ d := &document{text: "a\nbc\nxyz", lines: []string{"a", "bc", "xyz"}}
+ // single line
+ got := extractRangeText(d, Range{Start: Position{Line: 1, Character: 0}, End: Position{Line: 1, Character: 2}})
+ if got != "bc" {
+ t.Fatalf("got %q", got)
+ }
+ // multi-line
+ got = extractRangeText(d, Range{Start: Position{Line: 0, Character: 0}, End: Position{Line: 2, Character: 2}})
+ if got != "a\nbc\nxy" {
+ t.Fatalf("got %q", got)
+ }
+ // invalid range (start after end) returns empty string
+ if got := extractRangeText(d, Range{Start: Position{Line: 1, Character: 5}, End: Position{Line: 1, Character: 2}}); got != "" {
+ t.Fatalf("expected empty for invalid range, got %q", got)
+ }
}
func TestRangesOverlapAndOrder(t *testing.T) {
- a := Range{Start: Position{Line:1, Character:2}, End: Position{Line:1, Character:5}}
- b := Range{Start: Position{Line:1, Character:4}, End: Position{Line:1, Character:8}}
- if !rangesOverlap(a, b) { t.Fatalf("expected overlap") }
- c := Range{Start: Position{Line:2, Character:0}, End: Position{Line:2, Character:1}}
- if rangesOverlap(a, c) { t.Fatalf("no overlap expected") }
- if !lessPos(Position{Line:0, Character:1}, Position{Line:1, Character:0}) { t.Fatalf("lessPos failed") }
- if !greaterPos(Position{Line:2, Character:0}, Position{Line:1, Character:9}) { t.Fatalf("greaterPos failed") }
+ a := Range{Start: Position{Line: 1, Character: 2}, End: Position{Line: 1, Character: 5}}
+ b := Range{Start: Position{Line: 1, Character: 4}, End: Position{Line: 1, Character: 8}}
+ if !rangesOverlap(a, b) {
+ t.Fatalf("expected overlap")
+ }
+ c := Range{Start: Position{Line: 2, Character: 0}, End: Position{Line: 2, Character: 1}}
+ if rangesOverlap(a, c) {
+ t.Fatalf("no overlap expected")
+ }
+ if !lessPos(Position{Line: 0, Character: 1}, Position{Line: 1, Character: 0}) {
+ t.Fatalf("lessPos failed")
+ }
+ if !greaterPos(Position{Line: 2, Character: 0}, Position{Line: 1, Character: 9}) {
+ t.Fatalf("greaterPos failed")
+ }
}
func TestPromptRemovalEditsForLine(t *testing.T) {
- edits := promptRemovalEditsForLine(">>do thing>", 3)
- if len(edits) != 1 || edits[0].Range.Start.Line != 3 {
- t.Fatalf("expected full-line removal for double-semicolon")
- }
- edits2 := promptRemovalEditsForLine(">act> and >b>", 1)
- if len(edits2) == 0 { t.Fatalf("expected edits to remove strict markers") }
+ edits := promptRemovalEditsForLine(">>do thing>", 3)
+ if len(edits) != 1 || edits[0].Range.Start.Line != 3 {
+ t.Fatalf("expected full-line removal for double-semicolon")
+ }
+ edits2 := promptRemovalEditsForLine(">act> and >b>", 1)
+ if len(edits2) == 0 {
+ t.Fatalf("expected edits to remove strict markers")
+ }
}
func TestCollectPromptRemovalEdits_MultiLine(t *testing.T) {
- s := newTestServer()
- uri := "file:///t.go"
- s.setDocument(uri, "a\n>do> x\n>>wipe>\nend")
- edits := s.collectPromptRemovalEdits(uri)
- if len(edits) < 2 { t.Fatalf("expected >=2 edits, got %d", len(edits)) }
+ s := newTestServer()
+ uri := "file:///t.go"
+ s.setDocument(uri, "a\n>do> x\n>>wipe>\nend")
+ edits := s.collectPromptRemovalEdits(uri)
+ if len(edits) < 2 {
+ t.Fatalf("expected >=2 edits, got %d", len(edits))
+ }
}
func TestInParamListAndBuildPrompts(t *testing.T) {
- cur := "func add(a int, b string) int"
- if !inParamList(cur, 12) { t.Fatalf("expected in param list") }
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: 5}}
- sys, user := buildPrompts(false, p, "above", "current", "below", "func add")
- if sys == "" || user == "" { t.Fatalf("prompts empty") }
+ cur := "func add(a int, b string) int"
+ if !inParamList(cur, 12) {
+ t.Fatalf("expected in param list")
+ }
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: 5}}
+ sys, user := buildPrompts(false, p, "above", "current", "below", "func add")
+ if sys == "" || user == "" {
+ t.Fatalf("prompts empty")
+ }
}
func TestLabelForCompletion(t *testing.T) {
- if got := labelForCompletion("line one\nline two", "lin"); got != "line one" { t.Fatalf("expected label, got %q", got) }
- if got := labelForCompletion("result", "zzz"); got != "zzz" { t.Fatalf("expected filter preferred when not prefix, got %q", got) }
- if got := labelForCompletion("result", "re"); got != "result" { t.Fatalf("expected label when filter prefixes label, got %q", got) }
+ if got := labelForCompletion("line one\nline two", "lin"); got != "line one" {
+ t.Fatalf("expected label, got %q", got)
+ }
+ if got := labelForCompletion("result", "zzz"); got != "zzz" {
+ t.Fatalf("expected filter preferred when not prefix, got %q", got)
+ }
+ if got := labelForCompletion("result", "re"); got != "result" {
+ t.Fatalf("expected label when filter prefixes label, got %q", got)
+ }
}
func TestComputeTextEditAndFilter(t *testing.T) {
- // non-params edit
- p := CompletionParams{Position: Position{Line: 1, Character: 4}}
- te, filter := computeTextEditAndFilter("X", false, "ab cd", p)
- if te == nil || filter == "" { t.Fatalf("expected edit and filter") }
- // inside params
- line := "func add(a int, b int)"
- p2 := CompletionParams{Position: Position{Line: 0, Character: 12}}
- te2, _ := computeTextEditAndFilter("string", true, line, p2)
- if te2 == nil || te2.Range.Start.Character == 0 { t.Fatalf("expected param-range edit") }
+ // non-params edit
+ p := CompletionParams{Position: Position{Line: 1, Character: 4}}
+ te, filter := computeTextEditAndFilter("X", false, "ab cd", p)
+ if te == nil || filter == "" {
+ t.Fatalf("expected edit and filter")
+ }
+ // inside params
+ line := "func add(a int, b int)"
+ p2 := CompletionParams{Position: Position{Line: 0, Character: 12}}
+ te2, _ := computeTextEditAndFilter("string", true, line, p2)
+ if te2 == nil || te2.Range.Start.Character == 0 {
+ t.Fatalf("expected param-range edit")
+ }
}
func TestIsBareDoubleOpen(t *testing.T) {
- if !isBareDoubleOpen(">> ") { t.Fatalf("expected true") }
- if isBareDoubleOpen(">>x>") { t.Fatalf("expected false for content form") }
+ if !isBareDoubleOpen(">> ") {
+ t.Fatalf("expected true")
+ }
+ if isBareDoubleOpen(">>x>") {
+ t.Fatalf("expected false for content form")
+ }
}
func TestIsDefiningNewFunction(t *testing.T) {
- s := newTestServer()
- uri := "file:///z.go"
- s.setDocument(uri, "package p\n\nfunc add(a int) int\n{")
- if !s.isDefiningNewFunction(uri, Position{Line:2, Character:10}) {
- t.Fatalf("expected true before opening brace")
- }
- if s.isDefiningNewFunction(uri, Position{Line:3, Character:1}) {
- t.Fatalf("expected false inside body")
- }
+ s := newTestServer()
+ uri := "file:///z.go"
+ s.setDocument(uri, "package p\n\nfunc add(a int) int\n{")
+ if !s.isDefiningNewFunction(uri, Position{Line: 2, Character: 10}) {
+ t.Fatalf("expected true before opening brace")
+ }
+ if s.isDefiningNewFunction(uri, Position{Line: 3, Character: 1}) {
+ t.Fatalf("expected false inside body")
+ }
}
diff --git a/internal/lsp/init_and_trigger_test.go b/internal/lsp/init_and_trigger_test.go
index 64253a9..10c04fd 100644
--- a/internal/lsp/init_and_trigger_test.go
+++ b/internal/lsp/init_and_trigger_test.go
@@ -1,51 +1,73 @@
package lsp
import (
- "bytes"
- "encoding/json"
- "io"
- "log"
- "testing"
+ "bytes"
+ "encoding/json"
+ "io"
+ "log"
+ "testing"
)
func TestHandleInitialize_Capabilities(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- s.triggerChars = []string{".", ":"}
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("7"), Method: "initialize"}
- out.Reset()
- s.handleInitialize(req)
- resp := captureResponse(t, &out)
- var init InitializeResult
- b, _ := json.Marshal(resp.Result)
- if err := json.Unmarshal(b, &init); err != nil { t.Fatalf("decode init: %v", err) }
- if init.Capabilities.CodeActionProvider == nil { t.Fatalf("expected codeActionProvider") }
- // CodeActionProvider is any; re-marshal to struct
- var cap struct{ ResolveProvider bool `json:"resolveProvider"` }
- cb, _ := json.Marshal(init.Capabilities.CodeActionProvider)
- _ = json.Unmarshal(cb, &cap)
- if !cap.ResolveProvider { t.Fatalf("expected resolveProvider=true") }
- if init.Capabilities.CompletionProvider == nil || len(init.Capabilities.CompletionProvider.TriggerCharacters) == 0 {
- t.Fatalf("expected trigger characters") }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ s.triggerChars = []string{".", ":"}
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("7"), Method: "initialize"}
+ out.Reset()
+ s.handleInitialize(req)
+ resp := captureResponse(t, &out)
+ var init InitializeResult
+ b, _ := json.Marshal(resp.Result)
+ if err := json.Unmarshal(b, &init); err != nil {
+ t.Fatalf("decode init: %v", err)
+ }
+ if init.Capabilities.CodeActionProvider == nil {
+ t.Fatalf("expected codeActionProvider")
+ }
+ // CodeActionProvider is any; re-marshal to struct
+ var cap struct {
+ ResolveProvider bool `json:"resolveProvider"`
+ }
+ cb, _ := json.Marshal(init.Capabilities.CodeActionProvider)
+ _ = json.Unmarshal(cb, &cap)
+ if !cap.ResolveProvider {
+ t.Fatalf("expected resolveProvider=true")
+ }
+ if init.Capabilities.CompletionProvider == nil || len(init.Capabilities.CompletionProvider.TriggerCharacters) == 0 {
+ t.Fatalf("expected trigger characters")
+ }
}
func TestIsTriggerEvent_Variants(t *testing.T) {
- s := newTestServer()
- s.triggerChars = []string{".", ":"}
- // 1) Manual invoke via context
- ctx := struct{ TriggerKind int `json:"triggerKind"` }{TriggerKind:1}
- raw, _ := json.Marshal(ctx)
- p := CompletionParams{Position: Position{Line:0, Character:1}, Context: json.RawMessage(raw)}
- if !s.isTriggerEvent(p, "a") { t.Fatalf("manual invoke should trigger") }
- // 2) TriggerCharacter present and allowed
- ctx2 := struct{ TriggerKind int `json:"triggerKind"`; TriggerCharacter string `json:"triggerCharacter"` }{TriggerKind:2, TriggerCharacter: "."}
- raw2, _ := json.Marshal(ctx2)
- p2 := CompletionParams{Position: Position{Line:0, Character:1}, Context: json.RawMessage(raw2)}
- if !s.isTriggerEvent(p2, "a.") { t.Fatalf("trigger char should trigger") }
- // 3) Fallback char left of cursor
- p3 := CompletionParams{Position: Position{Line:0, Character:3}}
- if !s.isTriggerEvent(p3, "ab:") { t.Fatalf("fallback char should trigger") }
- // 4) Bare double-open disables trigger
- p4 := CompletionParams{Position: Position{Line:0, Character:2}}
- if s.isTriggerEvent(p4, ">>") { t.Fatalf("bare double-open should not trigger") }
+ s := newTestServer()
+ s.triggerChars = []string{".", ":"}
+ // 1) Manual invoke via context
+ ctx := struct {
+ TriggerKind int `json:"triggerKind"`
+ }{TriggerKind: 1}
+ raw, _ := json.Marshal(ctx)
+ p := CompletionParams{Position: Position{Line: 0, Character: 1}, Context: json.RawMessage(raw)}
+ if !s.isTriggerEvent(p, "a") {
+ t.Fatalf("manual invoke should trigger")
+ }
+ // 2) TriggerCharacter present and allowed
+ ctx2 := struct {
+ TriggerKind int `json:"triggerKind"`
+ TriggerCharacter string `json:"triggerCharacter"`
+ }{TriggerKind: 2, TriggerCharacter: "."}
+ raw2, _ := json.Marshal(ctx2)
+ p2 := CompletionParams{Position: Position{Line: 0, Character: 1}, Context: json.RawMessage(raw2)}
+ if !s.isTriggerEvent(p2, "a.") {
+ t.Fatalf("trigger char should trigger")
+ }
+ // 3) Fallback char left of cursor
+ p3 := CompletionParams{Position: Position{Line: 0, Character: 3}}
+ if !s.isTriggerEvent(p3, "ab:") {
+ t.Fatalf("fallback char should trigger")
+ }
+ // 4) Bare double-open disables trigger
+ p4 := CompletionParams{Position: Position{Line: 0, Character: 2}}
+ if s.isTriggerEvent(p4, ">>") {
+ t.Fatalf("bare double-open should not trigger")
+ }
}
diff --git a/internal/lsp/init_shutdown_test.go b/internal/lsp/init_shutdown_test.go
index 7b08f2c..19b9b33 100644
--- a/internal/lsp/init_shutdown_test.go
+++ b/internal/lsp/init_shutdown_test.go
@@ -1,20 +1,21 @@
package lsp
import (
- "bytes"
- "encoding/json"
- "io"
- "log"
- "testing"
+ "bytes"
+ "encoding/json"
+ "io"
+ "log"
+ "testing"
)
func TestHandleShutdown_Replies(t *testing.T) {
- var out bytes.Buffer
- s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
- req := Request{JSONRPC: "2.0", ID: json.RawMessage("12"), Method: "shutdown"}
- out.Reset()
- s.handleShutdown(req)
- resp := captureResponse(t, &out)
- if string(resp.ID) != "12" || resp.Error != nil { t.Fatalf("unexpected shutdown response: %+v", resp) }
+ var out bytes.Buffer
+ s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out}
+ req := Request{JSONRPC: "2.0", ID: json.RawMessage("12"), Method: "shutdown"}
+ out.Reset()
+ s.handleShutdown(req)
+ resp := captureResponse(t, &out)
+ if string(resp.ID) != "12" || resp.Error != nil {
+ t.Fatalf("unexpected shutdown response: %+v", resp)
+ }
}
-
diff --git a/internal/lsp/instruction_table_test.go b/internal/lsp/instruction_table_test.go
index 06364db..ff750ca 100644
--- a/internal/lsp/instruction_table_test.go
+++ b/internal/lsp/instruction_table_test.go
@@ -3,22 +3,22 @@ package lsp
import "testing"
func TestFindFirstInstructionInLine_Table(t *testing.T) {
- cases := []struct{
- name string
- line string
- instr string
- }{
- {"strict_inline_marker", ">do> trailing", "do"},
- {"c_block", "x /* add docs */ y", "add docs"},
- {"html_comment", "<!-- fix --> code", "fix"},
- {"slash_slash", "code // please refactor", "please refactor"},
- {"hash", "# summarize", "summarize"},
- {"double_dash", "-- rewrite quickly", "rewrite quickly"},
- }
- for _, c := range cases {
- instr, _, ok := findFirstInstructionInLine(c.line)
- if !ok || instr != c.instr {
- t.Fatalf("%s: got %q ok=%v", c.name, instr, ok)
- }
- }
+ cases := []struct {
+ name string
+ line string
+ instr string
+ }{
+ {"strict_inline_marker", ">do> trailing", "do"},
+ {"c_block", "x /* add docs */ y", "add docs"},
+ {"html_comment", "<!-- fix --> code", "fix"},
+ {"slash_slash", "code // please refactor", "please refactor"},
+ {"hash", "# summarize", "summarize"},
+ {"double_dash", "-- rewrite quickly", "rewrite quickly"},
+ }
+ for _, c := range cases {
+ instr, _, ok := findFirstInstructionInLine(c.line)
+ if !ok || instr != c.instr {
+ t.Fatalf("%s: got %q ok=%v", c.name, instr, ok)
+ }
+ }
}
diff --git a/internal/lsp/label_filter_table_test.go b/internal/lsp/label_filter_table_test.go
index c42b0b1..b6b69c1 100644
--- a/internal/lsp/label_filter_table_test.go
+++ b/internal/lsp/label_filter_table_test.go
@@ -3,15 +3,14 @@ package lsp
import "testing"
func TestLabelForCompletion_Table(t *testing.T) {
- cases := []struct{ cleaned, filter, want string }{
- {"line one\nline two", "zzz", "zzz"},
- {"result", "re", "result"},
- {"hello world", "he", "hello world"},
- }
- for _, c := range cases {
- if got := labelForCompletion(c.cleaned, c.filter); got != c.want {
- t.Fatalf("cleaned=%q filter=%q got %q want %q", c.cleaned, c.filter, got, c.want)
- }
- }
+ cases := []struct{ cleaned, filter, want string }{
+ {"line one\nline two", "zzz", "zzz"},
+ {"result", "re", "result"},
+ {"hello world", "he", "hello world"},
+ }
+ for _, c := range cases {
+ if got := labelForCompletion(c.cleaned, c.filter); got != c.want {
+ t.Fatalf("cleaned=%q filter=%q got %q want %q", c.cleaned, c.filter, got, c.want)
+ }
+ }
}
-
diff --git a/internal/lsp/llm_stats_test.go b/internal/lsp/llm_stats_test.go
index 9e27823..43582a2 100644
--- a/internal/lsp/llm_stats_test.go
+++ b/internal/lsp/llm_stats_test.go
@@ -3,9 +3,8 @@ package lsp
import "testing"
func TestLogLLMStats_CoversCounters(t *testing.T) {
- s := newTestServer()
- s.incSentCounters(10)
- s.incRecvCounters(20)
- s.logLLMStats() // just ensure it does not panic and executes
+ s := newTestServer()
+ s.incSentCounters(10)
+ s.incRecvCounters(20)
+ s.logLLMStats() // just ensure it does not panic and executes
}
-
diff --git a/internal/lsp/log_context_test.go b/internal/lsp/log_context_test.go
index 0bc4ed3..02b4efd 100644
--- a/internal/lsp/log_context_test.go
+++ b/internal/lsp/log_context_test.go
@@ -1,15 +1,14 @@
package lsp
import (
- "io"
- "log"
- "testing"
+ "io"
+ "log"
+ "testing"
)
func TestLogCompletionContext(t *testing.T) {
- s := newTestServer()
- s.logger = log.New(io.Discard, "", 0)
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line:1, Character:2}}
- s.logCompletionContext(p, "a", "b", "c", "f")
+ s := newTestServer()
+ s.logger = log.New(io.Discard, "", 0)
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line: 1, Character: 2}}
+ s.logCompletionContext(p, "a", "b", "c", "f")
}
-
diff --git a/internal/lsp/postprocess_indent_test.go b/internal/lsp/postprocess_indent_test.go
index b546068..28f73a5 100644
--- a/internal/lsp/postprocess_indent_test.go
+++ b/internal/lsp/postprocess_indent_test.go
@@ -3,11 +3,11 @@ package lsp
import "testing"
func TestPostProcessCompletion_IndentWithDoubleOpen(t *testing.T) {
- s := newTestServer()
- cleaned := s.postProcessCompletion("a\nb", "", " >>gen>")
- // Expect each non-empty line to be indented by two spaces
- want := " a\n b"
- if cleaned != want {
- t.Fatalf("got %q want %q", cleaned, want)
- }
+ s := newTestServer()
+ cleaned := s.postProcessCompletion("a\nb", "", " >>gen>")
+ // Expect each non-empty line to be indented by two spaces
+ want := " a\n b"
+ if cleaned != want {
+ t.Fatalf("got %q want %q", cleaned, want)
+ }
}
diff --git a/internal/lsp/prefix_table_test.go b/internal/lsp/prefix_table_test.go
index 0ca23d2..d362927 100644
--- a/internal/lsp/prefix_table_test.go
+++ b/internal/lsp/prefix_table_test.go
@@ -3,22 +3,21 @@ package lsp
import "testing"
func TestPrefixStripping_Table(t *testing.T) {
- cases := []struct{ name, prefix, sugg, want string }{
- {"assign_walrus", "name := ", "name := compute()", "compute()"},
- {"assign_equals", "x = ", "x = y+1", "y+1"},
- {"general_db", "db.", "db.Query()", "Query()"},
- {"general_func", "func New ", "func New() *T", "() *T"},
- }
- for _, c := range cases {
- var got string
- if c.name == "assign_walrus" || c.name == "assign_equals" {
- got = stripDuplicateAssignmentPrefix(c.prefix, c.sugg)
- } else {
- got = stripDuplicateGeneralPrefix(c.prefix, c.sugg)
- }
- if got != c.want {
- t.Fatalf("%s: got %q want %q", c.name, got, c.want)
- }
- }
+ cases := []struct{ name, prefix, sugg, want string }{
+ {"assign_walrus", "name := ", "name := compute()", "compute()"},
+ {"assign_equals", "x = ", "x = y+1", "y+1"},
+ {"general_db", "db.", "db.Query()", "Query()"},
+ {"general_func", "func New ", "func New() *T", "() *T"},
+ }
+ for _, c := range cases {
+ var got string
+ if c.name == "assign_walrus" || c.name == "assign_equals" {
+ got = stripDuplicateAssignmentPrefix(c.prefix, c.sugg)
+ } else {
+ got = stripDuplicateGeneralPrefix(c.prefix, c.sugg)
+ }
+ if got != c.want {
+ t.Fatalf("%s: got %q want %q", c.name, got, c.want)
+ }
+ }
}
-
diff --git a/internal/lsp/provider_native_success_test.go b/internal/lsp/provider_native_success_test.go
index fd7afad..dd1abcd 100644
--- a/internal/lsp/provider_native_success_test.go
+++ b/internal/lsp/provider_native_success_test.go
@@ -1,54 +1,62 @@
package lsp
import (
- "context"
- "testing"
+ "context"
+ "testing"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/llm"
)
type fakeCompleterOk struct{}
-func (fakeCompleterOk) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) { return "", nil }
+func (fakeCompleterOk) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) {
+ return "", nil
+}
func (fakeCompleterOk) Name() string { return "prov" }
func (fakeCompleterOk) DefaultModel() string { return "m" }
func (fakeCompleterOk) CodeCompletion(context.Context, string, string, int, string, float64) ([]string, error) {
- return []string{"SUGG"}, nil
+ return []string{"SUGG"}, nil
}
func TestProviderNativeCompletion_Success(t *testing.T) {
- s := newTestServer()
- s.llmClient = fakeCompleterOk{}
- // current line with dot trigger; position after dot
- current := "fmt."
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
- items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false)
- if !ok || len(items) == 0 {
- t.Fatalf("expected provider-native items")
- }
- if items[0].Label == "" || items[0].TextEdit == nil {
- t.Fatalf("unexpected completion item: %+v", items[0])
- }
+ s := newTestServer()
+ s.llmClient = fakeCompleterOk{}
+ // current line with dot trigger; position after dot
+ current := "fmt."
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
+ items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false)
+ if !ok || len(items) == 0 {
+ t.Fatalf("expected provider-native items")
+ }
+ if items[0].Label == "" || items[0].TextEdit == nil {
+ t.Fatalf("unexpected completion item: %+v", items[0])
+ }
}
type fakeCompleterIndent struct{}
-func (fakeCompleterIndent) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) { return "", nil }
+func (fakeCompleterIndent) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) {
+ return "", nil
+}
func (fakeCompleterIndent) Name() string { return "prov" }
func (fakeCompleterIndent) DefaultModel() string { return "m" }
func (fakeCompleterIndent) CodeCompletion(context.Context, string, string, int, string, float64) ([]string, error) {
- return []string{"a\nb"}, nil
+ return []string{"a\nb"}, nil
}
func TestProviderNativeCompletion_IndentWithDoubleOpen(t *testing.T) {
- s := newTestServer()
- s.llmClient = fakeCompleterIndent{}
- current := " >>do>" // leading indent + double-open marker
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
- items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false)
- if !ok || len(items) == 0 { t.Fatalf("expected provider-native items") }
- if items[0].TextEdit == nil { t.Fatalf("expected text edit") }
- if got := items[0].TextEdit.NewText; len(got) < 2 || got[:2] != " " {
- t.Fatalf("expected indentation applied, got %q", got)
- }
+ s := newTestServer()
+ s.llmClient = fakeCompleterIndent{}
+ current := " >>do>" // leading indent + double-open marker
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
+ items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false)
+ if !ok || len(items) == 0 {
+ t.Fatalf("expected provider-native items")
+ }
+ if items[0].TextEdit == nil {
+ t.Fatalf("expected text edit")
+ }
+ if got := items[0].TextEdit.NewText; len(got) < 2 || got[:2] != " " {
+ t.Fatalf("expected indentation applied, got %q", got)
+ }
}
diff --git a/internal/lsp/rewrite_diagnostics_realism_test.go b/internal/lsp/rewrite_diagnostics_realism_test.go
index 87ff571..eb7ff5a 100644
--- a/internal/lsp/rewrite_diagnostics_realism_test.go
+++ b/internal/lsp/rewrite_diagnostics_realism_test.go
@@ -1,62 +1,77 @@
package lsp
import (
- "encoding/json"
- "testing"
+ "encoding/json"
+ "testing"
)
func TestResolveRewrite_MultiLine_PreservesRange(t *testing.T) {
- s := newTestServer()
- s.llmClient = fakeLLM{resp: "line1\nline2"}
- uri := "file:///x.go"
- s.setDocument(uri, "package p\nvar a=1\n")
- r := Range{Start: Position{Line:1, Character:0}, End: Position{Line:1, Character:5}}
- payload := struct {
- Type string `json:"type"`
- URI string `json:"uri"`
- Range Range `json:"range"`
- Instruction string `json:"instruction"`
- Selection string `json:"selection"`
- }{Type: "rewrite", URI: uri, Range: r, Instruction: "expand", Selection: "var a"}
- raw, _ := json.Marshal(payload)
- ca := CodeAction{Title: "Hexai: rewrite selection", Data: raw}
- resolved, ok := s.resolveCodeAction(ca)
- if !ok || resolved.Edit == nil { t.Fatalf("expected resolved rewrite edit") }
- edits := resolved.Edit.Changes[uri]
- if len(edits) != 1 { t.Fatalf("expected 1 edit") }
- if edits[0].Range != r { t.Fatalf("range mismatch: got %+v want %+v", edits[0].Range, r) }
- if edits[0].NewText == "" || !containsNewline(edits[0].NewText) {
- t.Fatalf("expected multi-line replacement text, got %q", edits[0].NewText)
- }
+ s := newTestServer()
+ s.llmClient = fakeLLM{resp: "line1\nline2"}
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\nvar a=1\n")
+ r := Range{Start: Position{Line: 1, Character: 0}, End: Position{Line: 1, Character: 5}}
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ Instruction string `json:"instruction"`
+ Selection string `json:"selection"`
+ }{Type: "rewrite", URI: uri, Range: r, Instruction: "expand", Selection: "var a"}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: rewrite selection", Data: raw}
+ resolved, ok := s.resolveCodeAction(ca)
+ if !ok || resolved.Edit == nil {
+ t.Fatalf("expected resolved rewrite edit")
+ }
+ edits := resolved.Edit.Changes[uri]
+ if len(edits) != 1 {
+ t.Fatalf("expected 1 edit")
+ }
+ if edits[0].Range != r {
+ t.Fatalf("range mismatch: got %+v want %+v", edits[0].Range, r)
+ }
+ if edits[0].NewText == "" || !containsNewline(edits[0].NewText) {
+ t.Fatalf("expected multi-line replacement text, got %q", edits[0].NewText)
+ }
}
func TestResolveDiagnostics_MultiLine_PreservesRange(t *testing.T) {
- s := newTestServer()
- s.llmClient = fakeLLM{resp: "fixed\nvalue"}
- uri := "file:///x.go"
- s.setDocument(uri, "package p\nvar x = 1\n")
- r := Range{Start: Position{Line:1, Character:0}, End: Position{Line:1, Character:10}}
- payload := struct {
- Type string `json:"type"`
- URI string `json:"uri"`
- Range Range `json:"range"`
- Selection string `json:"selection"`
- Diagnostics []Diagnostic `json:"diagnostics"`
- }{Type: "diagnostics", URI: uri, Range: r, Selection: "var x = 1", Diagnostics: []Diagnostic{{Range: Range{Start: Position{Line:1}, End: Position{Line:1, Character:5}}, Message: "msg"}}}
- raw, _ := json.Marshal(payload)
- ca := CodeAction{Title: "Hexai: resolve diagnostics", Data: raw}
- resolved, ok := s.resolveCodeAction(ca)
- if !ok || resolved.Edit == nil { t.Fatalf("expected resolved diagnostics edit") }
- edits := resolved.Edit.Changes[uri]
- if len(edits) != 1 { t.Fatalf("expected 1 edit") }
- if edits[0].Range != r { t.Fatalf("range mismatch: got %+v want %+v", edits[0].Range, r) }
- if edits[0].NewText == "" || !containsNewline(edits[0].NewText) {
- t.Fatalf("expected multi-line replacement text, got %q", edits[0].NewText)
- }
+ s := newTestServer()
+ s.llmClient = fakeLLM{resp: "fixed\nvalue"}
+ uri := "file:///x.go"
+ s.setDocument(uri, "package p\nvar x = 1\n")
+ r := Range{Start: Position{Line: 1, Character: 0}, End: Position{Line: 1, Character: 10}}
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ Selection string `json:"selection"`
+ Diagnostics []Diagnostic `json:"diagnostics"`
+ }{Type: "diagnostics", URI: uri, Range: r, Selection: "var x = 1", Diagnostics: []Diagnostic{{Range: Range{Start: Position{Line: 1}, End: Position{Line: 1, Character: 5}}, Message: "msg"}}}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: resolve diagnostics", Data: raw}
+ resolved, ok := s.resolveCodeAction(ca)
+ if !ok || resolved.Edit == nil {
+ t.Fatalf("expected resolved diagnostics edit")
+ }
+ edits := resolved.Edit.Changes[uri]
+ if len(edits) != 1 {
+ t.Fatalf("expected 1 edit")
+ }
+ if edits[0].Range != r {
+ t.Fatalf("range mismatch: got %+v want %+v", edits[0].Range, r)
+ }
+ if edits[0].NewText == "" || !containsNewline(edits[0].NewText) {
+ t.Fatalf("expected multi-line replacement text, got %q", edits[0].NewText)
+ }
}
func containsNewline(s string) bool {
- for i := 0; i < len(s); i++ { if s[i] == '\n' { return true } }
- return false
+ for i := 0; i < len(s); i++ {
+ if s[i] == '\n' {
+ return true
+ }
+ }
+ return false
}
-
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index e040d08..fa4467b 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -2,15 +2,16 @@
package lsp
import (
- "bufio"
- "encoding/json"
- "codeberg.org/snonux/hexai/internal/llm"
- "codeberg.org/snonux/hexai/internal/logging"
- "io"
- "log"
- "strings"
- "sync"
- "time"
+ "bufio"
+ "encoding/json"
+ "io"
+ "log"
+ "strings"
+ "sync"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/logging"
)
// Server implements a minimal LSP over stdio.
@@ -27,8 +28,8 @@ type Server struct {
maxTokens int
contextMode string
windowLines int
- maxContextTokens int
- triggerChars []string
+ maxContextTokens int
+ triggerChars []string
// If set, used as the LSP coding temperature for all LLM calls
codingTemperature *float64
// LLM request stats
@@ -40,46 +41,46 @@ type Server struct {
// Small LRU cache for recent code completion outputs (keyed by context)
compCache map[string]string
compCacheOrder []string // most-recent at end; cap ~10
- // Outgoing JSON-RPC id counter for server-initiated requests
- nextID int64
+ // Outgoing JSON-RPC id counter for server-initiated requests
+ nextID int64
// Minimum identifier chars required for manual invoke to bypass prefix checks
manualInvokeMinPrefix int
- // Debounce and throttle settings
- completionDebounce time.Duration
- throttleInterval time.Duration
- lastLLMCall time.Time
+ // Debounce and throttle settings
+ completionDebounce time.Duration
+ throttleInterval time.Duration
+ lastLLMCall time.Time
- // Dispatch table for JSON-RPC methods → handler functions
- handlers map[string]func(Request)
+ // Dispatch table for JSON-RPC methods → handler functions
+ handlers map[string]func(Request)
- // Configurable trigger characters
- inlineOpen string
- inlineClose string
- chatSuffix string
- chatPrefixes []string
+ // Configurable trigger characters
+ inlineOpen string
+ inlineClose string
+ chatSuffix string
+ chatPrefixes []string
}
// ServerOptions collects configuration for NewServer to avoid long parameter lists.
type ServerOptions struct {
- LogContext bool
- MaxTokens int
- ContextMode string
- WindowLines int
- MaxContextTokens int
+ LogContext bool
+ MaxTokens int
+ ContextMode string
+ WindowLines int
+ MaxContextTokens int
- Client llm.Client
- TriggerCharacters []string
- CodingTemperature *float64
- ManualInvokeMinPrefix int
- CompletionDebounceMs int
- CompletionThrottleMs int
+ Client llm.Client
+ TriggerCharacters []string
+ CodingTemperature *float64
+ ManualInvokeMinPrefix int
+ CompletionDebounceMs int
+ CompletionThrottleMs int
- // Inline/chat triggers
- InlineOpen string
- InlineClose string
- ChatSuffix string
- ChatPrefixes []string
+ // Inline/chat triggers
+ InlineOpen string
+ InlineClose string
+ ChatSuffix string
+ ChatPrefixes []string
}
func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) *Server {
@@ -113,38 +114,62 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions)
} else {
s.triggerChars = append([]string{}, opts.TriggerCharacters...)
}
- s.codingTemperature = opts.CodingTemperature
- s.compCache = make(map[string]string)
- s.manualInvokeMinPrefix = opts.ManualInvokeMinPrefix
- if opts.CompletionDebounceMs > 0 {
- s.completionDebounce = time.Duration(opts.CompletionDebounceMs) * time.Millisecond
- }
- if opts.CompletionThrottleMs > 0 {
- s.throttleInterval = time.Duration(opts.CompletionThrottleMs) * time.Millisecond
- }
- // Trigger character config (with sane defaults if missing)
- if strings.TrimSpace(opts.InlineOpen) == "" { s.inlineOpen = ">" } else { s.inlineOpen = opts.InlineOpen }
- if strings.TrimSpace(opts.InlineClose) == "" { s.inlineClose = ">" } else { s.inlineClose = opts.InlineClose }
- if strings.TrimSpace(opts.ChatSuffix) == "" { s.chatSuffix = ">" } else { s.chatSuffix = opts.ChatSuffix }
- if len(opts.ChatPrefixes) == 0 { s.chatPrefixes = []string{"?","!",":",";"} } else { s.chatPrefixes = append([]string{}, opts.ChatPrefixes...) }
+ s.codingTemperature = opts.CodingTemperature
+ s.compCache = make(map[string]string)
+ s.manualInvokeMinPrefix = opts.ManualInvokeMinPrefix
+ if opts.CompletionDebounceMs > 0 {
+ s.completionDebounce = time.Duration(opts.CompletionDebounceMs) * time.Millisecond
+ }
+ if opts.CompletionThrottleMs > 0 {
+ s.throttleInterval = time.Duration(opts.CompletionThrottleMs) * time.Millisecond
+ }
+ // Trigger character config (with sane defaults if missing)
+ if strings.TrimSpace(opts.InlineOpen) == "" {
+ s.inlineOpen = ">"
+ } else {
+ s.inlineOpen = opts.InlineOpen
+ }
+ if strings.TrimSpace(opts.InlineClose) == "" {
+ s.inlineClose = ">"
+ } else {
+ s.inlineClose = opts.InlineClose
+ }
+ if strings.TrimSpace(opts.ChatSuffix) == "" {
+ s.chatSuffix = ">"
+ } else {
+ s.chatSuffix = opts.ChatSuffix
+ }
+ if len(opts.ChatPrefixes) == 0 {
+ s.chatPrefixes = []string{"?", "!", ":", ";"}
+ } else {
+ s.chatPrefixes = append([]string{}, opts.ChatPrefixes...)
+ }
- // Assign package-level inline trigger chars for free helper functions
- if s.inlineOpen != "" { inlineOpenChar = s.inlineOpen[0] }
- if s.inlineClose != "" { inlineCloseChar = s.inlineClose[0] }
- if s.chatSuffix != "" { chatSuffixChar = s.chatSuffix[0] }
- if len(s.chatPrefixes) > 0 { chatPrefixSingles = append([]string{}, s.chatPrefixes...) }
+ // Assign package-level inline trigger chars for free helper functions
+ if s.inlineOpen != "" {
+ inlineOpenChar = s.inlineOpen[0]
+ }
+ if s.inlineClose != "" {
+ inlineCloseChar = s.inlineClose[0]
+ }
+ if s.chatSuffix != "" {
+ chatSuffixChar = s.chatSuffix[0]
+ }
+ if len(s.chatPrefixes) > 0 {
+ chatPrefixSingles = append([]string{}, s.chatPrefixes...)
+ }
// Initialize dispatch table
s.handlers = map[string]func(Request){
- "initialize": s.handleInitialize,
- "initialized": func(_ Request) { s.handleInitialized() },
- "shutdown": s.handleShutdown,
- "exit": func(_ Request) { s.handleExit() },
- "textDocument/didOpen": s.handleDidOpen,
- "textDocument/didChange": s.handleDidChange,
- "textDocument/didClose": s.handleDidClose,
- "textDocument/completion": s.handleCompletion,
- "textDocument/codeAction": s.handleCodeAction,
- "codeAction/resolve": s.handleCodeActionResolve,
+ "initialize": s.handleInitialize,
+ "initialized": func(_ Request) { s.handleInitialized() },
+ "shutdown": s.handleShutdown,
+ "exit": func(_ Request) { s.handleExit() },
+ "textDocument/didOpen": s.handleDidOpen,
+ "textDocument/didChange": s.handleDidChange,
+ "textDocument/didClose": s.handleDidClose,
+ "textDocument/completion": s.handleCompletion,
+ "textDocument/codeAction": s.handleCodeAction,
+ "codeAction/resolve": s.handleCodeActionResolve,
"workspace/executeCommand": s.handleExecuteCommand,
}
return s
diff --git a/internal/lsp/testfakes_test.go b/internal/lsp/testfakes_test.go
index 41fa705..3d42587 100644
--- a/internal/lsp/testfakes_test.go
+++ b/internal/lsp/testfakes_test.go
@@ -1,8 +1,9 @@
package lsp
import (
- "context"
- "codeberg.org/snonux/hexai/internal/llm"
+ "context"
+
+ "codeberg.org/snonux/hexai/internal/llm"
)
// countingLLM counts Chat calls; minimal implementation for tests that need
diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go
index c30fbd1..bdd01a1 100644
--- a/internal/lsp/transport.go
+++ b/internal/lsp/transport.go
@@ -4,11 +4,12 @@ package lsp
import (
"encoding/json"
"fmt"
- "codeberg.org/snonux/hexai/internal/logging"
"io"
"net/textproto"
"strconv"
"strings"
+
+ "codeberg.org/snonux/hexai/internal/logging"
)
func (s *Server) readMessage() ([]byte, error) {
diff --git a/internal/lsp/transport_test.go b/internal/lsp/transport_test.go
index c00b405..7ea47c4 100644
--- a/internal/lsp/transport_test.go
+++ b/internal/lsp/transport_test.go
@@ -1,40 +1,57 @@
package lsp
import (
- "bufio"
- "bytes"
- "testing"
+ "bufio"
+ "bytes"
+ "testing"
)
func TestReadMessage_ParsesContentLength(t *testing.T) {
- body := []byte(`{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
- frame := []byte("Content-Length: ")
- frame = append(frame, []byte(stringInt(len(body)))...)
- frame = append(frame, []byte("\r\n\r\n")...)
- frame = append(frame, body...)
- s := &Server{in: bufio.NewReader(bytes.NewReader(frame))}
- got, err := s.readMessage()
- if err != nil || string(got) != string(body) { t.Fatalf("readMessage failed: %v %q", err, string(got)) }
+ body := []byte(`{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
+ frame := []byte("Content-Length: ")
+ frame = append(frame, []byte(stringInt(len(body)))...)
+ frame = append(frame, []byte("\r\n\r\n")...)
+ frame = append(frame, body...)
+ s := &Server{in: bufio.NewReader(bytes.NewReader(frame))}
+ got, err := s.readMessage()
+ if err != nil || string(got) != string(body) {
+ t.Fatalf("readMessage failed: %v %q", err, string(got))
+ }
}
func TestWriteMessage_FramesJSON(t *testing.T) {
- var out bytes.Buffer
- s := &Server{out: &out}
- payload := struct{ JSONRPC string `json:"jsonrpc"`; Ping string `json:"ping"` }{JSONRPC: "2.0", Ping: "pong"}
- s.writeMessage(payload)
- got := out.String()
- if !bytes.HasPrefix([]byte(got), []byte("Content-Length: ")) { t.Fatalf("missing Content-Length header: %q", got) }
- // Header/body delimiter must be present
- idx := bytes.Index([]byte(got), []byte("\r\n\r\n"))
- if idx < 0 { t.Fatalf("missing CRLFCRLF delimiter: %q", got) }
- body := got[idx+4:]
- if body == "" || body[0] != '{' || body[len(body)-1] != '}' { t.Fatalf("body not JSON: %q", body) }
+ var out bytes.Buffer
+ s := &Server{out: &out}
+ payload := struct {
+ JSONRPC string `json:"jsonrpc"`
+ Ping string `json:"ping"`
+ }{JSONRPC: "2.0", Ping: "pong"}
+ s.writeMessage(payload)
+ got := out.String()
+ if !bytes.HasPrefix([]byte(got), []byte("Content-Length: ")) {
+ t.Fatalf("missing Content-Length header: %q", got)
+ }
+ // Header/body delimiter must be present
+ idx := bytes.Index([]byte(got), []byte("\r\n\r\n"))
+ if idx < 0 {
+ t.Fatalf("missing CRLFCRLF delimiter: %q", got)
+ }
+ body := got[idx+4:]
+ if body == "" || body[0] != '{' || body[len(body)-1] != '}' {
+ t.Fatalf("body not JSON: %q", body)
+ }
}
func stringInt(n int) string {
- if n == 0 { return "0" }
- var b [20]byte
- i := len(b)
- for n > 0 { i--; b[i] = byte('0' + n%10); n /= 10 }
- return string(b[i:])
+ if n == 0 {
+ return "0"
+ }
+ var b [20]byte
+ i := len(b)
+ for n > 0 {
+ i--
+ b[i] = byte('0' + n%10)
+ n /= 10
+ }
+ return string(b[i:])
}
diff --git a/internal/lsp/triggers_config_test.go b/internal/lsp/triggers_config_test.go
index 7fd6ecd..93d312a 100644
--- a/internal/lsp/triggers_config_test.go
+++ b/internal/lsp/triggers_config_test.go
@@ -1,74 +1,78 @@
package lsp
import (
- "bytes"
- "encoding/json"
- "io"
- "log"
- "testing"
- "time"
+ "bytes"
+ "encoding/json"
+ "io"
+ "log"
+ "testing"
+ "time"
)
func TestShouldSuppressForChatTriggerEOL_CustomConfig(t *testing.T) {
- s := newTestServer()
- // Customize: only ")#" at EOL suppresses
- s.chatSuffix = "#"
- s.chatPrefixes = []string{")"}
+ s := newTestServer()
+ // Customize: only ")#" at EOL suppresses
+ s.chatSuffix = "#"
+ s.chatPrefixes = []string{")"}
- p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line:0, Character:6}}
- if !s.shouldSuppressForChatTriggerEOL("ok)#", p) {
- t.Fatalf("expected suppression for custom prefix+suffix at EOL")
- }
- if s.shouldSuppressForChatTriggerEOL("ok]#", p) {
- t.Fatalf("did not expect suppression for non-matching prefix")
- }
+ p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x"}, Position: Position{Line: 0, Character: 6}}
+ if !s.shouldSuppressForChatTriggerEOL("ok)#", p) {
+ t.Fatalf("expected suppression for custom prefix+suffix at EOL")
+ }
+ if s.shouldSuppressForChatTriggerEOL("ok]#", p) {
+ t.Fatalf("did not expect suppression for non-matching prefix")
+ }
}
func TestNewServer_AssignsTriggerGlobals_AndParsingUsesThem(t *testing.T) {
- var out bytes.Buffer
- s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{
- InlineOpen: "<", InlineClose: ">", ChatSuffix: ")", ChatPrefixes: []string{":"},
- })
- _ = s // ensure server constructed applies globals
- if inlineOpenChar != '<' || inlineCloseChar != '>' {
- t.Fatalf("inline markers not applied: %q %q", string(inlineOpenChar), string(inlineCloseChar))
- }
- if chatSuffixChar != ')' || len(chatPrefixSingles) == 0 || chatPrefixSingles[0] != ":" {
- t.Fatalf("chat markers not applied: suffix=%q prefixes=%v", string(chatSuffixChar), chatPrefixSingles)
- }
- if txt, l, r, ok := findStrictInlineTag("x<do>y"); !ok || txt != "do" || l != 1 || r != 5 {
- t.Fatalf("findStrictInlineTag failed: ok=%v txt=%q l=%d r=%d", ok, txt, l, r)
- }
- if got := stripTrailingTrigger("note:)"); got != "note:" {
- t.Fatalf("stripTrailingTrigger failed: %q", got)
- }
+ var out bytes.Buffer
+ s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{
+ InlineOpen: "<", InlineClose: ">", ChatSuffix: ")", ChatPrefixes: []string{":"},
+ })
+ _ = s // ensure server constructed applies globals
+ if inlineOpenChar != '<' || inlineCloseChar != '>' {
+ t.Fatalf("inline markers not applied: %q %q", string(inlineOpenChar), string(inlineCloseChar))
+ }
+ if chatSuffixChar != ')' || len(chatPrefixSingles) == 0 || chatPrefixSingles[0] != ":" {
+ t.Fatalf("chat markers not applied: suffix=%q prefixes=%v", string(chatSuffixChar), chatPrefixSingles)
+ }
+ if txt, l, r, ok := findStrictInlineTag("x<do>y"); !ok || txt != "do" || l != 1 || r != 5 {
+ t.Fatalf("findStrictInlineTag failed: ok=%v txt=%q l=%d r=%d", ok, txt, l, r)
+ }
+ if got := stripTrailingTrigger("note:)"); got != "note:" {
+ t.Fatalf("stripTrailingTrigger failed: %q", got)
+ }
}
func TestIsTriggerEvent_BareDoubleOpenBlocksEvenWithContextTriggerChar(t *testing.T) {
- s := newTestServer()
- s.inlineOpen = ">" // ensure bare ">>" check is active
- s.triggerChars = []string{"."}
- // LSP context indicates TriggerCharacter '.' but current line is bare ">>"
- ctx := struct {
- TriggerKind int `json:"triggerKind"`
- TriggerCharacter string `json:"triggerCharacter"`
- }{TriggerKind: 2, TriggerCharacter: "."}
- raw, _ := json.Marshal(ctx)
- p := CompletionParams{Position: Position{Line: 0, Character: 2}, Context: json.RawMessage(raw)}
- if s.isTriggerEvent(p, ">>") {
- t.Fatalf("bare double-open should block trigger event even with context trigger char")
- }
+ s := newTestServer()
+ s.inlineOpen = ">" // ensure bare ">>" check is active
+ s.triggerChars = []string{"."}
+ // LSP context indicates TriggerCharacter '.' but current line is bare ">>"
+ ctx := struct {
+ TriggerKind int `json:"triggerKind"`
+ TriggerCharacter string `json:"triggerCharacter"`
+ }{TriggerKind: 2, TriggerCharacter: "."}
+ raw, _ := json.Marshal(ctx)
+ p := CompletionParams{Position: Position{Line: 0, Character: 2}, Context: json.RawMessage(raw)}
+ if s.isTriggerEvent(p, ">>") {
+ t.Fatalf("bare double-open should block trigger event even with context trigger char")
+ }
}
func TestDetectAndHandleChat_CustomConfig_InsertsReply(t *testing.T) {
- var out bytes.Buffer
- s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{ChatSuffix: "#", ChatPrefixes: []string{")"}})
- s.llmClient = fakeLLM{resp: "Hello\nmulti-line reply"}
- uri := "file:///chat2.go"
- s.setDocument(uri, "ok)#\n\n")
- out.Reset()
- s.detectAndHandleChat(uri)
- // Give time for applyEdit request
- for i := 0; i < 20 && out.Len() == 0; i++ { time.Sleep(10 * time.Millisecond) }
- if out.Len() == 0 { t.Fatalf("no output written for custom chat config") }
+ var out bytes.Buffer
+ s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{ChatSuffix: "#", ChatPrefixes: []string{")"}})
+ s.llmClient = fakeLLM{resp: "Hello\nmulti-line reply"}
+ uri := "file:///chat2.go"
+ s.setDocument(uri, "ok)#\n\n")
+ out.Reset()
+ s.detectAndHandleChat(uri)
+ // Give time for applyEdit request
+ for i := 0; i < 20 && out.Len() == 0; i++ {
+ time.Sleep(10 * time.Millisecond)
+ }
+ if out.Len() == 0 {
+ t.Fatalf("no output written for custom chat config")
+ }
}
diff --git a/internal/lsp/types.go b/internal/lsp/types.go
index 1598b96..fa9e71f 100644
--- a/internal/lsp/types.go
+++ b/internal/lsp/types.go
@@ -124,8 +124,8 @@ type CodeActionParams struct {
}
type WorkspaceEdit struct {
- Changes map[string][]TextEdit `json:"changes,omitempty"`
- DocumentChanges []any `json:"documentChanges,omitempty"`
+ Changes map[string][]TextEdit `json:"changes,omitempty"`
+ DocumentChanges []any `json:"documentChanges,omitempty"`
}
// ApplyWorkspaceEditParams is the client request payload for workspace/applyEdit.
@@ -135,34 +135,34 @@ type ApplyWorkspaceEditParams struct {
}
type CodeAction struct {
- Title string `json:"title"`
- Kind string `json:"kind,omitempty"`
- Edit *WorkspaceEdit `json:"edit,omitempty"`
- Data json.RawMessage `json:"data,omitempty"`
- Command *Command `json:"command,omitempty"`
+ Title string `json:"title"`
+ Kind string `json:"kind,omitempty"`
+ Edit *WorkspaceEdit `json:"edit,omitempty"`
+ Data json.RawMessage `json:"data,omitempty"`
+ Command *Command `json:"command,omitempty"`
}
// Extended workspace edit types (minimal subset)
type TextDocumentEdit struct {
- TextDocument VersionedTextDocumentIdentifier `json:"textDocument"`
- Edits []TextEdit `json:"edits"`
+ TextDocument VersionedTextDocumentIdentifier `json:"textDocument"`
+ Edits []TextEdit `json:"edits"`
}
type CreateFile struct {
- Kind string `json:"kind"`
- URI string `json:"uri"`
+ Kind string `json:"kind"`
+ URI string `json:"uri"`
}
// Commands
type Command struct {
- Title string `json:"title"`
- Command string `json:"command"`
- Arguments []any `json:"arguments,omitempty"`
+ Title string `json:"title"`
+ Command string `json:"command"`
+ Arguments []any `json:"arguments,omitempty"`
}
type ExecuteCommandParams struct {
- Command string `json:"command"`
- Arguments []any `json:"arguments,omitempty"`
+ Command string `json:"command"`
+ Arguments []any `json:"arguments,omitempty"`
}
// Diagnostics (subset needed for code action context)
diff --git a/internal/testutil/fixtures.go b/internal/testutil/fixtures.go
index 41993d3..be18242 100644
--- a/internal/testutil/fixtures.go
+++ b/internal/testutil/fixtures.go
@@ -2,26 +2,25 @@ package testutil
// MultilineDocBlock returns a realistic multi-line documentation block.
func MultilineDocBlock() string {
- return "// add adds two numbers\n// returns their sum"
+ return "// add adds two numbers\n// returns their sum"
}
// MultilineChatReply returns a multi-line assistant reply for chat tests.
func MultilineChatReply() string {
- return "Hello, world!\nThis is a multi-line reply."
+ return "Hello, world!\nThis is a multi-line reply."
}
// MultilineFunctionSuggestion returns a more realistic multi-line function body suggestion.
func MultilineFunctionSuggestion() string {
- return "(ctx context.Context, input string) (*CustData, error) {\n // TODO: implement\n return &CustData{}, nil\n}"
+ return "(ctx context.Context, input string) (*CustData, error) {\n // TODO: implement\n return &CustData{}, nil\n}"
}
// MarkdownCodeFence returns a fenced markdown snippet used in post-processing tests.
func MarkdownCodeFence() string {
- return "```go\nname := value\n```"
+ return "```go\nname := value\n```"
}
// MalformedJSON returns a deliberately malformed JSON string.
func MalformedJSON() string {
- return "{\"choices\":[{\"delta\":{\"content\":\"oops\"}}]"
+ return "{\"choices\":[{\"delta\":{\"content\":\"oops\"}}]"
}
-