summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-02 14:11:47 +0200
committerPaul Buetow <paul@buetow.org>2026-03-02 14:11:47 +0200
commitf6f829ba620509dbc501ae282eeaae3ba123e231 (patch)
tree9abd95f893c7a05deab5d02233e08d270359d420
parentfae0964bed7e77e11df2fa98783c63c806670049 (diff)
hexaicli: decompose runCLIJobs and runChat helpers (task 417)
-rw-r--r--internal/hexaicli/run.go282
1 files changed, 175 insertions, 107 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 85a3fdd..1505f31 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -200,87 +200,116 @@ type cliJobResult struct {
err error
}
+type chatRunSummary struct {
+ snapshot stats.Snapshot
+ sent int
+ recv int
+ scopeReq int64
+ scopeRPM float64
+}
+
func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error {
+ results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr)
+ if printer == nil {
+ if err := writeCLIJobOutputs(stdout, results); err != nil {
+ return err
+ }
+ }
+ return writeCLIJobSummaries(stderr, results)
+}
+
+func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer) ([]*cliJobResult, *columnPrinter) {
results := make([]*cliJobResult, len(jobs))
+ printer := setupCLIPrinter(stdout, jobs)
var wg sync.WaitGroup
- var printer *columnPrinter
- if len(jobs) > 0 {
- printer = newColumnPrinter(stdout, jobs)
- printer.PrintHeader()
- }
for _, job := range jobs {
job := job
wg.Add(1)
printProviderInfo(stderr, job.client, job.req.model)
go func() {
defer wg.Done()
- var errBuf bytes.Buffer
- var outBuf bytes.Buffer
- jobMsgs := make([]llm.Message, len(msgs))
- copy(jobMsgs, msgs)
- writer := io.Writer(&outBuf)
- if printer != nil {
- writer = printer.Writer(job.index)
- }
- err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
- if printer != nil {
- printer.Flush(job.index)
- }
- results[job.index] = &cliJobResult{
- provider: job.client.Name(),
- model: job.req.model,
- output: outBuf.String(),
- summary: errBuf.String(),
- err: err,
- }
+ results[job.index] = runSingleCLIJob(ctx, job, msgs, input, printer)
}()
}
wg.Wait()
- var firstErr error
- if printer == nil {
- printed := false
- for _, res := range results {
- if res == nil {
- continue
- }
- if printed {
- if _, err := io.WriteString(stdout, "\n"); err != nil {
- return err
- }
- }
- heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
- if _, err := io.WriteString(stdout, heading); err != nil {
+ return results, printer
+}
+
+func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter {
+ if len(jobs) == 0 {
+ return nil
+ }
+ printer := newColumnPrinter(stdout, jobs)
+ printer.PrintHeader()
+ return printer
+}
+
+func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, printer *columnPrinter) *cliJobResult {
+ var errBuf bytes.Buffer
+ var outBuf bytes.Buffer
+ jobMsgs := append([]llm.Message(nil), msgs...)
+ writer := io.Writer(&outBuf)
+ if printer != nil {
+ writer = printer.Writer(job.index)
+ }
+ err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
+ if printer != nil {
+ printer.Flush(job.index)
+ }
+ return &cliJobResult{
+ provider: job.client.Name(),
+ model: job.req.model,
+ output: outBuf.String(),
+ summary: errBuf.String(),
+ err: err,
+ }
+}
+
+func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
+ printed := false
+ for _, res := range results {
+ if res == nil {
+ continue
+ }
+ if printed {
+ if _, err := io.WriteString(stdout, "\n"); err != nil {
return err
}
- if res.output != "" {
- if _, err := io.WriteString(stdout, res.output); err != nil {
- return err
- }
- if !strings.HasSuffix(res.output, "\n") {
- if _, err := io.WriteString(stdout, "\n"); err != nil {
- return err
- }
- }
- }
- printed = true
}
+ if err := writeCLIJobOutput(stdout, res); err != nil {
+ return err
+ }
+ printed = true
}
+ return nil
+}
+
+func writeCLIJobOutput(stdout io.Writer, res *cliJobResult) error {
+ heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
+ if _, err := io.WriteString(stdout, heading); err != nil {
+ return err
+ }
+ if res.output == "" {
+ return nil
+ }
+ if _, err := io.WriteString(stdout, res.output); err != nil {
+ return err
+ }
+ if strings.HasSuffix(res.output, "\n") {
+ return nil
+ }
+ _, err := io.WriteString(stdout, "\n")
+ return err
+}
+
+func writeCLIJobSummaries(stderr io.Writer, results []*cliJobResult) error {
+ var firstErr error
for _, res := range results {
if res == nil {
continue
}
- if res.summary != "" {
- summary := strings.TrimLeft(res.summary, "\n")
- if summary != "" {
- if _, err := io.WriteString(stderr, summary); err != nil {
- return err
- }
- }
- }
- if res.err != nil {
- if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil {
- return err
- }
+ if err := writeCLIJobSummary(stderr, res); err != nil {
+ return err
}
if firstErr == nil && res.err != nil {
firstErr = res.err
@@ -289,6 +318,20 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st
return firstErr
}
+func writeCLIJobSummary(stderr io.Writer, res *cliJobResult) error {
+ summary := strings.TrimLeft(res.summary, "\n")
+ if summary != "" {
+ if _, err := io.WriteString(stderr, summary); err != nil {
+ return err
+ }
+ }
+ if res.err == nil {
+ return nil
+ }
+ _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err)
+ return err
+}
+
func newColumnPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter {
cols := len(jobs)
width := detectTerminalWidth(stdout)
@@ -581,67 +624,92 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message {
// runChat executes the chat request, handling streaming and summary output.
func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error {
start := time.Now()
- // Best-effort tmux status update (colored start heartbeat)
+ model := effectiveModel(req, client)
+ _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model))
+
+ output, err := runChatRequest(ctx, client, req, msgs, out)
+ if err != nil {
+ return err
+ }
+
+ dur := time.Since(start)
+ summary := summarizeChatRun(ctx, client, model, msgs, output)
+ if _, err := fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n",
+ client.Name(), model, dur.Round(time.Millisecond), summary.sent, summary.recv, summary.snapshot.Global.Reqs, summary.snapshot.RPM); err != nil {
+ return err
+ }
+ _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(summary.snapshot.Global.Reqs, summary.snapshot.RPM, summary.snapshot.Global.Sent, summary.snapshot.Global.Recv, client.Name(), model, summary.scopeRPM, summary.scopeReq, summary.snapshot.Window))
+ return nil
+}
+
+func effectiveModel(req requestArgs, client llm.Client) string {
model := strings.TrimSpace(req.model)
if model == "" {
model = client.DefaultModel()
}
- _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model))
- var output string
- if s, ok := client.(llm.Streamer); ok {
- var b strings.Builder
- var streamErr error
- if err := s.ChatStream(ctx, msgs, func(chunk string) {
- if streamErr != nil {
- return
- }
- b.WriteString(chunk)
- if _, err := fmt.Fprint(out, chunk); err != nil {
- streamErr = err
- }
- }, req.options...); err != nil {
- return err
- }
- if streamErr != nil {
- return streamErr
- }
- output = b.String()
- } else {
- txt, err := client.Chat(ctx, msgs, req.options...)
- if err != nil {
- return err
+ return model
+}
+
+func runChatRequest(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, out io.Writer) (string, error) {
+ if streamer, ok := client.(llm.Streamer); ok {
+ return runStreamingChat(ctx, streamer, msgs, req.options, out)
+ }
+ return runSimpleChat(ctx, client, msgs, req.options, out)
+}
+
+func runStreamingChat(ctx context.Context, client llm.Streamer, msgs []llm.Message, options []llm.RequestOption, out io.Writer) (string, error) {
+ var output strings.Builder
+ var writeErr error
+ if err := client.ChatStream(ctx, msgs, func(chunk string) {
+ if writeErr != nil {
+ return
}
- output = txt
- if _, err := fmt.Fprint(out, output); err != nil {
- return err
+ output.WriteString(chunk)
+ if _, err := fmt.Fprint(out, chunk); err != nil {
+ writeErr = err
}
+ }, options...); err != nil {
+ return "", err
}
- dur := time.Since(start)
- // Contribute to global stats and update tmux status
- sent := 0
+ if writeErr != nil {
+ return "", writeErr
+ }
+ return output.String(), nil
+}
+
+func runSimpleChat(ctx context.Context, client llm.Client, msgs []llm.Message, options []llm.RequestOption, out io.Writer) (string, error) {
+ output, err := client.Chat(ctx, msgs, options...)
+ if err != nil {
+ return "", err
+ }
+ if _, err := fmt.Fprint(out, output); err != nil {
+ return "", err
+ }
+ return output, nil
+}
+
+func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs []llm.Message, output string) chatRunSummary {
+ summary := chatRunSummary{snapshot: stats.Snapshot{Window: time.Hour}}
for _, m := range msgs {
- sent += len(m.Content)
+ summary.sent += len(m.Content)
+ }
+ summary.recv = len(output)
+ _ = stats.Update(ctx, client.Name(), model, summary.sent, summary.recv)
+ snap, err := stats.TakeSnapshot()
+ if err == nil {
+ summary.snapshot = snap
}
- recv := len(output)
- _ = stats.Update(ctx, client.Name(), model, sent, recv)
- snap, _ := stats.TakeSnapshot()
- minsWin := snap.Window.Minutes()
+ minsWin := summary.snapshot.Window.Minutes()
if minsWin <= 0 {
minsWin = 0.001
}
- scopeReqs := int64(0)
- if pe, ok := snap.Providers[client.Name()]; ok {
+ if pe, ok := summary.snapshot.Providers[client.Name()]; ok {
if mc, ok2 := pe.Models[model]; ok2 {
- scopeReqs = mc.Reqs
+ summary.scopeReq = mc.Reqs
}
}
- scopeRPM := float64(scopeReqs) / minsWin
- if _, err := fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n",
- client.Name(), model, dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM); err != nil {
- return err
- }
- _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), model, scopeRPM, scopeReqs, snap.Window))
- return nil
+ summary.scopeRPM = float64(summary.scopeReq) / minsWin
+ return summary
}
// printProviderInfo writes the provider/model line to stderr.