diff options
Diffstat (limited to 'internal/llm/copilot.go')
| -rw-r--r-- | internal/llm/copilot.go | 297 |
1 files changed, 170 insertions, 127 deletions
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go index 16eeda6..d3b1a9d 100644 --- a/internal/llm/copilot.go +++ b/internal/llm/copilot.go @@ -4,6 +4,7 @@ package llm import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -13,7 +14,6 @@ import ( "strings" "time" - "encoding/base64" appver "codeberg.org/snonux/hexai/internal" "codeberg.org/snonux/hexai/internal/logging" ) @@ -162,10 +162,14 @@ func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64 } func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { return nil, err } - for k, v := range headers { req.Header.Set(k, v) } - return c.httpClient.Do(req) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + return c.httpClient.Do(req) } func handleCopilotNon2xx(resp *http.Response, start time.Time) error { @@ -194,55 +198,73 @@ func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatRespons // --- Copilot session token management --- type ghCopilotTokenResp struct { - Token string `json:"token"` + Token string `json:"token"` } func (c *copilotClient) ensureSession(ctx context.Context) error { - // If token valid for >60s, reuse - if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) { - return nil - } - if strings.TrimSpace(c.apiKey) == "" { - return errors.New("missing Copilot API key") - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil) - if err != nil { return err } - req.Header.Set("Authorization", "Bearer "+c.apiKey) - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "hexai/"+appver.Version) - resp, err := c.httpClient.Do(req) - if err != nil { return err } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("copilot token http error: %d", resp.StatusCode) - } - var out ghCopilotTokenResp - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return err } - if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") } - // Parse JWT exp - exp := parseJWTExp(out.Token) - if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) } - c.sessionToken = out.Token - c.tokenExpiry = exp - return nil + // If token valid for >60s, reuse + if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) { + return nil + } + if strings.TrimSpace(c.apiKey) == "" { + return errors.New("missing Copilot API key") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "hexai/"+appver.Version) + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("copilot token http error: %d", resp.StatusCode) + } + var out ghCopilotTokenResp + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return err + } + if strings.TrimSpace(out.Token) == "" { + return errors.New("empty copilot session token") + } + // Parse JWT exp + exp := parseJWTExp(out.Token) + if exp.IsZero() { + exp = time.Now().Add(10 * time.Minute) + } + c.sessionToken = out.Token + c.tokenExpiry = exp + return nil } var jwtExpRe = regexp.MustCompile(`"exp"\s*:\s*([0-9]+)`) // fallback if we can't base64 decode func parseJWTExp(token string) time.Time { - parts := strings.Split(token, ".") - if len(parts) < 2 { return time.Time{} } - b, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 { - if n, err2 := parseInt64(m[1]); err2 == nil { return time.Unix(n, 0) } - } - return time.Time{} - } - var payload struct{ Exp int64 `json:"exp"` } - _ = json.Unmarshal(b, &payload) - if payload.Exp == 0 { return time.Time{} } - return time.Unix(payload.Exp, 0) + parts := strings.Split(token, ".") + if len(parts) < 2 { + return time.Time{} + } + b, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 { + if n, err2 := parseInt64(m[1]); err2 == nil { + return time.Unix(n, 0) + } + } + return time.Time{} + } + var payload struct { + Exp int64 `json:"exp"` + } + _ = json.Unmarshal(b, &payload) + if payload.Exp == 0 { + return time.Time{} + } + return time.Unix(payload.Exp, 0) } func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err } @@ -250,99 +272,120 @@ func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, & // --- Copilot headers --- func (c *copilotClient) headersChat() map[string]string { - _ = c.ensureSession(context.Background()) - h := map[string]string{ - "Content-Type": "application/json; charset=utf-8", - "Accept": "application/json", - "Authorization": "Bearer " + c.sessionToken, - "User-Agent": "GitHubCopilotChat/0.8.0", - "Editor-Plugin-Version": "copilot-chat/0.8.0", - "Editor-Version": "vscode/1.85.1", - "Openai-Intent": "conversation-panel", - "Openai-Organization": "github-copilot", - "VScode-MachineId": randHex(64), - "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - } - return h + _ = c.ensureSession(context.Background()) + h := map[string]string{ + "Content-Type": "application/json; charset=utf-8", + "Accept": "application/json", + "Authorization": "Bearer " + c.sessionToken, + "User-Agent": "GitHubCopilotChat/0.8.0", + "Editor-Plugin-Version": "copilot-chat/0.8.0", + "Editor-Version": "vscode/1.85.1", + "Openai-Intent": "conversation-panel", + "Openai-Organization": "github-copilot", + "VScode-MachineId": randHex(64), + "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + } + return h } func (c *copilotClient) headersGhost() map[string]string { - _ = c.ensureSession(context.Background()) - h := map[string]string{ - "Content-Type": "application/json; charset=utf-8", - "Accept": "*/*", - "Authorization": "Bearer " + c.sessionToken, - "User-Agent": "GithubCopilot/1.155.0", - "Editor-Plugin-Version": "copilot/1.155.0", - "Editor-Version": "vscode/1.85.1", - "Openai-Intent": "copilot-ghost", - "Openai-Organization": "github-copilot", - "VScode-MachineId": randHex(64), - "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - } - return h + _ = c.ensureSession(context.Background()) + h := map[string]string{ + "Content-Type": "application/json; charset=utf-8", + "Accept": "*/*", + "Authorization": "Bearer " + c.sessionToken, + "User-Agent": "GithubCopilot/1.155.0", + "Editor-Plugin-Version": "copilot/1.155.0", + "Editor-Version": "vscode/1.85.1", + "Openai-Intent": "copilot-ghost", + "Openai-Organization": "github-copilot", + "VScode-MachineId": randHex(64), + "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + } + return h } func randHex(n int) string { - const hex = "0123456789abcdef" - b := make([]byte, n) - for i := range b { - b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)] - } - return string(b) + const hex = "0123456789abcdef" + b := make([]byte, n) + for i := range b { + b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)] + } + return string(b) } // --- Codex-style code completion --- // CodeCompletion implements CodeCompleter; returns up to n suggestions. func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) { - if strings.TrimSpace(c.apiKey) == "" { return nil, errors.New("missing Copilot API key") } - if err := c.ensureSession(ctx); err != nil { return nil, err } - if n <= 0 { n = 1 } - maxTokens := 500 - body := map[string]any{ - "extra": map[string]any{ - "language": language, - "next_indent": 0, - "prompt_tokens": 500, - "suffix_tokens": 400, - "trim_by_indentation": true, - }, - "max_tokens": maxTokens, - "n": n, - "nwo": "hexai", - "prompt": prompt, - "stop": []string{"\n\n"}, - "stream": true, - "suffix": suffix, - "temperature": temperature, - "top_p": 1, - } - buf, _ := json.Marshal(body) - url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions" - resp, err := c.postJSON(ctx, url, buf, c.headersGhost()) - if err != nil { return nil, err } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode) - } - // Read all and parse lines that start with "data: " accumulating by index - raw, _ := io.ReadAll(resp.Body) - byIndex := make(map[int]string) - lines := strings.Split(string(raw), "\n") - for _, ln := range lines { - if !strings.HasPrefix(ln, "data: ") { continue } - var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` } - if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue } - for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text } - } - out := make([]string, 0, len(byIndex)) - for i := 0; i < n; i++ { - if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) } - } - return out, nil + if strings.TrimSpace(c.apiKey) == "" { + return nil, errors.New("missing Copilot API key") + } + if err := c.ensureSession(ctx); err != nil { + return nil, err + } + if n <= 0 { + n = 1 + } + maxTokens := 500 + body := map[string]any{ + "extra": map[string]any{ + "language": language, + "next_indent": 0, + "prompt_tokens": 500, + "suffix_tokens": 400, + "trim_by_indentation": true, + }, + "max_tokens": maxTokens, + "n": n, + "nwo": "hexai", + "prompt": prompt, + "stop": []string{"\n\n"}, + "stream": true, + "suffix": suffix, + "temperature": temperature, + "top_p": 1, + } + buf, _ := json.Marshal(body) + url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions" + resp, err := c.postJSON(ctx, url, buf, c.headersGhost()) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode) + } + // Read all and parse lines that start with "data: " accumulating by index + raw, _ := io.ReadAll(resp.Body) + byIndex := make(map[int]string) + lines := strings.Split(string(raw), "\n") + for _, ln := range lines { + if !strings.HasPrefix(ln, "data: ") { + continue + } + var evt struct { + Choices []struct { + Index int `json:"index"` + Text string `json:"text"` + } `json:"choices"` + } + if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { + continue + } + for _, ch := range evt.Choices { + byIndex[ch.Index] += ch.Text + } + } + out := make([]string, 0, len(byIndex)) + for i := 0; i < n; i++ { + if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { + out = append(out, s) + } + } + return out, nil } // newLineDataReader wraps a streaming body and exposes a JSON decoder that |
