diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-19 08:58:38 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-19 08:58:38 +0200 |
| commit | 31394385e72dd3a317585838ed1696076043cc60 (patch) | |
| tree | 79c0c43b514df611fdcd15c899c8d1048cd01ab0 | |
| parent | 934266d5bbefbc33c95e933b6ef02875539bb72f (diff) | |
Inject runner dependencies across CLI, action, and LSP
| -rw-r--r-- | internal/hexaiaction/run.go | 34 | ||||
| -rw-r--r-- | internal/hexaiaction/run_seam_test.go | 35 | ||||
| -rw-r--r-- | internal/hexaicli/run.go | 97 | ||||
| -rw-r--r-- | internal/hexaicli/run_output_test.go | 2 | ||||
| -rw-r--r-- | internal/hexaicli/run_test.go | 2 | ||||
| -rw-r--r-- | internal/hexaicli/runner.go | 161 | ||||
| -rw-r--r-- | internal/hexaicli/runner_test.go | 61 | ||||
| -rw-r--r-- | internal/hexailsp/dependencies.go | 77 | ||||
| -rw-r--r-- | internal/hexailsp/run.go | 49 | ||||
| -rw-r--r-- | internal/hexailsp/run_more_test.go | 36 |
10 files changed, 443 insertions, 111 deletions
diff --git a/internal/hexaiaction/run.go b/internal/hexaiaction/run.go index 84cb9b1..f34a4cd 100644 --- a/internal/hexaiaction/run.go +++ b/internal/hexaiaction/run.go @@ -64,10 +64,18 @@ type actionClient interface { type actionClientFactory func(cfg appconfig.App) (actionClient, error) +type actionConfigLoader func(context.Context, *log.Logger) appconfig.App + +type actionStatusSink interface { + SetLLMStart(provider, model string) error +} + // Runner executes action requests with injectable dependencies for testability. type Runner struct { chooseAction actionChooser newClient actionClientFactory + loadConfig actionConfigLoader + statusSink actionStatusSink } // NewRunner builds a Runner with production dependencies. @@ -75,6 +83,8 @@ func NewRunner() *Runner { return &Runner{ chooseAction: chooseActionFromConfig, newClient: defaultActionClientFactory, + loadConfig: loadActionConfig, + statusSink: tmuxActionStatusSink{}, } } @@ -91,6 +101,16 @@ func defaultActionClientFactory(cfg appconfig.App) (actionClient, error) { return llmutils.NewClientFromApp(cfg) } +type tmuxActionStatusSink struct{} + +func (tmuxActionStatusSink) SetLLMStart(provider, model string) error { + return tmux.SetStatus(tmux.FormatLLMStartStatus(provider, model)) +} + +func loadActionConfig(ctx context.Context, logger *log.Logger) appconfig.App { + return appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPathFromContext(ctx)}) +} + type actionPlan struct { fallback string run func(context.Context) (string, error) @@ -127,6 +147,8 @@ func Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error { func (r *Runner) Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error { chooser := chooseActionFromConfig newClient := defaultActionClientFactory + loadConfig := loadActionConfig + statusSink := actionStatusSink(tmuxActionStatusSink{}) if r != nil { if r.chooseAction != nil { chooser = r.chooseAction @@ -134,10 +156,16 @@ func (r *Runner) Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Wri if r.newClient != nil { newClient = r.newClient } + if r.loadConfig != nil { + loadConfig = r.loadConfig + } + if r.statusSink != nil { + statusSink = r.statusSink + } } logger := log.New(stderr, "hexai-tmux-action ", log.LstdFlags|log.Lmsgprefix) - cfg := appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPathFromContext(ctx)}) + cfg := loadConfig(ctx, logger) if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } @@ -159,7 +187,9 @@ func (r *Runner) Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Wri if primaryModel == "" { primaryModel = cli.DefaultModel() } - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(cli.Name(), primaryModel)) + if statusSink != nil { + _ = statusSink.SetLLMStart(cli.Name(), primaryModel) + } var client chatDoer = cli parts, err := ParseInput(stdin) if err != nil { diff --git a/internal/hexaiaction/run_seam_test.go b/internal/hexaiaction/run_seam_test.go index affd68e..8fb8533 100644 --- a/internal/hexaiaction/run_seam_test.go +++ b/internal/hexaiaction/run_seam_test.go @@ -3,6 +3,7 @@ package hexaiaction import ( "bytes" "context" + "log" "testing" "codeberg.org/snonux/hexai/internal/appconfig" @@ -17,6 +18,17 @@ func (llmFake) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) func (llmFake) Name() string { return "fake" } func (llmFake) DefaultModel() string { return "model" } +type recordingActionStatusSink struct { + provider string + model string +} + +func (s *recordingActionStatusSink) SetLLMStart(provider, model string) error { + s.provider = provider + s.model = model + return nil +} + func TestRun_WithSeams_SkipAndRewrite(t *testing.T) { // Isolate from user config to avoid environment-dependent behavior/logging. t.Setenv("XDG_CONFIG_HOME", t.TempDir()) @@ -48,3 +60,26 @@ func TestRun_WithSeams_SkipAndRewrite(t *testing.T) { t.Fatalf("expected non-empty rewrite output") } } + +func TestRun_WithInjectedConfigAndStatusSink(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + sink := &recordingActionStatusSink{} + runner := NewRunner() + runner.loadConfig = func(context.Context, *log.Logger) appconfig.App { return appconfig.Load(nil) } + runner.newClient = func(_ appconfig.App) (actionClient, error) { return llmFake{}, nil } + runner.chooseAction = func(_ appconfig.App) (actionChoice, error) { + return actionChoice{kind: ActionSkip}, nil + } + runner.statusSink = sink + + var out bytes.Buffer + if err := runner.Run(context.Background(), bytes.NewBufferString("selection"), &out, &bytes.Buffer{}); err != nil { + t.Fatalf("Run: %v", err) + } + if out.String() != "selection" { + t.Fatalf("unexpected output %q", out.String()) + } + if sink.provider != "fake" || sink.model == "" { + t.Fatalf("unexpected status sink values: provider=%q model=%q", sink.provider, sink.model) + } +} diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 9c7ba73..0da1a5f 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -7,20 +7,17 @@ import ( "context" "fmt" "io" - "log" "os" "strings" "sync" "time" "codeberg.org/snonux/hexai/internal/appconfig" - "codeberg.org/snonux/hexai/internal/editor" "codeberg.org/snonux/hexai/internal/llm" "codeberg.org/snonux/hexai/internal/llmutils" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" "codeberg.org/snonux/hexai/internal/termprint" - "codeberg.org/snonux/hexai/internal/tmux" ) type requestArgs struct { @@ -95,77 +92,13 @@ func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig // Run executes the Hexai CLI behavior given arguments and I/O streams. // It assumes flags have already been parsed by the caller. func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { - if spec, ok, err := tpsSimulationFromContext(ctx); err != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) - return err - } else if ok { - input, inputErr := readSimulationInput(stdin, args) - if inputErr != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+inputErr.Error()+logging.AnsiReset) - return inputErr - } - return runTPSSimulation(ctx, spec, input, stdout) - } - - // Load configuration silently; config-load messages are noise in the CLI. - logger := log.New(io.Discard, "", 0) - configPath := configPathFromContext(ctx) - cfg := appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPath}) - if cfg.StatsWindowMinutes > 0 { - stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) - } - jobs, err := buildCLIJobs(cfg) - if err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) - return err - } - if selected := selectionFromContext(ctx); len(selected) > 0 { - jobs, err = filterJobsBySelection(jobs, selected) - if err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err) - return err - } - } - if len(jobs) == 0 { - return fmt.Errorf("hexai: no CLI providers configured") - } - // Prefer piped stdin when present; only open the editor when there are no args - // and no stdin content available. - input, rerr := readInput(stdin, args) - if rerr != nil && len(args) == 0 { - if prompt, eerr := editor.OpenTempAndEdit(nil); eerr == nil && strings.TrimSpace(prompt) != "" { - args = []string{prompt} - input, rerr = readInput(stdin, args) - } - } - if rerr != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) - return rerr - } - msgs := buildMessagesFromConfig(cfg, input) - if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr); err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) - return err - } - return nil + return NewRunner().Run(ctx, args, stdin, stdout, stderr) } // RunWithClient executes the CLI flow using an already-constructed client. // Useful for testing and embedding. func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer, client llm.Client) error { - input, err := readInput(stdin, args) - if err != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) - return err - } - req := requestArgs{model: strings.TrimSpace(client.DefaultModel())} - printProviderInfo(stderr, client, req.model) - msgs := buildMessages(input) - if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) - return err - } - return nil + return NewRunner().RunWithClient(ctx, args, stdin, stdout, stderr, client) } type cliJobResult struct { @@ -184,9 +117,9 @@ type chatRunSummary struct { scopeRPM float64 } -func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { +func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer, clientFactory cliClientFactory, statusSink cliStatusSink) error { streamSingle := len(jobs) == 1 - results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle) + results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle, clientFactory, statusSink) if printer == nil && !streamSingle { if err := writeCLIJobOutputs(stdout, results); err != nil { return err @@ -195,7 +128,7 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st return writeCLIJobSummaries(stderr, results) } -func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool) ([]*cliJobResult, *termprint.ColumnPrinter) { +func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool, clientFactory cliClientFactory, statusSink cliStatusSink) ([]*cliJobResult, *termprint.ColumnPrinter) { results := make([]*cliJobResult, len(jobs)) printer := setupCLIPrinter(stdout, jobs) printCLIHeader(stderr, jobs, printer) @@ -205,7 +138,7 @@ func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, inpu wg.Add(1) go func() { defer wg.Done() - results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle) + results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle, clientFactory, statusSink) }() } wg.Wait() @@ -219,12 +152,12 @@ func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter { return newColumnPrinter(stdout, jobs) } -func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult { +func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool, clientFactory cliClientFactory, statusSink cliStatusSink) *cliJobResult { if res := cachedCLIJobResult(job, msgs, stdout, printer, streamOutput); res != nil { return res } - client, err := newClientFromApp(job.cfg) + client, err := clientFactory(job.cfg) if err != nil { return &cliJobResult{provider: job.provider, model: job.req.model, err: err} } @@ -239,7 +172,7 @@ func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input } else if streamOutput { writer = io.MultiWriter(stdout, &outBuf) } - err = runChat(ctx, client, job.req, jobMsgs, input, writer, &errBuf) + err = runChatWithStatus(statusSink, ctx, client, job.req, jobMsgs, input, writer, &errBuf) if printer != nil { printer.Flush(job.index) } @@ -532,9 +465,15 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message { // runChat executes the chat request, handling streaming and summary output. func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { + return runChatWithStatus(tmuxCLIStatusSink{}, ctx, client, req, msgs, input, out, errw) +} + +func runChatWithStatus(statusSink cliStatusSink, ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { start := time.Now() model := effectiveModel(req, client) - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model)) + if statusSink != nil { + _ = statusSink.SetLLMStart(client.Name(), model) + } output, err := runChatRequest(ctx, client, req, msgs, out) if err != nil { @@ -547,7 +486,9 @@ func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm client.Name(), model, dur.Round(time.Millisecond), summary.sent, summary.recv, summary.snapshot.Global.Reqs, summary.snapshot.RPM); err != nil { return err } - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(summary.snapshot.Global.Reqs, summary.snapshot.RPM, summary.snapshot.Global.Sent, summary.snapshot.Global.Recv, client.Name(), model, summary.scopeRPM, summary.scopeReq, summary.snapshot.Window)) + if statusSink != nil { + _ = statusSink.SetGlobal(summary.snapshot, client.Name(), model, summary.scopeRPM, summary.scopeReq) + } return nil } diff --git a/internal/hexaicli/run_output_test.go b/internal/hexaicli/run_output_test.go index b4614da..e61e4b6 100644 --- a/internal/hexaicli/run_output_test.go +++ b/internal/hexaicli/run_output_test.go @@ -400,7 +400,7 @@ func TestRunCLIJobs_MultiJob_WritesOutputs(t *testing.T) { } stdout.Reset() stderr.Reset() - if err := runCLIJobs(context.Background(), singleJobs, msgs, "hello", &stdout, &stderr); err != nil { + if err := runCLIJobs(context.Background(), singleJobs, msgs, "hello", &stdout, &stderr, newClientFromApp, nil); err != nil { t.Fatalf("runCLIJobs single: %v", err) } if !strings.Contains(stdout.String(), "out-a") { diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index 9711399..69e5d98 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -229,7 +229,7 @@ func TestExecuteCLIJobs_MultiProviderHeaderUsesStderr(t *testing.T) { } var stdout, stderr bytes.Buffer - results, printer := executeCLIJobs(context.Background(), jobs, buildMessages("hello"), "hello", &stdout, &stderr, false) + results, printer := executeCLIJobs(context.Background(), jobs, buildMessages("hello"), "hello", &stdout, &stderr, false, newClientFromApp, nil) if printer == nil { t.Fatalf("expected column printer for multi-provider run") } diff --git a/internal/hexaicli/runner.go b/internal/hexaicli/runner.go new file mode 100644 index 0000000..f372021 --- /dev/null +++ b/internal/hexaicli/runner.go @@ -0,0 +1,161 @@ +package hexaicli + +import ( + "context" + "fmt" + "io" + "log" + "strings" + "time" + + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/editor" + "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/logging" + "codeberg.org/snonux/hexai/internal/stats" + "codeberg.org/snonux/hexai/internal/tmux" +) + +type cliConfigLoader func(context.Context, *log.Logger) appconfig.App + +type cliEditorOpener func([]byte) (string, error) + +type cliClientFactory func(appconfig.App) (llm.Client, error) + +type cliStatusSink interface { + SetLLMStart(provider, model string) error + SetGlobal(snapshot stats.Snapshot, provider, model string, scopeRPM float64, scopeReq int64) error +} + +// Runner executes the CLI with injectable configuration, editor, client, and status dependencies. +type Runner struct { + loadConfig cliConfigLoader + openEditor cliEditorOpener + newClient cliClientFactory + statusSink cliStatusSink +} + +type tmuxCLIStatusSink struct{} + +func (tmuxCLIStatusSink) SetLLMStart(provider, model string) error { + return tmux.SetStatus(tmux.FormatLLMStartStatus(provider, model)) +} + +func (tmuxCLIStatusSink) SetGlobal(snapshot stats.Snapshot, provider, model string, scopeRPM float64, scopeReq int64) error { + return tmux.SetStatus(tmux.FormatGlobalStatusColored( + snapshot.Global.Reqs, + snapshot.RPM, + snapshot.Global.Sent, + snapshot.Global.Recv, + provider, + model, + scopeRPM, + scopeReq, + snapshot.Window, + )) +} + +// NewRunner builds a CLI runner with production dependencies. +func NewRunner() *Runner { + return &Runner{ + loadConfig: loadConfigFromContext, + openEditor: editor.OpenTempAndEdit, + newClient: newClientFromApp, + statusSink: tmuxCLIStatusSink{}, + } +} + +func (r *Runner) Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { + runner := normalizeRunner(r) + if spec, ok, err := tpsSimulationFromContext(ctx); err != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) + return err + } else if ok { + input, inputErr := readSimulationInput(stdin, args) + if inputErr != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+inputErr.Error()+logging.AnsiReset) + return inputErr + } + return runTPSSimulation(ctx, spec, input, stdout) + } + + logger := log.New(io.Discard, "", 0) + cfg := runner.loadConfig(ctx, logger) + if cfg.StatsWindowMinutes > 0 { + stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) + } + jobs, err := buildCLIJobs(cfg) + if err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) + return err + } + if selected := selectionFromContext(ctx); len(selected) > 0 { + jobs, err = filterJobsBySelection(jobs, selected) + if err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err) + return err + } + } + if len(jobs) == 0 { + return fmt.Errorf("hexai: no CLI providers configured") + } + + input, rerr := readInput(stdin, args) + if rerr != nil && len(args) == 0 { + if prompt, eerr := runner.openEditor(nil); eerr == nil && strings.TrimSpace(prompt) != "" { + args = []string{prompt} + input, rerr = readInput(stdin, args) + } + } + if rerr != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) + return rerr + } + msgs := buildMessagesFromConfig(cfg, input) + if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr, runner.newClient, runner.statusSink); err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) + return err + } + return nil +} + +func (r *Runner) RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer, client llm.Client) error { + runner := normalizeRunner(r) + input, err := readInput(stdin, args) + if err != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) + return err + } + req := requestArgs{model: strings.TrimSpace(client.DefaultModel())} + printProviderInfo(stderr, client, req.model) + msgs := buildMessages(input) + if err := runChatWithStatus(runner.statusSink, ctx, client, req, msgs, input, stdout, stderr); err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) + return err + } + return nil +} + +func normalizeRunner(r *Runner) Runner { + if r == nil { + return *NewRunner() + } + runner := *r + if runner.loadConfig == nil { + runner.loadConfig = loadConfigFromContext + } + if runner.openEditor == nil { + runner.openEditor = editor.OpenTempAndEdit + } + if runner.newClient == nil { + runner.newClient = newClientFromApp + } + if runner.statusSink == nil { + runner.statusSink = tmuxCLIStatusSink{} + } + return runner +} + +func loadConfigFromContext(ctx context.Context, logger *log.Logger) appconfig.App { + return appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPathFromContext(ctx)}) +} diff --git a/internal/hexaicli/runner_test.go b/internal/hexaicli/runner_test.go new file mode 100644 index 0000000..1d438b0 --- /dev/null +++ b/internal/hexaicli/runner_test.go @@ -0,0 +1,61 @@ +package hexaicli + +import ( + "bytes" + "context" + "log" + "strings" + "testing" + + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/stats" +) + +type recordingCLIStatusSink struct { + startProvider string + startModel string + globalCalls int +} + +func (s *recordingCLIStatusSink) SetLLMStart(provider, model string) error { + s.startProvider = provider + s.startModel = model + return nil +} + +func (s *recordingCLIStatusSink) SetGlobal(stats.Snapshot, string, string, float64, int64) error { + s.globalCalls++ + return nil +} + +func TestRunner_UsesInjectedDependencies(t *testing.T) { + sink := &recordingCLIStatusSink{} + runner := &Runner{ + loadConfig: func(context.Context, *log.Logger) appconfig.App { + return appconfig.App{ + CoreConfig: appconfig.CoreConfig{Provider: "openai"}, + PromptConfig: appconfig.PromptConfig{PromptCLIDefaultSystem: "SYS"}, + } + }, + openEditor: func([]byte) (string, error) { return "PROMPT", nil }, + newClient: func(appconfig.App) (client llm.Client, err error) { + return &fakeClient{name: "fake", model: "m", resp: "OUT"}, nil + }, + statusSink: sink, + } + + var stdout, stderr bytes.Buffer + if err := runner.Run(context.Background(), nil, strings.NewReader(""), &stdout, &stderr); err != nil { + t.Fatalf("Run: %v", err) + } + if stdout.String() != "OUT" { + t.Fatalf("stdout = %q, want OUT", stdout.String()) + } + if sink.startProvider != "fake" || sink.startModel == "" { + t.Fatalf("unexpected start status: provider=%q model=%q", sink.startProvider, sink.startModel) + } + if sink.globalCalls != 1 { + t.Fatalf("expected one global status update, got %d", sink.globalCalls) + } +} diff --git a/internal/hexailsp/dependencies.go b/internal/hexailsp/dependencies.go new file mode 100644 index 0000000..7e025d4 --- /dev/null +++ b/internal/hexailsp/dependencies.go @@ -0,0 +1,77 @@ +package hexailsp + +import ( + "log" + + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/ignore" + "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/llmutils" + "codeberg.org/snonux/hexai/internal/logging" + "codeberg.org/snonux/hexai/internal/lsp" + "codeberg.org/snonux/hexai/internal/runtimeconfig" +) + +type configLoader func(*log.Logger, appconfig.LoadOptions) appconfig.App + +type clientBuilder func(appconfig.App, llm.Client) llm.Client + +type configStoreFactory func(appconfig.App) *runtimeconfig.Store + +type ignoreCheckerFactory func(appconfig.App) *ignore.Checker + +type runDependencies struct { + loadConfig configLoader + buildClient clientBuilder + newConfigStore configStoreFactory + newIgnoreChecker ignoreCheckerFactory + statusSink lsp.StatusSink +} + +func defaultRunDependencies() runDependencies { + return runDependencies{ + loadConfig: appconfig.LoadWithOptions, + buildClient: defaultClientBuilder, + newConfigStore: runtimeconfig.New, + newIgnoreChecker: defaultIgnoreCheckerFactory, + statusSink: tmuxStatusSink{}, + } +} + +func normalizeRunDependencies(deps runDependencies) runDependencies { + if deps.loadConfig == nil { + deps.loadConfig = appconfig.LoadWithOptions + } + if deps.buildClient == nil { + deps.buildClient = defaultClientBuilder + } + if deps.newConfigStore == nil { + deps.newConfigStore = runtimeconfig.New + } + if deps.newIgnoreChecker == nil { + deps.newIgnoreChecker = defaultIgnoreCheckerFactory + } + if deps.statusSink == nil { + deps.statusSink = tmuxStatusSink{} + } + return deps +} + +func defaultClientBuilder(cfg appconfig.App, client llm.Client) llm.Client { + if client != nil { + return client + } + c, err := llmutils.NewClientFromApp(cfg) + if err != nil { + logging.Logf("lsp ", "llm disabled: %v", err) + return nil + } + logging.Logf("lsp ", "llm enabled provider=%s model=%s", c.Name(), c.DefaultModel()) + return c +} + +func defaultIgnoreCheckerFactory(cfg appconfig.App) *ignore.Checker { + gitRoot := appconfig.FindGitRoot() + useGI := cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore + return ignore.New(gitRoot, useGI, cfg.IgnoreExtraPatterns) +} diff --git a/internal/hexailsp/run.go b/internal/hexailsp/run.go index ec5dbba..57e7476 100644 --- a/internal/hexailsp/run.go +++ b/internal/hexailsp/run.go @@ -13,10 +13,8 @@ import ( "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/ignore" "codeberg.org/snonux/hexai/internal/llm" - "codeberg.org/snonux/hexai/internal/llmutils" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/lsp" - "codeberg.org/snonux/hexai/internal/runtimeconfig" "codeberg.org/snonux/hexai/internal/stats" tmx "codeberg.org/snonux/hexai/internal/tmux" ) @@ -56,6 +54,11 @@ func Run(logPath string, stdin io.Reader, stdout io.Writer, stderr io.Writer) er // RunWithConfig is like Run but accepts an explicit config file path. func RunWithConfig(logPath string, configPath string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error { + return runWithConfigDependencies(logPath, configPath, stdin, stdout, stderr, defaultRunDependencies()) +} + +func runWithConfigDependencies(logPath string, configPath string, stdin io.Reader, stdout io.Writer, stderr io.Writer, deps runDependencies) error { + deps = normalizeRunDependencies(deps) logger := log.New(stderr, "hexai-lsp-server ", log.LstdFlags|log.Lmsgprefix) if strings.TrimSpace(logPath) != "" { f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) @@ -71,35 +74,36 @@ func RunWithConfig(logPath string, configPath string, stdin io.Reader, stdout io } logging.Bind(logger) loadOpts := appconfig.LoadOptions{ConfigPath: configPath} - cfg := appconfig.LoadWithOptions(logger, loadOpts) + cfg := deps.loadConfig(logger, loadOpts) if err := cfg.Validate(); err != nil { return fmt.Errorf("invalid config: %w", err) } if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } - return RunWithFactory(logPath, configPath, stdin, stdout, logger, cfg, nil, nil) + return runWithDependencies(logPath, configPath, stdin, stdout, logger, cfg, nil, nil, deps) } // RunWithFactory is the testable entrypoint. When client is nil, it is built from cfg+env. // When factory is nil, lsp.NewServer is used. func RunWithFactory(logPath string, configPath string, stdin io.Reader, stdout io.Writer, logger *log.Logger, cfg appconfig.App, client llm.Client, factory ServerFactory) error { + return runWithDependencies(logPath, configPath, stdin, stdout, logger, cfg, client, factory, defaultRunDependencies()) +} + +func runWithDependencies(logPath string, configPath string, stdin io.Reader, stdout io.Writer, logger *log.Logger, cfg appconfig.App, client llm.Client, factory ServerFactory, deps runDependencies) error { + deps = normalizeRunDependencies(deps) normalizeLoggingConfig(&cfg) if err := cfg.Validate(); err != nil { return fmt.Errorf("invalid config: %w", err) } - client = buildClientIfNil(cfg, client) + client = deps.buildClient(cfg, client) factory = ensureFactory(factory) - // Create gitignore-aware file checker for LSP filtering - gitRoot := appconfig.FindGitRoot() - useGI := cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore - ignoreChecker := ignore.New(gitRoot, useGI, cfg.IgnoreExtraPatterns) - - store := runtimeconfig.New(cfg) + ignoreChecker := deps.newIgnoreChecker(cfg) + store := deps.newConfigStore(cfg) logContext := strings.TrimSpace(logPath) != "" loadOpts := appconfig.LoadOptions{ConfigPath: strings.TrimSpace(configPath)} - opts := makeServerOptions(cfg, logContext, client, loadOpts, ignoreChecker) + opts := makeServerOptions(cfg, logContext, client, loadOpts, ignoreChecker, deps.statusSink) opts.ConfigLoadOptions = loadOpts opts.ConfigStore = store server := factory(stdin, stdout, logger, opts) @@ -110,13 +114,13 @@ func RunWithFactory(logPath string, configPath string, stdin io.Reader, stdout i if updated.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(updated.StatsWindowMinutes) * time.Minute) } - if newClient := buildClientIfNil(updated, nil); newClient != nil { + if newClient := deps.buildClient(updated, nil); newClient != nil { client = newClient } // Update ignore checker patterns on config hot-reload useGI := updated.IgnoreGitignore == nil || *updated.IgnoreGitignore ignoreChecker.Update(useGI, updated.IgnoreExtraPatterns) - opts := makeServerOptions(updated, logContext, client, loadOpts, ignoreChecker) + opts := makeServerOptions(updated, logContext, client, loadOpts, ignoreChecker, deps.statusSink) opts.ConfigStore = store configurable.ApplyOptions(opts) }) @@ -136,19 +140,6 @@ func normalizeLoggingConfig(cfg *appconfig.App) { } } -func buildClientIfNil(cfg appconfig.App, client llm.Client) llm.Client { - if client != nil { - return client - } - c, err := llmutils.NewClientFromApp(cfg) - if err != nil { - logging.Logf("lsp ", "llm disabled: %v", err) - return nil - } - logging.Logf("lsp ", "llm enabled provider=%s model=%s", c.Name(), c.DefaultModel()) - return c -} - func ensureFactory(factory ServerFactory) ServerFactory { if factory != nil { return factory @@ -158,7 +149,7 @@ func ensureFactory(factory ServerFactory) ServerFactory { } } -func makeServerOptions(cfg appconfig.App, logContext bool, client llm.Client, loadOpts appconfig.LoadOptions, ignoreChecker *ignore.Checker) lsp.ServerOptions { +func makeServerOptions(cfg appconfig.App, logContext bool, client llm.Client, loadOpts appconfig.LoadOptions, ignoreChecker *ignore.Checker, statusSink lsp.StatusSink) lsp.ServerOptions { return lsp.ServerOptions{ ConfigLoadOptions: loadOpts, LogContext: logContext, @@ -166,6 +157,6 @@ func makeServerOptions(cfg appconfig.App, logContext bool, client llm.Client, lo Config: &cfg, Client: client, IgnoreChecker: ignoreChecker, - StatusSink: tmuxStatusSink{}, + StatusSink: statusSink, } } diff --git a/internal/hexailsp/run_more_test.go b/internal/hexailsp/run_more_test.go index 7017811..d0f17b5 100644 --- a/internal/hexailsp/run_more_test.go +++ b/internal/hexailsp/run_more_test.go @@ -30,6 +30,11 @@ func (stubClient) Chat(context.Context, []llm.Message, ...llm.RequestOption) (st func (stubClient) Name() string { return "stub" } func (stubClient) DefaultModel() string { return "stub-model" } +type recordingStatusSink struct{} + +func (recordingStatusSink) SetLLMStart(string, string) error { return nil } +func (recordingStatusSink) SetGlobal(lsp.GlobalStatus) error { return nil } + func TestRunWithFactory_BuildsOptionsAndClient(t *testing.T) { var captured lsp.ServerOptions factory := func(r io.Reader, w io.Writer, logger *log.Logger, opts lsp.ServerOptions) ServerRunner { @@ -101,3 +106,34 @@ func TestRunWithFactory_SubscriptionAppliesUpdates(t *testing.T) { t.Fatalf("expected normalized context mode, got %+v", latest) } } + +func TestRunWithDependencies_UsesInjectedClientBuilderAndStatusSink(t *testing.T) { + var captured lsp.ServerOptions + sink := &recordingStatusSink{} + buildCalls := 0 + factory := func(r io.Reader, w io.Writer, logger *log.Logger, opts lsp.ServerOptions) ServerRunner { + captured = opts + return &recRunner{} + } + cfg := appconfig.Load(nil) + if err := runWithDependencies("", "", bytes.NewBuffer(nil), bytes.NewBuffer(nil), log.New(io.Discard, "", 0), cfg, nil, factory, runDependencies{ + buildClient: func(appconfig.App, llm.Client) llm.Client { + buildCalls++ + return stubClient{} + }, + newConfigStore: runtimeconfig.New, + newIgnoreChecker: defaultIgnoreCheckerFactory, + statusSink: sink, + }); err != nil { + t.Fatalf("runWithDependencies error: %v", err) + } + if buildCalls != 1 { + t.Fatalf("expected one client build, got %d", buildCalls) + } + if captured.Client == nil { + t.Fatal("expected injected client to be passed through") + } + if captured.StatusSink != sink { + t.Fatal("expected injected status sink to be passed through") + } +} |
