diff options
Diffstat (limited to 'internal/hexaicli/run.go')
| -rw-r--r-- | internal/hexaicli/run.go | 64 |
1 files changed, 47 insertions, 17 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index bc0341d..b48bee0 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -205,8 +205,9 @@ type chatRunSummary struct { } func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { - results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr) - if printer == nil { + streamSingle := len(jobs) == 1 + results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle) + if printer == nil && !streamSingle { if err := writeCLIJobOutputs(stdout, results); err != nil { return err } @@ -214,17 +215,17 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st return writeCLIJobSummaries(stderr, results) } -func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer) ([]*cliJobResult, *termprint.ColumnPrinter) { +func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool) ([]*cliJobResult, *termprint.ColumnPrinter) { results := make([]*cliJobResult, len(jobs)) printer := setupCLIPrinter(stdout, jobs) + printCLIHeader(stderr, jobs, printer) var wg sync.WaitGroup for _, job := range jobs { job := job wg.Add(1) - printProviderInfo(stderr, job.client, job.req.model) go func() { defer wg.Done() - results[job.index] = runSingleCLIJob(ctx, job, msgs, input, printer) + results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle) }() } wg.Wait() @@ -232,21 +233,21 @@ func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, inpu } func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter { - if len(jobs) == 0 { + if len(jobs) < 2 { return nil } - printer := newColumnPrinter(stdout, jobs) - printer.PrintHeader() - return printer + return newColumnPrinter(stdout, jobs) } -func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, printer *termprint.ColumnPrinter) *cliJobResult { +func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult { var errBuf bytes.Buffer var outBuf bytes.Buffer jobMsgs := append([]llm.Message(nil), msgs...) writer := io.Writer(&outBuf) if printer != nil { writer = printer.Writer(job.index) + } else if streamOutput { + writer = io.MultiWriter(stdout, &outBuf) } err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf) if printer != nil { @@ -263,6 +264,7 @@ func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error { printed := false + showHeading := cliJobResultCount(results) > 1 for _, res := range results { if res == nil { continue @@ -272,7 +274,7 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error { return err } } - if err := writeCLIJobOutput(stdout, res); err != nil { + if err := writeCLIJobOutput(stdout, res, showHeading); err != nil { return err } printed = true @@ -280,10 +282,22 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error { return nil } -func writeCLIJobOutput(stdout io.Writer, res *cliJobResult) error { - heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) - if _, err := io.WriteString(stdout, heading); err != nil { - return err +func cliJobResultCount(results []*cliJobResult) int { + count := 0 + for _, res := range results { + if res != nil { + count++ + } + } + return count +} + +func writeCLIJobOutput(stdout io.Writer, res *cliJobResult, showHeading bool) error { + if showHeading { + heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) + if _, err := io.WriteString(stdout, heading); err != nil { + return err + } } if res.output == "" { return nil @@ -338,6 +352,18 @@ func newColumnPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter return termprint.NewColumnPrinter(stdout, providers, models) } +func printCLIHeader(stderr io.Writer, jobs []cliJob, printer *termprint.ColumnPrinter) { + if len(jobs) == 0 { + return + } + if printer != nil { + printer.PrintHeaderTo(stderr) + return + } + job := jobs[0] + printProviderInfo(stderr, job.client, job.req.model) +} + // WithCLISelection injects provider indices into the context so Run only executes those jobs. func WithCLISelection(ctx context.Context, indices []int) context.Context { if ctx == nil { @@ -549,12 +575,16 @@ func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs return summary } -// printProviderInfo writes the provider/model line to stderr. +// printProviderInfo writes the provider:model header and divider to stderr. func printProviderInfo(errw io.Writer, client llm.Client, model string) { if strings.TrimSpace(model) == "" { model = client.DefaultModel() } - _, _ = fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model) + printer := termprint.NewColumnPrinter(errw, []string{client.Name()}, []string{model}) + if printer == nil { + return + } + printer.PrintHeader() } // newClientFromConfig is kept for tests; delegates to llmutils. |
