-rw-r--r--  docs/usage.md                            9
-rw-r--r--  internal/hexaicli/cache.go             121
-rw-r--r--  internal/hexaicli/cache_test.go        207
-rw-r--r--  internal/hexaicli/run.go               114
-rw-r--r--  internal/hexaicli/run_test.go           38
-rw-r--r--  internal/hexaicli/run_timeout_test.go    1
-rw-r--r--  internal/hexaicli/testhelpers_test.go   17
-rw-r--r--  internal/llmutils/client.go             20
-rw-r--r--  internal/llmutils/client_test.go        16
-rw-r--r--  internal/version.go                      2
10 files changed, 493 insertions, 52 deletions
diff --git a/docs/usage.md b/docs/usage.md
index 1ecf30f..8404969 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -87,6 +87,10 @@ Defaults: concise answers. If the prompt asks for commands, Hexai outputs only c
 Provider/model headers and run summaries are written to `stderr`, so `stdout` stays usable in pipes.
 
+Successful CLI responses are cached for 24 hours under Hexai's cache directory. The cache key includes the input text, provider, resolved model, and effective CLI prompt text, so prompt/config changes automatically invalidate old entries.
+
+To rerun a multi-provider prompt and print only one response cleanly, use the existing numbered provider flags such as `-0`, `-1`, etc. That reuses the cached response for just that provider when available, which avoids the side-by-side layout on `stdout`.
+
 `--tps-simulation` accepts either a fixed rate such as `20` or a range such as `12-18`. It streams positional arguments, piped stdin, or built-in placeholder text when no input is provided, so you can preview perceived model latency without needing a real provider or local hardware.
 
 ### Examples
@@ -107,6 +111,11 @@ hexai 'install ripgrep on macOS'
 # Verbose explanation
 hexai 'install ripgrep on macOS and explain'
 
+# Warm the cache with two configured CLI providers, then print only the
+# second provider's cached response on a rerun
+hexai 'summarize this file'
+hexai -1 'summarize this file'
+
 # Simulate 12-18 tokens per second with placeholder text
 hexai --tps-simulation 12-18
 
diff --git a/internal/hexaicli/cache.go b/internal/hexaicli/cache.go
new file mode 100644
index 0000000..544eab0
--- /dev/null
+++ b/internal/hexaicli/cache.go
@@ -0,0 +1,121 @@
+package hexaicli
+
+import (
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"time"
+
+	"codeberg.org/snonux/hexai/internal/llm"
+	"codeberg.org/snonux/hexai/internal/stats"
+)
+
+const cliResponseCacheTTL = 24 * time.Hour
+
+var nowCLIResponseCache = time.Now
+
+type cliResponseCacheKey struct {
+	Provider    string        `json:"provider"`
+	Model       string        `json:"model"`
+	Messages    []llm.Message `json:"messages"`
+	MaxTokens   int           `json:"max_tokens,omitempty"`
+	Temperature *float64      `json:"temperature,omitempty"`
+}
+
+type cliResponseCacheEntry struct {
+	CreatedAt time.Time `json:"created_at"`
+	Output    string    `json:"output"`
+}
+
+func newCLIResponseCacheKey(provider, model string, req requestArgs, msgs []llm.Message) cliResponseCacheKey {
+	return cliResponseCacheKey{
+		Provider:    provider,
+		Model:       model,
+		Messages:    cloneCLIMessages(msgs),
+		MaxTokens:   req.maxTokens,
+		Temperature: cloneCLITemperature(req.temperature),
+	}
+}
+
+func lookupCLIResponseCache(key cliResponseCacheKey) (string, time.Duration, bool) {
+	path, ok := cliResponseCachePath(key)
+	if !ok {
+		return "", 0, false
+	}
+	entry, ok := loadCLIResponseCacheEntry(path)
+	if !ok {
+		return "", 0, false
+	}
+	age := nowCLIResponseCache().Sub(entry.CreatedAt)
+	if age > cliResponseCacheTTL {
+		_ = os.Remove(path)
+		return "", 0, false
+	}
+	return entry.Output, age, true
+}
+
+func storeCLIResponseCache(key cliResponseCacheKey, output string) {
+	path, ok := cliResponseCachePath(key)
+	if !ok {
+		return
+	}
+	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+		return
+	}
+	entry := cliResponseCacheEntry{CreatedAt: nowCLIResponseCache().UTC(), Output: output}
+	data, err := json.Marshal(entry)
+	if err != nil {
+		return
+	}
+	_ = os.WriteFile(path, data, 0o600)
+}
+
+func cliResponseCachePath(key cliResponseCacheKey) (string, bool) {
+	dir, err := stats.CacheDir()
+	if err != nil {
+		return "", false
+	}
+	fingerprint, ok := cliResponseCacheFingerprint(key)
+	if !ok {
+		return "", false
+	}
+	return filepath.Join(dir, "cli-responses", fingerprint+".json"), true
+}
+
+func cliResponseCacheFingerprint(key cliResponseCacheKey) (string, bool) {
+	data, err := json.Marshal(key)
+	if err != nil {
+		return "", false
+	}
+	sum := sha256.Sum256(data)
+	return hex.EncodeToString(sum[:]), true
+}
+
+func loadCLIResponseCacheEntry(path string) (cliResponseCacheEntry, bool) {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return cliResponseCacheEntry{}, false
+	}
+	var entry cliResponseCacheEntry
+	if err := json.Unmarshal(data, &entry); err != nil {
+		_ = os.Remove(path)
+		return cliResponseCacheEntry{}, false
+	}
+	return entry, true
+}
+
+func cloneCLIMessages(msgs []llm.Message) []llm.Message {
+	out := make([]llm.Message, len(msgs))
+	copy(out, msgs)
+	return out
+}
+
+func cloneCLITemperature(temp *float64) *float64 {
+	if temp == nil {
+		return nil
+	}
+	value := *temp
+	return &value
+}
diff --git a/internal/hexaicli/cache_test.go b/internal/hexaicli/cache_test.go
new file mode 100644
index 0000000..5f00e7b
--- /dev/null
+++ b/internal/hexaicli/cache_test.go
@@ -0,0 +1,207 @@
+package hexaicli
+
+import (
+	"bytes"
+	"context"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+	"time"
+
+	"codeberg.org/snonux/hexai/internal/appconfig"
+	"codeberg.org/snonux/hexai/internal/llm"
+)
+
+func TestCLIResponseCacheFingerprintChanges(t *testing.T) {
+	base := newCLIResponseCacheKey("openai", "gpt-4.1", requestArgs{maxTokens: 42}, []llm.Message{
+		{Role: "system", Content: "sys"},
+		{Role: "user", Content: "hello"},
+	})
+	baseFingerprint, ok := cliResponseCacheFingerprint(base)
+	if !ok {
+		t.Fatal("expected fingerprint for base key")
+	}
+
+	tests := []struct {
+		name string
+		key  cliResponseCacheKey
+	}{
+		{name: "provider", key: newCLIResponseCacheKey("anthropic", "gpt-4.1", requestArgs{maxTokens: 42}, base.Messages)},
+		{name: "model", key: newCLIResponseCacheKey("openai", "gpt-5", requestArgs{maxTokens: 42}, base.Messages)},
+		{name: "prompt", key: newCLIResponseCacheKey("openai", "gpt-4.1", requestArgs{maxTokens: 42}, []llm.Message{{Role: "system", Content: "different"}, {Role: "user", Content: "hello"}})},
+		{name: "temperature", key: newCLIResponseCacheKey("openai", "gpt-4.1", requestArgs{maxTokens: 42, temperature: floatPtr(0.7)}, base.Messages)},
+	}
+
+	for _, tc := range tests {
+		fingerprint, ok := cliResponseCacheFingerprint(tc.key)
+		if !ok {
+			t.Fatalf("%s: expected fingerprint", tc.name)
+		}
+		if fingerprint == baseFingerprint {
+			t.Fatalf("%s: expected fingerprint change", tc.name)
+		}
+	}
+}
+
+func TestLookupCLIResponseCacheExpiresEntries(t *testing.T) {
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+
+	oldNow := nowCLIResponseCache
+	nowCLIResponseCache = func() time.Time { return time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC) }
+	defer func() { nowCLIResponseCache = oldNow }()
+
+	key := newCLIResponseCacheKey("openai", "gpt-4.1", requestArgs{maxTokens: 10}, []llm.Message{{Role: "user", Content: "hello"}})
+	storeCLIResponseCache(key, "cached")
+
+	path, ok := cliResponseCachePath(key)
+	if !ok {
+		t.Fatal("expected cache path")
+	}
+	if _, err := os.Stat(path); err != nil {
+		t.Fatalf("expected cache file: %v", err)
+	}
+
+	nowCLIResponseCache = func() time.Time { return time.Date(2026, 3, 16, 11, 0, 0, 0, time.UTC) }
+	if _, _, hit := lookupCLIResponseCache(key); hit {
+		t.Fatal("expected expired cache miss")
+	}
+	if _, err := os.Stat(path); !os.IsNotExist(err) {
+		t.Fatalf("expected expired cache file removal, got %v", err)
+	}
+}
+
+func TestRun_UsesCachedResponseWithoutClientCall(t *testing.T) {
+	t.Chdir(t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+
+	oldNew := newClientFromApp
+	defer func() { newClientFromApp = oldNew }()
+
+	calls := 0
+	newClientFromApp = func(cfg appconfig.App) (llm.Client, error) {
+		calls++
+		return &fakeClient{name: cfg.Provider, model: "gpt-4.1", resp: "cached output"}, nil
+	}
+
+	var firstOut, firstErr bytes.Buffer
+	if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &firstOut, &firstErr); err != nil {
+		t.Fatalf("first Run: %v", err)
+	}
+	if calls != 1 {
+		t.Fatalf("expected one live client call, got %d", calls)
+	}
+
+	newClientFromApp = func(appconfig.App) (llm.Client, error) {
+		t.Fatal("client should not be constructed on cache hit")
+		return nil, nil
+	}
+
+	var secondOut, secondErr bytes.Buffer
+	if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &secondOut, &secondErr); err != nil {
+		t.Fatalf("second Run: %v", err)
+	}
+	if got := secondOut.String(); got != "cached output" {
+		t.Fatalf("stdout = %q, want cached output", got)
+	}
+	if !strings.Contains(secondErr.String(), "cache hit provider=openai model=gpt-4.1") {
+		t.Fatalf("expected cache hit note, got %q", secondErr.String())
+	}
+}
+
+func TestRun_WithSelectionUsesChosenCachedResponse(t *testing.T) {
+	workDir := t.TempDir()
+	configHome := t.TempDir()
+	t.Chdir(workDir)
+	t.Setenv("XDG_CONFIG_HOME", configHome)
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+
+	configPath := filepath.Join(configHome, "hexai", "config.toml")
+	writeConfigString(t, configPath, `
+[provider]
+name = "openai"
+
+[[models.cli]]
+provider = "openai"
+model = "gpt-4.1"
+
+[[models.cli]]
+provider = "anthropic"
+model = "claude-3-5-sonnet-20240620"
+`)
+
+	oldNew := newClientFromApp
+	defer func() { newClientFromApp = oldNew }()
+	newClientFromApp = func(cfg appconfig.App) (llm.Client, error) {
+		switch cfg.Provider {
+		case "anthropic":
+			return &fakeClient{name: "anthropic", model: "claude-3-5-sonnet-20240620", resp: "RIGHT"}, nil
+		default:
+			return &fakeClient{name: "openai", model: "gpt-4.1", resp: "LEFT"}, nil
+		}
+	}
+
+	if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{}); err != nil {
+		t.Fatalf("warm cache Run: %v", err)
+	}
+
+	newClientFromApp = func(appconfig.App) (llm.Client, error) {
+		t.Fatal("client should not be constructed for selected cache hit")
+		return nil, nil
+	}
+
+	ctx := WithCLISelection(context.Background(), []int{1})
+	var out, errb bytes.Buffer
+	if err := Run(ctx, []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil {
+		t.Fatalf("selected Run: %v", err)
+	}
+	if got := out.String(); got != "RIGHT" {
+		t.Fatalf("stdout = %q, want RIGHT", got)
+	}
+	if strings.Contains(out.String(), "LEFT") {
+		t.Fatalf("unexpected other provider output: %q", out.String())
+	}
+	if !strings.Contains(errb.String(), "anthropic:claude-3-5-sonnet-20240620") {
+		t.Fatalf("expected selected provider header, got %q", errb.String())
+	}
+}
+
+func TestRun_ExpiredCacheFallsBackToProvider(t *testing.T) {
+	t.Chdir(t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+
+	oldNow := nowCLIResponseCache
+	nowCLIResponseCache = func() time.Time { return time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC) }
+	defer func() { nowCLIResponseCache = oldNow }()
+
+	oldNew := newClientFromApp
+	defer func() { newClientFromApp = oldNew }()
+
+	calls := 0
+	newClientFromApp = func(cfg appconfig.App) (llm.Client, error) {
+		calls++
+		resp := "first"
+		if calls > 1 {
+			resp = "second"
+		}
+		return &fakeClient{name: cfg.Provider, model: "gpt-4.1", resp: resp}, nil
+	}
+
+	if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{}); err != nil {
+		t.Fatalf("first Run: %v", err)
+	}
+
+	nowCLIResponseCache = func() time.Time { return time.Date(2026, 3, 16, 11, 0, 0, 0, time.UTC) }
+	var out, errb bytes.Buffer
+	if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil {
+		t.Fatalf("second Run: %v", err)
+	}
+	if calls != 2 {
+		t.Fatalf("expected second live provider call after expiry, got %d", calls)
+	}
+	if got := out.String(); got != "second" {
+		t.Fatalf("stdout = %q, want second", got)
+	}
+}
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index b48bee0..4cd94b4 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -24,15 +24,17 @@ import (
 )
 
 type requestArgs struct {
-	model   string
-	options []llm.RequestOption
+	model       string
+	maxTokens   int
+	temperature *float64
+	options     []llm.RequestOption
 }
 
 type cliJob struct {
 	index    int
 	provider string
 	entry    appconfig.SurfaceConfig
-	client   llm.Client
+	cfg      appconfig.App
 	req      requestArgs
 }
 
@@ -55,40 +57,33 @@ func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) {
 		}
 		provider = canonicalProvider(provider)
 		derived := llmutils.ConfigForProvider(cfg, provider, entry.Model)
-		client, err := newClientFromApp(derived)
-		if err != nil {
-			return nil, err
-		}
-		req := buildCLIRequest(entry, provider, cfg, client)
-		if strings.TrimSpace(req.model) == "" {
-			req.model = strings.TrimSpace(client.DefaultModel())
-		}
-		jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, client: client, req: req})
+		req := buildCLIRequest(entry, provider, derived)
+		jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, cfg: derived, req: req})
 	}
 	return jobs, nil
 }
 
-func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App, client llm.Client) requestArgs {
+func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App) requestArgs {
 	opts := make([]llm.RequestOption, 0, 2)
+	req := requestArgs{maxTokens: cfg.MaxTokens}
 	if cfg.MaxTokens > 0 {
 		opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens))
 	}
 	model := strings.TrimSpace(entry.Model)
 	if model == "" {
-		if client != nil {
-			model = strings.TrimSpace(client.DefaultModel())
-		}
-		if model == "" {
-			model = strings.TrimSpace(llmutils.DefaultModelForProvider(cfg, provider))
-		}
+		model = strings.TrimSpace(llmutils.DefaultModelForProvider(cfg, provider))
 	}
 	if entry.Model != "" {
 		opts = append(opts, llm.WithModel(entry.Model))
 	}
 	if temp, ok := cliTemperatureFromEntry(cfg, provider, entry, model); ok {
+		tempValue := temp
+		req.temperature = &tempValue
 		opts = append(opts, llm.WithTemperature(temp))
 	}
-	return requestArgs{model: model, options: opts}
+	req.model = model
+	req.options = opts
+	return req
 }
 
 func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) {
@@ -240,28 +235,72 @@ func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter {
 }
 
 func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
+	if res := cachedCLIJobResult(job, msgs, stdout, printer, streamOutput); res != nil {
+		return res
+	}
+
+	client, err := newClientFromApp(job.cfg)
+	if err != nil {
+		return &cliJobResult{provider: job.provider, model: job.req.model, err: err}
+	}
+	model := effectiveModel(job.req, client)
+
 	var errBuf bytes.Buffer
 	var outBuf bytes.Buffer
 	jobMsgs := append([]llm.Message(nil), msgs...)
 	writer := io.Writer(&outBuf)
 	if printer != nil {
-		writer = printer.Writer(job.index)
+		writer = io.MultiWriter(printer.Writer(job.index), &outBuf)
 	} else if streamOutput {
 		writer = io.MultiWriter(stdout, &outBuf)
 	}
-	err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
+	err = runChat(ctx, client, job.req, jobMsgs, input, writer, &errBuf)
 	if printer != nil {
 		printer.Flush(job.index)
 	}
+	if err == nil {
+		storeCLIResponseCache(newCLIResponseCacheKey(job.provider, model, job.req, jobMsgs), outBuf.String())
+	}
 	return &cliJobResult{
-		provider: job.client.Name(),
-		model:    job.req.model,
+		provider: job.provider,
+		model:    model,
 		output:   outBuf.String(),
 		summary:  errBuf.String(),
 		err:      err,
 	}
 }
 
+func cachedCLIJobResult(job cliJob, msgs []llm.Message, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
+	output, age, ok := lookupCLIResponseCache(newCLIResponseCacheKey(job.provider, job.req.model, job.req, msgs))
+	if !ok {
+		return nil
+	}
+	if err := writeCachedCLIJobOutput(output, stdout, printer, job.index, streamOutput); err != nil {
+		return &cliJobResult{provider: job.provider, model: job.req.model, err: err}
+	}
+	return &cliJobResult{
+		provider: job.provider,
+		model:    job.req.model,
+		output:   output,
+		summary:  cacheHitSummary(job.provider, job.req.model, age),
+	}
+}
+
+func writeCachedCLIJobOutput(output string, stdout io.Writer, printer *termprint.ColumnPrinter, idx int, streamOutput bool) error {
+	if printer != nil {
+		if _, err := io.WriteString(printer.Writer(idx), output); err != nil {
+			return err
+		}
+		printer.Flush(idx)
+		return nil
+	}
+	if !streamOutput {
+		return nil
+	}
+	_, err := io.WriteString(stdout, output)
+	return err
+}
+
 func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
 	printed := false
 	showHeading := cliJobResultCount(results) > 1
@@ -346,7 +385,7 @@ func newColumnPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter
 	providers := make([]string, len(jobs))
 	models := make([]string, len(jobs))
 	for _, job := range jobs {
-		providers[job.index] = job.client.Name()
+		providers[job.index] = job.provider
 		models[job.index] = job.req.model
 	}
 	return termprint.NewColumnPrinter(stdout, providers, models)
@@ -361,7 +400,7 @@ func printCLIHeader(stderr io.Writer, jobs []cliJob, printer *termprint.ColumnPr
 		return
 	}
 	job := jobs[0]
-	printProviderInfo(stderr, job.client, job.req.model)
+	printProviderLabel(stderr, job.provider, job.req.model)
 }
 
 // WithCLISelection injects provider indices into the context so Run only executes those jobs.
@@ -577,16 +616,35 @@ func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs
 
 // printProviderInfo writes the provider:model header and divider to stderr.
 func printProviderInfo(errw io.Writer, client llm.Client, model string) {
+	printProviderLabel(errw, client.Name(), chooseCLIModel(model, client.DefaultModel()))
+}
+
+func printProviderLabel(errw io.Writer, provider, model string) {
 	if strings.TrimSpace(model) == "" {
-		model = client.DefaultModel()
+		return
 	}
-	printer := termprint.NewColumnPrinter(errw, []string{client.Name()}, []string{model})
+	printer := termprint.NewColumnPrinter(errw, []string{provider}, []string{model})
 	if printer == nil {
 		return
 	}
 	printer.PrintHeader()
 }
 
+func chooseCLIModel(model, fallback string) string {
+	model = strings.TrimSpace(model)
+	if model != "" {
+		return model
+	}
+	return strings.TrimSpace(fallback)
+}
+
+func cacheHitSummary(provider, model string, age time.Duration) string {
+	if age < 0 {
+		age = 0
+	}
+	return fmt.Sprintf(logging.AnsiBase+"cache hit provider=%s model=%s age=%s"+logging.AnsiReset+"\n", provider, model, age.Round(time.Second))
+}
+
 // newClientFromConfig is kept for tests; delegates to llmutils.
 var newClientFromApp = llmutils.NewClientFromApp
 
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
index 315d016..8059c25 100644
--- a/internal/hexaicli/run_test.go
+++ b/internal/hexaicli/run_test.go
@@ -188,17 +188,28 @@ func TestRun_SingleProviderHeaderUsesStderr(t *testing.T) {
 }
 
 func TestExecuteCLIJobs_MultiProviderHeaderUsesStderr(t *testing.T) {
+	oldNew := newClientFromApp
+	defer func() { newClientFromApp = oldNew }()
+	newClientFromApp = func(cfg appconfig.App) (llm.Client, error) {
+		switch cfg.Provider {
+		case "anthropic":
+			return &fakeClient{name: "anthropic", model: "claude", resp: "RIGHT"}, nil
+		default:
+			return &fakeClient{name: "openai", model: "gpt-4.1", resp: "LEFT"}, nil
+		}
+	}
+
 	jobs := []cliJob{
 		{
 			index:    0,
 			provider: "openai",
-			client:   &fakeClient{name: "openai", model: "gpt-4.1", resp: "LEFT"},
+			cfg:      appconfig.App{Provider: "openai", OpenAIModel: "gpt-4.1"},
 			req:      requestArgs{model: "gpt-4.1"},
 		},
 		{
 			index:    1,
 			provider: "anthropic",
-			client:   &fakeClient{name: "anthropic", model: "claude", resp: "RIGHT"},
+			cfg:      appconfig.App{Provider: "anthropic", AnthropicModel: "claude"},
 			req:      requestArgs{model: "claude"},
 		},
 	}
@@ -225,7 +236,7 @@ func TestBuildCLIRequest_Override(t *testing.T) {
 		AnthropicModel: "claude-3-5-sonnet",
 	}
 	entry := appconfig.SurfaceConfig{Provider: "anthropic", Model: "override", Temperature: floatPtr(0.7)}
-	req := buildCLIRequest(entry, "anthropic", cfg, &fakeClient{name: "anthropic", model: "default"})
+	req := buildCLIRequest(entry, "anthropic", cfg)
 	if req.model != "override" {
 		t.Fatalf("expected model override, got %q", req.model)
 	}
@@ -241,7 +252,8 @@ func TestBuildCLIRequest_Gpt5Temp(t *testing.T) {
 	cfg := appconfig.App{Provider: "openai", CodingTemperature: floatPtr(0.2)}
 	entry := appconfig.SurfaceConfig{}
-	req := buildCLIRequest(entry, "openai", cfg, &fakeClient{name: "openai", model: "gpt-5.1"})
+	cfg.OpenAIModel = "gpt-5.1"
+	req := buildCLIRequest(entry, "openai", cfg)
 	if req.model != "gpt-5.1" {
 		t.Fatalf("expected fallback model, got %q", req.model)
 	}
@@ -255,21 +267,6 @@
 }
 
 func TestBuildCLIJobs_MultiEntries(t *testing.T) {
-	old := newClientFromApp
-	defer func() { newClientFromApp = old }()
-	newClientFromApp = func(cfg appconfig.App) (llm.Client, error) {
-		model := cfg.OpenAIModel
-		if cfg.Provider == "anthropic" {
-			model = cfg.AnthropicModel
-		}
-		if cfg.Provider == "ollama" {
-			model = cfg.OllamaModel
-		}
-		if strings.TrimSpace(model) == "" {
-			model = "default"
-		}
-		return &fakeClient{name: cfg.Provider, model: model}, nil
-	}
 	cfg := appconfig.App{
 		Provider:    "ollama",
 		OllamaModel: "llama3",
@@ -291,6 +288,9 @@
 	if jobs[1].provider != "anthropic" || jobs[1].req.model != "claude" {
 		t.Fatalf("unexpected second job: %+v", jobs[1])
 	}
+	if jobs[0].cfg.Provider != "openai" || jobs[1].cfg.Provider != "anthropic" {
+		t.Fatalf("unexpected derived configs: %+v", jobs)
+	}
 }
 
 func TestFilterJobsBySelection(t *testing.T) {
diff --git a/internal/hexaicli/run_timeout_test.go b/internal/hexaicli/run_timeout_test.go
index 642f172..9561617 100644
--- a/internal/hexaicli/run_timeout_test.go
+++ b/internal/hexaicli/run_timeout_test.go
@@ -13,6 +13,7 @@ import (
 func TestRun_DefaultRequestTimeoutIsTenMinutes(t *testing.T) {
 	t.Chdir(t.TempDir())
 	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
 	t.Setenv("HEXAI_REQUEST_TIMEOUT", "")
 
 	oldNew := newClientFromApp
diff --git a/internal/hexaicli/testhelpers_test.go b/internal/hexaicli/testhelpers_test.go
index 3197880..a4c9e6b 100644
--- a/internal/hexaicli/testhelpers_test.go
+++ b/internal/hexaicli/testhelpers_test.go
@@ -10,6 +10,23 @@ import (
 	"codeberg.org/snonux/hexai/internal/llm"
 )
 
+func TestMain(m *testing.M) {
+	cacheDir, err := os.MkdirTemp("", "hexai-cli-cache-*")
+	if err != nil {
+		panic(err)
+	}
+	oldCacheHome := os.Getenv("XDG_CACHE_HOME")
+	_ = os.Setenv("XDG_CACHE_HOME", cacheDir)
+	code := m.Run()
+	if oldCacheHome == "" {
+		_ = os.Unsetenv("XDG_CACHE_HOME")
+	} else {
+		_ = os.Setenv("XDG_CACHE_HOME", oldCacheHome)
+	}
+	_ = os.RemoveAll(cacheDir)
+	os.Exit(code)
+}
+
 // setStdin sets os.Stdin from a string and returns a restore func and reader.
 func setStdin(t *testing.T, content string) (func(), *os.File) {
 	t.Helper()
diff --git a/internal/llmutils/client.go b/internal/llmutils/client.go
index 16a6338..3641556 100644
--- a/internal/llmutils/client.go
+++ b/internal/llmutils/client.go
@@ -21,13 +21,25 @@ func CanonicalProvider(name string) string {
 func DefaultModelForProvider(cfg appconfig.App, provider string) string {
 	switch CanonicalProvider(provider) {
 	case "openrouter":
-		return strings.TrimSpace(cfg.OpenRouterModel)
+		if model := strings.TrimSpace(cfg.OpenRouterModel); model != "" {
+			return model
+		}
+		return "openrouter/auto"
 	case "ollama":
-		return strings.TrimSpace(cfg.OllamaModel)
+		if model := strings.TrimSpace(cfg.OllamaModel); model != "" {
+			return model
+		}
+		return "qwen3-coder:30b-a3b-q4_K_M"
 	case "anthropic":
-		return strings.TrimSpace(cfg.AnthropicModel)
+		if model := strings.TrimSpace(cfg.AnthropicModel); model != "" {
+			return model
+		}
+		return "claude-3-5-sonnet-20240620"
 	default:
-		return strings.TrimSpace(cfg.OpenAIModel)
+		if model := strings.TrimSpace(cfg.OpenAIModel); model != "" {
+			return model
+		}
+		return "gpt-4.1"
 	}
 }
 
diff --git a/internal/llmutils/client_test.go b/internal/llmutils/client_test.go
index 837d408..0e38476 100644
--- a/internal/llmutils/client_test.go
+++ b/internal/llmutils/client_test.go
@@ -56,6 +56,22 @@ func TestDefaultModelForProvider(t *testing.T) {
 	}
 }
 
+func TestDefaultModelForProvider_Fallbacks(t *testing.T) {
+	cfg := appconfig.App{}
+	if got := DefaultModelForProvider(cfg, "openai"); got != "gpt-4.1" {
+		t.Fatalf("openai fallback = %q", got)
+	}
+	if got := DefaultModelForProvider(cfg, "openrouter"); got != "openrouter/auto" {
+		t.Fatalf("openrouter fallback = %q", got)
+	}
+	if got := DefaultModelForProvider(cfg, "ollama"); got != "qwen3-coder:30b-a3b-q4_K_M" {
+		t.Fatalf("ollama fallback = %q", got)
+	}
+	if got := DefaultModelForProvider(cfg, "anthropic"); got != "claude-3-5-sonnet-20240620" {
+		t.Fatalf("anthropic fallback = %q", got)
+	}
+}
+
 func TestConfigForProvider(t *testing.T) {
 	base := appconfig.App{
 		Provider: "openai",
diff --git a/internal/version.go b/internal/version.go
index 923040e..74f331e 100644
--- a/internal/version.go
+++ b/internal/version.go
@@ -1,4 +1,4 @@
 // Summary: Hexai semantic version identifier used by CLI and LSP binaries.
 package internal
 
-const Version = "0.22.3"
+const Version = "0.23.0"
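
Side note (not part of the patch): the cache-key scheme in cache.go above amounts to "hash the JSON encoding of the key struct". A minimal standalone sketch of that idea, with message as a stand-in for llm.Message and fingerprint mirroring cliResponseCacheFingerprint:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
)

// message stands in for llm.Message.
type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// key mirrors cliResponseCacheKey from the patch.
type key struct {
	Provider    string    `json:"provider"`
	Model       string    `json:"model"`
	Messages    []message `json:"messages"`
	MaxTokens   int       `json:"max_tokens,omitempty"`
	Temperature *float64  `json:"temperature,omitempty"`
}

// fingerprint hashes the JSON encoding of the key; struct fields marshal
// in declaration order, so the encoding is stable for a given key value.
func fingerprint(k key) string {
	data, err := json.Marshal(k)
	if err != nil {
		panic(err) // cannot happen for this struct
	}
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}

func main() {
	base := key{Provider: "openai", Model: "gpt-4.1", Messages: []message{{Role: "user", Content: "hello"}}, MaxTokens: 42}
	bumped := base
	bumped.Model = "gpt-5"
	// Any keyed field change yields a new fingerprint and therefore a new
	// cache file under cli-responses/, which is how prompt or config
	// changes invalidate old entries.
	fmt.Println(fingerprint(base) == fingerprint(bumped)) // false
}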
