summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-23 08:08:57 +0200
committerPaul Buetow <paul@buetow.org>2026-03-23 08:08:57 +0200
commit4958ea5100ebf8d4ff9fd818b7bc59d01989feb4 (patch)
tree44bc03cdfa7d58ea948023c87bfbe8a37fe315c9
parentba929c035c7c74113d061c57cc5b500af0b20b74 (diff)
fix: address all HIGH-severity code quality audit findings
- lsp/server.go: track request goroutines in inflight WaitGroup to prevent use-after-close writes on shutdown - lsp/llm_client_registry.go: acquire write lock before calling build() to eliminate TOCTOU race on cache population - lsp/handlers_codeaction.go: resolveSimplifyCodeAction now uses PromptCodeActionSimplify{System,User} (was wrongly using rewrite prompts) - askcli/taskexport.go: remove exported MustParseTaskExport to prevent panic on malformed external input; move to unexported test helper - cmd/ask/main.go: print error to stderr before os.Exit - llm/{openai,ollama,openrouter}.go: add interface satisfaction assertions - integrationtests/ask_test.go: replace type assertions with errors.As for robust exec.ExitError unwrapping Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
-rw-r--r--cmd/ask/main.go3
-rw-r--r--integrationtests/ask_test.go17
-rw-r--r--internal/askcli/taskexport.go8
-rw-r--r--internal/askcli/taskexport_test.go20
-rw-r--r--internal/llm/ollama.go6
-rw-r--r--internal/llm/openai.go6
-rw-r--r--internal/llm/openrouter.go6
-rw-r--r--internal/lsp/handlers_codeaction.go8
-rw-r--r--internal/lsp/llm_client_registry.go26
-rw-r--r--internal/lsp/server.go8
10 files changed, 73 insertions, 35 deletions
diff --git a/cmd/ask/main.go b/cmd/ask/main.go
index 72b67e3..afab992 100644
--- a/cmd/ask/main.go
+++ b/cmd/ask/main.go
@@ -2,6 +2,7 @@ package main
import (
"context"
+ "fmt"
"os"
"codeberg.org/snonux/hexai/internal/askcli"
@@ -11,6 +12,8 @@ func main() {
d := askcli.NewDispatcher(nil)
code, err := d.Dispatch(context.Background(), os.Args[1:], os.Stdin, os.Stdout, os.Stderr)
if err != nil {
+ // Print the internal error so callers get a useful diagnostic message.
+ fmt.Fprintln(os.Stderr, err)
os.Exit(code)
}
if code != 0 {
diff --git a/integrationtests/ask_test.go b/integrationtests/ask_test.go
index 7dfbcb4..a762b9d 100644
--- a/integrationtests/ask_test.go
+++ b/integrationtests/ask_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"os"
"os/exec"
@@ -56,8 +57,8 @@ func runAsk(ctx context.Context, args []string) (stdout, stderr bytes.Buffer, ex
if err == nil {
return
}
- ee, ok := err.(*exec.ExitError)
- if !ok {
+ var ee *exec.ExitError
+ if !errors.As(err, &ee) {
return bytes.Buffer{}, stderr, -1
}
return stdout, stderr, ee.ExitCode()
@@ -75,8 +76,8 @@ func runAskWithStdin(ctx context.Context, args []string, stdin string) (stdout,
if err == nil {
return
}
- ee, ok := err.(*exec.ExitError)
- if !ok {
+ var ee *exec.ExitError
+ if !errors.As(err, &ee) {
return bytes.Buffer{}, stderr, -1
}
return stdout, stderr, ee.ExitCode()
@@ -91,8 +92,8 @@ func runTask(ctx context.Context, args []string) (stdout, stderr bytes.Buffer, e
if err == nil {
return
}
- ee, ok := err.(*exec.ExitError)
- if !ok {
+ var ee *exec.ExitError
+ if !errors.As(err, &ee) {
return bytes.Buffer{}, stderr, -1
}
return stdout, stderr, ee.ExitCode()
@@ -108,8 +109,8 @@ func runTaskWithStdin(ctx context.Context, args []string, stdin string) (stdout,
if err == nil {
return
}
- ee, ok := err.(*exec.ExitError)
- if !ok {
+ var ee *exec.ExitError
+ if !errors.As(err, &ee) {
return bytes.Buffer{}, stderr, -1
}
return stdout, stderr, ee.ExitCode()
diff --git a/internal/askcli/taskexport.go b/internal/askcli/taskexport.go
index 9841821..ca67ef5 100644
--- a/internal/askcli/taskexport.go
+++ b/internal/askcli/taskexport.go
@@ -32,11 +32,3 @@ func ParseTaskExport(r io.Reader) ([]TaskExport, error) {
}
return tasks, nil
}
-
-func MustParseTaskExport(data []byte) []TaskExport {
- var tasks []TaskExport
- if err := json.Unmarshal(data, &tasks); err != nil {
- panic(fmt.Sprintf("failed to parse task export JSON: %v", err))
- }
- return tasks
-}
diff --git a/internal/askcli/taskexport_test.go b/internal/askcli/taskexport_test.go
index e7779aa..799415a 100644
--- a/internal/askcli/taskexport_test.go
+++ b/internal/askcli/taskexport_test.go
@@ -2,11 +2,21 @@ package askcli
import (
"encoding/json"
+ "fmt"
"io"
"strings"
"testing"
)
+// mustParseTaskExport is a test-only helper that panics on parse failure.
+func mustParseTaskExport(data []byte) []TaskExport {
+ var tasks []TaskExport
+ if err := json.Unmarshal(data, &tasks); err != nil {
+ panic(fmt.Sprintf("failed to parse task export JSON: %v", err))
+ }
+ return tasks
+}
+
func TestParseTaskExport_ValidJSON(t *testing.T) {
data := `[{"uuid":"abc123","description":"Test task","status":"pending","priority":"M","tags":["cli"],"urgency":10.5,"depends":[]}]`
tasks, err := ParseTaskExport(strings.NewReader(data))
@@ -31,18 +41,18 @@ func TestParseTaskExport_InvalidJSON(t *testing.T) {
}
}
-func TestMustParseTaskExport_Panics(t *testing.T) {
+func TestMustParseTaskExportHelper_Panics(t *testing.T) {
defer func() {
if r := recover(); r == nil {
- t.Fatal("MustParseTaskExport should panic on invalid JSON")
+ t.Fatal("mustParseTaskExport should panic on invalid JSON")
}
}()
- MustParseTaskExport([]byte("not json"))
+ mustParseTaskExport([]byte("not json"))
}
-func TestMustParseTaskExport_ValidJSON(t *testing.T) {
+func TestMustParseTaskExportHelper_ValidJSON(t *testing.T) {
data := []byte(`[{"uuid":"xyz789","description":"Another task","status":"completed","priority":"H","tags":["agent"],"urgency":15.0,"depends":["dep1"]}]`)
- tasks := MustParseTaskExport(data)
+ tasks := mustParseTaskExport(data)
if len(tasks) != 1 {
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
}
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index ade62a9..be93ab0 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -23,6 +23,12 @@ type ollamaClient struct {
defaultTemperature *float64
}
+// Ensure ollamaClient implements Client and Streamer.
+var (
+ _ Client = ollamaClient{}
+ _ Streamer = ollamaClient{}
+)
+
type ollamaChatRequest struct {
Model string `json:"model"`
Messages []oaMessage `json:"messages"`
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index eccd558..cf18d9b 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -24,6 +24,12 @@ type openAIClient struct {
defaultTemperature *float64
}
+// Ensure openAIClient implements Client and Streamer.
+var (
+ _ Client = openAIClient{}
+ _ Streamer = openAIClient{}
+)
+
type oaChatRequest struct {
Model string `json:"model"`
Messages []oaMessage `json:"messages"`
diff --git a/internal/llm/openrouter.go b/internal/llm/openrouter.go
index 60a594a..451e9ad 100644
--- a/internal/llm/openrouter.go
+++ b/internal/llm/openrouter.go
@@ -21,6 +21,12 @@ type openRouterClient struct {
defaultTemperature *float64
}
+// Ensure openRouterClient implements Client and Streamer.
+var (
+ _ Client = openRouterClient{}
+ _ Streamer = openRouterClient{}
+)
+
func init() {
RegisterProvider("openrouter", openRouterProviderFactory)
}
diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go
index 1d8a36f..aba2113 100644
--- a/internal/lsp/handlers_codeaction.go
+++ b/internal/lsp/handlers_codeaction.go
@@ -266,10 +266,10 @@ func resolveGoTestCodeAction(s *Server, action CodeAction, payload codeActionPay
func resolveSimplifyCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) {
cfg := s.currentConfig()
- sys := cfg.PromptCodeActionRewriteSystem
- user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{
- "instruction": "Simplify and improve the code while preserving behavior. Return only the improved code.",
- "selection": payload.Selection,
+ // Use the simplify-specific prompts, not the rewrite prompts.
+ sys := cfg.PromptCodeActionSimplifySystem
+ user := renderTemplate(cfg.PromptCodeActionSimplifyUser, map[string]string{
+ "selection": payload.Selection,
})
return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second)
}
diff --git a/internal/lsp/llm_client_registry.go b/internal/lsp/llm_client_registry.go
index 6b9c722..a018158 100644
--- a/internal/lsp/llm_client_registry.go
+++ b/internal/lsp/llm_client_registry.go
@@ -71,17 +71,28 @@ func (r *llmClientRegistry) clientFor(spec requestSpec, cfg appconfig.App, build
if modelOverride == "" {
modelOverride = strings.TrimSpace(spec.fallbackModel)
}
+
+ // Acquire write lock before calling build to prevent concurrent goroutines
+ // from all passing the read-lock cache-miss check and issuing duplicate
+ // build calls for the same provider (TOCTOU race).
+ r.clientsMu.Lock()
+ defer r.clientsMu.Unlock()
+
+ // Re-check under the write lock; another goroutine may have populated the
+ // cache between our RUnlock and this Lock.
+ if provider == r.llmProvider && r.llmClient != nil {
+ return r.llmClient
+ }
+ if existing, ok := r.altClients[provider]; ok {
+ return existing
+ }
+
client, err := build(cfg, provider, modelOverride)
if err != nil {
logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err)
- if baseClient != nil {
- return baseClient
- }
- return nil
+ return baseClient // may be nil; callers must handle nil
}
- r.clientsMu.Lock()
- defer r.clientsMu.Unlock()
if provider == r.llmProvider {
if r.llmClient == nil {
r.llmClient = client
@@ -89,9 +100,6 @@ func (r *llmClientRegistry) clientFor(spec requestSpec, cfg appconfig.App, build
}
return r.llmClient
}
- if existing, ok := r.altClients[provider]; ok {
- return existing
- }
if r.altClients == nil {
r.altClients = make(map[string]llm.Client)
}
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index c266e91..25c5e5c 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -378,7 +378,13 @@ func (s *Server) Run() error {
// A response from client; ignore
continue
}
- go s.handle(req)
+ // Track every request goroutine so Run's deferred inflight.Wait()
+ // catches them all and prevents use-after-close writes to s.out.
+ s.inflight.Add(1)
+ go func(r Request) {
+ defer s.inflight.Done()
+ s.handle(r)
+ }(req)
if s.exited.Load() {
return nil
}