enable dupl check

Signed-off-by: Kai Xia <kaix+github@fastmail.com>
This commit is contained in:
Kai Xia
2026-03-01 18:17:32 +11:00
parent 44a52c0cf6
commit 32c864c309
14 changed files with 347 additions and 739 deletions
-1
View File
@@ -7,7 +7,6 @@ linters:
- containedctx
- cyclop
- depguard
- dupl
- dupword
- err113
- exhaustruct
+47 -211
View File
@@ -10,8 +10,8 @@ import (
picoclawconfig "github.com/sipeed/picoclaw/pkg/config"
)
func (s *appState) channelMenu() tview.Primitive {
items := []MenuItem{
func (s *appState) buildChannelMenuItems() []MenuItem {
return []MenuItem{
{Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
channelItem(
"Telegram",
@@ -86,8 +86,10 @@ func (s *appState) channelMenu() tview.Primitive {
func() { s.push("channel-wecomapp", s.wecomAppForm()) },
),
}
}
menu := NewMenu("Channels", items)
func (s *appState) channelMenu() tview.Primitive {
menu := NewMenu("Channels", s.buildChannelMenuItems())
menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyEsc {
s.pop()
@@ -103,199 +105,72 @@ func (s *appState) channelMenu() tview.Primitive {
}
func refreshChannelMenuFromState(menu *Menu, s *appState) {
items := []MenuItem{
{Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
channelItem(
"Telegram",
"Telegram bot settings",
s.config.Channels.Telegram.Enabled,
func() { s.push("channel-telegram", s.telegramForm()) },
),
channelItem(
"Discord",
"Discord bot settings",
s.config.Channels.Discord.Enabled,
func() { s.push("channel-discord", s.discordForm()) },
),
channelItem(
"QQ",
"QQ bot settings",
s.config.Channels.QQ.Enabled,
func() { s.push("channel-qq", s.qqForm()) },
),
channelItem(
"MaixCam",
"MaixCam gateway",
s.config.Channels.MaixCam.Enabled,
func() { s.push("channel-maixcam", s.maixcamForm()) },
),
channelItem(
"WhatsApp",
"WhatsApp bridge",
s.config.Channels.WhatsApp.Enabled,
func() { s.push("channel-whatsapp", s.whatsappForm()) },
),
channelItem(
"Feishu",
"Feishu bot settings",
s.config.Channels.Feishu.Enabled,
func() { s.push("channel-feishu", s.feishuForm()) },
),
channelItem(
"DingTalk",
"DingTalk bot settings",
s.config.Channels.DingTalk.Enabled,
func() { s.push("channel-dingtalk", s.dingtalkForm()) },
),
channelItem(
"Slack",
"Slack bot settings",
s.config.Channels.Slack.Enabled,
func() { s.push("channel-slack", s.slackForm()) },
),
channelItem(
"LINE",
"LINE bot settings",
s.config.Channels.LINE.Enabled,
func() { s.push("channel-line", s.lineForm()) },
),
channelItem(
"OneBot",
"OneBot settings",
s.config.Channels.OneBot.Enabled,
func() { s.push("channel-onebot", s.onebotForm()) },
),
channelItem(
"WeCom",
"WeCom bot settings",
s.config.Channels.WeCom.Enabled,
func() { s.push("channel-wecom", s.wecomForm()) },
),
channelItem(
"WeCom App",
"WeCom App settings",
s.config.Channels.WeComApp.Enabled,
func() { s.push("channel-wecomapp", s.wecomAppForm()) },
),
}
menu.applyItems(items)
menu.applyItems(s.buildChannelMenuItems())
}
func (s *appState) telegramForm() tview.Primitive {
cfg := &s.config.Channels.Telegram
form := baseChannelForm("Telegram", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("Telegram", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
cfg.Token = strings.TrimSpace(text)
})
form.AddInputField("Proxy", cfg.Proxy, 128, nil, func(text string) {
cfg.Proxy = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) discordForm() tview.Primitive {
cfg := &s.config.Channels.Discord
form := baseChannelForm("Discord", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("Discord", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
cfg.Token = strings.TrimSpace(text)
})
form.AddCheckbox("Mention Only", cfg.MentionOnly, func(checked bool) {
cfg.MentionOnly = checked
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) qqForm() tview.Primitive {
cfg := &s.config.Channels.QQ
form := baseChannelForm("QQ", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("QQ", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
cfg.AppID = strings.TrimSpace(text)
})
form.AddInputField("App Secret", cfg.AppSecret, 128, nil, func(text string) {
cfg.AppSecret = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) maixcamForm() tview.Primitive {
cfg := &s.config.Channels.MaixCam
form := baseChannelForm("MaixCam", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("MaixCam", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Host", cfg.Host, 64, nil, func(text string) {
cfg.Host = strings.TrimSpace(text)
})
addIntField(form, "Port", cfg.Port, func(value int) { cfg.Port = value })
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) whatsappForm() tview.Primitive {
cfg := &s.config.Channels.WhatsApp
form := baseChannelForm("WhatsApp", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("WhatsApp", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Bridge URL", cfg.BridgeURL, 128, nil, func(text string) {
cfg.BridgeURL = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) feishuForm() tview.Primitive {
cfg := &s.config.Channels.Feishu
form := baseChannelForm("Feishu", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("Feishu", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
cfg.AppID = strings.TrimSpace(text)
})
@@ -308,66 +183,39 @@ func (s *appState) feishuForm() tview.Primitive {
form.AddInputField("Verification Token", cfg.VerificationToken, 128, nil, func(text string) {
cfg.VerificationToken = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) dingtalkForm() tview.Primitive {
cfg := &s.config.Channels.DingTalk
form := baseChannelForm("DingTalk", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("DingTalk", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Client ID", cfg.ClientID, 64, nil, func(text string) {
cfg.ClientID = strings.TrimSpace(text)
})
form.AddInputField("Client Secret", cfg.ClientSecret, 128, nil, func(text string) {
cfg.ClientSecret = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) slackForm() tview.Primitive {
cfg := &s.config.Channels.Slack
form := baseChannelForm("Slack", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("Slack", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Bot Token", cfg.BotToken, 128, nil, func(text string) {
cfg.BotToken = strings.TrimSpace(text)
})
form.AddInputField("App Token", cfg.AppToken, 128, nil, func(text string) {
cfg.AppToken = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) lineForm() tview.Primitive {
cfg := &s.config.Channels.LINE
form := baseChannelForm("LINE", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("LINE", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Channel Secret", cfg.ChannelSecret, 128, nil, func(text string) {
cfg.ChannelSecret = strings.TrimSpace(text)
})
@@ -381,22 +229,13 @@ func (s *appState) lineForm() tview.Primitive {
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
cfg.WebhookPath = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) onebotForm() tview.Primitive {
cfg := &s.config.Channels.OneBot
form := baseChannelForm("OneBot", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("OneBot", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("WS URL", cfg.WSUrl, 128, nil, func(text string) {
cfg.WSUrl = strings.TrimSpace(text)
})
@@ -418,22 +257,13 @@ func (s *appState) onebotForm() tview.Primitive {
cfg.GroupTriggerPrefix = splitCSV(text)
},
)
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) wecomForm() tview.Primitive {
cfg := &s.config.Channels.WeCom
form := baseChannelForm("WeCom", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("WeCom", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
cfg.Token = strings.TrimSpace(text)
})
@@ -450,9 +280,7 @@ func (s *appState) wecomForm() tview.Primitive {
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
cfg.WebhookPath = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
addIntField(
form,
"Reply Timeout",
@@ -464,14 +292,7 @@ func (s *appState) wecomForm() tview.Primitive {
func (s *appState) wecomAppForm() tview.Primitive {
cfg := &s.config.Channels.WeComApp
form := baseChannelForm("WeCom App", cfg.Enabled, func(v bool) {
cfg.Enabled = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
})
form := baseChannelForm("WeCom App", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Corp ID", cfg.CorpID, 64, nil, func(text string) {
cfg.CorpID = strings.TrimSpace(text)
})
@@ -492,9 +313,7 @@ func (s *appState) wecomAppForm() tview.Primitive {
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
cfg.WebhookPath = strings.TrimSpace(text)
})
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
cfg.AllowFrom = splitCSV(text)
})
addAllowFromField(form, &cfg.AllowFrom)
addIntField(
form,
"Reply Timeout",
@@ -504,6 +323,23 @@ func (s *appState) wecomAppForm() tview.Primitive {
return wrapWithBack(form, s)
}
func (s *appState) makeChannelOnEnabled(enabledPtr *bool) func(bool) {
return func(v bool) {
*enabledPtr = v
s.dirty = true
refreshMainMenuIfPresent(s)
if menu, ok := s.menus["channel"]; ok {
refreshChannelMenuFromState(menu, s)
}
}
}
func addAllowFromField(form *tview.Form, allowFrom *picoclawconfig.FlexibleStringSlice) {
form.AddInputField("Allow From", strings.Join(*allowFrom, ","), 128, nil, func(text string) {
*allowFrom = splitCSV(text)
})
}
func baseChannelForm(title string, enabled bool, onEnabled func(bool)) *tview.Form {
form := tview.NewForm()
form.SetBorder(true).SetTitle(fmt.Sprintf("Channel: %s", title))
+58 -65
View File
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
}
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "step-3.5-flash",
},
tests := []struct {
name string
aliasName string
modelName string
apiBase string
wantProvider string
wantModel string
}{
{
name: "alias with provider prefix",
aliasName: "step-3.5-flash",
modelName: "openrouter/stepfun/step-3.5-flash:free",
apiBase: "https://openrouter.ai/api/v1",
wantProvider: "openrouter",
wantModel: "stepfun/step-3.5-flash:free",
},
ModelList: []config.ModelConfig{
{
ModelName: "step-3.5-flash",
Model: "openrouter/stepfun/step-3.5-flash:free",
APIBase: "https://openrouter.ai/api/v1",
},
{
name: "alias without provider prefix",
aliasName: "glm-5",
modelName: "glm-5",
apiBase: "https://api.z.ai/api/coding/paas/v4",
wantProvider: "openai",
wantModel: "glm-5",
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != "openrouter" {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
}
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
}
}
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "glm-5",
},
},
ModelList: []config.ModelConfig{
{
ModelName: "glm-5",
Model: "glm-5",
APIBase: "https://api.z.ai/api/coding/paas/v4",
},
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != "openai" {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
}
if agent.Candidates[0].Model != "glm-5" {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: tt.aliasName,
},
},
ModelList: []config.ModelConfig{
{
ModelName: tt.aliasName,
Model: tt.modelName,
APIBase: tt.apiBase,
},
},
}
provider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
if len(agent.Candidates) != 1 {
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
}
if agent.Candidates[0].Provider != tt.wantProvider {
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
}
if agent.Candidates[0].Model != tt.wantModel {
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
}
})
}
}
+25 -57
View File
@@ -26,16 +26,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
cfg = &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
@@ -45,74 +44,43 @@ func TestRecordLastChannel(t *testing.T) {
},
},
}
msgBus = bus.NewMessageBus()
provider = &mockProvider{}
al = NewAgentLoop(cfg, msgBus, provider)
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
func TestRecordLastChannel(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
if err := al.RecordLastChannel(testChannel); err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
if got := al.state.GetLastChannel(); got != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
if got := al2.state.GetLastChannel(); got != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
if err := al.RecordLastChatID(testChatID); err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
if got := al.state.GetLastChatID(); got != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
if got := al2.state.GetLastChatID(); got != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
}
}
+52 -50
View File
@@ -539,86 +539,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
})
}
func (m *Manager) dispatchOutbound(ctx context.Context) {
logger.InfoC("channels", "Outbound dispatcher started")
func dispatchLoop[M any](
ctx context.Context,
m *Manager,
subscribe func(context.Context) (M, bool),
getChannel func(M) string,
enqueue func(context.Context, *channelWorker, M) bool,
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
) {
logger.InfoC("channels", startMsg)
for {
msg, ok := m.bus.SubscribeOutbound(ctx)
msg, ok := subscribe(ctx)
if !ok {
logger.InfoC("channels", "Outbound dispatcher stopped")
logger.InfoC("channels", stopMsg)
return
}
channel := getChannel(msg)
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
"channel": msg.Channel,
})
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
select {
case w.queue <- msg:
case <-ctx.Done():
if !enqueue(ctx, w, msg) {
return
}
} else if exists {
logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{
"channel": msg.Channel,
})
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
}
}
}
func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutbound,
func(msg bus.OutboundMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
case w.queue <- msg:
return true
case <-ctx.Done():
return false
}
},
"Outbound dispatcher started",
"Outbound dispatcher stopped",
"Unknown channel for outbound message",
"Channel has no active worker, skipping message",
)
}
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
logger.InfoC("channels", "Outbound media dispatcher started")
for {
msg, ok := m.bus.SubscribeOutboundMedia(ctx)
if !ok {
logger.InfoC("channels", "Outbound media dispatcher stopped")
return
}
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{
"channel": msg.Channel,
})
continue
}
if wExists && w != nil {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutboundMedia,
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
case w.mediaQueue <- msg:
return true
case <-ctx.Done():
return
return false
}
} else if exists {
logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{
"channel": msg.Channel,
})
}
}
},
"Outbound media dispatcher started",
"Outbound media dispatcher stopped",
"Unknown channel for outbound media message",
"Channel has no active worker, skipping media message",
)
}
// runMediaWorker processes outbound media messages for a single channel.
+16 -60
View File
@@ -342,18 +342,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
return result.MediaID, nil
}
// sendImageMessage sends an image message using a media_id.
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
msg := WeComImageMessage{
ToUser: userID,
MsgType: "image",
AgentID: c.config.AgentID,
}
msg.Image.MediaID = mediaID
jsonData, err := json.Marshal(msg)
jsonData, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
@@ -400,6 +393,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
return nil
}
// sendImageMessage sends an image message using a media_id.
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
msg := WeComImageMessage{
ToUser: userID,
MsgType: "image",
AgentID: c.config.AgentID,
}
msg.Image.MediaID = mediaID
return c.sendWeComMessage(ctx, accessToken, msg)
}
// WebhookPath returns the path for registering on the shared HTTP server.
func (c *WeComAppChannel) WebhookPath() string {
if c.config.WebhookPath != "" {
@@ -722,63 +726,15 @@ func (c *WeComAppChannel) getAccessToken() string {
return c.accessToken
}
// sendTextMessage sends a text message to a user
// sendTextMessage sends a text message to a user.
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
msg := WeComTextMessage{
ToUser: userID,
MsgType: "text",
AgentID: c.config.AgentID,
}
msg.Text.Content = content
jsonData, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
// Use configurable timeout (default 5 seconds)
timeout := c.config.ReplyTimeout
if timeout <= 0 {
timeout = 5
}
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body)))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
var sendResp WeComSendMessageResponse
if err := json.Unmarshal(body, &sendResp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if sendResp.ErrCode != 0 {
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
}
return nil
return c.sendWeComMessage(ctx, accessToken, msg)
}
// handleHealth handles health check requests
-54
View File
@@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) {
})
}
func TestWeComAppPKCS7Unpad(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
}{
{
name: "empty input",
input: []byte{},
expected: []byte{},
},
{
name: "valid padding 3 bytes",
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
expected: []byte("hello"),
},
{
name: "valid padding 16 bytes (full block)",
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("123456789012345"),
},
{
name: "invalid padding larger than data",
input: []byte{20},
expected: nil, // should return error
},
{
name: "invalid padding zero",
input: append([]byte("test"), byte(0)),
expected: nil, // should return error
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs7Unpad(tt.input)
if tt.expected == nil {
// This case should return an error
if err == nil {
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
}
return
}
if err != nil {
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
return
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
}
})
}
}
func TestWeComAppHandleVerification(t *testing.T) {
msgBus := bus.NewMessageBus()
aesKey := generateTestAESKeyApp()
+16 -47
View File
@@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
}
ch, _ := NewWeComBotChannel(cfg, msgBus)
t.Run("valid direct message callback", func(t *testing.T) {
// Create JSON message for direct chat (single)
jsonMsg := `{
"msgid": "test_msg_id_123",
"aibotid": "test_aibot_id",
"chattype": "single",
"from": {"userid": "user123"},
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello World"}
}`
// Encrypt message
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
t.Helper()
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
// Create encrypted XML wrapper
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
@@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
Encrypt: encrypted,
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encrypted)
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
return w
}
t.Run("valid direct message callback", func(t *testing.T) {
w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_123",
"aibotid": "test_aibot_id",
"chattype": "single",
"from": {"userid": "user123"},
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello World"}
}`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
@@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
})
t.Run("valid group message callback", func(t *testing.T) {
// Create JSON message for group chat
jsonMsg := `{
w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_456",
"aibotid": "test_aibot_id",
"chatid": "group_chat_id_123",
@@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello Group"}
}`
// Encrypt message
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
// Create encrypted XML wrapper
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
}{
Encrypt: encrypted,
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encrypted)
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
}`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
+1 -19
View File
@@ -742,25 +742,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig {
// HasProvidersConfig checks if any provider in the old providers config has configuration.
func (c *Config) HasProvidersConfig() bool {
v := c.Providers
return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" ||
v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" ||
v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" ||
v.Groq.APIKey != "" || v.Groq.APIBase != "" ||
v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" ||
v.VLLM.APIKey != "" || v.VLLM.APIBase != "" ||
v.Gemini.APIKey != "" || v.Gemini.APIBase != "" ||
v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" ||
v.Ollama.APIKey != "" || v.Ollama.APIBase != "" ||
v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" ||
v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" ||
v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" ||
v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" ||
v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" ||
v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" ||
v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
return !c.Providers.IsEmpty()
}
// ValidateModelList validates all ModelConfig entries in the model_list.
+51 -67
View File
@@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
}
}
func TestExecuteHeartbeat_Error(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
// Check log file for error message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
func TestExecuteHeartbeat_ResultLogging(t *testing.T) {
tests := []struct {
name string
result *tools.ToolResult
wantLog string
}{
{
name: "error result",
result: &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
},
wantLog: "error message",
},
{
name: "silent result",
result: &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
},
wantLog: "completion message",
},
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain error message")
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
func TestExecuteHeartbeat_Silent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return tt.result
})
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
}
})
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
// Check log file for completion message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain completion message")
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
if string(data) == "" {
t.Errorf("Expected log file to contain %s", tt.wantLog)
}
})
}
}
+42 -51
View File
@@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) {
assert.GreaterOrEqual(t, len(actions), 1)
}
func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
func TestPlanWorkspaceMigrationExistingFile(t *testing.T) {
tests := []struct {
name string
force bool
wantActionType ActionType
}{
{
name: "backup when not forced",
force: false,
wantActionType: ActionBackup,
},
{
name: "copy when forced",
force: true,
wantActionType: ActionCopy,
},
}
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
false,
)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, ActionBackup, actions[0].Type)
}
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
tt.force,
)
require.NoError(t, err)
func TestPlanWorkspaceMigrationForce(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
true,
)
require.NoError(t, err)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, ActionCopy, actions[0].Type)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, tt.wantActionType, actions[0].Type)
})
}
}
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
+1 -28
View File
@@ -108,34 +108,7 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe
// buildToolsPrompt creates the tool definitions section for the system prompt.
func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
return buildCLIToolsPrompt(tools)
}
// parseClaudeCliResponse parses the JSON output from the claude CLI.
+1 -28
View File
@@ -130,34 +130,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
// buildToolsPrompt creates a tool definitions section for the prompt.
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
return buildCLIToolsPrompt(tools)
}
// codexEvent represents a single JSONL event from `codex exec --json`.
+37 -1
View File
@@ -5,7 +5,43 @@
package providers
import "encoding/json"
import (
"encoding/json"
"fmt"
"strings"
)
// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
func buildCLIToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)