summaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-19 09:17:38 +0200
committerPaul Buetow <paul@buetow.org>2026-03-19 09:17:38 +0200
commit5642eaf74a4a70e5c82646bef3e0dd42846baea8 (patch)
treeb767627a1cc4d74d2f02598ef0e45bff139d4358 /cmd
parentb90e06b41ee0cd6a6bf420462c21b38ae2a788c1 (diff)
Add hexai task proxy dispatch tests
Diffstat (limited to 'cmd')
-rw-r--r--cmd/hexai/app_runner.go202
-rw-r--r--cmd/hexai/app_runner_test.go71
-rw-r--r--cmd/hexai/main.go117
-rw-r--r--cmd/hexai/task_command_test.go57
4 files changed, 333 insertions, 114 deletions
diff --git a/cmd/hexai/app_runner.go b/cmd/hexai/app_runner.go
new file mode 100644
index 0000000..cf1ed3a
--- /dev/null
+++ b/cmd/hexai/app_runner.go
@@ -0,0 +1,202 @@
+package main
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "strconv"
+ "strings"
+
+ "codeberg.org/snonux/hexai/internal"
+ "codeberg.org/snonux/hexai/internal/appconfig"
+ "codeberg.org/snonux/hexai/internal/hexaicli"
+)
+
+type configLoader func(string) appconfig.App
+
+type cliRunner func(context.Context, []string, io.Reader, io.Writer, io.Writer) error
+
+type taskSubcommandRunner func([]string, io.Reader, io.Writer, io.Writer) (bool, int, error)
+
+type appRunner struct {
+ loadConfig configLoader
+ runCLI cliRunner
+ runTaskSubcommand taskSubcommandRunner
+}
+
+type parsedAppArgs struct {
+ args []string
+ finalPath string
+ tpsSimulation string
+ selection []int
+ showVersion bool
+}
+
+func newAppRunner() appRunner {
+ return appRunner{
+ loadConfig: loadAppConfig,
+ runCLI: hexaicli.Run,
+ runTaskSubcommand: runTaskSubcommandIfRequested,
+ }
+}
+
+func (r appRunner) run(args []string, stdin io.Reader, stdout, stderr io.Writer) int {
+ runner := normalizeAppRunner(r)
+ configPath, remaining := splitConfigPath(args)
+ cfg := runner.loadConfig(configPath)
+ parsed, err := parseAppArgs(cfg, configPath, remaining, stderr)
+ if err != nil {
+ return 2
+ }
+ if parsed.showVersion {
+ fmt.Fprintln(stdout, internal.Version)
+ return 0
+ }
+ if handled, exitCode, err := runner.runTaskSubcommand(parsed.args, stdin, stdout, stderr); handled {
+ if err != nil {
+ fmt.Fprintln(stderr, err)
+ }
+ return exitCode
+ }
+ ctx := buildCLIContext(parsed)
+ if err := runner.runCLI(ctx, parsed.args, stdin, stdout, stderr); err != nil {
+ return 1
+ }
+ return 0
+}
+
+func normalizeAppRunner(r appRunner) appRunner {
+ if r.loadConfig == nil {
+ r.loadConfig = loadAppConfig
+ }
+ if r.runCLI == nil {
+ r.runCLI = hexaicli.Run
+ }
+ if r.runTaskSubcommand == nil {
+ r.runTaskSubcommand = runTaskSubcommandIfRequested
+ }
+ return r
+}
+
+func loadAppConfig(configPath string) appconfig.App {
+ logger := log.New(io.Discard, "", 0)
+ return appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPath})
+}
+
+func parseAppArgs(cfg appconfig.App, configPath string, args []string, stderr io.Writer) (parsedAppArgs, error) {
+ cliEntries := cliEntriesFromConfig(cfg)
+ fs := flag.NewFlagSet("hexai", flag.ContinueOnError)
+ fs.SetOutput(stderr)
+ defaultPath := appconfig.DefaultConfigPath()
+ configFlag := fs.String("config", configPath, fmt.Sprintf("path to config file (default: %s)", defaultPath))
+ tpsSimulation := fs.String("tps-simulation", "", "simulate stdout at a token-per-second rate; accepts '12' or '10-20'")
+ showVersion := fs.Bool("version", false, "print version and exit")
+ selectedFlags := bindProviderFlags(fs, cfg, cliEntries)
+ if err := fs.Parse(args); err != nil {
+ return parsedAppArgs{}, err
+ }
+ return parsedAppArgs{
+ args: fs.Args(),
+ finalPath: normalizeConfigPath(*configFlag, configPath),
+ tpsSimulation: strings.TrimSpace(*tpsSimulation),
+ selection: selectionFromFlags(selectedFlags),
+ showVersion: *showVersion,
+ }, nil
+}
+
+func cliEntriesFromConfig(cfg appconfig.App) []appconfig.SurfaceConfig {
+ if len(cfg.CLIConfigs) > 0 {
+ return cfg.CLIConfigs
+ }
+ return []appconfig.SurfaceConfig{{Provider: cfg.Provider}}
+}
+
+func bindProviderFlags(fs *flag.FlagSet, cfg appconfig.App, cliEntries []appconfig.SurfaceConfig) []bool {
+ selectedFlags := make([]bool, len(cliEntries))
+ for i, entry := range cliEntries {
+ name := strconv.Itoa(i)
+ provider := strings.TrimSpace(entry.Provider)
+ if provider == "" {
+ provider = cfg.Provider
+ }
+ model := strings.TrimSpace(entry.Model)
+ if model == "" {
+ model = pickDefaultModel(cfg, provider)
+ }
+ desc := fmt.Sprintf("use only provider #%d (%s:%s)", i, provider, model)
+ fs.BoolVar(&selectedFlags[i], name, false, desc)
+ }
+ return selectedFlags
+}
+
+func normalizeConfigPath(flagValue, fallback string) string {
+ finalPath := strings.TrimSpace(flagValue)
+ if finalPath != "" {
+ return finalPath
+ }
+ return fallback
+}
+
+func selectionFromFlags(selectedFlags []bool) []int {
+ var selection []int
+ for i, sel := range selectedFlags {
+ if sel {
+ selection = append(selection, i)
+ }
+ }
+ return selection
+}
+
+func buildCLIContext(parsed parsedAppArgs) context.Context {
+ ctx := context.Background()
+ if parsed.finalPath != "" {
+ ctx = hexaicli.WithCLIConfigPath(ctx, parsed.finalPath)
+ }
+ if parsed.tpsSimulation != "" {
+ ctx = hexaicli.WithCLITPSSimulation(ctx, parsed.tpsSimulation)
+ }
+ if len(parsed.selection) > 0 {
+ ctx = hexaicli.WithCLISelection(ctx, parsed.selection)
+ }
+ return ctx
+}
+
+func splitConfigPath(args []string) (string, []string) {
+ var path string
+ rest := make([]string, 0, len(args))
+ skip := false
+ for i := 0; i < len(args); i++ {
+ if skip {
+ skip = false
+ continue
+ }
+ arg := args[i]
+ switch {
+ case arg == "--config" || arg == "-config":
+ if i+1 < len(args) {
+ path = args[i+1]
+ skip = true
+ }
+ case strings.HasPrefix(arg, "--config="):
+ path = arg[len("--config="):]
+ case strings.HasPrefix(arg, "-config="):
+ path = arg[len("-config="):]
+ default:
+ rest = append(rest, arg)
+ }
+ }
+ return strings.TrimSpace(path), rest
+}
+
+func pickDefaultModel(cfg appconfig.App, provider string) string {
+ switch strings.ToLower(strings.TrimSpace(provider)) {
+ case "ollama":
+ return strings.TrimSpace(cfg.OllamaModel)
+ case "anthropic":
+ return strings.TrimSpace(cfg.AnthropicModel)
+ default:
+ return strings.TrimSpace(cfg.OpenAIModel)
+ }
+}
diff --git a/cmd/hexai/app_runner_test.go b/cmd/hexai/app_runner_test.go
new file mode 100644
index 0000000..2f03210
--- /dev/null
+++ b/cmd/hexai/app_runner_test.go
@@ -0,0 +1,71 @@
+package main
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "reflect"
+ "strings"
+ "testing"
+
+ "codeberg.org/snonux/hexai/internal/appconfig"
+)
+
+func TestAppRunnerRun_TaskDispatchAfterConfigFlag(t *testing.T) {
+ var gotConfigPath string
+ var gotArgs []string
+ runner := appRunner{
+ loadConfig: func(path string) appconfig.App {
+ gotConfigPath = path
+ return appconfig.App{}
+ },
+ runCLI: func(context.Context, []string, io.Reader, io.Writer, io.Writer) error {
+ t.Fatal("runCLI should not be called when task subcommand is handled")
+ return nil
+ },
+ runTaskSubcommand: func(args []string, stdin io.Reader, stdout, stderr io.Writer) (bool, int, error) {
+ gotArgs = append([]string(nil), args...)
+ return true, 0, nil
+ },
+ }
+
+ exitCode := runner.run([]string{"--config", "/tmp/hexai.toml", "task", "list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{})
+ if exitCode != 0 {
+ t.Fatalf("exitCode = %d, want 0", exitCode)
+ }
+ if gotConfigPath != "/tmp/hexai.toml" {
+ t.Fatalf("configPath = %q, want /tmp/hexai.toml", gotConfigPath)
+ }
+ wantArgs := []string{"task", "list"}
+ if !reflect.DeepEqual(gotArgs, wantArgs) {
+ t.Fatalf("task args = %v, want %v", gotArgs, wantArgs)
+ }
+}
+
+func TestAppRunnerRun_SingleArgumentTaskListFallsThroughToCLI(t *testing.T) {
+ var taskArgs []string
+ var cliArgs []string
+ runner := appRunner{
+ loadConfig: func(string) appconfig.App { return appconfig.App{} },
+ runCLI: func(_ context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
+ cliArgs = append([]string(nil), args...)
+ return nil
+ },
+ runTaskSubcommand: func(args []string, stdin io.Reader, stdout, stderr io.Writer) (bool, int, error) {
+ taskArgs = append([]string(nil), args...)
+ return false, 0, nil
+ },
+ }
+
+ exitCode := runner.run([]string{"task list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{})
+ if exitCode != 0 {
+ t.Fatalf("exitCode = %d, want 0", exitCode)
+ }
+ wantArgs := []string{"task list"}
+ if !reflect.DeepEqual(taskArgs, wantArgs) {
+ t.Fatalf("task dispatch args = %v, want %v", taskArgs, wantArgs)
+ }
+ if !reflect.DeepEqual(cliArgs, wantArgs) {
+ t.Fatalf("cli args = %v, want %v", cliArgs, wantArgs)
+ }
+}
diff --git a/cmd/hexai/main.go b/cmd/hexai/main.go
index b14ee37..c6414e1 100644
--- a/cmd/hexai/main.go
+++ b/cmd/hexai/main.go
@@ -1,121 +1,10 @@
// Package main is the Hexai CLI entrypoint; parses flags and delegates to internal/hexaicli.
package main
-import (
- "context"
- "flag"
- "fmt"
- "io"
- "log"
- "os"
- "strconv"
- "strings"
-
- "codeberg.org/snonux/hexai/internal"
- "codeberg.org/snonux/hexai/internal/appconfig"
- "codeberg.org/snonux/hexai/internal/hexaicli"
-)
+import "os"
func main() {
- configPath, remaining := splitConfigPath(os.Args[1:])
- logger := log.New(io.Discard, "", 0)
- cfg := appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPath})
- cliEntries := cfg.CLIConfigs
- if len(cliEntries) == 0 {
- cliEntries = []appconfig.SurfaceConfig{{Provider: cfg.Provider}}
- }
- fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
- defaultPath := appconfig.DefaultConfigPath()
- configFlag := fs.String("config", configPath, fmt.Sprintf("path to config file (default: %s)", defaultPath))
- tpsSimulation := fs.String("tps-simulation", "", "simulate stdout at a token-per-second rate; accepts '12' or '10-20'")
- showVersion := fs.Bool("version", false, "print version and exit")
- selectedFlags := make([]bool, len(cliEntries))
- for i, entry := range cliEntries {
- name := strconv.Itoa(i)
- provider := strings.TrimSpace(entry.Provider)
- if provider == "" {
- provider = cfg.Provider
- }
- model := strings.TrimSpace(entry.Model)
- if model == "" {
- model = pickDefaultModel(cfg, provider)
- }
- desc := fmt.Sprintf("use only provider #%d (%s:%s)", i, provider, model)
- fs.BoolVar(&selectedFlags[i], name, false, desc)
- }
- _ = fs.Parse(remaining)
- if *showVersion {
- fmt.Fprintln(os.Stdout, internal.Version)
- return
- }
- var selection []int
- for i, sel := range selectedFlags {
- if sel {
- selection = append(selection, i)
- }
- }
- finalPath := strings.TrimSpace(*configFlag)
- if finalPath == "" {
- finalPath = configPath
- }
- if handled, exitCode, err := runTaskSubcommandIfRequested(fs.Args(), os.Stdin, os.Stdout, os.Stderr); handled {
- if err != nil {
- fmt.Fprintln(os.Stderr, err)
- }
- if exitCode != 0 {
- os.Exit(exitCode)
- }
- return
- }
- ctx := context.Background()
- if finalPath != "" {
- ctx = hexaicli.WithCLIConfigPath(ctx, finalPath)
- }
- if strings.TrimSpace(*tpsSimulation) != "" {
- ctx = hexaicli.WithCLITPSSimulation(ctx, *tpsSimulation)
- }
- if len(selection) > 0 {
- ctx = hexaicli.WithCLISelection(ctx, selection)
- }
- if err := hexaicli.Run(ctx, fs.Args(), os.Stdin, os.Stdout, os.Stderr); err != nil {
- os.Exit(1)
- }
-}
-
-func splitConfigPath(args []string) (string, []string) {
- var path string
- rest := make([]string, 0, len(args))
- skip := false
- for i := 0; i < len(args); i++ {
- if skip {
- skip = false
- continue
- }
- arg := args[i]
- switch {
- case arg == "--config" || arg == "-config":
- if i+1 < len(args) {
- path = args[i+1]
- skip = true
- }
- case strings.HasPrefix(arg, "--config="):
- path = arg[len("--config="):]
- case strings.HasPrefix(arg, "-config="):
- path = arg[len("-config="):]
- default:
- rest = append(rest, arg)
- }
- }
- return strings.TrimSpace(path), rest
-}
-
-func pickDefaultModel(cfg appconfig.App, provider string) string {
- switch strings.ToLower(strings.TrimSpace(provider)) {
- case "ollama":
- return strings.TrimSpace(cfg.OllamaModel)
- case "anthropic":
- return strings.TrimSpace(cfg.AnthropicModel)
- default:
- return strings.TrimSpace(cfg.OpenAIModel)
+ if exitCode := newAppRunner().run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr); exitCode != 0 {
+ os.Exit(exitCode)
}
}
diff --git a/cmd/hexai/task_command_test.go b/cmd/hexai/task_command_test.go
index 2900910..498ac65 100644
--- a/cmd/hexai/task_command_test.go
+++ b/cmd/hexai/task_command_test.go
@@ -83,3 +83,60 @@ func TestTaskRunnerRun_PreservesTaskwarriorExitCode(t *testing.T) {
t.Fatalf("exitCode = %d, want 7", exitCode)
}
}
+
+func TestTaskRunnerRun_PreservesStdoutAndStderr(t *testing.T) {
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+ runner := taskRunner{
+ findTaskBinary: func() (string, error) { return "/usr/bin/task", nil },
+ detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil },
+ runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, out, errOut io.Writer) error {
+ _, _ = io.WriteString(out, "task stdout")
+ _, _ = io.WriteString(errOut, "task stderr")
+ return nil
+ },
+ }
+
+ exitCode, err := runner.run(context.Background(), []string{"list"}, strings.NewReader(""), &stdout, &stderr)
+ if err != nil {
+ t.Fatalf("run returned error: %v", err)
+ }
+ if exitCode != 0 {
+ t.Fatalf("exitCode = %d, want 0", exitCode)
+ }
+ if stdout.String() != "task stdout" {
+ t.Fatalf("stdout = %q, want %q", stdout.String(), "task stdout")
+ }
+ if stderr.String() != "task stderr" {
+ t.Fatalf("stderr = %q, want %q", stderr.String(), "task stderr")
+ }
+}
+
+func TestTaskRunnerRun_TaskLookupFailure_IsActionable(t *testing.T) {
+ runner := taskRunner{
+ findTaskBinary: func() (string, error) { return "", errors.New("not found") },
+ }
+
+ exitCode, err := runner.run(context.Background(), []string{"list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{})
+ if exitCode != 1 {
+ t.Fatalf("exitCode = %d, want 1", exitCode)
+ }
+ if err == nil || !strings.Contains(err.Error(), "Taskwarrior binary lookup failed") {
+ t.Fatalf("expected actionable task lookup error, got %v", err)
+ }
+}
+
+func TestTaskRunnerRun_EmptyRepoName_IsActionable(t *testing.T) {
+ runner := taskRunner{
+ findTaskBinary: func() (string, error) { return "/usr/bin/task", nil },
+ detectRepoRoot: func(context.Context) (string, error) { return "/", nil },
+ }
+
+ exitCode, err := runner.run(context.Background(), []string{"list"}, strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{})
+ if exitCode != 1 {
+ t.Fatalf("exitCode = %d, want 1", exitCode)
+ }
+ if err == nil || !strings.Contains(err.Error(), "could not derive project name") {
+ t.Fatalf("expected actionable project-name error, got %v", err)
+ }
+}