summaryrefslogtreecommitdiff
path: root/internal/hexaicli/run.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli/run.go')
-rw-r--r--internal/hexaicli/run.go64
1 files changed, 47 insertions, 17 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index bc0341d..b48bee0 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -205,8 +205,9 @@ type chatRunSummary struct {
}
func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error {
- results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr)
- if printer == nil {
+ streamSingle := len(jobs) == 1
+ results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle)
+ if printer == nil && !streamSingle {
if err := writeCLIJobOutputs(stdout, results); err != nil {
return err
}
@@ -214,17 +215,17 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st
return writeCLIJobSummaries(stderr, results)
}
-func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer) ([]*cliJobResult, *termprint.ColumnPrinter) {
+func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool) ([]*cliJobResult, *termprint.ColumnPrinter) {
results := make([]*cliJobResult, len(jobs))
printer := setupCLIPrinter(stdout, jobs)
+ printCLIHeader(stderr, jobs, printer)
var wg sync.WaitGroup
for _, job := range jobs {
job := job
wg.Add(1)
- printProviderInfo(stderr, job.client, job.req.model)
go func() {
defer wg.Done()
- results[job.index] = runSingleCLIJob(ctx, job, msgs, input, printer)
+ results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle)
}()
}
wg.Wait()
@@ -232,21 +233,21 @@ func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, inpu
}
func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter {
- if len(jobs) == 0 {
+ if len(jobs) < 2 {
return nil
}
- printer := newColumnPrinter(stdout, jobs)
- printer.PrintHeader()
- return printer
+ return newColumnPrinter(stdout, jobs)
}
-func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, printer *termprint.ColumnPrinter) *cliJobResult {
+func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
var errBuf bytes.Buffer
var outBuf bytes.Buffer
jobMsgs := append([]llm.Message(nil), msgs...)
writer := io.Writer(&outBuf)
if printer != nil {
writer = printer.Writer(job.index)
+ } else if streamOutput {
+ writer = io.MultiWriter(stdout, &outBuf)
}
err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
if printer != nil {
@@ -263,6 +264,7 @@ func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input
func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
printed := false
+ showHeading := cliJobResultCount(results) > 1
for _, res := range results {
if res == nil {
continue
@@ -272,7 +274,7 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
return err
}
}
- if err := writeCLIJobOutput(stdout, res); err != nil {
+ if err := writeCLIJobOutput(stdout, res, showHeading); err != nil {
return err
}
printed = true
@@ -280,10 +282,22 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
return nil
}
-func writeCLIJobOutput(stdout io.Writer, res *cliJobResult) error {
- heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
- if _, err := io.WriteString(stdout, heading); err != nil {
- return err
+func cliJobResultCount(results []*cliJobResult) int {
+ count := 0
+ for _, res := range results {
+ if res != nil {
+ count++
+ }
+ }
+ return count
+}
+
+func writeCLIJobOutput(stdout io.Writer, res *cliJobResult, showHeading bool) error {
+ if showHeading {
+ heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
+ if _, err := io.WriteString(stdout, heading); err != nil {
+ return err
+ }
}
if res.output == "" {
return nil
@@ -338,6 +352,18 @@ func newColumnPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter
return termprint.NewColumnPrinter(stdout, providers, models)
}
+func printCLIHeader(stderr io.Writer, jobs []cliJob, printer *termprint.ColumnPrinter) {
+ if len(jobs) == 0 {
+ return
+ }
+ if printer != nil {
+ printer.PrintHeaderTo(stderr)
+ return
+ }
+ job := jobs[0]
+ printProviderInfo(stderr, job.client, job.req.model)
+}
+
// WithCLISelection injects provider indices into the context so Run only executes those jobs.
func WithCLISelection(ctx context.Context, indices []int) context.Context {
if ctx == nil {
@@ -549,12 +575,16 @@ func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs
return summary
}
-// printProviderInfo writes the provider/model line to stderr.
+// printProviderInfo writes the provider:model header and divider to stderr.
func printProviderInfo(errw io.Writer, client llm.Client, model string) {
if strings.TrimSpace(model) == "" {
model = client.DefaultModel()
}
- _, _ = fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model)
+ printer := termprint.NewColumnPrinter(errw, []string{client.Name()}, []string{model})
+ if printer == nil {
+ return
+ }
+ printer.PrintHeader()
}
// newClientFromConfig is kept for tests; delegates to llmutils.