diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-19 09:17:38 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-19 09:17:38 +0200 |
| commit | 5642eaf74a4a70e5c82646bef3e0dd42846baea8 (patch) | |
| tree | b767627a1cc4d74d2f02598ef0e45bff139d4358 /cmd | |
| parent | b90e06b41ee0cd6a6bf420462c21b38ae2a788c1 (diff) | |
Add hexai task proxy dispatch tests
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/hexai/app_runner.go | 202 | ||||
| -rw-r--r-- | cmd/hexai/app_runner_test.go | 71 | ||||
| -rw-r--r-- | cmd/hexai/main.go | 117 | ||||
| -rw-r--r-- | cmd/hexai/task_command_test.go | 57 |
4 files changed, 333 insertions, 114 deletions
diff --git a/cmd/hexai/app_runner.go b/cmd/hexai/app_runner.go new file mode 100644 index 0000000..cf1ed3a --- /dev/null +++ b/cmd/hexai/app_runner.go @@ -0,0 +1,202 @@ +package main + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "strconv" + "strings" + + "codeberg.org/snonux/hexai/internal" + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/hexaicli" +) + +type configLoader func(string) appconfig.App + +type cliRunner func(context.Context, []string, io.Reader, io.Writer, io.Writer) error + +type taskSubcommandRunner func([]string, io.Reader, io.Writer, io.Writer) (bool, int, error) + +type appRunner struct { + loadConfig configLoader + runCLI cliRunner + runTaskSubcommand taskSubcommandRunner +} + +type parsedAppArgs struct { + args []string + finalPath string + tpsSimulation string + selection []int + showVersion bool +} + +func newAppRunner() appRunner { + return appRunner{ + loadConfig: loadAppConfig, + runCLI: hexaicli.Run, + runTaskSubcommand: runTaskSubcommandIfRequested, + } +} + +func (r appRunner) run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { + runner := normalizeAppRunner(r) + configPath, remaining := splitConfigPath(args) + cfg := runner.loadConfig(configPath) + parsed, err := parseAppArgs(cfg, configPath, remaining, stderr) + if err != nil { + return 2 + } + if parsed.showVersion { + fmt.Fprintln(stdout, internal.Version) + return 0 + } + if handled, exitCode, err := runner.runTaskSubcommand(parsed.args, stdin, stdout, stderr); handled { + if err != nil { + fmt.Fprintln(stderr, err) + } + return exitCode + } + ctx := buildCLIContext(parsed) + if err := runner.runCLI(ctx, parsed.args, stdin, stdout, stderr); err != nil { + return 1 + } + return 0 +} + +func normalizeAppRunner(r appRunner) appRunner { + if r.loadConfig == nil { + r.loadConfig = loadAppConfig + } + if r.runCLI == nil { + r.runCLI = hexaicli.Run + } + if r.runTaskSubcommand == nil { + r.runTaskSubcommand = runTaskSubcommandIfRequested + } + return r +} + +func loadAppConfig(configPath string) appconfig.App { + logger := log.New(io.Discard, "", 0) + return appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPath}) +} + +func parseAppArgs(cfg appconfig.App, configPath string, args []string, stderr io.Writer) (parsedAppArgs, error) { + cliEntries := cliEntriesFromConfig(cfg) + fs := flag.NewFlagSet("hexai", flag.ContinueOnError) + fs.SetOutput(stderr) + defaultPath := appconfig.DefaultConfigPath() + configFlag := fs.String("config", configPath, fmt.Sprintf("path to config file (default: %s)", defaultPath)) + tpsSimulation := fs.String("tps-simulation", "", "simulate stdout at a token-per-second rate; accepts '12' or '10-20'") + showVersion := fs.Bool("version", false, "print version and exit") + selectedFlags := bindProviderFlags(fs, cfg, cliEntries) + if err := fs.Parse(args); err != nil { + return parsedAppArgs{}, err + } + return parsedAppArgs{ + args: fs.Args(), + finalPath: normalizeConfigPath(*configFlag, configPath), + tpsSimulation: strings.TrimSpace(*tpsSimulation), + selection: selectionFromFlags(selectedFlags), + showVersion: *showVersion, + }, nil +} + +func cliEntriesFromConfig(cfg appconfig.App) []appconfig.SurfaceConfig { + if len(cfg.CLIConfigs) > 0 { + return cfg.CLIConfigs + } + return []appconfig.SurfaceConfig{{Provider: cfg.Provider}} +} + +func bindProviderFlags(fs *flag.FlagSet, cfg appconfig.App, cliEntries []appconfig.SurfaceConfig) []bool { + selectedFlags := make([]bool, len(cliEntries)) + for i, entry := range cliEntries { + name := strconv.Itoa(i) + provider := strings.TrimSpace(entry.Provider) + if provider == "" { + provider = cfg.Provider + } + model := strings.TrimSpace(entry.Model) + if model == "" { + model = pickDefaultModel(cfg, provider) + } + desc := fmt.Sprintf("use only provider #%d (%s:%s)", i, provider, model) + fs.BoolVar(&selectedFlags[i], name, false, desc) + } + return selectedFlags +} + +func normalizeConfigPath(flagValue, fallback string) string { + finalPath := strings.TrimSpace(flagValue) + if finalPath != "" { + return finalPath + } + return fallback +} + +func selectionFromFlags(selectedFlags []bool) []int { + var selection []int + for i, sel := range selectedFlags { + if sel { + selection = append(selection, i) + } + } + return selection +} + +func buildCLIContext(parsed parsedAppArgs) context.Context { + ctx := context.Background() + if parsed.finalPath != "" { + ctx = hexaicli.WithCLIConfigPath(ctx, parsed.finalPath) + } + if parsed.tpsSimulation != "" { + ctx = hexaicli.WithCLITPSSimulation(ctx, parsed.tpsSimulation) + } + if len(parsed.selection) > 0 { + ctx = hexaicli.WithCLISelection(ctx, parsed.selection) + } + return ctx +} + +func splitConfigPath(args []string) (string, []string) { + var path string + rest := make([]string, 0, len(args)) + skip := false + for i := 0; i < len(args); i++ { + if skip { + skip = false + continue + } + arg := args[i] + switch { + case arg == "--config" || arg == "-config": + if i+1 < len(args) { + path = args[i+1] + skip = true + } + case strings.HasPrefix(arg, "--config="): + path = arg[len("--config="):] + case strings.HasPrefix(arg, "-config="): + path = arg[len("-config="):] + default: + rest = append(rest, arg) + } + } + return strings.TrimSpace(path), rest +} + +func pickDefaultModel(cfg appconfig.App, provider string) string { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "ollama": + return strings.TrimSpace(cfg.OllamaModel) + case "anthropic": + return strings.TrimSpace(cfg.AnthropicModel) + default: + return strings.TrimSpace(cfg.OpenAIModel) + } +} diff --git a/cmd/hexai/app_runner_test.go b/cmd/hexai/app_runner_test.go new file mode 100644 index 0000000..2f03210 --- /dev/null +++ b/cmd/hexai/app_runner_test.go @@ -0,0 +1,71 @@ +package main + +import ( + "bytes" + "context" + "io" + "reflect" + "strings" + "testing" + + "codeberg.org/snonux/hexai/internal/appconfig" +) + +func TestAppRunnerRun_TaskDispatchAfterConfigFlag(t *testing.T) { + var gotConfigPath string + var gotArgs []string + runner := appRunner{ + loadConfig: func(path string) appconfig.App { + gotConfigPath = path + return appconfig.App{} + }, + runCLI: func(context.Context, []string, io.Reader, io.Writer, io.Writer) error { + t.Fatal("runCLI should not be called when task subcommand is handled") + return nil + }, + runTaskSubcommand: func(args []string, stdin io.Reader, stdout, stderr io.Writer) (bool, int, error) { + gotArgs = append([]string(nil), args...) + return true, 0, nil + }, + } + + exitCode := runner.run([]string{"--config", "/tmp/hexai.toml", "task", "list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{}) + if exitCode != 0 { + t.Fatalf("exitCode = %d, want 0", exitCode) + } + if gotConfigPath != "/tmp/hexai.toml" { + t.Fatalf("configPath = %q, want /tmp/hexai.toml", gotConfigPath) + } + wantArgs := []string{"task", "list"} + if !reflect.DeepEqual(gotArgs, wantArgs) { + t.Fatalf("task args = %v, want %v", gotArgs, wantArgs) + } +} + +func TestAppRunnerRun_SingleArgumentTaskListFallsThroughToCLI(t *testing.T) { + var taskArgs []string + var cliArgs []string + runner := appRunner{ + loadConfig: func(string) appconfig.App { return appconfig.App{} }, + runCLI: func(_ context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { + cliArgs = append([]string(nil), args...) + return nil + }, + runTaskSubcommand: func(args []string, stdin io.Reader, stdout, stderr io.Writer) (bool, int, error) { + taskArgs = append([]string(nil), args...) + return false, 0, nil + }, + } + + exitCode := runner.run([]string{"task list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{}) + if exitCode != 0 { + t.Fatalf("exitCode = %d, want 0", exitCode) + } + wantArgs := []string{"task list"} + if !reflect.DeepEqual(taskArgs, wantArgs) { + t.Fatalf("task dispatch args = %v, want %v", taskArgs, wantArgs) + } + if !reflect.DeepEqual(cliArgs, wantArgs) { + t.Fatalf("cli args = %v, want %v", cliArgs, wantArgs) + } +} diff --git a/cmd/hexai/main.go b/cmd/hexai/main.go index b14ee37..c6414e1 100644 --- a/cmd/hexai/main.go +++ b/cmd/hexai/main.go @@ -1,121 +1,10 @@ // Package main is the Hexai CLI entrypoint; parses flags and delegates to internal/hexaicli. package main -import ( - "context" - "flag" - "fmt" - "io" - "log" - "os" - "strconv" - "strings" - - "codeberg.org/snonux/hexai/internal" - "codeberg.org/snonux/hexai/internal/appconfig" - "codeberg.org/snonux/hexai/internal/hexaicli" -) +import "os" func main() { - configPath, remaining := splitConfigPath(os.Args[1:]) - logger := log.New(io.Discard, "", 0) - cfg := appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPath}) - cliEntries := cfg.CLIConfigs - if len(cliEntries) == 0 { - cliEntries = []appconfig.SurfaceConfig{{Provider: cfg.Provider}} - } - fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError) - defaultPath := appconfig.DefaultConfigPath() - configFlag := fs.String("config", configPath, fmt.Sprintf("path to config file (default: %s)", defaultPath)) - tpsSimulation := fs.String("tps-simulation", "", "simulate stdout at a token-per-second rate; accepts '12' or '10-20'") - showVersion := fs.Bool("version", false, "print version and exit") - selectedFlags := make([]bool, len(cliEntries)) - for i, entry := range cliEntries { - name := strconv.Itoa(i) - provider := strings.TrimSpace(entry.Provider) - if provider == "" { - provider = cfg.Provider - } - model := strings.TrimSpace(entry.Model) - if model == "" { - model = pickDefaultModel(cfg, provider) - } - desc := fmt.Sprintf("use only provider #%d (%s:%s)", i, provider, model) - fs.BoolVar(&selectedFlags[i], name, false, desc) - } - _ = fs.Parse(remaining) - if *showVersion { - fmt.Fprintln(os.Stdout, internal.Version) - return - } - var selection []int - for i, sel := range selectedFlags { - if sel { - selection = append(selection, i) - } - } - finalPath := strings.TrimSpace(*configFlag) - if finalPath == "" { - finalPath = configPath - } - if handled, exitCode, err := runTaskSubcommandIfRequested(fs.Args(), os.Stdin, os.Stdout, os.Stderr); handled { - if err != nil { - fmt.Fprintln(os.Stderr, err) - } - if exitCode != 0 { - os.Exit(exitCode) - } - return - } - ctx := context.Background() - if finalPath != "" { - ctx = hexaicli.WithCLIConfigPath(ctx, finalPath) - } - if strings.TrimSpace(*tpsSimulation) != "" { - ctx = hexaicli.WithCLITPSSimulation(ctx, *tpsSimulation) - } - if len(selection) > 0 { - ctx = hexaicli.WithCLISelection(ctx, selection) - } - if err := hexaicli.Run(ctx, fs.Args(), os.Stdin, os.Stdout, os.Stderr); err != nil { - os.Exit(1) - } -} - -func splitConfigPath(args []string) (string, []string) { - var path string - rest := make([]string, 0, len(args)) - skip := false - for i := 0; i < len(args); i++ { - if skip { - skip = false - continue - } - arg := args[i] - switch { - case arg == "--config" || arg == "-config": - if i+1 < len(args) { - path = args[i+1] - skip = true - } - case strings.HasPrefix(arg, "--config="): - path = arg[len("--config="):] - case strings.HasPrefix(arg, "-config="): - path = arg[len("-config="):] - default: - rest = append(rest, arg) - } - } - return strings.TrimSpace(path), rest -} - -func pickDefaultModel(cfg appconfig.App, provider string) string { - switch strings.ToLower(strings.TrimSpace(provider)) { - case "ollama": - return strings.TrimSpace(cfg.OllamaModel) - case "anthropic": - return strings.TrimSpace(cfg.AnthropicModel) - default: - return strings.TrimSpace(cfg.OpenAIModel) + if exitCode := newAppRunner().run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr); exitCode != 0 { + os.Exit(exitCode) } } diff --git a/cmd/hexai/task_command_test.go b/cmd/hexai/task_command_test.go index 2900910..498ac65 100644 --- a/cmd/hexai/task_command_test.go +++ b/cmd/hexai/task_command_test.go @@ -83,3 +83,60 @@ func TestTaskRunnerRun_PreservesTaskwarriorExitCode(t *testing.T) { t.Fatalf("exitCode = %d, want 7", exitCode) } } + +func TestTaskRunnerRun_PreservesStdoutAndStderr(t *testing.T) { + var stdout bytes.Buffer + var stderr bytes.Buffer + runner := taskRunner{ + findTaskBinary: func() (string, error) { return "/usr/bin/task", nil }, + detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil }, + runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, out, errOut io.Writer) error { + _, _ = io.WriteString(out, "task stdout") + _, _ = io.WriteString(errOut, "task stderr") + return nil + }, + } + + exitCode, err := runner.run(context.Background(), []string{"list"}, strings.NewReader(""), &stdout, &stderr) + if err != nil { + t.Fatalf("run returned error: %v", err) + } + if exitCode != 0 { + t.Fatalf("exitCode = %d, want 0", exitCode) + } + if stdout.String() != "task stdout" { + t.Fatalf("stdout = %q, want %q", stdout.String(), "task stdout") + } + if stderr.String() != "task stderr" { + t.Fatalf("stderr = %q, want %q", stderr.String(), "task stderr") + } +} + +func TestTaskRunnerRun_TaskLookupFailure_IsActionable(t *testing.T) { + runner := taskRunner{ + findTaskBinary: func() (string, error) { return "", errors.New("not found") }, + } + + exitCode, err := runner.run(context.Background(), []string{"list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{}) + if exitCode != 1 { + t.Fatalf("exitCode = %d, want 1", exitCode) + } + if err == nil || !strings.Contains(err.Error(), "Taskwarrior binary lookup failed") { + t.Fatalf("expected actionable task lookup error, got %v", err) + } +} + +func TestTaskRunnerRun_EmptyRepoName_IsActionable(t *testing.T) { + runner := taskRunner{ + findTaskBinary: func() (string, error) { return "/usr/bin/task", nil }, + detectRepoRoot: func(context.Context) (string, error) { return "/", nil }, + } + + exitCode, err := runner.run(context.Background(), []string{"list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{}) + if exitCode != 1 { + t.Fatalf("exitCode = %d, want 1", exitCode) + } + if err == nil || !strings.Contains(err.Error(), "could not derive project name") { + t.Fatalf("expected actionable project-name error, got %v", err) + } +} |
