diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-02 14:11:47 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-02 14:11:47 +0200 |
| commit | f6f829ba620509dbc501ae282eeaae3ba123e231 (patch) | |
| tree | 9abd95f893c7a05deab5d02233e08d270359d420 | |
| parent | fae0964bed7e77e11df2fa98783c63c806670049 (diff) | |
hexaicli: decompose runCLIJobs and runChat helpers (task 417)
| -rw-r--r-- | internal/hexaicli/run.go | 282 |
1 files changed, 175 insertions, 107 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 85a3fdd..1505f31 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -200,87 +200,116 @@ type cliJobResult struct { err error } +type chatRunSummary struct { + snapshot stats.Snapshot + sent int + recv int + scopeReq int64 + scopeRPM float64 +} + func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { + results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr) + if printer == nil { + if err := writeCLIJobOutputs(stdout, results); err != nil { + return err + } + } + return writeCLIJobSummaries(stderr, results) +} + +func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer) ([]*cliJobResult, *columnPrinter) { results := make([]*cliJobResult, len(jobs)) + printer := setupCLIPrinter(stdout, jobs) var wg sync.WaitGroup - var printer *columnPrinter - if len(jobs) > 0 { - printer = newColumnPrinter(stdout, jobs) - printer.PrintHeader() - } for _, job := range jobs { job := job wg.Add(1) printProviderInfo(stderr, job.client, job.req.model) go func() { defer wg.Done() - var errBuf bytes.Buffer - var outBuf bytes.Buffer - jobMsgs := make([]llm.Message, len(msgs)) - copy(jobMsgs, msgs) - writer := io.Writer(&outBuf) - if printer != nil { - writer = printer.Writer(job.index) - } - err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf) - if printer != nil { - printer.Flush(job.index) - } - results[job.index] = &cliJobResult{ - provider: job.client.Name(), - model: job.req.model, - output: outBuf.String(), - summary: errBuf.String(), - err: err, - } + results[job.index] = runSingleCLIJob(ctx, job, msgs, input, printer) }() } wg.Wait() - var firstErr error - if printer == nil { - printed := false - for _, res := range results { - if res == nil { - continue - } - if printed { - if _, err := io.WriteString(stdout, "\n"); err != nil { - return err - } - } - heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) - if _, err := io.WriteString(stdout, heading); err != nil { + return results, printer +} + +func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter { + if len(jobs) == 0 { + return nil + } + printer := newColumnPrinter(stdout, jobs) + printer.PrintHeader() + return printer +} + +func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, printer *columnPrinter) *cliJobResult { + var errBuf bytes.Buffer + var outBuf bytes.Buffer + jobMsgs := append([]llm.Message(nil), msgs...) + writer := io.Writer(&outBuf) + if printer != nil { + writer = printer.Writer(job.index) + } + err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf) + if printer != nil { + printer.Flush(job.index) + } + return &cliJobResult{ + provider: job.client.Name(), + model: job.req.model, + output: outBuf.String(), + summary: errBuf.String(), + err: err, + } +} + +func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error { + printed := false + for _, res := range results { + if res == nil { + continue + } + if printed { + if _, err := io.WriteString(stdout, "\n"); err != nil { return err } - if res.output != "" { - if _, err := io.WriteString(stdout, res.output); err != nil { - return err - } - if !strings.HasSuffix(res.output, "\n") { - if _, err := io.WriteString(stdout, "\n"); err != nil { - return err - } - } - } - printed = true } + if err := writeCLIJobOutput(stdout, res); err != nil { + return err + } + printed = true } + return nil +} + +func writeCLIJobOutput(stdout io.Writer, res *cliJobResult) error { + heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) + if _, err := io.WriteString(stdout, heading); err != nil { + return err + } + if res.output == "" { + return nil + } + if _, err := io.WriteString(stdout, res.output); err != nil { + return err + } + if strings.HasSuffix(res.output, "\n") { + return nil + } + _, err := io.WriteString(stdout, "\n") + return err +} + +func writeCLIJobSummaries(stderr io.Writer, results []*cliJobResult) error { + var firstErr error for _, res := range results { if res == nil { continue } - if res.summary != "" { - summary := strings.TrimLeft(res.summary, "\n") - if summary != "" { - if _, err := io.WriteString(stderr, summary); err != nil { - return err - } - } - } - if res.err != nil { - if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil { - return err - } + if err := writeCLIJobSummary(stderr, res); err != nil { + return err } if firstErr == nil && res.err != nil { firstErr = res.err @@ -289,6 +318,20 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st return firstErr } +func writeCLIJobSummary(stderr io.Writer, res *cliJobResult) error { + summary := strings.TrimLeft(res.summary, "\n") + if summary != "" { + if _, err := io.WriteString(stderr, summary); err != nil { + return err + } + } + if res.err == nil { + return nil + } + _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err) + return err +} + func newColumnPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter { cols := len(jobs) width := detectTerminalWidth(stdout) @@ -581,67 +624,92 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message { // runChat executes the chat request, handling streaming and summary output. func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { start := time.Now() - // Best-effort tmux status update (colored start heartbeat) + model := effectiveModel(req, client) + _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model)) + + output, err := runChatRequest(ctx, client, req, msgs, out) + if err != nil { + return err + } + + dur := time.Since(start) + summary := summarizeChatRun(ctx, client, model, msgs, output) + if _, err := fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n", + client.Name(), model, dur.Round(time.Millisecond), summary.sent, summary.recv, summary.snapshot.Global.Reqs, summary.snapshot.RPM); err != nil { + return err + } + _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(summary.snapshot.Global.Reqs, summary.snapshot.RPM, summary.snapshot.Global.Sent, summary.snapshot.Global.Recv, client.Name(), model, summary.scopeRPM, summary.scopeReq, summary.snapshot.Window)) + return nil +} + +func effectiveModel(req requestArgs, client llm.Client) string { model := strings.TrimSpace(req.model) if model == "" { model = client.DefaultModel() } - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model)) - var output string - if s, ok := client.(llm.Streamer); ok { - var b strings.Builder - var streamErr error - if err := s.ChatStream(ctx, msgs, func(chunk string) { - if streamErr != nil { - return - } - b.WriteString(chunk) - if _, err := fmt.Fprint(out, chunk); err != nil { - streamErr = err - } - }, req.options...); err != nil { - return err - } - if streamErr != nil { - return streamErr - } - output = b.String() - } else { - txt, err := client.Chat(ctx, msgs, req.options...) - if err != nil { - return err + return model +} + +func runChatRequest(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, out io.Writer) (string, error) { + if streamer, ok := client.(llm.Streamer); ok { + return runStreamingChat(ctx, streamer, msgs, req.options, out) + } + return runSimpleChat(ctx, client, msgs, req.options, out) +} + +func runStreamingChat(ctx context.Context, client llm.Streamer, msgs []llm.Message, options []llm.RequestOption, out io.Writer) (string, error) { + var output strings.Builder + var writeErr error + if err := client.ChatStream(ctx, msgs, func(chunk string) { + if writeErr != nil { + return } - output = txt - if _, err := fmt.Fprint(out, output); err != nil { - return err + output.WriteString(chunk) + if _, err := fmt.Fprint(out, chunk); err != nil { + writeErr = err } + }, options...); err != nil { + return "", err } - dur := time.Since(start) - // Contribute to global stats and update tmux status - sent := 0 + if writeErr != nil { + return "", writeErr + } + return output.String(), nil +} + +func runSimpleChat(ctx context.Context, client llm.Client, msgs []llm.Message, options []llm.RequestOption, out io.Writer) (string, error) { + output, err := client.Chat(ctx, msgs, options...) + if err != nil { + return "", err + } + if _, err := fmt.Fprint(out, output); err != nil { + return "", err + } + return output, nil +} + +func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs []llm.Message, output string) chatRunSummary { + summary := chatRunSummary{snapshot: stats.Snapshot{Window: time.Hour}} for _, m := range msgs { - sent += len(m.Content) + summary.sent += len(m.Content) + } + summary.recv = len(output) + _ = stats.Update(ctx, client.Name(), model, summary.sent, summary.recv) + snap, err := stats.TakeSnapshot() + if err == nil { + summary.snapshot = snap } - recv := len(output) - _ = stats.Update(ctx, client.Name(), model, sent, recv) - snap, _ := stats.TakeSnapshot() - minsWin := snap.Window.Minutes() + minsWin := summary.snapshot.Window.Minutes() if minsWin <= 0 { minsWin = 0.001 } - scopeReqs := int64(0) - if pe, ok := snap.Providers[client.Name()]; ok { + if pe, ok := summary.snapshot.Providers[client.Name()]; ok { if mc, ok2 := pe.Models[model]; ok2 { - scopeReqs = mc.Reqs + summary.scopeReq = mc.Reqs } } - scopeRPM := float64(scopeReqs) / minsWin - if _, err := fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n", - client.Name(), model, dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM); err != nil { - return err - } - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), model, scopeRPM, scopeReqs, snap.Window)) - return nil + summary.scopeRPM = float64(summary.scopeReq) / minsWin + return summary } // printProviderInfo writes the provider/model line to stderr. |
