mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #535 from xiaket/ci-enable-dupl-linter
ci: enable duplication linter in CI
This commit is contained in:
+58
-65
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "step-3.5-flash",
|
||||
},
|
||||
tests := []struct {
|
||||
name string
|
||||
aliasName string
|
||||
modelName string
|
||||
apiBase string
|
||||
wantProvider string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "alias with provider prefix",
|
||||
aliasName: "step-3.5-flash",
|
||||
modelName: "openrouter/stepfun/step-3.5-flash:free",
|
||||
apiBase: "https://openrouter.ai/api/v1",
|
||||
wantProvider: "openrouter",
|
||||
wantModel: "stepfun/step-3.5-flash:free",
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "step-3.5-flash",
|
||||
Model: "openrouter/stepfun/step-3.5-flash:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
{
|
||||
name: "alias without provider prefix",
|
||||
aliasName: "glm-5",
|
||||
modelName: "glm-5",
|
||||
apiBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
wantProvider: "openai",
|
||||
wantModel: "glm-5",
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openrouter" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
|
||||
}
|
||||
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "glm-5",
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "glm-5",
|
||||
Model: "glm-5",
|
||||
APIBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openai" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
|
||||
}
|
||||
if agent.Candidates[0].Model != "glm-5" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: tt.aliasName,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: tt.aliasName,
|
||||
Model: tt.modelName,
|
||||
APIBase: tt.apiBase,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != tt.wantProvider {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
|
||||
}
|
||||
if agent.Candidates[0].Model != tt.wantModel {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+25
-57
@@ -27,16 +27,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
cfg = &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
@@ -46,74 +45,43 @@ func TestRecordLastChannel(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus = bus.NewMessageBus()
|
||||
provider = &mockProvider{}
|
||||
al = NewAgentLoop(cfg, msgBus, provider)
|
||||
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Test RecordLastChannel
|
||||
testChannel := "test-channel"
|
||||
err = al.RecordLastChannel(testChannel)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChannel(testChannel); err != nil {
|
||||
t.Fatalf("RecordLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify channel was saved
|
||||
lastChannel := al.state.GetLastChannel()
|
||||
if lastChannel != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
|
||||
if got := al.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChannel() != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
|
||||
if got := al2.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChatID(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChatID
|
||||
testChatID := "test-chat-id-123"
|
||||
err = al.RecordLastChatID(testChatID)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChatID(testChatID); err != nil {
|
||||
t.Fatalf("RecordLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify chat ID was saved
|
||||
lastChatID := al.state.GetLastChatID()
|
||||
if lastChatID != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
|
||||
if got := al.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChatID() != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
|
||||
if got := al2.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+52
-50
@@ -539,86 +539,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
logger.InfoC("channels", "Outbound dispatcher started")
|
||||
func dispatchLoop[M any](
|
||||
ctx context.Context,
|
||||
m *Manager,
|
||||
subscribe func(context.Context) (M, bool),
|
||||
getChannel func(M) string,
|
||||
enqueue func(context.Context, *channelWorker, M) bool,
|
||||
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
|
||||
) {
|
||||
logger.InfoC("channels", startMsg)
|
||||
|
||||
for {
|
||||
msg, ok := m.bus.SubscribeOutbound(ctx)
|
||||
msg, ok := subscribe(ctx)
|
||||
if !ok {
|
||||
logger.InfoC("channels", "Outbound dispatcher stopped")
|
||||
logger.InfoC("channels", stopMsg)
|
||||
return
|
||||
}
|
||||
|
||||
channel := getChannel(msg)
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
if constants.IsInternalChannel(channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channel]
|
||||
w, wExists := m.workers[channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
case <-ctx.Done():
|
||||
if !enqueue(ctx, w, msg) {
|
||||
return
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutbound,
|
||||
func(msg bus.OutboundMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
},
|
||||
"Outbound dispatcher started",
|
||||
"Outbound dispatcher stopped",
|
||||
"Unknown channel for outbound message",
|
||||
"Channel has no active worker, skipping message",
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
logger.InfoC("channels", "Outbound media dispatcher started")
|
||||
|
||||
for {
|
||||
msg, ok := m.bus.SubscribeOutboundMedia(ctx)
|
||||
if !ok {
|
||||
logger.InfoC("channels", "Outbound media dispatcher stopped")
|
||||
return
|
||||
}
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutboundMedia,
|
||||
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return false
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
"Outbound media dispatcher started",
|
||||
"Outbound media dispatcher stopped",
|
||||
"Unknown channel for outbound media message",
|
||||
"Channel has no active worker, skipping media message",
|
||||
)
|
||||
}
|
||||
|
||||
// runMediaWorker processes outbound media messages for a single channel.
|
||||
|
||||
+16
-60
@@ -342,18 +342,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
|
||||
return result.MediaID, nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
|
||||
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
@@ -400,6 +393,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// WebhookPath returns the path for registering on the shared HTTP server.
|
||||
func (c *WeComAppChannel) WebhookPath() string {
|
||||
if c.config.WebhookPath != "" {
|
||||
@@ -722,63 +726,15 @@ func (c *WeComAppChannel) getAccessToken() string {
|
||||
return c.accessToken
|
||||
}
|
||||
|
||||
// sendTextMessage sends a text message to a user
|
||||
// sendTextMessage sends a text message to a user.
|
||||
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComTextMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "text",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Text.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body)))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var sendResp WeComSendMessageResponse
|
||||
if err := json.Unmarshal(body, &sendResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if sendResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
|
||||
@@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAppPKCS7Unpad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "valid padding 3 bytes",
|
||||
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
|
||||
expected: []byte("hello"),
|
||||
},
|
||||
{
|
||||
name: "valid padding 16 bytes (full block)",
|
||||
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("123456789012345"),
|
||||
},
|
||||
{
|
||||
name: "invalid padding larger than data",
|
||||
input: []byte{20},
|
||||
expected: nil, // should return error
|
||||
},
|
||||
{
|
||||
name: "invalid padding zero",
|
||||
input: append([]byte("test"), byte(0)),
|
||||
expected: nil, // should return error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs7Unpad(tt.input)
|
||||
if tt.expected == nil {
|
||||
// This case should return an error
|
||||
if err == nil {
|
||||
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComAppHandleVerification(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKeyApp()
|
||||
|
||||
@@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
// Create JSON message for direct chat (single)
|
||||
jsonMsg := `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
@@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
@@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("valid group message callback", func(t *testing.T) {
|
||||
// Create JSON message for group chat
|
||||
jsonMsg := `{
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_456",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
@@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello Group"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
+1
-19
@@ -747,25 +747,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig {
|
||||
|
||||
// HasProvidersConfig checks if any provider in the old providers config has configuration.
|
||||
func (c *Config) HasProvidersConfig() bool {
|
||||
v := c.Providers
|
||||
return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" ||
|
||||
v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" ||
|
||||
v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" ||
|
||||
v.Groq.APIKey != "" || v.Groq.APIBase != "" ||
|
||||
v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" ||
|
||||
v.VLLM.APIKey != "" || v.VLLM.APIBase != "" ||
|
||||
v.Gemini.APIKey != "" || v.Gemini.APIBase != "" ||
|
||||
v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" ||
|
||||
v.Ollama.APIKey != "" || v.Ollama.APIBase != "" ||
|
||||
v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" ||
|
||||
v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" ||
|
||||
v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" ||
|
||||
v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" ||
|
||||
v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" ||
|
||||
v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" ||
|
||||
v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
|
||||
v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
|
||||
v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
|
||||
return !c.Providers.IsEmpty()
|
||||
}
|
||||
|
||||
// ValidateModelList validates all ModelConfig entries in the model_list.
|
||||
|
||||
@@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_Error(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for error message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
func TestExecuteHeartbeat_ResultLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *tools.ToolResult
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "error result",
|
||||
result: &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
},
|
||||
wantLog: "error message",
|
||||
},
|
||||
{
|
||||
name: "silent result",
|
||||
result: &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
},
|
||||
wantLog: "completion message",
|
||||
},
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain error message")
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
func TestExecuteHeartbeat_Silent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return tt.result
|
||||
})
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for completion message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain completion message")
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
if string(data) == "" {
|
||||
t.Errorf("Expected log file to contain %s", tt.wantLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) {
|
||||
assert.GreaterOrEqual(t, len(actions), 1)
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
func TestPlanWorkspaceMigrationExistingFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
force bool
|
||||
wantActionType ActionType
|
||||
}{
|
||||
{
|
||||
name: "backup when not forced",
|
||||
force: false,
|
||||
wantActionType: ActionBackup,
|
||||
},
|
||||
{
|
||||
name: "copy when forced",
|
||||
force: true,
|
||||
wantActionType: ActionCopy,
|
||||
},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionBackup, actions[0].Type)
|
||||
}
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
tt.force,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
func TestPlanWorkspaceMigrationForce(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionCopy, actions[0].Type)
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, tt.wantActionType, actions[0].Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
|
||||
|
||||
@@ -100,44 +100,12 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
parts = append(parts, p.buildToolsPrompt(tools))
|
||||
parts = append(parts, buildCLIToolsPrompt(tools))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// buildToolsPrompt creates the tool definitions section for the system prompt.
|
||||
func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// parseClaudeCliResponse parses the JSON output from the claude CLI.
|
||||
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
|
||||
var resp claudeCliJSONResponse
|
||||
|
||||
@@ -660,12 +660,11 @@ func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
|
||||
// --- buildToolsPrompt tests ---
|
||||
|
||||
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
|
||||
}
|
||||
got := p.buildToolsPrompt(tools)
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if strings.Contains(got, "skip_me") {
|
||||
t.Error("buildToolsPrompt() should skip non-function tools")
|
||||
}
|
||||
@@ -675,11 +674,10 @@ func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
|
||||
}
|
||||
got := p.buildToolsPrompt(tools)
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if !strings.Contains(got, "bare_tool") {
|
||||
t.Error("should include tool name")
|
||||
}
|
||||
@@ -689,14 +687,13 @@ func TestBuildToolsPrompt_NoDescription(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{
|
||||
Name: "no_params_tool",
|
||||
Description: "A tool with no parameters",
|
||||
}},
|
||||
}
|
||||
got := p.buildToolsPrompt(tools)
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if strings.Contains(got, "Parameters:") {
|
||||
t.Error("should not include Parameters: section when nil")
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(p.buildToolsPrompt(tools))
|
||||
sb.WriteString(buildCLIToolsPrompt(tools))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
@@ -128,38 +128,6 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildToolsPrompt creates a tool definitions section for the prompt.
|
||||
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// codexEvent represents a single JSONL event from `codex exec --json`.
|
||||
type codexEvent struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
@@ -5,7 +5,43 @@
|
||||
|
||||
package providers
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
|
||||
func buildCLIToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
|
||||
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
|
||||
|
||||
Reference in New Issue
Block a user