diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-23 08:08:57 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-23 08:08:57 +0200 |
| commit | 4958ea5100ebf8d4ff9fd818b7bc59d01989feb4 (patch) | |
| tree | 44bc03cdfa7d58ea948023c87bfbe8a37fe315c9 | |
| parent | ba929c035c7c74113d061c57cc5b500af0b20b74 (diff) | |
fix: address all HIGH-severity code quality audit findings
- lsp/server.go: track request goroutines in inflight WaitGroup to
prevent use-after-close writes on shutdown
- lsp/llm_client_registry.go: acquire write lock before calling build()
to eliminate TOCTOU race on cache population
- lsp/handlers_codeaction.go: resolveSimplifyCodeAction now uses
PromptCodeActionSimplify{System,User} (was wrongly using rewrite prompts)
- askcli/taskexport.go: remove exported MustParseTaskExport to prevent
panic on malformed external input; move to unexported test helper
- cmd/ask/main.go: print error to stderr before os.Exit
- llm/{openai,ollama,openrouter}.go: add interface satisfaction assertions
- integrationtests/ask_test.go: replace type assertions with errors.As
for robust exec.ExitError unwrapping
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
| -rw-r--r-- | cmd/ask/main.go | 3 | ||||
| -rw-r--r-- | integrationtests/ask_test.go | 17 | ||||
| -rw-r--r-- | internal/askcli/taskexport.go | 8 | ||||
| -rw-r--r-- | internal/askcli/taskexport_test.go | 20 | ||||
| -rw-r--r-- | internal/llm/ollama.go | 6 | ||||
| -rw-r--r-- | internal/llm/openai.go | 6 | ||||
| -rw-r--r-- | internal/llm/openrouter.go | 6 | ||||
| -rw-r--r-- | internal/lsp/handlers_codeaction.go | 8 | ||||
| -rw-r--r-- | internal/lsp/llm_client_registry.go | 26 | ||||
| -rw-r--r-- | internal/lsp/server.go | 8 |
10 files changed, 73 insertions, 35 deletions
diff --git a/cmd/ask/main.go b/cmd/ask/main.go index 72b67e3..afab992 100644 --- a/cmd/ask/main.go +++ b/cmd/ask/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "os" "codeberg.org/snonux/hexai/internal/askcli" @@ -11,6 +12,8 @@ func main() { d := askcli.NewDispatcher(nil) code, err := d.Dispatch(context.Background(), os.Args[1:], os.Stdin, os.Stdout, os.Stderr) if err != nil { + // Print the internal error so callers get a useful diagnostic message. + fmt.Fprintln(os.Stderr, err) os.Exit(code) } if code != 0 { diff --git a/integrationtests/ask_test.go b/integrationtests/ask_test.go index 7dfbcb4..a762b9d 100644 --- a/integrationtests/ask_test.go +++ b/integrationtests/ask_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "os" "os/exec" @@ -56,8 +57,8 @@ func runAsk(ctx context.Context, args []string) (stdout, stderr bytes.Buffer, ex if err == nil { return } - ee, ok := err.(*exec.ExitError) - if !ok { + var ee *exec.ExitError + if !errors.As(err, &ee) { return bytes.Buffer{}, stderr, -1 } return stdout, stderr, ee.ExitCode() @@ -75,8 +76,8 @@ func runAskWithStdin(ctx context.Context, args []string, stdin string) (stdout, if err == nil { return } - ee, ok := err.(*exec.ExitError) - if !ok { + var ee *exec.ExitError + if !errors.As(err, &ee) { return bytes.Buffer{}, stderr, -1 } return stdout, stderr, ee.ExitCode() @@ -91,8 +92,8 @@ func runTask(ctx context.Context, args []string) (stdout, stderr bytes.Buffer, e if err == nil { return } - ee, ok := err.(*exec.ExitError) - if !ok { + var ee *exec.ExitError + if !errors.As(err, &ee) { return bytes.Buffer{}, stderr, -1 } return stdout, stderr, ee.ExitCode() @@ -108,8 +109,8 @@ func runTaskWithStdin(ctx context.Context, args []string, stdin string) (stdout, if err == nil { return } - ee, ok := err.(*exec.ExitError) - if !ok { + var ee *exec.ExitError + if !errors.As(err, &ee) { return bytes.Buffer{}, stderr, -1 } return stdout, stderr, ee.ExitCode() diff --git a/internal/askcli/taskexport.go b/internal/askcli/taskexport.go index 9841821..ca67ef5 100644 --- a/internal/askcli/taskexport.go +++ b/internal/askcli/taskexport.go @@ -32,11 +32,3 @@ func ParseTaskExport(r io.Reader) ([]TaskExport, error) { } return tasks, nil } - -func MustParseTaskExport(data []byte) []TaskExport { - var tasks []TaskExport - if err := json.Unmarshal(data, &tasks); err != nil { - panic(fmt.Sprintf("failed to parse task export JSON: %v", err)) - } - return tasks -} diff --git a/internal/askcli/taskexport_test.go b/internal/askcli/taskexport_test.go index e7779aa..799415a 100644 --- a/internal/askcli/taskexport_test.go +++ b/internal/askcli/taskexport_test.go @@ -2,11 +2,21 @@ package askcli import ( "encoding/json" + "fmt" "io" "strings" "testing" ) +// mustParseTaskExport is a test-only helper that panics on parse failure. +func mustParseTaskExport(data []byte) []TaskExport { + var tasks []TaskExport + if err := json.Unmarshal(data, &tasks); err != nil { + panic(fmt.Sprintf("failed to parse task export JSON: %v", err)) + } + return tasks +} + func TestParseTaskExport_ValidJSON(t *testing.T) { data := `[{"uuid":"abc123","description":"Test task","status":"pending","priority":"M","tags":["cli"],"urgency":10.5,"depends":[]}]` tasks, err := ParseTaskExport(strings.NewReader(data)) @@ -31,18 +41,18 @@ func TestParseTaskExport_InvalidJSON(t *testing.T) { } } -func TestMustParseTaskExport_Panics(t *testing.T) { +func TestMustParseTaskExportHelper_Panics(t *testing.T) { defer func() { if r := recover(); r == nil { - t.Fatal("MustParseTaskExport should panic on invalid JSON") + t.Fatal("mustParseTaskExport should panic on invalid JSON") } }() - MustParseTaskExport([]byte("not json")) + mustParseTaskExport([]byte("not json")) } -func TestMustParseTaskExport_ValidJSON(t *testing.T) { +func TestMustParseTaskExportHelper_ValidJSON(t *testing.T) { data := []byte(`[{"uuid":"xyz789","description":"Another task","status":"completed","priority":"H","tags":["agent"],"urgency":15.0,"depends":["dep1"]}]`) - tasks := MustParseTaskExport(data) + tasks := mustParseTaskExport(data) if len(tasks) != 1 { t.Fatalf("len(tasks) = %d, want 1", len(tasks)) } diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index ade62a9..be93ab0 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -23,6 +23,12 @@ type ollamaClient struct { defaultTemperature *float64 } +// Ensure ollamaClient implements Client and Streamer. +var ( + _ Client = ollamaClient{} + _ Streamer = ollamaClient{} +) + type ollamaChatRequest struct { Model string `json:"model"` Messages []oaMessage `json:"messages"` diff --git a/internal/llm/openai.go b/internal/llm/openai.go index eccd558..cf18d9b 100644 --- a/internal/llm/openai.go +++ b/internal/llm/openai.go @@ -24,6 +24,12 @@ type openAIClient struct { defaultTemperature *float64 } +// Ensure openAIClient implements Client and Streamer. +var ( + _ Client = openAIClient{} + _ Streamer = openAIClient{} +) + type oaChatRequest struct { Model string `json:"model"` Messages []oaMessage `json:"messages"` diff --git a/internal/llm/openrouter.go b/internal/llm/openrouter.go index 60a594a..451e9ad 100644 --- a/internal/llm/openrouter.go +++ b/internal/llm/openrouter.go @@ -21,6 +21,12 @@ type openRouterClient struct { defaultTemperature *float64 } +// Ensure openRouterClient implements Client and Streamer. +var ( + _ Client = openRouterClient{} + _ Streamer = openRouterClient{} +) + func init() { RegisterProvider("openrouter", openRouterProviderFactory) } diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go index 1d8a36f..aba2113 100644 --- a/internal/lsp/handlers_codeaction.go +++ b/internal/lsp/handlers_codeaction.go @@ -266,10 +266,10 @@ func resolveGoTestCodeAction(s *Server, action CodeAction, payload codeActionPay func resolveSimplifyCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { cfg := s.currentConfig() - sys := cfg.PromptCodeActionRewriteSystem - user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{ - "instruction": "Simplify and improve the code while preserving behavior. Return only the improved code.", - "selection": payload.Selection, + // Use the simplify-specific prompts, not the rewrite prompts. + sys := cfg.PromptCodeActionSimplifySystem + user := renderTemplate(cfg.PromptCodeActionSimplifyUser, map[string]string{ + "selection": payload.Selection, }) return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) } diff --git a/internal/lsp/llm_client_registry.go b/internal/lsp/llm_client_registry.go index 6b9c722..a018158 100644 --- a/internal/lsp/llm_client_registry.go +++ b/internal/lsp/llm_client_registry.go @@ -71,17 +71,28 @@ func (r *llmClientRegistry) clientFor(spec requestSpec, cfg appconfig.App, build if modelOverride == "" { modelOverride = strings.TrimSpace(spec.fallbackModel) } + + // Acquire write lock before calling build to prevent concurrent goroutines + // from all passing the read-lock cache-miss check and issuing duplicate + // build calls for the same provider (TOCTOU race). + r.clientsMu.Lock() + defer r.clientsMu.Unlock() + + // Re-check under the write lock; another goroutine may have populated the + // cache between our RUnlock and this Lock. + if provider == r.llmProvider && r.llmClient != nil { + return r.llmClient + } + if existing, ok := r.altClients[provider]; ok { + return existing + } + client, err := build(cfg, provider, modelOverride) if err != nil { logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err) - if baseClient != nil { - return baseClient - } - return nil + return baseClient // may be nil; callers must handle nil } - r.clientsMu.Lock() - defer r.clientsMu.Unlock() if provider == r.llmProvider { if r.llmClient == nil { r.llmClient = client @@ -89,9 +100,6 @@ func (r *llmClientRegistry) clientFor(spec requestSpec, cfg appconfig.App, build } return r.llmClient } - if existing, ok := r.altClients[provider]; ok { - return existing - } if r.altClients == nil { r.altClients = make(map[string]llm.Client) } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index c266e91..25c5e5c 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -378,7 +378,13 @@ func (s *Server) Run() error { // A response from client; ignore continue } - go s.handle(req) + // Track every request goroutine so Run's deferred inflight.Wait() + // catches them all and prevents use-after-close writes to s.out. + s.inflight.Add(1) + go func(r Request) { + defer s.inflight.Done() + s.handle(r) + }(req) if s.exited.Load() { return nil } |
