mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into t3
This commit is contained in:
+23
-5
@@ -48,6 +48,9 @@ type AgentInstance struct {
|
||||
// LightCandidates holds the resolved provider candidates for the light model.
|
||||
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
|
||||
LightCandidates []providers.FallbackCandidate
|
||||
// LightProvider is the concrete provider instance for the configured light model.
|
||||
// It is only used when routing selects the light tier for a turn.
|
||||
LightProvider providers.LLMProvider
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -171,14 +174,28 @@ func NewAgentInstance(
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
var lightProvider providers.LLMProvider
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
lightModelCfg, err := resolvedModelConfig(cfg, rc.LightModel, workspace)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Routing light model config invalid; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()})
|
||||
} else {
|
||||
lp, _, err := providers.CreateProviderFromConfig(lightModelCfg)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Routing light model provider init failed; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()})
|
||||
} else {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
lightProvider = lp
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.WarnCF("agent", "Routing light model not found; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID})
|
||||
@@ -207,6 +224,7 @@ func NewAgentInstance(
|
||||
Candidates: candidates,
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
LightProvider: lightProvider,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+19
-12
@@ -387,10 +387,17 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
for al.running.Load() {
|
||||
idleTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer idleTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-idleTicker.C:
|
||||
if !al.running.Load() {
|
||||
return nil
|
||||
}
|
||||
case msg, ok := <-al.bus.InboundChan():
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -517,12 +524,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, finalResponse)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
time.Sleep(time.Microsecond * 200)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// drainBusToSteering consumes inbound messages and redirects messages from the
|
||||
@@ -1680,7 +1683,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.recordPersistedMessage(rootMsg)
|
||||
}
|
||||
|
||||
activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
activeProvider := ts.agent.Provider
|
||||
if usedLight && ts.agent.LightProvider != nil {
|
||||
activeProvider = ts.agent.LightProvider
|
||||
}
|
||||
pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...)
|
||||
var finalContent string
|
||||
|
||||
@@ -1902,7 +1909,7 @@ turnLoop:
|
||||
providerCtx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return ts.agent.Provider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
@@ -1918,7 +1925,7 @@ turnLoop:
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
|
||||
return activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
|
||||
}
|
||||
|
||||
var response *providers.LLMResponse
|
||||
@@ -2747,9 +2754,9 @@ func (al *AgentLoop) selectCandidates(
|
||||
agent *AgentInstance,
|
||||
userMsg string,
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
) (candidates []providers.FallbackCandidate, model string, usedLight bool) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
@@ -2760,7 +2767,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
@@ -2770,7 +2777,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
|
||||
@@ -1296,6 +1296,46 @@ func newChatCompletionTestServer(
|
||||
}))
|
||||
}
|
||||
|
||||
func newStrictChatCompletionTestServer(
|
||||
t *testing.T,
|
||||
label string,
|
||||
expectedModel string,
|
||||
response string,
|
||||
calls *int,
|
||||
) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path)
|
||||
}
|
||||
*calls = *calls + 1
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode %s request: %v", label, err)
|
||||
}
|
||||
if req.Model != expectedModel {
|
||||
t.Fatalf("%s server model = %q, want %q", label, req.Model, expectedModel)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": response},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("encode %s response: %v", label, err)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
@@ -1697,6 +1737,92 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
heavyCalls := 0
|
||||
heavyServer := newStrictChatCompletionTestServer(
|
||||
t,
|
||||
"heavy",
|
||||
"gemini-2.5-flash",
|
||||
"heavy reply",
|
||||
&heavyCalls,
|
||||
)
|
||||
defer heavyServer.Close()
|
||||
|
||||
lightCalls := 0
|
||||
lightServer := newStrictChatCompletionTestServer(
|
||||
t,
|
||||
"light",
|
||||
"qwen2.5:0.5b",
|
||||
"light reply",
|
||||
&lightCalls,
|
||||
)
|
||||
defer lightServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "gemini-main",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
Routing: &config.RoutingConfig{
|
||||
Enabled: true,
|
||||
LightModel: "qwen-light",
|
||||
Threshold: 0.99,
|
||||
},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "gemini-main",
|
||||
Model: "gemini/gemini-2.5-flash",
|
||||
APIBase: heavyServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("heavy-key"),
|
||||
},
|
||||
{
|
||||
ModelName: "qwen-light",
|
||||
Model: "ollama/qwen2.5:0.5b",
|
||||
APIBase: lightServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("light-key"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if resp != "light reply" {
|
||||
t.Fatalf("response = %q, want %q", resp, "light reply")
|
||||
}
|
||||
if heavyCalls != 0 {
|
||||
t.Fatalf("heavy calls = %d, want 0", heavyCalls)
|
||||
}
|
||||
if lightCalls != 1 {
|
||||
t.Fatalf("light calls = %d, want 1", lightCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
+5
-7
@@ -545,13 +545,11 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||
cred.AccountID = accountID
|
||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||
if id := extractAccountID(tokenResp.IDToken); id != "" {
|
||||
cred.AccountID = id
|
||||
} else if id := extractAccountID(tokenResp.AccessToken); id != "" {
|
||||
cred.AccountID = id
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
|
||||
@@ -6,6 +6,7 @@ package dingtalk
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
@@ -135,13 +136,17 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
ctx context.Context,
|
||||
data *chatbot.BotCallbackDataModel,
|
||||
) ([]byte, error) {
|
||||
if data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Extract message content from Text field
|
||||
content := data.Text.Content
|
||||
content := strings.TrimSpace(data.Text.Content)
|
||||
if content == "" {
|
||||
// Try to extract from Content interface{} if Text is empty
|
||||
if contentMap, ok := data.Content.(map[string]any); ok {
|
||||
if textContent, ok := contentMap["content"].(string); ok {
|
||||
content = textContent
|
||||
content = strings.TrimSpace(textContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -150,12 +155,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
return nil, nil // Ignore empty messages
|
||||
}
|
||||
|
||||
senderID := data.SenderStaffId
|
||||
senderNick := data.SenderNick
|
||||
chatID := senderID
|
||||
if data.ConversationType != "1" {
|
||||
// For group chats
|
||||
chatID = data.ConversationId
|
||||
senderID := strings.TrimSpace(data.SenderStaffId)
|
||||
if senderID == "" {
|
||||
senderID = strings.TrimSpace(data.SenderId)
|
||||
}
|
||||
senderNick := strings.TrimSpace(data.SenderNick)
|
||||
|
||||
chatID := strings.TrimSpace(data.ConversationId)
|
||||
if chatID == "" && data.ConversationType == "1" {
|
||||
// Fallback for direct chats when conversation_id is absent.
|
||||
chatID = senderID
|
||||
}
|
||||
if chatID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Store the session webhook for this chat so we can reply later
|
||||
@@ -171,11 +183,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
|
||||
var peer bus.Peer
|
||||
if data.ConversationType == "1" {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
peerID := senderID
|
||||
if peerID == "" {
|
||||
peerID = chatID
|
||||
}
|
||||
peer = bus.Peer{Kind: "direct", ID: peerID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
|
||||
isMentioned := data.IsInAtList
|
||||
if isMentioned {
|
||||
content = stripLeadingAtMentions(content)
|
||||
}
|
||||
// In group chats, apply unified group trigger filtering
|
||||
respond, cleaned := c.ShouldRespondInGroup(false, content)
|
||||
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
|
||||
if !respond {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -189,10 +209,18 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
})
|
||||
|
||||
// Build sender info
|
||||
platformID := senderID
|
||||
if platformID == "" {
|
||||
platformID = chatID
|
||||
}
|
||||
resolvedSenderID := senderID
|
||||
if resolvedSenderID == "" {
|
||||
resolvedSenderID = platformID
|
||||
}
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "dingtalk",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("dingtalk", senderID),
|
||||
PlatformID: platformID,
|
||||
CanonicalID: identity.BuildCanonicalID("dingtalk", platformID),
|
||||
DisplayName: senderNick,
|
||||
}
|
||||
|
||||
@@ -201,7 +229,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender)
|
||||
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
|
||||
|
||||
// Return nil to indicate we've handled the message asynchronously
|
||||
// The response will be sent through the message bus
|
||||
@@ -229,3 +257,19 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stripLeadingAtMentions(content string) string {
|
||||
fields := strings.Fields(content)
|
||||
if len(fields) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(fields) && strings.HasPrefix(fields[i], "@") {
|
||||
i++
|
||||
}
|
||||
if i == 0 {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
return strings.Join(fields[i:], " ")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package dingtalk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkChannel, *bus.MessageBus) {
|
||||
t.Helper()
|
||||
|
||||
if cfg.ClientID == "" {
|
||||
cfg.ClientID = "test-client-id"
|
||||
}
|
||||
if cfg.ClientSecret.String() == "" {
|
||||
cfg.ClientSecret.Set("test-client-secret")
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewDingTalkChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("new channel: %v", err)
|
||||
}
|
||||
return ch, msgBus
|
||||
}
|
||||
|
||||
func mustReceiveInbound(t *testing.T, msgBus *bus.MessageBus) bus.InboundMessage {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-msgBus.InboundChan():
|
||||
return msg
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inbound message")
|
||||
return bus.InboundMessage{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention(t *testing.T) {
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{
|
||||
GroupTrigger: config.GroupTriggerConfig{MentionOnly: true},
|
||||
})
|
||||
|
||||
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
Text: chatbot.BotCallbackDataTextModel{Content: " @bot /help "},
|
||||
SenderStaffId: "staff-123",
|
||||
SenderNick: "Alice",
|
||||
ConversationType: "2",
|
||||
ConversationId: "group-abc",
|
||||
SessionWebhook: "https://example.com/webhook",
|
||||
IsInAtList: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handler returned error: %v", err)
|
||||
}
|
||||
|
||||
inbound := mustReceiveInbound(t, msgBus)
|
||||
if inbound.Channel != "dingtalk" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.ChatID != "group-abc" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
}
|
||||
if inbound.Content != "/help" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *testing.T) {
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{})
|
||||
|
||||
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
Text: chatbot.BotCallbackDataTextModel{Content: "ping"},
|
||||
SenderStaffId: "",
|
||||
SenderId: "openid-user-42",
|
||||
SenderNick: "Bob",
|
||||
ConversationType: "1",
|
||||
ConversationId: "conv-direct-42",
|
||||
SessionWebhook: "https://example.com/webhook-direct",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handler returned error: %v", err)
|
||||
}
|
||||
|
||||
inbound := mustReceiveInbound(t, msgBus)
|
||||
if inbound.ChatID != "conv-direct-42" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
}
|
||||
if inbound.SenderID != "dingtalk:openid-user-42" {
|
||||
t.Fatalf("sender_id=%q", inbound.SenderID)
|
||||
}
|
||||
|
||||
if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok {
|
||||
t.Fatal("expected session webhook keyed by conversation_id")
|
||||
}
|
||||
if _, ok := ch.sessionWebhooks.Load(""); ok {
|
||||
t.Fatal("unexpected empty chat_id webhook key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripLeadingAtMentions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantOut string
|
||||
}{
|
||||
{name: "single mention and command", input: "@bot /help", wantOut: "/help"},
|
||||
{name: "multiple mentions", input: "@bot @alice /new", wantOut: "/new"},
|
||||
{name: "no mention", input: "/help", wantOut: "/help"},
|
||||
{name: "mention only", input: "@bot", wantOut: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := stripLeadingAtMentions(tt.input)
|
||||
if got != tt.wantOut {
|
||||
t.Fatalf("stripLeadingAtMentions(%q)=%q want=%q", tt.input, got, tt.wantOut)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+42
-52
@@ -12,6 +12,14 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
weixinChannelVersion = "2.1.1"
|
||||
weixinIlinkAppID = "bot"
|
||||
// 2.1.1 encoded as 0x00MMNNPP => 0x00020101 => 131329
|
||||
weixinClientVersion = 131329
|
||||
)
|
||||
|
||||
type ApiClient struct {
|
||||
@@ -80,13 +88,9 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if endpoint == "ilink/bot/get_bot_qrcode" || endpoint == "ilink/bot/get_qrcode_status" {
|
||||
// QR routes have different headers sometimes, but let's stick to base ones
|
||||
if endpoint == "ilink/bot/get_qrcode_status" {
|
||||
// Use direct map assignment to send exact header name the Tencent API expects
|
||||
req.Header["iLink-App-ClientVersion"] = []string{"1"}
|
||||
}
|
||||
} else {
|
||||
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
|
||||
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
|
||||
if endpoint != "ilink/bot/get_bot_qrcode" && endpoint != "ilink/bot/get_qrcode_status" {
|
||||
req.Header["AuthorizationType"] = []string{"ilink_bot_token"}
|
||||
req.Header["X-WECHAT-UIN"] = []string{randomWechatUIN()}
|
||||
if c.Token != "" {
|
||||
@@ -119,7 +123,7 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpdatesResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp GetUpdatesResp
|
||||
err := c.post(ctx, "ilink/bot/getupdates", req, &resp)
|
||||
if err != nil {
|
||||
@@ -129,7 +133,7 @@ func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpda
|
||||
}
|
||||
|
||||
func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendMessageResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp SendMessageResp
|
||||
if err := c.post(ctx, "ilink/bot/sendmessage", req, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -138,7 +142,7 @@ func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendM
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*GetUploadUrlResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp GetUploadUrlResp
|
||||
err := c.post(ctx, "ilink/bot/getuploadurl", req, &resp)
|
||||
if err != nil {
|
||||
@@ -148,7 +152,7 @@ func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*Get
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfigResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp GetConfigResp
|
||||
if err := c.post(ctx, "ilink/bot/getconfig", req, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -157,7 +161,7 @@ func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfig
|
||||
}
|
||||
|
||||
func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTypingResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp SendTypingResp
|
||||
if err := c.post(ctx, "ilink/bot/sendtyping", req, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -165,38 +169,51 @@ func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTyp
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
|
||||
// get_bot_qrcode is GET, not POST
|
||||
func (c *ApiClient) getQR(ctx context.Context, endpoint string, query map[string]string, respObj any) error {
|
||||
u, err := url.Parse(c.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "ilink/bot/get_bot_qrcode")
|
||||
u.Path = path.Join(u.Path, endpoint)
|
||||
q := u.Query()
|
||||
q.Set("bot_type", botType)
|
||||
for key, value := range query {
|
||||
q.Set(key, value)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
|
||||
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("get_bot_qrcode failed: %d %s", resp.StatusCode, string(respBody))
|
||||
return fmt.Errorf("%s failed: %d %s", endpoint, resp.StatusCode, string(respBody))
|
||||
}
|
||||
if err := json.Unmarshal(respBody, respObj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
|
||||
// get_bot_qrcode is GET, not POST
|
||||
var qrcodeResp QRCodeResponse
|
||||
if err := json.Unmarshal(respBody, &qrcodeResp); err != nil {
|
||||
if err := c.getQR(ctx, "ilink/bot/get_bot_qrcode", map[string]string{
|
||||
"bot_type": botType,
|
||||
}, &qrcodeResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &qrcodeResp, nil
|
||||
@@ -204,37 +221,10 @@ func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeRespo
|
||||
|
||||
func (c *ApiClient) GetQRCodeStatus(ctx context.Context, qrcode string) (*StatusResponse, error) {
|
||||
// get_qrcode_status is GET
|
||||
u, err := url.Parse(c.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "ilink/bot/get_qrcode_status")
|
||||
q := u.Query()
|
||||
q.Set("qrcode", qrcode)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header["iLink-App-ClientVersion"] = []string{"1"}
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("get_qrcode_status failed: %d %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var statusResp StatusResponse
|
||||
if err := json.Unmarshal(respBody, &statusResp); err != nil {
|
||||
if err := c.getQR(ctx, "ilink/bot/get_qrcode_status", map[string]string{
|
||||
"qrcode": qrcode,
|
||||
}, &statusResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &statusResp, nil
|
||||
|
||||
@@ -40,6 +40,7 @@ func PerformLoginInteractive(
|
||||
if err != nil {
|
||||
return "", "", "", "", fmt.Errorf("failed to create api client: %w", err)
|
||||
}
|
||||
pollAPI := api
|
||||
|
||||
logger.InfoC("weixin", "Requesting Weixin QR code...")
|
||||
qrResp, err := api.GetQRCode(ctx, opts.BotType)
|
||||
@@ -76,7 +77,7 @@ func PerformLoginInteractive(
|
||||
case <-timeoutCtx.Done():
|
||||
return "", "", "", "", fmt.Errorf("login timeout")
|
||||
case <-pollTicker.C:
|
||||
statusResp, err := api.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
|
||||
statusResp, err := pollAPI.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
|
||||
if err != nil {
|
||||
// Long poll timeout or temporary error
|
||||
continue
|
||||
@@ -99,6 +100,27 @@ func PerformLoginInteractive(
|
||||
})
|
||||
|
||||
return statusResp.BotToken, statusResp.IlinkUserID, statusResp.IlinkBotID, statusResp.Baseurl, nil
|
||||
case "scaned_but_redirect":
|
||||
if statusResp.RedirectHost == "" {
|
||||
logger.WarnC(
|
||||
"weixin",
|
||||
"scaned_but_redirect received without redirect_host; continuing on current host",
|
||||
)
|
||||
continue
|
||||
}
|
||||
nextBaseURL := "https://" + statusResp.RedirectHost + "/"
|
||||
nextAPI, nextErr := NewApiClient(nextBaseURL, "", opts.Proxy)
|
||||
if nextErr != nil {
|
||||
logger.WarnCF("weixin", "Failed to switch QR polling host", map[string]any{
|
||||
"redirect_host": statusResp.RedirectHost,
|
||||
"error": nextErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
pollAPI = nextAPI
|
||||
logger.InfoCF("weixin", "Switched QR polling host", map[string]any{
|
||||
"redirect_host": statusResp.RedirectHost,
|
||||
})
|
||||
case "expired":
|
||||
return "", "", "", "", fmt.Errorf("qrcode expired, please try again")
|
||||
default:
|
||||
|
||||
+146
-27
@@ -34,6 +34,8 @@ const (
|
||||
weixinMediaMaxBytes = 100 << 20
|
||||
weixinTypingKeepAlive = 5 * time.Second
|
||||
weixinUploadRetryMax = 3
|
||||
weixinDownloadRetryMax = 2
|
||||
weixinDownloadRetryDelay = 300 * time.Millisecond
|
||||
weixinVoiceTranscodeTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
@@ -163,49 +165,108 @@ func buildCDNDownloadURL(base, encryptedQueryParam string) string {
|
||||
"/download?encrypted_query_param=" + url.QueryEscape(encryptedQueryParam)
|
||||
}
|
||||
|
||||
func shouldRetryCDNDownload(statusCode int) bool {
|
||||
// statusCode=0 represents transport/build errors from the HTTP client.
|
||||
return statusCode == 0 || statusCode >= 500 || statusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
func buildCDNUploadURL(base, uploadParam, filekey string) string {
|
||||
return strings.TrimRight(base, "/") +
|
||||
"/upload?encrypted_query_param=" + url.QueryEscape(uploadParam) +
|
||||
"&filekey=" + url.QueryEscape(filekey)
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadCDNBuffer(ctx context.Context, encryptedQueryParam string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam),
|
||||
nil,
|
||||
)
|
||||
func uniqCDNURLs(urls []string) []string {
|
||||
seen := make(map[string]struct{}, len(urls))
|
||||
out := make([]string, 0, len(urls))
|
||||
for _, raw := range urls {
|
||||
u := strings.TrimSpace(raw)
|
||||
if u == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[u]; ok {
|
||||
continue
|
||||
}
|
||||
seen[u] = struct{}{}
|
||||
out = append(out, u)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadCDNBufferOnce(ctx context.Context, downloadURL string) ([]byte, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
resp, err := c.api.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
|
||||
return nil, resp.StatusCode, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, weixinMediaMaxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
if len(data) > weixinMediaMaxBytes {
|
||||
return nil, fmt.Errorf("cdn media too large: %d bytes", len(data))
|
||||
return nil, resp.StatusCode, fmt.Errorf("cdn media too large: %d bytes", len(data))
|
||||
}
|
||||
return data, nil
|
||||
return data, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadCDNBuffer(
|
||||
ctx context.Context,
|
||||
encryptedQueryParam,
|
||||
fullURL string,
|
||||
) ([]byte, error) {
|
||||
candidates := uniqCDNURLs([]string{
|
||||
strings.TrimSpace(fullURL),
|
||||
func() string {
|
||||
if strings.TrimSpace(encryptedQueryParam) == "" {
|
||||
return ""
|
||||
}
|
||||
return buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam)
|
||||
}(),
|
||||
})
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("missing CDN download URL")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, downloadURL := range candidates {
|
||||
for attempt := 1; attempt <= weixinDownloadRetryMax; attempt++ {
|
||||
data, statusCode, err := c.downloadCDNBufferOnce(ctx, downloadURL)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
lastErr = fmt.Errorf("%w (attempt=%d url=%s)", err, attempt, downloadURL)
|
||||
if !shouldRetryCDNDownload(statusCode) {
|
||||
break
|
||||
}
|
||||
if attempt < weixinDownloadRetryMax {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(weixinDownloadRetryDelay):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
|
||||
ctx context.Context,
|
||||
encryptedQueryParam string,
|
||||
fullURL string,
|
||||
key []byte,
|
||||
) ([]byte, error) {
|
||||
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam)
|
||||
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam, fullURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -215,6 +276,33 @@ func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
|
||||
return decryptAESECB(data, key)
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadImageBuffer(
|
||||
ctx context.Context,
|
||||
img *ImageItem,
|
||||
key []byte,
|
||||
) ([]byte, error) {
|
||||
if img == nil {
|
||||
return nil, fmt.Errorf("image item is nil")
|
||||
}
|
||||
if img.Media != nil {
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.Media.EncryptQueryParam, img.Media.FullURL, key)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
if img.ThumbMedia == nil {
|
||||
return nil, fmt.Errorf("image download failed: %w", err)
|
||||
}
|
||||
}
|
||||
if img.ThumbMedia != nil {
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.ThumbMedia.EncryptQueryParam, img.ThumbMedia.FullURL, key)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
return nil, fmt.Errorf("image download failed: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("image media is nil")
|
||||
}
|
||||
|
||||
func detectMediaMetadata(data []byte, fallbackName, fallbackContentType string) (string, string) {
|
||||
contentType := strings.TrimSpace(fallbackContentType)
|
||||
ext := filepath.Ext(fallbackName)
|
||||
@@ -310,15 +398,18 @@ func isDownloadableMediaItem(item *MessageItem) bool {
|
||||
|
||||
switch item.Type {
|
||||
case MessageItemTypeImage:
|
||||
return item.ImageItem != nil && item.ImageItem.Media != nil && item.ImageItem.Media.EncryptQueryParam != ""
|
||||
return item.ImageItem != nil && item.ImageItem.Media != nil &&
|
||||
(item.ImageItem.Media.EncryptQueryParam != "" || item.ImageItem.Media.FullURL != "")
|
||||
case MessageItemTypeVideo:
|
||||
return item.VideoItem != nil && item.VideoItem.Media != nil && item.VideoItem.Media.EncryptQueryParam != ""
|
||||
return item.VideoItem != nil && item.VideoItem.Media != nil &&
|
||||
(item.VideoItem.Media.EncryptQueryParam != "" || item.VideoItem.Media.FullURL != "")
|
||||
case MessageItemTypeFile:
|
||||
return item.FileItem != nil && item.FileItem.Media != nil && item.FileItem.Media.EncryptQueryParam != ""
|
||||
return item.FileItem != nil && item.FileItem.Media != nil &&
|
||||
(item.FileItem.Media.EncryptQueryParam != "" || item.FileItem.Media.FullURL != "")
|
||||
case MessageItemTypeVoice:
|
||||
return item.VoiceItem != nil &&
|
||||
item.VoiceItem.Media != nil &&
|
||||
item.VoiceItem.Media.EncryptQueryParam != "" &&
|
||||
(item.VoiceItem.Media.EncryptQueryParam != "" || item.VoiceItem.Media.FullURL != "") &&
|
||||
strings.TrimSpace(item.VoiceItem.Text) == ""
|
||||
default:
|
||||
return false
|
||||
@@ -434,16 +525,20 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
|
||||
switch item.Type {
|
||||
case MessageItemTypeImage:
|
||||
if item.ImageItem == nil {
|
||||
return "", fmt.Errorf("image media is nil")
|
||||
}
|
||||
key, ok, err := imageAESKey(item.ImageItem)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.ImageItem.Media.EncryptQueryParam, func() []byte {
|
||||
decryptKey := func() []byte {
|
||||
if ok {
|
||||
return key
|
||||
}
|
||||
return nil
|
||||
}())
|
||||
}()
|
||||
data, err := c.downloadImageBuffer(ctx, item.ImageItem, decryptKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -454,7 +549,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
silk, err := c.downloadAndDecryptCDNBuffer(ctx, item.VoiceItem.Media.EncryptQueryParam, key)
|
||||
silk, err := c.downloadAndDecryptCDNBuffer(
|
||||
ctx,
|
||||
item.VoiceItem.Media.EncryptQueryParam,
|
||||
item.VoiceItem.Media.FullURL,
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -468,7 +568,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.FileItem.Media.EncryptQueryParam, key)
|
||||
data, err := c.downloadAndDecryptCDNBuffer(
|
||||
ctx,
|
||||
item.FileItem.Media.EncryptQueryParam,
|
||||
item.FileItem.Media.FullURL,
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -484,7 +589,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.VideoItem.Media.EncryptQueryParam, key)
|
||||
data, err := c.downloadAndDecryptCDNBuffer(
|
||||
ctx,
|
||||
item.VideoItem.Media.EncryptQueryParam,
|
||||
item.VideoItem.Media.FullURL,
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -701,11 +811,13 @@ func (c *WeixinChannel) uploadLocalFile(
|
||||
}
|
||||
return nil, fmt.Errorf("getuploadurl failed: ret=%d errcode=%d errmsg=%s", resp.Ret, resp.Errcode, resp.Errmsg)
|
||||
}
|
||||
if strings.TrimSpace(resp.UploadParam) == "" {
|
||||
return nil, fmt.Errorf("getuploadurl returned empty upload_param")
|
||||
uploadParam := strings.TrimSpace(resp.UploadParam)
|
||||
uploadFullURL := strings.TrimSpace(resp.UploadFullURL)
|
||||
if uploadParam == "" && uploadFullURL == "" {
|
||||
return nil, fmt.Errorf("getuploadurl returned no upload URL")
|
||||
}
|
||||
|
||||
downloadParam, err := c.uploadBufferToCDN(ctx, data, resp.UploadParam, filekey, aesKey)
|
||||
downloadParam, err := c.uploadBufferToCDN(ctx, data, uploadParam, uploadFullURL, filekey, aesKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -723,6 +835,7 @@ func (c *WeixinChannel) uploadBufferToCDN(
|
||||
ctx context.Context,
|
||||
plaintext []byte,
|
||||
uploadParam,
|
||||
uploadFullURL,
|
||||
filekey string,
|
||||
aesKey []byte,
|
||||
) (string, error) {
|
||||
@@ -731,7 +844,13 @@ func (c *WeixinChannel) uploadBufferToCDN(
|
||||
return "", err
|
||||
}
|
||||
|
||||
uploadURL := buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
|
||||
uploadURL := strings.TrimSpace(uploadFullURL)
|
||||
if uploadURL == "" {
|
||||
if strings.TrimSpace(uploadParam) == "" {
|
||||
return "", fmt.Errorf("missing CDN upload URL")
|
||||
}
|
||||
uploadURL = buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
|
||||
}
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= weixinUploadRetryMax; attempt++ {
|
||||
|
||||
@@ -36,18 +36,29 @@ type syncCursorFile struct {
|
||||
GetUpdatesBuf string `json:"get_updates_buf"`
|
||||
}
|
||||
|
||||
type contextTokensFile struct {
|
||||
Tokens map[string]string `json:"tokens"`
|
||||
}
|
||||
|
||||
func picoclawHomeDir() string {
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func buildWeixinSyncBufPath(cfg config.WeixinConfig) string {
|
||||
key := "default"
|
||||
func genWeixinAccountKey(cfg config.WeixinConfig) string {
|
||||
token := strings.TrimSpace(cfg.Token.String())
|
||||
if token != "" {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(cfg.BaseURL) + "|" + token))
|
||||
key = hex.EncodeToString(sum[:8])
|
||||
if token == "" {
|
||||
return "default"
|
||||
}
|
||||
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "sync", key+".json")
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(cfg.BaseURL) + "|" + token))
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func buildWeixinSyncBufPath(cfg config.WeixinConfig) string {
|
||||
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "sync", genWeixinAccountKey(cfg)+".json")
|
||||
}
|
||||
|
||||
func buildWeixinContextTokensPath(cfg config.WeixinConfig) string {
|
||||
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "context-tokens", genWeixinAccountKey(cfg)+".json")
|
||||
}
|
||||
|
||||
func loadGetUpdatesBuf(path string) (string, error) {
|
||||
@@ -75,6 +86,29 @@ func saveGetUpdatesBuf(path, cursor string) error {
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
func loadContextTokens(path string) (map[string]string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var decoded contextTokensFile
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decoded.Tokens, nil
|
||||
}
|
||||
|
||||
func saveContextTokens(path string, tokens map[string]string) error {
|
||||
data, err := json.Marshal(contextTokensFile{Tokens: tokens})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) cdnBaseURL() string {
|
||||
if base := strings.TrimSpace(c.config.CDNBaseURL); base != "" {
|
||||
return strings.TrimRight(base, "/")
|
||||
|
||||
@@ -38,6 +38,7 @@ type GetUploadUrlResp struct {
|
||||
APIStatus
|
||||
UploadParam string `json:"upload_param,omitempty"`
|
||||
ThumbUploadParam string `json:"thumb_upload_param,omitempty"`
|
||||
UploadFullURL string `json:"upload_full_url,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -69,6 +70,7 @@ type CDNMedia struct {
|
||||
EncryptQueryParam string `json:"encrypt_query_param,omitempty"`
|
||||
AesKey string `json:"aes_key,omitempty"` // base64 encoded
|
||||
EncryptType int `json:"encrypt_type,omitempty"`
|
||||
FullURL string `json:"full_url,omitempty"`
|
||||
}
|
||||
|
||||
type ImageItem struct {
|
||||
@@ -202,9 +204,10 @@ type QRCodeResponse struct {
|
||||
}
|
||||
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired"
|
||||
BotToken string `json:"bot_token,omitempty"`
|
||||
IlinkBotID string `json:"ilink_bot_id,omitempty"`
|
||||
Baseurl string `json:"baseurl,omitempty"`
|
||||
IlinkUserID string `json:"ilink_user_id,omitempty"`
|
||||
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired", "scaned_but_redirect"
|
||||
BotToken string `json:"bot_token,omitempty"`
|
||||
IlinkBotID string `json:"ilink_bot_id,omitempty"`
|
||||
Baseurl string `json:"baseurl,omitempty"`
|
||||
IlinkUserID string `json:"ilink_user_id,omitempty"`
|
||||
RedirectHost string `json:"redirect_host,omitempty"`
|
||||
}
|
||||
|
||||
@@ -26,12 +26,13 @@ type WeixinChannel struct {
|
||||
bus *bus.MessageBus
|
||||
// contextTokens stores the last context_token per user (from_user_id → context_token).
|
||||
// This is required by the iLink API to associate replies with the right chat session.
|
||||
contextTokens sync.Map
|
||||
typingMu sync.Mutex
|
||||
typingCache map[string]typingTicketCacheEntry
|
||||
pauseMu sync.Mutex
|
||||
pauseUntil time.Time
|
||||
syncBufPath string
|
||||
contextTokens sync.Map
|
||||
typingMu sync.Mutex
|
||||
typingCache map[string]typingTicketCacheEntry
|
||||
pauseMu sync.Mutex
|
||||
pauseUntil time.Time
|
||||
syncBufPath string
|
||||
contextTokensPath string
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -57,12 +58,13 @@ func NewWeixinChannel(cfg config.WeixinConfig, messageBus *bus.MessageBus) (*Wei
|
||||
)
|
||||
|
||||
return &WeixinChannel{
|
||||
BaseChannel: base,
|
||||
api: api,
|
||||
config: cfg,
|
||||
bus: messageBus,
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
syncBufPath: buildWeixinSyncBufPath(cfg),
|
||||
BaseChannel: base,
|
||||
api: api,
|
||||
config: cfg,
|
||||
bus: messageBus,
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
syncBufPath: buildWeixinSyncBufPath(cfg),
|
||||
contextTokensPath: buildWeixinContextTokensPath(cfg),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -70,11 +72,53 @@ func (c *WeixinChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("weixin", "Starting Weixin channel")
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
c.SetRunning(true)
|
||||
c.restoreContextTokens()
|
||||
go c.pollLoop(c.ctx)
|
||||
logger.InfoC("weixin", "Weixin channel started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreContextTokens loads persisted context tokens from disk into memory.
|
||||
func (c *WeixinChannel) restoreContextTokens() {
|
||||
tokens, err := loadContextTokens(c.contextTokensPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("weixin", "Failed to load persisted context tokens", map[string]any{
|
||||
"path": c.contextTokensPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
for userID, token := range tokens {
|
||||
c.contextTokens.Store(userID, token)
|
||||
}
|
||||
logger.InfoCF("weixin", "Restored context tokens from disk", map[string]any{
|
||||
"path": c.contextTokensPath,
|
||||
"count": len(tokens),
|
||||
})
|
||||
}
|
||||
|
||||
// persistContextTokens saves all in-memory context tokens to disk.
|
||||
func (c *WeixinChannel) persistContextTokens() {
|
||||
tokens := make(map[string]string)
|
||||
c.contextTokens.Range(func(k, v any) bool {
|
||||
if userID, ok := k.(string); ok {
|
||||
if token, ok := v.(string); ok {
|
||||
tokens[userID] = token
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err := saveContextTokens(c.contextTokensPath, tokens); err != nil {
|
||||
logger.WarnCF("weixin", "Failed to persist context tokens", map[string]any{
|
||||
"path": c.contextTokensPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("weixin", "Stopping Weixin channel")
|
||||
c.SetRunning(false)
|
||||
@@ -307,6 +351,7 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
|
||||
// Store context_token for outbound reply association
|
||||
if msg.ContextToken != "" {
|
||||
c.contextTokens.Store(fromUserID, msg.ContextToken)
|
||||
c.persistContextTokens()
|
||||
}
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender)
|
||||
|
||||
@@ -72,7 +72,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", key)
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "", key)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
|
||||
}
|
||||
@@ -81,6 +81,116 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("hello weixin")
|
||||
ciphertext, err := encryptAESECB(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("encryptAESECB() error = %v", err)
|
||||
}
|
||||
|
||||
fullURLAttempts := 0
|
||||
ch := &WeixinChannel{
|
||||
api: &ApiClient{
|
||||
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.String() == "https://full.example.com/download" {
|
||||
fullURLAttempts++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(ciphertext)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
t.Fatalf("unexpected fallback request: %s", r.URL.String())
|
||||
return nil, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "https://full.example.com/download", key)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
|
||||
}
|
||||
if fullURLAttempts == 0 {
|
||||
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("hello weixin")
|
||||
ciphertext, err := encryptAESECB(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("encryptAESECB() error = %v", err)
|
||||
}
|
||||
|
||||
fullURLAttempts := 0
|
||||
constructedAttempts := 0
|
||||
ch := &WeixinChannel{
|
||||
api: &ApiClient{
|
||||
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.String() == "https://full.example.com/download?encrypted_query_param=token&taskid=123" {
|
||||
fullURLAttempts++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
if r.URL.String() != "https://cdn.example.com/download?encrypted_query_param=token" {
|
||||
t.Fatalf("unexpected fallback request: %s", r.URL.String())
|
||||
}
|
||||
constructedAttempts++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(ciphertext)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(
|
||||
context.Background(),
|
||||
"token",
|
||||
"https://full.example.com/download?encrypted_query_param=token&taskid=123",
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
|
||||
}
|
||||
if fullURLAttempts == 0 {
|
||||
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
|
||||
}
|
||||
if constructedAttempts == 0 {
|
||||
t.Fatalf("constructedAttempts = %d, want > 0", constructedAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCDNDownloadURLEscapesOpaqueToken(t *testing.T) {
|
||||
token := "MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%3D"
|
||||
|
||||
got := buildCDNDownloadURL("https://cdn.example.com", token)
|
||||
|
||||
if got != "https://cdn.example.com/download?encrypted_query_param=MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%253D" {
|
||||
t.Fatalf("buildCDNDownloadURL() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadBufferToCDN(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("upload me")
|
||||
@@ -120,7 +230,7 @@ func TestUploadBufferToCDN(t *testing.T) {
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "file-key", key)
|
||||
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "", "file-key", key)
|
||||
if err != nil {
|
||||
t.Fatalf("uploadBufferToCDN() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -21,14 +25,12 @@ type (
|
||||
)
|
||||
|
||||
const (
|
||||
// azureAPIVersion is the Azure OpenAI API version used for all requests.
|
||||
azureAPIVersion = "2024-10-21"
|
||||
defaultRequestTimeout = common.DefaultRequestTimeout
|
||||
)
|
||||
|
||||
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
|
||||
// It handles Azure-specific authentication (api-key header), URL construction
|
||||
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
|
||||
// It handles Azure-specific authentication (Bearer token), URL construction
|
||||
// (Responses API), and request/response formatting.
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
@@ -72,8 +74,8 @@ func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds
|
||||
)
|
||||
}
|
||||
|
||||
// Chat sends a chat completion request to the Azure OpenAI endpoint.
|
||||
// The model parameter is used as the Azure deployment name in the URL.
|
||||
// Chat sends a request to the Azure OpenAI Responses API endpoint.
|
||||
// The model parameter is passed in the request body.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
@@ -85,34 +87,43 @@ func (p *Provider) Chat(
|
||||
return nil, fmt.Errorf("Azure API base not configured")
|
||||
}
|
||||
|
||||
// model is the deployment name for Azure OpenAI
|
||||
deployment := model
|
||||
|
||||
// Build Azure-specific URL safely using url.JoinPath and query encoding
|
||||
// to prevent path traversal or query injection via deployment names.
|
||||
base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions")
|
||||
requestURL, err := url.JoinPath(p.apiBase, "openai/v1/responses")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
|
||||
}
|
||||
requestURL := base + "?api-version=" + azureAPIVersion
|
||||
|
||||
// Build request body — no "model" field (Azure infers from deployment URL)
|
||||
requestBody := map[string]any{
|
||||
"messages": common.SerializeMessages(messages),
|
||||
input, instructions := orc.TranslateMessages(messages)
|
||||
|
||||
requestBody := responses.ResponseNewParams{
|
||||
Model: model,
|
||||
Input: responses.ResponseNewParamsInputUnion{
|
||||
OfInputItemList: input,
|
||||
},
|
||||
Store: openai.Opt(false),
|
||||
}
|
||||
|
||||
if instructions != "" {
|
||||
requestBody.Instructions = openai.Opt(instructions)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
enableWebSearch, _ := options["native_search"].(bool)
|
||||
requestBody.Tools = orc.TranslateTools(tools, enableWebSearch)
|
||||
requestBody.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{
|
||||
OfToolChoiceMode: openai.Opt(responses.ToolChoiceOptionsAuto),
|
||||
}
|
||||
}
|
||||
|
||||
// Azure OpenAI always uses max_completion_tokens
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
requestBody.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
requestBody.Temperature = openai.Opt(temperature)
|
||||
}
|
||||
|
||||
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
|
||||
requestBody.PromptCacheKey = openai.Opt(cacheKey)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
@@ -125,10 +136,9 @@ func (p *Provider) Chat(
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Azure uses api-key header instead of Authorization: Bearer
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Api-Key", p.apiKey)
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
@@ -141,7 +151,7 @@ func (p *Provider) Chat(
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
return orc.ParseResponseBody(resp.Body)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string as Azure deployments are user-configured.
|
||||
|
||||
@@ -6,17 +6,31 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
|
||||
// writeValidResponse writes a minimal valid Responses API response.
|
||||
func writeValidResponse(w http.ResponseWriter) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
"type": "message",
|
||||
"content": []map[string]any{
|
||||
{"type": "output_text", "text": "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 2,
|
||||
"total_tokens": 7,
|
||||
"input_tokens_details": map[string]any{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
@@ -24,11 +38,9 @@ func writeValidResponse(w http.ResponseWriter) {
|
||||
|
||||
func TestProviderChat_AzureURLConstruction(t *testing.T) {
|
||||
var capturedPath string
|
||||
var capturedAPIVersion string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
capturedAPIVersion = r.URL.Query().Get("api-version")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
@@ -39,22 +51,19 @@ func TestProviderChat_AzureURLConstruction(t *testing.T) {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
|
||||
wantPath := "/openai/v1/responses"
|
||||
if capturedPath != wantPath {
|
||||
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
|
||||
}
|
||||
if capturedAPIVersion != azureAPIVersion {
|
||||
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureAuthHeader(t *testing.T) {
|
||||
var capturedAPIKey string
|
||||
var capturedAuth string
|
||||
var capturedAPIKey string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAPIKey = r.Header.Get("Api-Key")
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
capturedAPIKey = r.Header.Get("Api-Key")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
@@ -65,15 +74,15 @@ func TestProviderChat_AzureAuthHeader(t *testing.T) {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if capturedAPIKey != "test-azure-key" {
|
||||
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
|
||||
if capturedAuth != "Bearer test-azure-key" {
|
||||
t.Errorf("Authorization header = %q, want %q", capturedAuth, "Bearer test-azure-key")
|
||||
}
|
||||
if capturedAuth != "" {
|
||||
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
|
||||
if capturedAPIKey != "" {
|
||||
t.Errorf("Api-Key header should be empty, got %q", capturedAPIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
|
||||
func TestProviderChat_AzureRequestBodyContainsModel(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -83,17 +92,17 @@ func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["model"]; exists {
|
||||
t.Error("request body should not contain 'model' field for Azure OpenAI")
|
||||
if requestBody["model"] != "my-deployment" {
|
||||
t.Errorf("model = %v, want %q", requestBody["model"], "my-deployment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
|
||||
func TestProviderChat_AzureUsesMaxOutputTokens(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -114,12 +123,35 @@ func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["max_completion_tokens"]; !exists {
|
||||
t.Error("request body should contain 'max_completion_tokens'")
|
||||
if requestBody["max_output_tokens"] == nil {
|
||||
t.Error("request body should contain 'max_output_tokens'")
|
||||
}
|
||||
if _, exists := requestBody["max_tokens"]; exists {
|
||||
t.Error("request body should not contain 'max_tokens'")
|
||||
}
|
||||
if _, exists := requestBody["max_completion_tokens"]; exists {
|
||||
t.Error("request body should not contain 'max_completion_tokens'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureStoreIsFalse(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["store"] != false {
|
||||
t.Errorf("store = %v, want false", requestBody["store"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureHTTPError(t *testing.T) {
|
||||
@@ -135,27 +167,66 @@ func TestProviderChat_AzureHTTPError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureParseTextOutput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
"id": "resp_1",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]any{
|
||||
{
|
||||
"type": "message",
|
||||
"content": []map[string]any{
|
||||
{"type": "output_text", "text": "Hello there!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 10, "output_tokens": 5, "total_tokens": 15,
|
||||
"input_tokens_details": map[string]any{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if out.Content != "Hello there!" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "Hello there!")
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
if out.Usage.TotalTokens != 15 {
|
||||
t.Errorf("TotalTokens = %d, want 15", out.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
"id": "resp_2",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "",
|
||||
"tool_calls": []map[string]any{
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"arguments": `{"city":"Seattle"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
"type": "function_call",
|
||||
"call_id": "call_1",
|
||||
"name": "get_weather",
|
||||
"arguments": `{"city":"Seattle"}`,
|
||||
},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 10, "output_tokens": 8, "total_tokens": 18,
|
||||
"input_tokens_details": map[string]any{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
@@ -167,13 +238,15 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
|
||||
@@ -205,28 +278,103 @@ func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
|
||||
var capturedPath string
|
||||
func TestProviderChat_AzureNativeWebSearchInjection(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
|
||||
if capturedPath == "" {
|
||||
capturedPath = r.URL.Path
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "local web search",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: "read a file",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
|
||||
// Deployment name with characters that could cause path injection
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
|
||||
// With native_search=true: user-defined web_search should be replaced by built-in
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment",
|
||||
map[string]any{"native_search": true})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// The slash and special chars in the deployment name must be escaped, not treated as path separators
|
||||
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
|
||||
t.Fatal("deployment name was interpolated without escaping — path injection possible")
|
||||
toolsAny, ok := requestBody["tools"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("request body should contain 'tools' array")
|
||||
}
|
||||
if len(toolsAny) != 2 {
|
||||
t.Fatalf("len(tools) = %d, want 2 (read_file + web_search builtin)", len(toolsAny))
|
||||
}
|
||||
|
||||
// First tool should be read_file (user-defined web_search was skipped)
|
||||
firstTool, _ := toolsAny[0].(map[string]any)
|
||||
if firstTool["name"] != "read_file" {
|
||||
t.Errorf("first tool name = %v, want %q", firstTool["name"], "read_file")
|
||||
}
|
||||
|
||||
// Second tool should be built-in web_search
|
||||
secondTool, _ := toolsAny[1].(map[string]any)
|
||||
if secondTool["type"] != "web_search" {
|
||||
t.Errorf("second tool type = %v, want %q", secondTool["type"], "web_search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "local web search",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
|
||||
// Without native_search: user-defined web_search should be kept as-is
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
toolsAny, ok := requestBody["tools"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("request body should contain 'tools' array")
|
||||
}
|
||||
if len(toolsAny) != 1 {
|
||||
t.Fatalf("len(tools) = %d, want 1", len(toolsAny))
|
||||
}
|
||||
|
||||
// Should be the user-defined function tool, not built-in
|
||||
tool, _ := toolsAny[0].(map[string]any)
|
||||
if tool["type"] != "function" {
|
||||
t.Errorf("tool type = %v, want %q", tool["type"], "function")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -13,6 +12,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -96,7 +96,7 @@ func (p *CodexProvider) Chat(
|
||||
}
|
||||
|
||||
// Respect tools.web.prefer_native: only inject native search when the agent
|
||||
// loop requested it (options["native_search"]), so prefer_native: false
|
||||
// loop passes options["native_search"]=true, so prefer_native=false means no injection.
|
||||
useNativeSearch := p.enableWebSearch && (options["native_search"] == true)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, useNativeSearch)
|
||||
|
||||
@@ -153,7 +153,7 @@ func (p *CodexProvider) Chat(
|
||||
return nil, fmt.Errorf("codex API call: stream ended without completed response")
|
||||
}
|
||||
|
||||
return parseCodexResponse(resp), nil
|
||||
return orc.ParseResponseFromStruct(resp), nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) GetDefaultModel() string {
|
||||
@@ -209,89 +209,14 @@ func resolveCodexModel(model string) (string, string) {
|
||||
func buildCodexParams(
|
||||
messages []Message, tools []ToolDefinition, model string, options map[string]any, enableWebSearch bool,
|
||||
) responses.ResponseNewParams {
|
||||
var inputItems responses.ResponseInputParam
|
||||
var instructions string
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
// Use the full concatenated system prompt (static + dynamic + summary)
|
||||
// as instructions. This keeps behavior consistent with Anthropic and
|
||||
// OpenAI-compat adapters where the complete system context lives in
|
||||
// one place. Prefix caching is handled by prompt_cache_key below,
|
||||
// not by splitting content across instructions vs input messages.
|
||||
instructions = msg.Content
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||
CallID: msg.ToolCallID,
|
||||
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
|
||||
OfString: openai.Opt(msg.Content),
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleUser,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if msg.Content != "" {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleAssistant,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
name, args, ok := resolveCodexToolCall(tc)
|
||||
if !ok {
|
||||
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]any{
|
||||
"call_id": tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||
CallID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleAssistant,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
case "tool":
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||
CallID: msg.ToolCallID,
|
||||
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
|
||||
OfString: openai.Opt(msg.Content),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
inputItems, instructions := orc.TranslateMessages(messages)
|
||||
|
||||
params := responses.ResponseNewParams{
|
||||
Model: model,
|
||||
Input: responses.ResponseNewParamsInputUnion{
|
||||
OfInputItemList: inputItems,
|
||||
},
|
||||
Instructions: openai.Opt(instructions),
|
||||
Store: openai.Opt(false),
|
||||
Store: openai.Opt(false),
|
||||
}
|
||||
|
||||
if instructions != "" {
|
||||
@@ -309,115 +234,12 @@ func buildCodexParams(
|
||||
}
|
||||
|
||||
if len(tools) > 0 || enableWebSearch {
|
||||
params.Tools = translateToolsForCodex(tools, enableWebSearch)
|
||||
params.Tools = orc.TranslateTools(tools, enableWebSearch)
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
|
||||
name = tc.Name
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
if name == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if len(tc.Arguments) > 0 {
|
||||
argsJSON, err := json.Marshal(tc.Arguments)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return name, string(argsJSON), true
|
||||
}
|
||||
|
||||
if tc.Function != nil && tc.Function.Arguments != "" {
|
||||
return name, tc.Function.Arguments, true
|
||||
}
|
||||
|
||||
return name, "{}", true
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
|
||||
capHint := len(tools)
|
||||
if enableWebSearch {
|
||||
capHint++
|
||||
}
|
||||
result := make([]responses.ToolUnionParam, 0, capHint)
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
ft := responses.FunctionToolParam{
|
||||
Name: t.Function.Name,
|
||||
Parameters: t.Function.Parameters,
|
||||
Strict: openai.Opt(false),
|
||||
}
|
||||
if t.Function.Description != "" {
|
||||
ft.Description = openai.Opt(t.Function.Description)
|
||||
}
|
||||
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||
}
|
||||
if enableWebSearch {
|
||||
result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseCodexResponse(resp *responses.Response) *LLMResponse {
|
||||
var content strings.Builder
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
for _, c := range item.Content {
|
||||
if c.Type == "output_text" {
|
||||
content.WriteString(c.Text)
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
|
||||
args = map[string]any{"raw": item.Arguments}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: item.CallID,
|
||||
Name: item.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
if resp.Status == "incomplete" {
|
||||
finishReason = "length"
|
||||
}
|
||||
|
||||
var usage *UsageInfo
|
||||
if resp.Usage.TotalTokens > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: int(resp.Usage.InputTokens),
|
||||
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||
TotalTokens: int(resp.Usage.TotalTokens),
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func createCodexTokenSource() func() (string, string, error) {
|
||||
return func() (string, string, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/openai/openai-go/v3"
|
||||
openaiopt "github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
|
||||
orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common"
|
||||
)
|
||||
|
||||
func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
@@ -225,7 +227,7 @@ func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
result := parseCodexResponse(&resp)
|
||||
result := orc.ParseResponseFromStruct(&resp)
|
||||
if result.Content != "Hello there!" {
|
||||
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
|
||||
}
|
||||
@@ -266,7 +268,7 @@ func TestParseCodexResponse_FunctionCall(t *testing.T) {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
result := parseCodexResponse(&resp)
|
||||
result := orc.ParseResponseFromStruct(&resp)
|
||||
if len(result.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
// Package openai_responses_common provides shared utilities for providers
|
||||
// that use the OpenAI Responses API (e.g., Azure, Codex).
|
||||
package openai_responses_common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// TranslateMessages converts internal Message entries to the OpenAI Responses API
|
||||
// input format. System messages are extracted as instructions (returned separately),
|
||||
// user/assistant/tool messages become ResponseInputItemUnionParam entries.
|
||||
// Supports multipart media (images, audio).
|
||||
func TranslateMessages(messages []protocoltypes.Message) (input responses.ResponseInputParam, instructions string) {
|
||||
input = make(responses.ResponseInputParam, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
instructions = msg.Content
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
input = append(input, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||
CallID: msg.ToolCallID,
|
||||
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
|
||||
OfString: openai.Opt(msg.Content),
|
||||
},
|
||||
},
|
||||
})
|
||||
} else if len(msg.Media) > 0 {
|
||||
content := BuildMultipartContent(msg.Content, msg.Media)
|
||||
input = append(input, responses.ResponseInputItemUnionParam{
|
||||
OfInputMessage: &responses.ResponseInputItemMessageParam{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
input = append(input, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleUser,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if msg.Content != "" {
|
||||
input = append(input, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleAssistant,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
name, args, ok := ResolveToolCall(tc)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
input = append(input, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||
CallID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
input = append(input, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleAssistant,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
case "tool":
|
||||
input = append(input, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||
CallID: msg.ToolCallID,
|
||||
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
|
||||
OfString: openai.Opt(msg.Content),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return input, instructions
|
||||
}
|
||||
|
||||
// BuildMultipartContent constructs a ResponseInputMessageContentListParam from
|
||||
// text content and media URLs (data:image/... and data:audio/... URIs).
|
||||
func BuildMultipartContent(text string, media []string) responses.ResponseInputMessageContentListParam {
|
||||
parts := make(responses.ResponseInputMessageContentListParam, 0, 1+len(media))
|
||||
|
||||
if text != "" {
|
||||
parts = append(parts, responses.ResponseInputContentUnionParam{
|
||||
OfInputText: &responses.ResponseInputTextParam{
|
||||
Text: text,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for _, mediaURL := range media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, responses.ResponseInputContentUnionParam{
|
||||
OfInputImage: &responses.ResponseInputImageParam{
|
||||
ImageURL: openai.Opt(mediaURL),
|
||||
Detail: responses.ResponseInputImageDetailAuto,
|
||||
},
|
||||
})
|
||||
} else if strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
if format, data, ok := ParseDataAudioURL(mediaURL); ok {
|
||||
parts = append(parts, responses.ResponseInputContentUnionParam{
|
||||
OfInputFile: &responses.ResponseInputFileParam{
|
||||
FileData: openai.Opt(data),
|
||||
Filename: openai.Opt("audio." + format),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL.
|
||||
func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) {
|
||||
if !strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
return "", "", false
|
||||
}
|
||||
payload := strings.TrimPrefix(mediaURL, "data:audio/")
|
||||
meta, data, found := strings.Cut(payload, ",")
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
format, _, _ = strings.Cut(meta, ";")
|
||||
format = strings.TrimSpace(format)
|
||||
data = strings.TrimSpace(data)
|
||||
if format == "" || data == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return format, data, true
|
||||
}
|
||||
|
||||
// ResolveToolCall extracts the function name and JSON arguments string from a ToolCall.
|
||||
// Returns ok=false if the tool call has no name or if arguments fail to marshal.
|
||||
func ResolveToolCall(tc protocoltypes.ToolCall) (name string, arguments string, ok bool) {
|
||||
name = tc.Name
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
if name == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if len(tc.Arguments) > 0 {
|
||||
argsJSON, err := json.Marshal(tc.Arguments)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return name, string(argsJSON), true
|
||||
}
|
||||
|
||||
if tc.Function != nil && tc.Function.Arguments != "" {
|
||||
return name, tc.Function.Arguments, true
|
||||
}
|
||||
|
||||
return name, "{}", true
|
||||
}
|
||||
|
||||
// TranslateTools converts internal ToolDefinition entries to the OpenAI Responses API
|
||||
// tool format. If enableWebSearch is true, a web_search tool is appended and any
|
||||
// user-defined tool named "web_search" is skipped to avoid duplicates.
|
||||
func TranslateTools(tools []protocoltypes.ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
|
||||
capHint := len(tools)
|
||||
if enableWebSearch {
|
||||
capHint++
|
||||
}
|
||||
result := make([]responses.ToolUnionParam, 0, capHint)
|
||||
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
ft := responses.FunctionToolParam{
|
||||
Name: t.Function.Name,
|
||||
Parameters: t.Function.Parameters,
|
||||
Strict: openai.Opt(false),
|
||||
}
|
||||
if t.Function.Description != "" {
|
||||
ft.Description = openai.Opt(t.Function.Description)
|
||||
}
|
||||
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||
}
|
||||
|
||||
if enableWebSearch {
|
||||
result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ParseResponseBody parses an OpenAI Responses API JSON body into an LLMResponse.
|
||||
// Handles output item types: "message" (output_text + refusal), "function_call", and "reasoning".
|
||||
func ParseResponseBody(body io.Reader) (*protocoltypes.LLMResponse, error) {
|
||||
var apiResp responses.Response
|
||||
if err := json.NewDecoder(body).Decode(&apiResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return parseResponse(&apiResp), nil
|
||||
}
|
||||
|
||||
// ParseResponseFromStruct converts a decoded responses.Response into an LLMResponse.
|
||||
// Used by providers that receive the Response struct directly (e.g., via streaming SDK).
|
||||
func ParseResponseFromStruct(resp *responses.Response) *protocoltypes.LLMResponse {
|
||||
return parseResponse(resp)
|
||||
}
|
||||
|
||||
// parseResponse is the shared implementation for extracting LLMResponse fields
|
||||
// from a decoded responses.Response.
|
||||
func parseResponse(apiResp *responses.Response) *protocoltypes.LLMResponse {
|
||||
var content strings.Builder
|
||||
var reasoningContent strings.Builder
|
||||
var toolCalls []protocoltypes.ToolCall
|
||||
|
||||
for _, item := range apiResp.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
for _, c := range item.Content {
|
||||
switch c.Type {
|
||||
case "output_text":
|
||||
content.WriteString(c.Text)
|
||||
case "refusal":
|
||||
content.WriteString(c.Refusal)
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
|
||||
args = map[string]any{"raw": item.Arguments}
|
||||
}
|
||||
toolCalls = append(toolCalls, protocoltypes.ToolCall{
|
||||
ID: item.CallID,
|
||||
Name: item.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
reasoningContent.WriteString(s.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
if apiResp.Status == "incomplete" {
|
||||
finishReason = "length"
|
||||
}
|
||||
|
||||
var usage *protocoltypes.UsageInfo
|
||||
if apiResp.Usage.TotalTokens > 0 {
|
||||
usage = &protocoltypes.UsageInfo{
|
||||
PromptTokens: int(apiResp.Usage.InputTokens),
|
||||
CompletionTokens: int(apiResp.Usage.OutputTokens),
|
||||
TotalTokens: int(apiResp.Usage.TotalTokens),
|
||||
}
|
||||
}
|
||||
|
||||
return &protocoltypes.LLMResponse{
|
||||
Content: content.String(),
|
||||
ReasoningContent: reasoningContent.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,593 @@
|
||||
package openai_responses_common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// --- TranslateMessages tests ---
|
||||
|
||||
func TestTranslateMessages_SystemExtractedAsInstructions(t *testing.T) {
|
||||
msgs := []protocoltypes.Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
input, instructions := TranslateMessages(msgs)
|
||||
if instructions != "You are helpful" {
|
||||
t.Errorf("instructions = %q, want %q", instructions, "You are helpful")
|
||||
}
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("len(input) = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].OfMessage == nil {
|
||||
t.Fatal("expected user message item")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessages_UserTextMessage(t *testing.T) {
|
||||
msgs := []protocoltypes.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
input, instructions := TranslateMessages(msgs)
|
||||
if instructions != "" {
|
||||
t.Errorf("instructions = %q, want empty", instructions)
|
||||
}
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("len(input) = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].OfMessage == nil {
|
||||
t.Fatal("expected EasyInputMessage")
|
||||
}
|
||||
if string(input[0].OfMessage.Role) != "user" {
|
||||
t.Errorf("role = %q, want %q", input[0].OfMessage.Role, "user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessages_UserWithToolCallID(t *testing.T) {
|
||||
msgs := []protocoltypes.Message{
|
||||
{Role: "user", Content: `{"temp":72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
input, _ := TranslateMessages(msgs)
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("len(input) = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].OfFunctionCallOutput == nil {
|
||||
t.Fatal("expected FunctionCallOutput for user with ToolCallID")
|
||||
}
|
||||
if input[0].OfFunctionCallOutput.CallID != "call_1" {
|
||||
t.Errorf("CallID = %q, want %q", input[0].OfFunctionCallOutput.CallID, "call_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessages_UserWithMedia(t *testing.T) {
|
||||
msgs := []protocoltypes.Message{
|
||||
{Role: "user", Content: "Describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
input, _ := TranslateMessages(msgs)
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("len(input) = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].OfInputMessage == nil {
|
||||
t.Fatal("expected InputMessage for multipart content")
|
||||
}
|
||||
if input[0].OfInputMessage.Role != "user" {
|
||||
t.Errorf("role = %q, want %q", input[0].OfInputMessage.Role, "user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessages_AssistantWithToolCalls(t *testing.T) {
|
||||
msgs := []protocoltypes.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check",
|
||||
ToolCalls: []protocoltypes.ToolCall{
|
||||
{ID: "call_1", Name: "get_weather", Arguments: map[string]any{"city": "SF"}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temp":72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
input, _ := TranslateMessages(msgs)
|
||||
// user + assistant text + function_call + tool output = 4 items
|
||||
if len(input) != 4 {
|
||||
t.Fatalf("len(input) = %d, want 4", len(input))
|
||||
}
|
||||
// item[1] = assistant text
|
||||
if input[1].OfMessage == nil {
|
||||
t.Fatal("expected assistant text message")
|
||||
}
|
||||
// item[2] = function call
|
||||
if input[2].OfFunctionCall == nil {
|
||||
t.Fatal("expected function call")
|
||||
}
|
||||
if input[2].OfFunctionCall.Name != "get_weather" {
|
||||
t.Errorf("function name = %q, want %q", input[2].OfFunctionCall.Name, "get_weather")
|
||||
}
|
||||
// item[3] = tool output
|
||||
if input[3].OfFunctionCallOutput == nil {
|
||||
t.Fatal("expected function call output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessages_AssistantWithoutToolCalls(t *testing.T) {
|
||||
msgs := []protocoltypes.Message{
|
||||
{Role: "assistant", Content: "Sure thing"},
|
||||
}
|
||||
input, _ := TranslateMessages(msgs)
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("len(input) = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].OfMessage == nil {
|
||||
t.Fatal("expected EasyInputMessage for assistant without tool calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateMessages_ToolMessage(t *testing.T) {
|
||||
msgs := []protocoltypes.Message{
|
||||
{Role: "tool", Content: "result data", ToolCallID: "call_99"},
|
||||
}
|
||||
input, _ := TranslateMessages(msgs)
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("len(input) = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].OfFunctionCallOutput == nil {
|
||||
t.Fatal("expected FunctionCallOutput")
|
||||
}
|
||||
if input[0].OfFunctionCallOutput.CallID != "call_99" {
|
||||
t.Errorf("CallID = %q, want %q", input[0].OfFunctionCallOutput.CallID, "call_99")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResolveToolCall tests ---
|
||||
|
||||
func TestResolveToolCall_FromNameAndArguments(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "SF"},
|
||||
}
|
||||
name, args, ok := ResolveToolCall(tc)
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true")
|
||||
}
|
||||
if name != "get_weather" {
|
||||
t.Errorf("name = %q, want %q", name, "get_weather")
|
||||
}
|
||||
if !strings.Contains(args, "SF") {
|
||||
t.Errorf("args = %q, want to contain SF", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolCall_FromFunctionField(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{
|
||||
ID: "call_1",
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
}
|
||||
name, args, ok := ResolveToolCall(tc)
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true")
|
||||
}
|
||||
if name != "read_file" {
|
||||
t.Errorf("name = %q, want %q", name, "read_file")
|
||||
}
|
||||
if args != `{"path":"README.md"}` {
|
||||
t.Errorf("args = %q, want %q", args, `{"path":"README.md"}`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolCall_EmptyName(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{}
|
||||
_, _, ok := ResolveToolCall(tc)
|
||||
if ok {
|
||||
t.Error("expected ok=false for empty tool call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolCall_NoArgsFallsBackToEmptyObject(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{Name: "do_something"}
|
||||
name, args, ok := ResolveToolCall(tc)
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true")
|
||||
}
|
||||
if name != "do_something" {
|
||||
t.Errorf("name = %q, want %q", name, "do_something")
|
||||
}
|
||||
if args != "{}" {
|
||||
t.Errorf("args = %q, want %q", args, "{}")
|
||||
}
|
||||
}
|
||||
|
||||
// --- TranslateTools tests ---
|
||||
|
||||
func TestTranslateTools_FunctionTools(t *testing.T) {
|
||||
tools := []protocoltypes.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := TranslateTools(tools, false)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
if result[0].OfFunction == nil {
|
||||
t.Fatal("expected function tool")
|
||||
}
|
||||
if result[0].OfFunction.Name != "get_weather" {
|
||||
t.Errorf("name = %q, want %q", result[0].OfFunction.Name, "get_weather")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateTools_SkipsNonFunction(t *testing.T) {
|
||||
tools := []protocoltypes.ToolDefinition{
|
||||
{Type: "not_function"},
|
||||
}
|
||||
result := TranslateTools(tools, false)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("len(result) = %d, want 0", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateTools_WebSearchAppended(t *testing.T) {
|
||||
result := TranslateTools(nil, true)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
if result[0].OfWebSearch == nil {
|
||||
t.Fatal("expected web_search tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateTools_WebSearchReplacesUserDefined(t *testing.T) {
|
||||
tools := []protocoltypes.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "read_file",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := TranslateTools(tools, true)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
if result[0].OfFunction == nil || result[0].OfFunction.Name != "read_file" {
|
||||
t.Errorf("first tool should be read_file, got %v", result[0])
|
||||
}
|
||||
if result[1].OfWebSearch == nil {
|
||||
t.Error("second tool should be web_search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateTools_DescriptionOmittedWhenEmpty(t *testing.T) {
|
||||
tools := []protocoltypes.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "no_desc",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := TranslateTools(tools, false)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
if result[0].OfFunction.Description.Valid() {
|
||||
t.Error("Description should not be set when empty")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponseBody tests ---
|
||||
|
||||
func TestParseResponseBody_TextOutput(t *testing.T) {
|
||||
body := strings.NewReader(`{
|
||||
"id": "resp_123",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"content": [{"type": "output_text", "text": "Hello!"}]
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 0}
|
||||
}
|
||||
}`)
|
||||
|
||||
result, err := ParseResponseBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponseBody error: %v", err)
|
||||
}
|
||||
if result.Content != "Hello!" {
|
||||
t.Errorf("Content = %q, want %q", result.Content, "Hello!")
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
if result.Usage.TotalTokens != 15 {
|
||||
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponseBody_FunctionCall(t *testing.T) {
|
||||
body := strings.NewReader(`{
|
||||
"id": "resp_456",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": [
|
||||
{
|
||||
"type": "function_call",
|
||||
"call_id": "call_abc",
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"SF\"}"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 8,
|
||||
"total_tokens": 18,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 0}
|
||||
}
|
||||
}`)
|
||||
|
||||
result, err := ParseResponseBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponseBody error: %v", err)
|
||||
}
|
||||
if len(result.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
|
||||
}
|
||||
if result.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("Name = %q, want %q", result.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if result.ToolCalls[0].ID != "call_abc" {
|
||||
t.Errorf("ID = %q, want %q", result.ToolCalls[0].ID, "call_abc")
|
||||
}
|
||||
if result.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponseBody_Reasoning(t *testing.T) {
|
||||
body := strings.NewReader(`{
|
||||
"id": "resp_789",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": [
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": "rs_1",
|
||||
"summary": [{"type": "summary_text", "text": "Thinking about it..."}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"content": [{"type": "output_text", "text": "The answer is 42."}]
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 10}
|
||||
}
|
||||
}`)
|
||||
|
||||
result, err := ParseResponseBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponseBody error: %v", err)
|
||||
}
|
||||
if result.Content != "The answer is 42." {
|
||||
t.Errorf("Content = %q, want %q", result.Content, "The answer is 42.")
|
||||
}
|
||||
if result.ReasoningContent != "Thinking about it..." {
|
||||
t.Errorf("ReasoningContent = %q, want %q", result.ReasoningContent, "Thinking about it...")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponseBody_Refusal(t *testing.T) {
|
||||
body := strings.NewReader(`{
|
||||
"id": "resp_ref",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"content": [{"type": "refusal", "refusal": "I cannot help with that."}]
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 10,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 0}
|
||||
}
|
||||
}`)
|
||||
|
||||
result, err := ParseResponseBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponseBody error: %v", err)
|
||||
}
|
||||
if result.Content != "I cannot help with that." {
|
||||
t.Errorf("Content = %q, want %q", result.Content, "I cannot help with that.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponseBody_IncompleteStatus(t *testing.T) {
|
||||
body := strings.NewReader(`{
|
||||
"id": "resp_inc",
|
||||
"object": "response",
|
||||
"status": "incomplete",
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"content": [{"type": "output_text", "text": "partial"}]
|
||||
}
|
||||
],
|
||||
"usage": {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 0}}
|
||||
}`)
|
||||
|
||||
result, err := ParseResponseBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if result.FinishReason != "length" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "length")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponseBody_FailedStatus(t *testing.T) {
|
||||
body := strings.NewReader(`{
|
||||
"id": "resp_fail",
|
||||
"object": "response",
|
||||
"status": "failed",
|
||||
"output": [],
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 0}}
|
||||
}`)
|
||||
|
||||
result, err := ParseResponseBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
// failed/canceled statuses are not specially mapped; they fall through to "stop"
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseDataAudioURL tests ---
|
||||
|
||||
func TestParseDataAudioURL_Valid(t *testing.T) {
|
||||
format, data, ok := ParseDataAudioURL("data:audio/mp3;base64,SGVsbG8=")
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true")
|
||||
}
|
||||
if format != "mp3" {
|
||||
t.Errorf("format = %q, want %q", format, "mp3")
|
||||
}
|
||||
if data != "SGVsbG8=" {
|
||||
t.Errorf("data = %q, want %q", data, "SGVsbG8=")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDataAudioURL_NotAudio(t *testing.T) {
|
||||
_, _, ok := ParseDataAudioURL("data:image/png;base64,abc")
|
||||
if ok {
|
||||
t.Error("expected ok=false for non-audio URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDataAudioURL_MalformedNoComma(t *testing.T) {
|
||||
_, _, ok := ParseDataAudioURL("data:audio/mp3;base64")
|
||||
if ok {
|
||||
t.Error("expected ok=false for malformed URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDataAudioURL_EmptyData(t *testing.T) {
|
||||
_, _, ok := ParseDataAudioURL("data:audio/mp3;base64,")
|
||||
if ok {
|
||||
t.Error("expected ok=false for empty data")
|
||||
}
|
||||
}
|
||||
|
||||
// --- BuildMultipartContent tests ---
|
||||
|
||||
func TestBuildMultipartContent_TextOnly(t *testing.T) {
|
||||
parts := BuildMultipartContent("hello", nil)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("len(parts) = %d, want 1", len(parts))
|
||||
}
|
||||
if parts[0].OfInputText == nil {
|
||||
t.Fatal("expected text part")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMultipartContent_TextAndImage(t *testing.T) {
|
||||
parts := BuildMultipartContent("describe", []string{"data:image/png;base64,abc"})
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("len(parts) = %d, want 2", len(parts))
|
||||
}
|
||||
if parts[0].OfInputText == nil {
|
||||
t.Error("first part should be text")
|
||||
}
|
||||
if parts[1].OfInputImage == nil {
|
||||
t.Error("second part should be image")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMultipartContent_AudioFile(t *testing.T) {
|
||||
parts := BuildMultipartContent("", []string{"data:audio/wav;base64,AAAA"})
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("len(parts) = %d, want 1", len(parts))
|
||||
}
|
||||
if parts[0].OfInputFile == nil {
|
||||
t.Fatal("expected file part for audio")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMultipartContent_EmptyTextSkipped(t *testing.T) {
|
||||
parts := BuildMultipartContent("", []string{"data:image/png;base64,abc"})
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("len(parts) = %d, want 1", len(parts))
|
||||
}
|
||||
if parts[0].OfInputImage == nil {
|
||||
t.Error("should only have image part")
|
||||
}
|
||||
}
|
||||
|
||||
// --- JSON serialization sanity checks ---
|
||||
|
||||
func TestTranslateTools_SerializesToJSON(t *testing.T) {
|
||||
tools := []protocoltypes.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: protocoltypes.ToolFunctionDefinition{
|
||||
Name: "test_tool",
|
||||
Description: "A test",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := TranslateTools(tools, true)
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal error: %v", err)
|
||||
}
|
||||
s := string(data)
|
||||
if !strings.Contains(s, "test_tool") {
|
||||
t.Errorf("JSON should contain test_tool, got: %s", s)
|
||||
}
|
||||
if !strings.Contains(s, "web_search") {
|
||||
t.Errorf("JSON should contain web_search, got: %s", s)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user