mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
@@ -7,7 +7,6 @@ linters:
|
||||
- containedctx
|
||||
- cyclop
|
||||
- depguard
|
||||
- dupl
|
||||
- dupword
|
||||
- err113
|
||||
- exhaustruct
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
picoclawconfig "github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func (s *appState) channelMenu() tview.Primitive {
|
||||
items := []MenuItem{
|
||||
func (s *appState) buildChannelMenuItems() []MenuItem {
|
||||
return []MenuItem{
|
||||
{Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
|
||||
channelItem(
|
||||
"Telegram",
|
||||
@@ -86,8 +86,10 @@ func (s *appState) channelMenu() tview.Primitive {
|
||||
func() { s.push("channel-wecomapp", s.wecomAppForm()) },
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
menu := NewMenu("Channels", items)
|
||||
func (s *appState) channelMenu() tview.Primitive {
|
||||
menu := NewMenu("Channels", s.buildChannelMenuItems())
|
||||
menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEsc {
|
||||
s.pop()
|
||||
@@ -103,199 +105,72 @@ func (s *appState) channelMenu() tview.Primitive {
|
||||
}
|
||||
|
||||
func refreshChannelMenuFromState(menu *Menu, s *appState) {
|
||||
items := []MenuItem{
|
||||
{Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
|
||||
channelItem(
|
||||
"Telegram",
|
||||
"Telegram bot settings",
|
||||
s.config.Channels.Telegram.Enabled,
|
||||
func() { s.push("channel-telegram", s.telegramForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"Discord",
|
||||
"Discord bot settings",
|
||||
s.config.Channels.Discord.Enabled,
|
||||
func() { s.push("channel-discord", s.discordForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"QQ",
|
||||
"QQ bot settings",
|
||||
s.config.Channels.QQ.Enabled,
|
||||
func() { s.push("channel-qq", s.qqForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"MaixCam",
|
||||
"MaixCam gateway",
|
||||
s.config.Channels.MaixCam.Enabled,
|
||||
func() { s.push("channel-maixcam", s.maixcamForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"WhatsApp",
|
||||
"WhatsApp bridge",
|
||||
s.config.Channels.WhatsApp.Enabled,
|
||||
func() { s.push("channel-whatsapp", s.whatsappForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"Feishu",
|
||||
"Feishu bot settings",
|
||||
s.config.Channels.Feishu.Enabled,
|
||||
func() { s.push("channel-feishu", s.feishuForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"DingTalk",
|
||||
"DingTalk bot settings",
|
||||
s.config.Channels.DingTalk.Enabled,
|
||||
func() { s.push("channel-dingtalk", s.dingtalkForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"Slack",
|
||||
"Slack bot settings",
|
||||
s.config.Channels.Slack.Enabled,
|
||||
func() { s.push("channel-slack", s.slackForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"LINE",
|
||||
"LINE bot settings",
|
||||
s.config.Channels.LINE.Enabled,
|
||||
func() { s.push("channel-line", s.lineForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"OneBot",
|
||||
"OneBot settings",
|
||||
s.config.Channels.OneBot.Enabled,
|
||||
func() { s.push("channel-onebot", s.onebotForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"WeCom",
|
||||
"WeCom bot settings",
|
||||
s.config.Channels.WeCom.Enabled,
|
||||
func() { s.push("channel-wecom", s.wecomForm()) },
|
||||
),
|
||||
channelItem(
|
||||
"WeCom App",
|
||||
"WeCom App settings",
|
||||
s.config.Channels.WeComApp.Enabled,
|
||||
func() { s.push("channel-wecomapp", s.wecomAppForm()) },
|
||||
),
|
||||
}
|
||||
menu.applyItems(items)
|
||||
menu.applyItems(s.buildChannelMenuItems())
|
||||
}
|
||||
|
||||
func (s *appState) telegramForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.Telegram
|
||||
form := baseChannelForm("Telegram", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("Telegram", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
|
||||
cfg.Token = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Proxy", cfg.Proxy, 128, nil, func(text string) {
|
||||
cfg.Proxy = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) discordForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.Discord
|
||||
form := baseChannelForm("Discord", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("Discord", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
|
||||
cfg.Token = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddCheckbox("Mention Only", cfg.MentionOnly, func(checked bool) {
|
||||
cfg.MentionOnly = checked
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) qqForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.QQ
|
||||
form := baseChannelForm("QQ", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("QQ", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
|
||||
cfg.AppID = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("App Secret", cfg.AppSecret, 128, nil, func(text string) {
|
||||
cfg.AppSecret = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) maixcamForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.MaixCam
|
||||
form := baseChannelForm("MaixCam", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("MaixCam", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Host", cfg.Host, 64, nil, func(text string) {
|
||||
cfg.Host = strings.TrimSpace(text)
|
||||
})
|
||||
addIntField(form, "Port", cfg.Port, func(value int) { cfg.Port = value })
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) whatsappForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.WhatsApp
|
||||
form := baseChannelForm("WhatsApp", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("WhatsApp", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Bridge URL", cfg.BridgeURL, 128, nil, func(text string) {
|
||||
cfg.BridgeURL = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) feishuForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.Feishu
|
||||
form := baseChannelForm("Feishu", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("Feishu", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
|
||||
cfg.AppID = strings.TrimSpace(text)
|
||||
})
|
||||
@@ -308,66 +183,39 @@ func (s *appState) feishuForm() tview.Primitive {
|
||||
form.AddInputField("Verification Token", cfg.VerificationToken, 128, nil, func(text string) {
|
||||
cfg.VerificationToken = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) dingtalkForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.DingTalk
|
||||
form := baseChannelForm("DingTalk", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("DingTalk", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Client ID", cfg.ClientID, 64, nil, func(text string) {
|
||||
cfg.ClientID = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Client Secret", cfg.ClientSecret, 128, nil, func(text string) {
|
||||
cfg.ClientSecret = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) slackForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.Slack
|
||||
form := baseChannelForm("Slack", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("Slack", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Bot Token", cfg.BotToken, 128, nil, func(text string) {
|
||||
cfg.BotToken = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("App Token", cfg.AppToken, 128, nil, func(text string) {
|
||||
cfg.AppToken = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) lineForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.LINE
|
||||
form := baseChannelForm("LINE", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("LINE", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Channel Secret", cfg.ChannelSecret, 128, nil, func(text string) {
|
||||
cfg.ChannelSecret = strings.TrimSpace(text)
|
||||
})
|
||||
@@ -381,22 +229,13 @@ func (s *appState) lineForm() tview.Primitive {
|
||||
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
|
||||
cfg.WebhookPath = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) onebotForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.OneBot
|
||||
form := baseChannelForm("OneBot", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("OneBot", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("WS URL", cfg.WSUrl, 128, nil, func(text string) {
|
||||
cfg.WSUrl = strings.TrimSpace(text)
|
||||
})
|
||||
@@ -418,22 +257,13 @@ func (s *appState) onebotForm() tview.Primitive {
|
||||
cfg.GroupTriggerPrefix = splitCSV(text)
|
||||
},
|
||||
)
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) wecomForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.WeCom
|
||||
form := baseChannelForm("WeCom", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("WeCom", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
|
||||
cfg.Token = strings.TrimSpace(text)
|
||||
})
|
||||
@@ -450,9 +280,7 @@ func (s *appState) wecomForm() tview.Primitive {
|
||||
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
|
||||
cfg.WebhookPath = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
addIntField(
|
||||
form,
|
||||
"Reply Timeout",
|
||||
@@ -464,14 +292,7 @@ func (s *appState) wecomForm() tview.Primitive {
|
||||
|
||||
func (s *appState) wecomAppForm() tview.Primitive {
|
||||
cfg := &s.config.Channels.WeComApp
|
||||
form := baseChannelForm("WeCom App", cfg.Enabled, func(v bool) {
|
||||
cfg.Enabled = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
})
|
||||
form := baseChannelForm("WeCom App", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
|
||||
form.AddInputField("Corp ID", cfg.CorpID, 64, nil, func(text string) {
|
||||
cfg.CorpID = strings.TrimSpace(text)
|
||||
})
|
||||
@@ -492,9 +313,7 @@ func (s *appState) wecomAppForm() tview.Primitive {
|
||||
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
|
||||
cfg.WebhookPath = strings.TrimSpace(text)
|
||||
})
|
||||
form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
|
||||
cfg.AllowFrom = splitCSV(text)
|
||||
})
|
||||
addAllowFromField(form, &cfg.AllowFrom)
|
||||
addIntField(
|
||||
form,
|
||||
"Reply Timeout",
|
||||
@@ -504,6 +323,23 @@ func (s *appState) wecomAppForm() tview.Primitive {
|
||||
return wrapWithBack(form, s)
|
||||
}
|
||||
|
||||
func (s *appState) makeChannelOnEnabled(enabledPtr *bool) func(bool) {
|
||||
return func(v bool) {
|
||||
*enabledPtr = v
|
||||
s.dirty = true
|
||||
refreshMainMenuIfPresent(s)
|
||||
if menu, ok := s.menus["channel"]; ok {
|
||||
refreshChannelMenuFromState(menu, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addAllowFromField(form *tview.Form, allowFrom *picoclawconfig.FlexibleStringSlice) {
|
||||
form.AddInputField("Allow From", strings.Join(*allowFrom, ","), 128, nil, func(text string) {
|
||||
*allowFrom = splitCSV(text)
|
||||
})
|
||||
}
|
||||
|
||||
func baseChannelForm(title string, enabled bool, onEnabled func(bool)) *tview.Form {
|
||||
form := tview.NewForm()
|
||||
form.SetBorder(true).SetTitle(fmt.Sprintf("Channel: %s", title))
|
||||
|
||||
+58
-65
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "step-3.5-flash",
|
||||
},
|
||||
tests := []struct {
|
||||
name string
|
||||
aliasName string
|
||||
modelName string
|
||||
apiBase string
|
||||
wantProvider string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "alias with provider prefix",
|
||||
aliasName: "step-3.5-flash",
|
||||
modelName: "openrouter/stepfun/step-3.5-flash:free",
|
||||
apiBase: "https://openrouter.ai/api/v1",
|
||||
wantProvider: "openrouter",
|
||||
wantModel: "stepfun/step-3.5-flash:free",
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "step-3.5-flash",
|
||||
Model: "openrouter/stepfun/step-3.5-flash:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
{
|
||||
name: "alias without provider prefix",
|
||||
aliasName: "glm-5",
|
||||
modelName: "glm-5",
|
||||
apiBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
wantProvider: "openai",
|
||||
wantModel: "glm-5",
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openrouter" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
|
||||
}
|
||||
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "glm-5",
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "glm-5",
|
||||
Model: "glm-5",
|
||||
APIBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openai" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
|
||||
}
|
||||
if agent.Candidates[0].Model != "glm-5" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: tt.aliasName,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: tt.aliasName,
|
||||
Model: tt.modelName,
|
||||
APIBase: tt.apiBase,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != tt.wantProvider {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
|
||||
}
|
||||
if agent.Candidates[0].Model != tt.wantModel {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+25
-57
@@ -26,16 +26,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
cfg = &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
@@ -45,74 +44,43 @@ func TestRecordLastChannel(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus = bus.NewMessageBus()
|
||||
provider = &mockProvider{}
|
||||
al = NewAgentLoop(cfg, msgBus, provider)
|
||||
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Test RecordLastChannel
|
||||
testChannel := "test-channel"
|
||||
err = al.RecordLastChannel(testChannel)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChannel(testChannel); err != nil {
|
||||
t.Fatalf("RecordLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify channel was saved
|
||||
lastChannel := al.state.GetLastChannel()
|
||||
if lastChannel != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
|
||||
if got := al.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChannel() != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
|
||||
if got := al2.state.GetLastChannel(); got != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChatID(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChatID
|
||||
testChatID := "test-chat-id-123"
|
||||
err = al.RecordLastChatID(testChatID)
|
||||
if err != nil {
|
||||
if err := al.RecordLastChatID(testChatID); err != nil {
|
||||
t.Fatalf("RecordLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify chat ID was saved
|
||||
lastChatID := al.state.GetLastChatID()
|
||||
if lastChatID != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
|
||||
if got := al.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChatID() != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
|
||||
if got := al2.state.GetLastChatID(); got != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+52
-50
@@ -539,86 +539,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
logger.InfoC("channels", "Outbound dispatcher started")
|
||||
func dispatchLoop[M any](
|
||||
ctx context.Context,
|
||||
m *Manager,
|
||||
subscribe func(context.Context) (M, bool),
|
||||
getChannel func(M) string,
|
||||
enqueue func(context.Context, *channelWorker, M) bool,
|
||||
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
|
||||
) {
|
||||
logger.InfoC("channels", startMsg)
|
||||
|
||||
for {
|
||||
msg, ok := m.bus.SubscribeOutbound(ctx)
|
||||
msg, ok := subscribe(ctx)
|
||||
if !ok {
|
||||
logger.InfoC("channels", "Outbound dispatcher stopped")
|
||||
logger.InfoC("channels", stopMsg)
|
||||
return
|
||||
}
|
||||
|
||||
channel := getChannel(msg)
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
if constants.IsInternalChannel(channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channel]
|
||||
w, wExists := m.workers[channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
case <-ctx.Done():
|
||||
if !enqueue(ctx, w, msg) {
|
||||
return
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutbound,
|
||||
func(msg bus.OutboundMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
},
|
||||
"Outbound dispatcher started",
|
||||
"Outbound dispatcher stopped",
|
||||
"Unknown channel for outbound message",
|
||||
"Channel has no active worker, skipping message",
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
logger.InfoC("channels", "Outbound media dispatcher started")
|
||||
|
||||
for {
|
||||
msg, ok := m.bus.SubscribeOutboundMedia(ctx)
|
||||
if !ok {
|
||||
logger.InfoC("channels", "Outbound media dispatcher stopped")
|
||||
return
|
||||
}
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutboundMedia,
|
||||
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return false
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
"Outbound media dispatcher started",
|
||||
"Outbound media dispatcher stopped",
|
||||
"Unknown channel for outbound media message",
|
||||
"Channel has no active worker, skipping media message",
|
||||
)
|
||||
}
|
||||
|
||||
// runMediaWorker processes outbound media messages for a single channel.
|
||||
|
||||
+16
-60
@@ -342,18 +342,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
|
||||
return result.MediaID, nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
|
||||
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
@@ -400,6 +393,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// WebhookPath returns the path for registering on the shared HTTP server.
|
||||
func (c *WeComAppChannel) WebhookPath() string {
|
||||
if c.config.WebhookPath != "" {
|
||||
@@ -722,63 +726,15 @@ func (c *WeComAppChannel) getAccessToken() string {
|
||||
return c.accessToken
|
||||
}
|
||||
|
||||
// sendTextMessage sends a text message to a user
|
||||
// sendTextMessage sends a text message to a user.
|
||||
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComTextMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "text",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Text.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body)))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var sendResp WeComSendMessageResponse
|
||||
if err := json.Unmarshal(body, &sendResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if sendResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
|
||||
@@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAppPKCS7Unpad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "valid padding 3 bytes",
|
||||
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
|
||||
expected: []byte("hello"),
|
||||
},
|
||||
{
|
||||
name: "valid padding 16 bytes (full block)",
|
||||
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("123456789012345"),
|
||||
},
|
||||
{
|
||||
name: "invalid padding larger than data",
|
||||
input: []byte{20},
|
||||
expected: nil, // should return error
|
||||
},
|
||||
{
|
||||
name: "invalid padding zero",
|
||||
input: append([]byte("test"), byte(0)),
|
||||
expected: nil, // should return error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs7Unpad(tt.input)
|
||||
if tt.expected == nil {
|
||||
// This case should return an error
|
||||
if err == nil {
|
||||
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComAppHandleVerification(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKeyApp()
|
||||
|
||||
@@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
// Create JSON message for direct chat (single)
|
||||
jsonMsg := `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
@@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
@@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("valid group message callback", func(t *testing.T) {
|
||||
// Create JSON message for group chat
|
||||
jsonMsg := `{
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_456",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
@@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello Group"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
+1
-19
@@ -742,25 +742,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig {
|
||||
|
||||
// HasProvidersConfig checks if any provider in the old providers config has configuration.
|
||||
func (c *Config) HasProvidersConfig() bool {
|
||||
v := c.Providers
|
||||
return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" ||
|
||||
v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" ||
|
||||
v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" ||
|
||||
v.Groq.APIKey != "" || v.Groq.APIBase != "" ||
|
||||
v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" ||
|
||||
v.VLLM.APIKey != "" || v.VLLM.APIBase != "" ||
|
||||
v.Gemini.APIKey != "" || v.Gemini.APIBase != "" ||
|
||||
v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" ||
|
||||
v.Ollama.APIKey != "" || v.Ollama.APIBase != "" ||
|
||||
v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" ||
|
||||
v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" ||
|
||||
v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" ||
|
||||
v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" ||
|
||||
v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" ||
|
||||
v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" ||
|
||||
v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
|
||||
v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
|
||||
v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
|
||||
return !c.Providers.IsEmpty()
|
||||
}
|
||||
|
||||
// ValidateModelList validates all ModelConfig entries in the model_list.
|
||||
|
||||
@@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_Error(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for error message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
func TestExecuteHeartbeat_ResultLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *tools.ToolResult
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "error result",
|
||||
result: &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
},
|
||||
wantLog: "error message",
|
||||
},
|
||||
{
|
||||
name: "silent result",
|
||||
result: &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
},
|
||||
wantLog: "completion message",
|
||||
},
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain error message")
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
func TestExecuteHeartbeat_Silent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return tt.result
|
||||
})
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for completion message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain completion message")
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
if string(data) == "" {
|
||||
t.Errorf("Expected log file to contain %s", tt.wantLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) {
|
||||
assert.GreaterOrEqual(t, len(actions), 1)
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
func TestPlanWorkspaceMigrationExistingFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
force bool
|
||||
wantActionType ActionType
|
||||
}{
|
||||
{
|
||||
name: "backup when not forced",
|
||||
force: false,
|
||||
wantActionType: ActionBackup,
|
||||
},
|
||||
{
|
||||
name: "copy when forced",
|
||||
force: true,
|
||||
wantActionType: ActionCopy,
|
||||
},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionBackup, actions[0].Type)
|
||||
}
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
tt.force,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
func TestPlanWorkspaceMigrationForce(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionCopy, actions[0].Type)
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, tt.wantActionType, actions[0].Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
|
||||
|
||||
@@ -108,34 +108,7 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe
|
||||
|
||||
// buildToolsPrompt creates the tool definitions section for the system prompt.
|
||||
func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
return buildCLIToolsPrompt(tools)
|
||||
}
|
||||
|
||||
// parseClaudeCliResponse parses the JSON output from the claude CLI.
|
||||
|
||||
@@ -130,34 +130,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
|
||||
|
||||
// buildToolsPrompt creates a tool definitions section for the prompt.
|
||||
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
return buildCLIToolsPrompt(tools)
|
||||
}
|
||||
|
||||
// codexEvent represents a single JSONL event from `codex exec --json`.
|
||||
|
||||
@@ -5,7 +5,43 @@
|
||||
|
||||
package providers
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
|
||||
func buildCLIToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
|
||||
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
|
||||
|
||||
Reference in New Issue
Block a user