diff options
| author | Paul Buetow <paul@buetow.org> | 2025-10-02 08:41:45 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-10-02 08:41:45 +0300 |
| commit | e36a5446bc62842ae3b3e165f66fecb7285a8c6a (patch) | |
| tree | d3f9f7a66d8b4e5fdb13903722580a8f90eae5d1 /internal | |
| parent | f14eb9199f4e1aee49594e590c08996244bb77b3 (diff) | |
feat: add OpenRouter providerv0.15.0
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/appconfig/config.go | 47 | ||||
| -rw-r--r-- | internal/hexailsp/run.go | 30 | ||||
| -rw-r--r-- | internal/llm/openai.go | 38 | ||||
| -rw-r--r-- | internal/llm/openai_http_test.go | 2 | ||||
| -rw-r--r-- | internal/llm/openai_request_test.go | 6 | ||||
| -rw-r--r-- | internal/llm/openai_temp_test.go | 6 | ||||
| -rw-r--r-- | internal/llm/openrouter.go | 168 | ||||
| -rw-r--r-- | internal/llm/openrouter_test.go | 125 | ||||
| -rw-r--r-- | internal/llm/provider.go | 15 | ||||
| -rw-r--r-- | internal/llm/provider_more2_test.go | 2 | ||||
| -rw-r--r-- | internal/llm/provider_more_test.go | 4 | ||||
| -rw-r--r-- | internal/llm/provider_test.go | 6 | ||||
| -rw-r--r-- | internal/llmutils/client.go | 29 | ||||
| -rw-r--r-- | internal/lsp/server.go | 35 | ||||
| -rw-r--r-- | internal/lsp/server_test.go | 17 | ||||
| -rw-r--r-- | internal/version.go | 2 |
16 files changed, 463 insertions, 69 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index 96ac300..e5a8d5f 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -56,8 +56,12 @@ type App struct { OpenAIModel string `json:"openai_model" toml:"openai_model"` // Default temperature for OpenAI requests (nil means use provider default) OpenAITemperature *float64 `json:"openai_temperature" toml:"openai_temperature"` - OllamaBaseURL string `json:"ollama_base_url" toml:"ollama_base_url"` - OllamaModel string `json:"ollama_model" toml:"ollama_model"` + OpenRouterBaseURL string `json:"openrouter_base_url" toml:"openrouter_base_url"` + OpenRouterModel string `json:"openrouter_model" toml:"openrouter_model"` + // Default temperature for OpenRouter requests (nil means use provider default) + OpenRouterTemperature *float64 `json:"openrouter_temperature" toml:"openrouter_temperature"` + OllamaBaseURL string `json:"ollama_base_url" toml:"ollama_base_url"` + OllamaModel string `json:"ollama_model" toml:"ollama_model"` // Default temperature for Ollama requests (nil means use provider default) OllamaTemperature *float64 `json:"ollama_temperature" toml:"ollama_temperature"` CopilotBaseURL string `json:"copilot_base_url" toml:"copilot_base_url"` @@ -228,6 +232,7 @@ type fileConfig struct { Chat sectionChat `toml:"chat"` Provider sectionProvider `toml:"provider"` OpenAI sectionOpenAI `toml:"openai"` + OpenRouter sectionOpenRouter `toml:"openrouter"` Copilot sectionCopilot `toml:"copilot"` Ollama sectionOllama `toml:"ollama"` Prompts sectionPrompts `toml:"prompts"` @@ -308,6 +313,12 @@ func (s sectionOpenAI) resolvedModel() string { return model } +type sectionOpenRouter struct { + Model string `toml:"model"` + BaseURL string `toml:"base_url"` + Temperature *float64 `toml:"temperature"` +} + type sectionCopilot struct { Model string `toml:"model"` BaseURL string `toml:"base_url"` @@ -445,6 +456,16 @@ func (fc *fileConfig) toApp() App { out.mergeProviderFields(&tmp) } + // openrouter + if (fc.OpenRouter != sectionOpenRouter{}) || fc.OpenRouter.Temperature != nil { + tmp := App{ + OpenRouterBaseURL: fc.OpenRouter.BaseURL, + OpenRouterModel: fc.OpenRouter.Model, + OpenRouterTemperature: fc.OpenRouter.Temperature, + } + out.mergeProviderFields(&tmp) + } + // copilot if (fc.Copilot != sectionCopilot{}) || fc.Copilot.Temperature != nil { tmp := App{ @@ -1025,6 +1046,15 @@ func (a *App) mergeProviderFields(other *App) { if other.OpenAITemperature != nil { // allow explicit 0.0 a.OpenAITemperature = other.OpenAITemperature } + if s := strings.TrimSpace(other.OpenRouterBaseURL); s != "" { + a.OpenRouterBaseURL = s + } + if s := strings.TrimSpace(other.OpenRouterModel); s != "" { + a.OpenRouterModel = s + } + if other.OpenRouterTemperature != nil { // allow explicit 0.0 + a.OpenRouterTemperature = other.OpenRouterTemperature + } if s := strings.TrimSpace(other.OllamaBaseURL); s != "" { a.OllamaBaseURL = s } @@ -1223,6 +1253,19 @@ func loadFromEnv(logger *log.Logger) *App { any = true } + if s := getenv("HEXAI_OPENROUTER_BASE_URL"); s != "" { + out.OpenRouterBaseURL = s + any = true + } + if model, ok := pickModel("openrouter", getenv("HEXAI_OPENROUTER_MODEL")); ok { + out.OpenRouterModel = model + any = true + } + if f, ok := parseFloatPtr("HEXAI_OPENROUTER_TEMPERATURE"); ok { + out.OpenRouterTemperature = f + any = true + } + if s := getenv("HEXAI_OLLAMA_BASE_URL"); s != "" { out.OllamaBaseURL = s any = true diff --git a/internal/hexailsp/run.go b/internal/hexailsp/run.go index 0e383ac..f0ab404 100644 --- a/internal/hexailsp/run.go +++ b/internal/hexailsp/run.go @@ -104,28 +104,36 @@ func buildClientIfNil(cfg appconfig.App, client llm.Client) llm.Client { return client } llmCfg := llm.Config{ - Provider: cfg.Provider, - OpenAIBaseURL: cfg.OpenAIBaseURL, - OpenAIModel: cfg.OpenAIModel, - OpenAITemperature: cfg.OpenAITemperature, - OllamaBaseURL: cfg.OllamaBaseURL, - OllamaModel: cfg.OllamaModel, - OllamaTemperature: cfg.OllamaTemperature, - CopilotBaseURL: cfg.CopilotBaseURL, - CopilotModel: cfg.CopilotModel, - CopilotTemperature: cfg.CopilotTemperature, + Provider: cfg.Provider, + OpenAIBaseURL: cfg.OpenAIBaseURL, + OpenAIModel: cfg.OpenAIModel, + OpenAITemperature: cfg.OpenAITemperature, + OpenRouterBaseURL: cfg.OpenRouterBaseURL, + OpenRouterModel: cfg.OpenRouterModel, + OpenRouterTemperature: cfg.OpenRouterTemperature, + OllamaBaseURL: cfg.OllamaBaseURL, + OllamaModel: cfg.OllamaModel, + OllamaTemperature: cfg.OllamaTemperature, + CopilotBaseURL: cfg.CopilotBaseURL, + CopilotModel: cfg.CopilotModel, + CopilotTemperature: cfg.CopilotTemperature, } // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY oaKey := os.Getenv("HEXAI_OPENAI_API_KEY") if strings.TrimSpace(oaKey) == "" { oaKey = os.Getenv("OPENAI_API_KEY") } + // Prefer HEXAI_OPENROUTER_API_KEY; fall back to OPENROUTER_API_KEY + orKey := os.Getenv("HEXAI_OPENROUTER_API_KEY") + if strings.TrimSpace(orKey) == "" { + orKey = os.Getenv("OPENROUTER_API_KEY") + } // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY cpKey := os.Getenv("HEXAI_COPILOT_API_KEY") if strings.TrimSpace(cpKey) == "" { cpKey = os.Getenv("COPILOT_API_KEY") } - if c, err := llm.NewFromConfig(llmCfg, oaKey, cpKey); err != nil { + if c, err := llm.NewFromConfig(llmCfg, oaKey, orKey, cpKey); err != nil { logging.Logf("lsp ", "llm disabled: %v", err) return nil } else { diff --git a/internal/llm/openai.go b/internal/llm/openai.go index 8a0d6d7..c284bb3 100644 --- a/internal/llm/openai.go +++ b/internal/llm/openai.go @@ -106,7 +106,7 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ } start := time.Now() c.logStart(false, o, messages) - req := buildOAChatRequest(o, messages, c.defaultTemperature, false) + req := buildOAChatRequest(o, messages, c.defaultTemperature, false, "llm/openai ") body, err := json.Marshal(req) if err != nil { c.logf("marshal error: %v", err) @@ -122,10 +122,10 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ return "", err } defer resp.Body.Close() - if err := handleOpenAINon2xx(resp, start); err != nil { + if err := handleOpenAINon2xx(resp, start, "llm/openai ", "openai"); err != nil { return "", err } - out, err := decodeOpenAIChat(resp, start) + out, err := decodeOpenAIChat(resp, start, "llm/openai ") if err != nil { return "", err } @@ -157,7 +157,7 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt } start := time.Now() c.logStart(true, o, messages) - req := buildOAChatRequest(o, messages, c.defaultTemperature, true) + req := buildOAChatRequest(o, messages, c.defaultTemperature, true, "llm/openai ") body, err := json.Marshal(req) if err != nil { c.logf("marshal error: %v", err) @@ -173,11 +173,11 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt return err } defer resp.Body.Close() - if err := handleOpenAINon2xx(resp, start); err != nil { + if err := handleOpenAINon2xx(resp, start, "llm/openai ", "openai"); err != nil { return err } - if err := parseOpenAIStream(resp, start, onDelta); err != nil { + if err := parseOpenAIStream(resp, start, onDelta, "llm/openai ", "openai"); err != nil { return err } logging.Logf("llm/openai ", "stream end duration=%s", time.Since(start)) @@ -196,7 +196,7 @@ func (c openAIClient) logStart(stream bool, o Options, messages []Message) { c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) } -func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool) oaChatRequest { +func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool, logPrefix string) oaChatRequest { req := oaChatRequest{Model: o.Model, Stream: stream} req.Messages = make([]oaMessage, len(messages)) for i, m := range messages { @@ -223,7 +223,7 @@ func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, str if req.Temperature == nil || *req.Temperature != 1.0 { t := 1.0 req.Temperature = &t - logging.Logf("llm/openai ", "forcing temperature=1.0 for model=%s (gpt-5 constraint)", o.Model) + logging.Logf(logPrefix, "forcing temperature=1.0 for model=%s (gpt-5 constraint)", o.Model) } } return req @@ -262,30 +262,30 @@ func (c openAIClient) doJSONWithAccept(ctx context.Context, url string, body []b return c.httpClient.Do(req) } -func handleOpenAINon2xx(resp *http.Response, start time.Time) error { +func handleOpenAINon2xx(resp *http.Response, start time.Time, logPrefix, provider string) error { if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } var apiErr oaChatResponse _ = json.NewDecoder(resp.Body).Decode(&apiErr) if apiErr.Error != nil && apiErr.Error.Message != "" { - logging.Logf("llm/openai ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase) - return fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode) + logging.Logf(logPrefix, "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase) + return fmt.Errorf("%s error: %s (status %d)", provider, apiErr.Error.Message, resp.StatusCode) } - logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) - return fmt.Errorf("openai http error: status %d", resp.StatusCode) + logging.Logf(logPrefix, "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) + return fmt.Errorf("%s http error: status %d", provider, resp.StatusCode) } -func decodeOpenAIChat(resp *http.Response, start time.Time) (oaChatResponse, error) { +func decodeOpenAIChat(resp *http.Response, start time.Time, logPrefix string) (oaChatResponse, error) { var out oaChatResponse if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { - logging.Logf("llm/openai ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) + logging.Logf(logPrefix, "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return oaChatResponse{}, err } return out, nil } -func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string)) error { +func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string), logPrefix, provider string) error { // Parse SSE: lines starting with "data: " containing JSON or [DONE] scanner := bufio.NewScanner(resp.Body) const maxBuf = 1024 * 1024 @@ -305,8 +305,8 @@ func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string continue } if chunk.Error != nil && chunk.Error.Message != "" { - logging.Logf("llm/openai ", "%sstream error: %s%s", logging.AnsiRed, chunk.Error.Message, logging.AnsiBase) - return fmt.Errorf("openai stream error: %s", chunk.Error.Message) + logging.Logf(logPrefix, "%sstream error: %s%s", logging.AnsiRed, chunk.Error.Message, logging.AnsiBase) + return fmt.Errorf("%s stream error: %s", provider, chunk.Error.Message) } for _, ch := range chunk.Choices { if ch.Delta.Content != "" { @@ -315,7 +315,7 @@ func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string } } if err := scanner.Err(); err != nil { - logging.Logf("llm/openai ", "%sstream read error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) + logging.Logf(logPrefix, "%sstream read error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return err } return nil diff --git a/internal/llm/openai_http_test.go b/internal/llm/openai_http_test.go index cb4bfcb..affcae9 100644 --- a/internal/llm/openai_http_test.go +++ b/internal/llm/openai_http_test.go @@ -60,7 +60,7 @@ func TestOpenAI_ChatStream_SSE(t *testing.T) { func TestHandleOpenAINon2xx_NoErrorBody(t *testing.T) { resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))} - if err := handleOpenAINon2xx(resp, time.Now()); err == nil { + if err := handleOpenAINon2xx(resp, time.Now(), "llm/openai ", "openai"); err == nil { t.Fatalf("expected http error") } } diff --git a/internal/llm/openai_request_test.go b/internal/llm/openai_request_test.go index 001e3b7..d053031 100644 --- a/internal/llm/openai_request_test.go +++ b/internal/llm/openai_request_test.go @@ -9,13 +9,13 @@ func TestBuildOAChatRequest_MaxTokensKeyByModel(t *testing.T) { msgs := []Message{{Role: "user", Content: "hi"}} mt := 123 // Legacy model: use max_tokens - r1 := buildOAChatRequest(Options{Model: "gpt-4.1", MaxTokens: mt}, msgs, nil, false) + r1 := buildOAChatRequest(Options{Model: "gpt-4.1", MaxTokens: mt}, msgs, nil, false, "llm/test ") b1, _ := json.Marshal(r1) if !contains(string(b1), "max_tokens") || contains(string(b1), "max_completion_tokens") { t.Fatalf("expected max_tokens only, got %s", string(b1)) } // gpt-5 family: use max_completion_tokens - r2 := buildOAChatRequest(Options{Model: "gpt-5.0-preview", MaxTokens: mt}, msgs, nil, false) + r2 := buildOAChatRequest(Options{Model: "gpt-5.0-preview", MaxTokens: mt}, msgs, nil, false, "llm/test ") b2, _ := json.Marshal(r2) if !contains(string(b2), "max_completion_tokens") || contains(string(b2), "max_tokens\":") { t.Fatalf("expected max_completion_tokens only, got %s", string(b2)) @@ -25,7 +25,7 @@ func TestBuildOAChatRequest_MaxTokensKeyByModel(t *testing.T) { func TestBuildOAChatRequest_TemperatureForcedForGpt5(t *testing.T) { msgs := []Message{{Role: "user", Content: "hi"}} // Explicit temp 0.2 → should be forced to 1.0 for gpt-5 - r := buildOAChatRequest(Options{Model: "gpt-5.0", Temperature: 0.2, MaxTokens: 50}, msgs, nil, false) + r := buildOAChatRequest(Options{Model: "gpt-5.0", Temperature: 0.2, MaxTokens: 50}, msgs, nil, false, "llm/test ") b, _ := json.Marshal(r) if !contains(string(b), "\"temperature\":1") { t.Fatalf("expected forced temperature 1.0 for gpt-5, got %s", string(b)) diff --git a/internal/llm/openai_temp_test.go b/internal/llm/openai_temp_test.go index 7615117..3d71b94 100644 --- a/internal/llm/openai_temp_test.go +++ b/internal/llm/openai_temp_test.go @@ -5,7 +5,7 @@ import "testing" func TestNewFromConfig_DefaultTemp_ByModel(t *testing.T) { // OpenAI, gpt-5.* → default temp 1.0 when not provided cfg := Config{Provider: "openai", OpenAIModel: "gpt-5.0-preview"} - c, err := NewFromConfig(cfg, "key", "") + c, err := NewFromConfig(cfg, "key", "", "") if err != nil { t.Fatalf("new: %v", err) } @@ -18,7 +18,7 @@ func TestNewFromConfig_DefaultTemp_ByModel(t *testing.T) { } // OpenAI, gpt-4.* → default temp 0.2 when not provided cfg2 := Config{Provider: "openai", OpenAIModel: "gpt-4.1"} - c2, err := NewFromConfig(cfg2, "key", "") + c2, err := NewFromConfig(cfg2, "key", "", "") if err != nil { t.Fatalf("new2: %v", err) } @@ -32,7 +32,7 @@ func TestNewFromConfig_DefaultTemp_UpgradeWhenGpt5AndDefault02(t *testing.T) { // Simulate app-default of 0.2 while selecting a gpt-5 model: should upgrade to 1.0 v := 0.2 cfg := Config{Provider: "openai", OpenAIModel: "gpt-5.0", OpenAITemperature: &v} - c, err := NewFromConfig(cfg, "key", "") + c, err := NewFromConfig(cfg, "key", "", "") if err != nil { t.Fatalf("new: %v", err) } diff --git a/internal/llm/openrouter.go b/internal/llm/openrouter.go new file mode 100644 index 0000000..f03844a --- /dev/null +++ b/internal/llm/openrouter.go @@ -0,0 +1,168 @@ +// Summary: OpenRouter client implementation leveraging OpenAI-compatible helpers with provider-specific headers. +package llm + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "strings" + "time" + + "codeberg.org/snonux/hexai/internal/logging" +) + +type openRouterClient struct { + httpClient *http.Client + apiKey string + baseURL string + defaultModel string + chatLogger logging.ChatLogger + defaultTemperature *float64 +} + +func newOpenRouter(baseURL, model, apiKey string, defaultTemp *float64) Client { + if strings.TrimSpace(baseURL) == "" { + baseURL = "https://openrouter.ai/api/v1" + } + if strings.TrimSpace(model) == "" { + model = "openrouter/auto" + } + return openRouterClient{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + apiKey: apiKey, + baseURL: baseURL, + defaultModel: model, + chatLogger: logging.NewChatLogger("openrouter"), + defaultTemperature: defaultTemp, + } +} + +func (c openRouterClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { + if strings.TrimSpace(c.apiKey) == "" { + return nilStringErr("missing OpenRouter API key") + } + o := Options{Model: c.defaultModel} + for _, opt := range opts { + opt(&o) + } + if strings.TrimSpace(o.Model) == "" { + o.Model = c.defaultModel + } + start := time.Now() + c.logStart(false, o, messages) + req := buildOAChatRequest(o, messages, c.defaultTemperature, false, "llm/openrouter ") + body, err := json.Marshal(req) + if err != nil { + c.logf("marshal error: %v", err) + return "", err + } + endpoint := strings.TrimRight(c.baseURL, "/") + "/chat/completions" + logging.Logf("llm/openrouter ", "POST %s", endpoint) + resp, err := c.doJSON(ctx, endpoint, body) + if err != nil { + logging.Logf("llm/openrouter ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) + return "", err + } + defer resp.Body.Close() + if err := handleOpenAINon2xx(resp, start, "llm/openrouter ", "openrouter"); err != nil { + return "", err + } + out, err := decodeOpenAIChat(resp, start, "llm/openrouter ") + if err != nil { + return "", err + } + if len(out.Choices) == 0 { + logging.Logf("llm/openrouter ", "%sno choices returned duration=%s%s", logging.AnsiRed, time.Since(start), logging.AnsiBase) + return "", errors.New("openrouter: no choices returned") + } + content := out.Choices[0].Message.Content + logging.Logf("llm/openrouter ", "success choice=0 finish=%s size=%d preview=%s%s%s duration=%s", out.Choices[0].FinishReason, len(content), logging.AnsiGreen, logging.PreviewForLog(content), logging.AnsiBase, time.Since(start)) + return content, nil +} + +func (c openRouterClient) Name() string { return "openrouter" } +func (c openRouterClient) DefaultModel() string { return c.defaultModel } + +func (c openRouterClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error { + if strings.TrimSpace(c.apiKey) == "" { + return errors.New("missing OpenRouter API key") + } + o := Options{Model: c.defaultModel} + for _, opt := range opts { + opt(&o) + } + if strings.TrimSpace(o.Model) == "" { + o.Model = c.defaultModel + } + start := time.Now() + c.logStart(true, o, messages) + req := buildOAChatRequest(o, messages, c.defaultTemperature, true, "llm/openrouter ") + body, err := json.Marshal(req) + if err != nil { + c.logf("marshal error: %v", err) + return err + } + endpoint := strings.TrimRight(c.baseURL, "/") + "/chat/completions" + logging.Logf("llm/openrouter ", "POST %s (stream)", endpoint) + resp, err := c.doJSONWithAccept(ctx, endpoint, body, "text/event-stream") + if err != nil { + logging.Logf("llm/openrouter ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) + return err + } + defer resp.Body.Close() + if err := handleOpenAINon2xx(resp, start, "llm/openrouter ", "openrouter"); err != nil { + return err + } + if err := parseOpenAIStream(resp, start, onDelta, "llm/openrouter ", "openrouter"); err != nil { + return err + } + logging.Logf("llm/openrouter ", "stream end duration=%s", time.Since(start)) + return nil +} + +func (c openRouterClient) logf(format string, args ...any) { + logging.Logf("llm/openrouter ", format, args...) +} + +func (c openRouterClient) logStart(stream bool, o Options, messages []Message) { + logMessages := make([]struct{ Role, Content string }, len(messages)) + for i, m := range messages { + logMessages[i] = struct{ Role, Content string }{m.Role, m.Content} + } + c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) +} + +func (c openRouterClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) { + headers := map[string]string{ + "Authorization": "Bearer " + c.apiKey, + "HTTP-Referer": "https://github.com/snonux/hexai", + "X-Title": "Hexai", + } + return c.doJSONWithHeaders(ctx, url, body, headers, "") +} + +func (c openRouterClient) doJSONWithAccept(ctx context.Context, url string, body []byte, accept string) (*http.Response, error) { + headers := map[string]string{ + "Authorization": "Bearer " + c.apiKey, + "HTTP-Referer": "https://github.com/snonux/hexai", + "X-Title": "Hexai", + } + return c.doJSONWithHeaders(ctx, url, body, headers, accept) +} + +func (c openRouterClient) doJSONWithHeaders(ctx context.Context, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(accept) != "" { + req.Header.Set("Accept", accept) + } + for k, v := range headers { + req.Header.Set(k, v) + } + return c.httpClient.Do(req) +} diff --git a/internal/llm/openrouter_test.go b/internal/llm/openrouter_test.go new file mode 100644 index 0000000..2a07be0 --- /dev/null +++ b/internal/llm/openrouter_test.go @@ -0,0 +1,125 @@ +package llm + +import ( + "context" + "encoding/json" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "testing" + + "codeberg.org/snonux/hexai/internal/logging" +) + +func TestOpenRouter_Chat_SendsHeadersAndBody(t *testing.T) { + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + var capturedHeaders http.Header + var capturedBody []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + capturedBody = append([]byte(nil), body...) + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"index": 0, "message": map[string]string{"role": "assistant", "content": "ack"}}, + }, + }) + })) + defer srv.Close() + + c := newOpenRouter(srv.URL, "anthropic/claude-test", "KEY", f64p(0.2)).(openRouterClient) + c.httpClient = srv.Client() + out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "ping"}}) + if err != nil { + t.Fatalf("chat returned error: %v", err) + } + if out != "ack" { + t.Fatalf("unexpected response: %q", out) + } + if capturedHeaders.Get("Authorization") != "Bearer KEY" { + t.Fatalf("missing auth header: %#v", capturedHeaders) + } + if capturedHeaders.Get("HTTP-Referer") != "https://github.com/snonux/hexai" { + t.Fatalf("missing referer header: %#v", capturedHeaders) + } + if capturedHeaders.Get("X-Title") != "Hexai" { + t.Fatalf("missing title header: %#v", capturedHeaders) + } + + var req oaChatRequest + if err := json.Unmarshal(capturedBody, &req); err != nil { + t.Fatalf("unmarshal request: %v", err) + } + if req.Model != "anthropic/claude-test" { + t.Fatalf("unexpected model: %q", req.Model) + } + if len(req.Messages) != 1 || req.Messages[0].Role != "user" || req.Messages[0].Content != "ping" { + t.Fatalf("unexpected messages: %#v", req.Messages) + } +} + +func TestOpenRouter_ChatStream_SendsHeaders(t *testing.T) { + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + var acceptHeader string + var referer string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + acceptHeader = r.Header.Get("Accept") + referer = r.Header.Get("HTTP-Referer") + w.Header().Set("Content-Type", "text/event-stream") + io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n") + io.WriteString(w, "data: [DONE]\n") + })) + defer srv.Close() + + c := newOpenRouter(srv.URL, "anthropic/claude-test", "KEY", f64p(0.2)).(openRouterClient) + c.httpClient = srv.Client() + var got string + err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "ping"}}, func(s string) { got += s }) + if err != nil { + t.Fatalf("chat stream error: %v", err) + } + if got != "hi" { + t.Fatalf("expected stream output 'hi', got %q", got) + } + if acceptHeader != "text/event-stream" { + t.Fatalf("unexpected Accept header: %q", acceptHeader) + } + if referer != "https://github.com/snonux/hexai" { + t.Fatalf("missing referer header in stream: %q", referer) + } +} + +func TestOpenRouter_Chat_MissingKey(t *testing.T) { + c := newOpenRouter("http://example", "anthropic/claude-test", "", f64p(0.2)).(openRouterClient) + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "ping"}}); err == nil { + t.Fatalf("expected error for missing api key") + } +} + +func TestOpenRouter_DefaultsAndMetadata(t *testing.T) { + logger := log.New(io.Discard, "", 0) + logging.Bind(logger) + c := newOpenRouter("", "", "KEY", nil).(openRouterClient) + if c.baseURL != "https://openrouter.ai/api/v1" { + t.Fatalf("default baseURL mismatch: %s", c.baseURL) + } + if c.defaultModel != "openrouter/auto" { + t.Fatalf("default model mismatch: %s", c.defaultModel) + } + if name := c.Name(); name != "openrouter" { + t.Fatalf("Name() = %s", name) + } + if model := c.DefaultModel(); model != "openrouter/auto" { + t.Fatalf("DefaultModel() = %s", model) + } + c.logf("smoke") +} diff --git a/internal/llm/provider.go b/internal/llm/provider.go index 84efaf9..b2c47e4 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -69,6 +69,10 @@ type Config struct { OpenAIBaseURL string OpenAIModel string OpenAITemperature *float64 + // OpenRouter options + OpenRouterBaseURL string + OpenRouterModel string + OpenRouterTemperature *float64 // Ollama options OllamaBaseURL string OllamaModel string @@ -82,7 +86,7 @@ type Config struct { // NewFromConfig creates an LLM client using only the supplied configuration. // The OpenAI API key is supplied separately and may be read from the environment // by the caller; other environment-based configuration is not used. -func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) { +func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, copilotAPIKey string) (Client, error) { p := strings.ToLower(strings.TrimSpace(cfg.Provider)) if p == "" { p = "openai" @@ -112,6 +116,15 @@ func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, erro cfg.OpenAITemperature = &v } return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil + case "openrouter": + if strings.TrimSpace(openRouterAPIKey) == "" { + return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter") + } + if cfg.OpenRouterTemperature == nil { + t := 0.2 + cfg.OpenRouterTemperature = &t + } + return newOpenRouter(cfg.OpenRouterBaseURL, cfg.OpenRouterModel, openRouterAPIKey, cfg.OpenRouterTemperature), nil case "ollama": if cfg.OllamaTemperature == nil { t := 0.2 diff --git a/internal/llm/provider_more2_test.go b/internal/llm/provider_more2_test.go index 465be82..e001e5c 100644 --- a/internal/llm/provider_more2_test.go +++ b/internal/llm/provider_more2_test.go @@ -5,7 +5,7 @@ import "testing" func TestNewFromConfig_Copilot(t *testing.T) { t.Setenv("COPILOT_API_KEY", "x") cfg := Config{Provider: "copilot", CopilotModel: "small"} - c, err := NewFromConfig(cfg, "", "x") + c, err := NewFromConfig(cfg, "", "", "x") if err != nil || c == nil { t.Fatalf("copilot provider failed: %v %v", c, err) } diff --git a/internal/llm/provider_more_test.go b/internal/llm/provider_more_test.go index d7469af..eff99e6 100644 --- a/internal/llm/provider_more_test.go +++ b/internal/llm/provider_more_test.go @@ -16,13 +16,13 @@ func TestWithOptions_Apply(t *testing.T) { func TestNewFromConfig_Success_OpenAI_And_Copilot(t *testing.T) { // OpenAI success oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"} - c, err := NewFromConfig(oc, "KEY", "") + c, err := NewFromConfig(oc, "KEY", "", "") if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" { t.Fatalf("openai new: %v %v", c, err) } // Copilot success cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"} - c2, err := NewFromConfig(cc, "", "KEY") + c2, err := NewFromConfig(cc, "", "", "KEY") if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" { t.Fatalf("copilot new: %v %v", c2, err) } diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go index 29e2514..2c0d69c 100644 --- a/internal/llm/provider_test.go +++ b/internal/llm/provider_test.go @@ -7,15 +7,15 @@ import ( func TestNewFromConfig_DefaultsAndErrors(t *testing.T) { // Unknown provider - if _, err := NewFromConfig(Config{Provider: "bogus"}, "", ""); err == nil { + if _, err := NewFromConfig(Config{Provider: "bogus"}, "", "", ""); err == nil { t.Fatalf("expected error for unknown provider") } // OpenAI missing key - if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", ""); err == nil { + if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", "", ""); err == nil { t.Fatalf("expected key error") } // Copilot missing key - if _, err := NewFromConfig(Config{Provider: "copilot", CopilotModel: "m"}, "", ""); err == nil { + if _, err := NewFromConfig(Config{Provider: "copilot", CopilotModel: "m"}, "", "", ""); err == nil { t.Fatalf("expected key error") } } diff --git a/internal/llmutils/client.go b/internal/llmutils/client.go index 9bd39ee..2f3da55 100644 --- a/internal/llmutils/client.go +++ b/internal/llmutils/client.go @@ -11,24 +11,31 @@ import ( // NewClientFromApp builds an llm.Client using app config and environment keys. func NewClientFromApp(cfg appconfig.App) (llm.Client, error) { llmCfg := llm.Config{ - Provider: cfg.Provider, - OpenAIBaseURL: cfg.OpenAIBaseURL, - OpenAIModel: cfg.OpenAIModel, - OpenAITemperature: cfg.OpenAITemperature, - OllamaBaseURL: cfg.OllamaBaseURL, - OllamaModel: cfg.OllamaModel, - OllamaTemperature: cfg.OllamaTemperature, - CopilotBaseURL: cfg.CopilotBaseURL, - CopilotModel: cfg.CopilotModel, - CopilotTemperature: cfg.CopilotTemperature, + Provider: cfg.Provider, + OpenAIBaseURL: cfg.OpenAIBaseURL, + OpenAIModel: cfg.OpenAIModel, + OpenAITemperature: cfg.OpenAITemperature, + OpenRouterBaseURL: cfg.OpenRouterBaseURL, + OpenRouterModel: cfg.OpenRouterModel, + OpenRouterTemperature: cfg.OpenRouterTemperature, + OllamaBaseURL: cfg.OllamaBaseURL, + OllamaModel: cfg.OllamaModel, + OllamaTemperature: cfg.OllamaTemperature, + CopilotBaseURL: cfg.CopilotBaseURL, + CopilotModel: cfg.CopilotModel, + CopilotTemperature: cfg.CopilotTemperature, } oaKey := os.Getenv("HEXAI_OPENAI_API_KEY") if strings.TrimSpace(oaKey) == "" { oaKey = os.Getenv("OPENAI_API_KEY") } + orKey := os.Getenv("HEXAI_OPENROUTER_API_KEY") + if strings.TrimSpace(orKey) == "" { + orKey = os.Getenv("OPENROUTER_API_KEY") + } cpKey := os.Getenv("HEXAI_COPILOT_API_KEY") if strings.TrimSpace(cpKey) == "" { cpKey = os.Getenv("COPILOT_API_KEY") } - return llm.NewFromConfig(llmCfg, oaKey, cpKey) + return llm.NewFromConfig(llmCfg, oaKey, orKey, cpKey) } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 8e210b4..d55a967 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -217,26 +217,33 @@ func (s *Server) currentLLMClient() llm.Client { func newClientForProvider(cfg appconfig.App, provider string) (llm.Client, error) { llmCfg := llm.Config{ - Provider: provider, - OpenAIBaseURL: cfg.OpenAIBaseURL, - OpenAIModel: cfg.OpenAIModel, - OpenAITemperature: cfg.OpenAITemperature, - OllamaBaseURL: cfg.OllamaBaseURL, - OllamaModel: cfg.OllamaModel, - OllamaTemperature: cfg.OllamaTemperature, - CopilotBaseURL: cfg.CopilotBaseURL, - CopilotModel: cfg.CopilotModel, - CopilotTemperature: cfg.CopilotTemperature, + Provider: provider, + OpenAIBaseURL: cfg.OpenAIBaseURL, + OpenAIModel: cfg.OpenAIModel, + OpenAITemperature: cfg.OpenAITemperature, + OpenRouterBaseURL: cfg.OpenRouterBaseURL, + OpenRouterModel: cfg.OpenRouterModel, + OpenRouterTemperature: cfg.OpenRouterTemperature, + OllamaBaseURL: cfg.OllamaBaseURL, + OllamaModel: cfg.OllamaModel, + OllamaTemperature: cfg.OllamaTemperature, + CopilotBaseURL: cfg.CopilotBaseURL, + CopilotModel: cfg.CopilotModel, + CopilotTemperature: cfg.CopilotTemperature, } oaKey := strings.TrimSpace(os.Getenv("HEXAI_OPENAI_API_KEY")) if oaKey == "" { oaKey = strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) } + orKey := strings.TrimSpace(os.Getenv("HEXAI_OPENROUTER_API_KEY")) + if orKey == "" { + orKey = strings.TrimSpace(os.Getenv("OPENROUTER_API_KEY")) + } cpKey := strings.TrimSpace(os.Getenv("HEXAI_COPILOT_API_KEY")) if cpKey == "" { cpKey = strings.TrimSpace(os.Getenv("COPILOT_API_KEY")) } - return llm.NewFromConfig(llmCfg, oaKey, cpKey) + return llm.NewFromConfig(llmCfg, oaKey, orKey, cpKey) } func (s *Server) clientFor(spec requestSpec) llm.Client { @@ -273,6 +280,12 @@ func (s *Server) clientFor(spec requestSpec) llm.Client { } else if spec.fallbackModel != "" { cfg.OpenAIModel = spec.fallbackModel } + case "openrouter": + if modelOverride != "" { + cfg.OpenRouterModel = modelOverride + } else if spec.fallbackModel != "" { + cfg.OpenRouterModel = spec.fallbackModel + } case "copilot": if modelOverride != "" { cfg.CopilotModel = modelOverride diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 4f24b57..836e43f 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -85,3 +85,20 @@ func TestServerApplyOptions(t *testing.T) { t.Fatalf("expected config to update, got %d", got) } } + +func TestServerStoreAndTakePendingCompletion(t *testing.T) { + s := newTestServer() + items := []CompletionItem{{Label: "foo"}} + s.storePendingCompletion("key", items) + if len(s.pendingCompletions) != 1 { + t.Fatalf("expected pending map to be populated") + } + items[0].Label = "bar" // ensure copy stored + got := s.takePendingCompletion("key") + if len(got) != 1 || got[0].Label != "foo" { + t.Fatalf("expected preserved copy of completion, got %+v", got) + } + if len(s.pendingCompletions) != 0 { + t.Fatalf("expected pending map to be cleared after take") + } +} diff --git a/internal/version.go b/internal/version.go index f781c7a..a28ebba 100644 --- a/internal/version.go +++ b/internal/version.go @@ -1,4 +1,4 @@ // Summary: Hexai semantic version identifier used by CLI and LSP binaries. package internal -const Version = "0.14.0" +const Version = "0.15.0" |
