summaryrefslogtreecommitdiff
path: root/internal/llm/openrouter_test.go
blob: 07d6e0fa2c254c95e7dafcb2498baaa050ec1b99 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package llm

import (
	"context"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"

	"codeberg.org/snonux/hexai/internal/logging"
)

// TestOpenRouter_Chat_SendsHeadersAndBody checks that a non-streaming Chat call
// sends the bearer token plus the HTTP-Referer and X-Title headers, encodes the
// configured model and the caller's messages in the JSON request body, and
// returns the content of the first choice from the canned response.
func TestOpenRouter_Chat_SendsHeadersAndBody(t *testing.T) {
	if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
		t.Skip("skip network-bound tests in restricted environments")
	}
	var capturedHeaders http.Header
	var capturedBody []byte
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, not the test goroutine, so
		// t.Fatalf (FailNow) must not be used here. Record the failure with
		// t.Errorf and reply with an error status instead of the canned answer.
		capturedHeaders = r.Header.Clone()
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read body: %v", err)
			http.Error(w, "read body failed", http.StatusInternalServerError)
			return
		}
		// io.ReadAll returns a freshly allocated slice, so no defensive copy is needed.
		capturedBody = body
		err = json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{"index": 0, "message": map[string]string{"role": "assistant", "content": "ack"}},
			},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer srv.Close()

	c := newOpenRouter(srv.URL, "anthropic/claude-test", "KEY", f64p(0.2)).(openRouterClient)
	c.httpClient = srv.Client()
	out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "ping"}})
	if err != nil {
		t.Fatalf("chat returned error: %v", err)
	}
	if out != "ack" {
		t.Fatalf("unexpected response: %q", out)
	}
	if capturedHeaders.Get("Authorization") != "Bearer KEY" {
		t.Fatalf("missing auth header: %#v", capturedHeaders)
	}
	if capturedHeaders.Get("HTTP-Referer") != "https://github.com/snonux/hexai" {
		t.Fatalf("missing referer header: %#v", capturedHeaders)
	}
	if capturedHeaders.Get("X-Title") != "Hexai" {
		t.Fatalf("missing title header: %#v", capturedHeaders)
	}

	var req oaChatRequest
	if err := json.Unmarshal(capturedBody, &req); err != nil {
		t.Fatalf("unmarshal request: %v", err)
	}
	if req.Model != "anthropic/claude-test" {
		t.Fatalf("unexpected model: %q", req.Model)
	}
	if len(req.Messages) != 1 || req.Messages[0].Role != "user" || req.Messages[0].Content != "ping" {
		t.Fatalf("unexpected messages: %#v", req.Messages)
	}
}

// TestOpenRouter_ChatStream_SendsHeaders drives ChatStream against a fake
// server that answers with one "data:" delta line and a "data: [DONE]" line.
// It checks that the delta content reaches the callback and that the request
// carried "Accept: text/event-stream" and the HTTP-Referer header.
func TestOpenRouter_ChatStream_SendsHeaders(t *testing.T) {
	if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
		t.Skip("skip network-bound tests in restricted environments")
	}

	// Request headers observed by the fake server.
	var seen struct{ accept, referer string }
	handler := func(w http.ResponseWriter, r *http.Request) {
		seen.accept = r.Header.Get("Accept")
		seen.referer = r.Header.Get("HTTP-Referer")
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")
		_, _ = io.WriteString(w, "data: [DONE]\n")
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := newOpenRouter(srv.URL, "anthropic/claude-test", "KEY", f64p(0.2)).(openRouterClient)
	client.httpClient = srv.Client()

	var streamed strings.Builder
	msgs := []Message{{Role: "user", Content: "ping"}}
	if err := client.ChatStream(context.Background(), msgs, func(s string) { streamed.WriteString(s) }); err != nil {
		t.Fatalf("chat stream error: %v", err)
	}
	if got := streamed.String(); got != "hi" {
		t.Fatalf("expected stream output 'hi', got %q", got)
	}
	if seen.accept != "text/event-stream" {
		t.Fatalf("unexpected Accept header: %q", seen.accept)
	}
	if seen.referer != "https://github.com/snonux/hexai" {
		t.Fatalf("missing referer header in stream: %q", seen.referer)
	}
}

// TestOpenRouter_Chat_MissingKey checks that Chat returns an error when the
// client was built with an empty API key. The error message must name both
// OPENROUTER_API_KEY and HEXAI_OPENROUTER_API_KEY so the user knows what to set.
func TestOpenRouter_Chat_MissingKey(t *testing.T) {
	client := newOpenRouter("http://example", "anthropic/claude-test", "", f64p(0.2)).(openRouterClient)
	_, err := client.Chat(context.Background(), []Message{{Role: "user", Content: "ping"}})
	if err == nil {
		t.Fatalf("expected error for missing api key")
	}
	msg := err.Error()
	for _, hint := range []string{"OPENROUTER_API_KEY", "HEXAI_OPENROUTER_API_KEY"} {
		if !strings.Contains(msg, hint) {
			t.Fatalf("expected actionable API key hint, got %q", msg)
		}
	}
}

// TestOpenRouter_DefaultsAndMetadata checks the fallbacks newOpenRouter applies
// when the base URL and model are empty. It also checks the Name and
// DefaultModel accessors and smoke-calls logf.
func TestOpenRouter_DefaultsAndMetadata(t *testing.T) {
	// Bind a logger that writes to io.Discard. Presumably logf routes through it,
	// so the smoke call below prints nothing — confirm against the logging package.
	logging.Bind(log.New(io.Discard, "", 0))

	const (
		wantBaseURL = "https://openrouter.ai/api/v1"
		wantModel   = "openrouter/auto"
	)
	client := newOpenRouter("", "", "KEY", nil).(openRouterClient)

	if got := client.baseURL; got != wantBaseURL {
		t.Fatalf("default baseURL mismatch: %s", got)
	}
	if got := client.defaultModel; got != wantModel {
		t.Fatalf("default model mismatch: %s", got)
	}
	if got := client.Name(); got != "openrouter" {
		t.Fatalf("Name() = %s", got)
	}
	if got := client.DefaultModel(); got != wantModel {
		t.Fatalf("DefaultModel() = %s", got)
	}
	client.logf("smoke")
}