mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge: sync upstream/main into feat/multi-agent-routing
Resolve conflicts: - pkg/agent/loop.go: integrate context compression, command handling, utf8 token estimation, and summarization notification into multi-agent routing architecture - pkg/config/config_test.go: merge imports from both branches - pkg/agent/loop_test.go: update test to use registry-based sessions
This commit is contained in:
+242
-30
@@ -14,8 +14,10 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -27,13 +29,14 @@ import (
|
||||
)
|
||||
|
||||
type AgentLoop struct {
|
||||
bus *bus.MessageBus
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
bus *bus.MessageBus
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -183,6 +186,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
al.channelManager = cm
|
||||
}
|
||||
|
||||
// RecordLastChannel records the last active channel for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChannel(channel string) error {
|
||||
@@ -254,6 +261,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return al.processSystemMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// Check for commands
|
||||
if response, handled := al.handleCommand(ctx, msg); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Route to determine agent and session key
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
@@ -404,7 +416,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
|
||||
// 7. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
al.maybeSummarize(agent, opts.SessionKey)
|
||||
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
@@ -472,32 +484,72 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
err = fbErr
|
||||
} else {
|
||||
response = fbResult.Response
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]interface{}{"agent_id": agent.ID, "iteration": iteration})
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
} else {
|
||||
response, err = agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
}
|
||||
|
||||
// Retry loop for context/token errors
|
||||
maxRetries := 2
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
response, err = callLLM()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
isContextError := strings.Contains(errMsg, "token") ||
|
||||
strings.Contains(errMsg, "context") ||
|
||||
strings.Contains(errMsg, "invalidparameter") ||
|
||||
strings.Contains(errMsg, "length")
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: "Context window exceeded. Compressing history and retrying...",
|
||||
})
|
||||
}
|
||||
|
||||
al.forceCompression(agent, opts.SessionKey)
|
||||
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
|
||||
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
|
||||
messages = agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
nil, opts.Channel, opts.ChatID,
|
||||
)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "LLM call failed",
|
||||
map[string]interface{}{
|
||||
@@ -505,7 +557,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
|
||||
"iteration": iteration,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
// Check if no tool calls - we're done
|
||||
@@ -639,7 +691,7 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string) {
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := agent.ContextWindow * 75 / 100
|
||||
@@ -649,12 +701,79 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string) {
|
||||
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
|
||||
go func() {
|
||||
defer al.summarizing.Delete(summarizeKey)
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: "Memory threshold reached. Optimizing conversation history...",
|
||||
})
|
||||
}
|
||||
al.summarizeSession(agent, sessionKey)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest 50% of messages (keeping system prompt and last user message).
|
||||
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
// Keep system prompt (usually [0]) and the very last message (user's trigger)
|
||||
// We want to drop the oldest half of the *conversation*
|
||||
// Assuming [0] is system, [1:] is conversation
|
||||
conversation := history[1 : len(history)-1]
|
||||
if len(conversation) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper to find the mid-point of the conversation
|
||||
mid := len(conversation) / 2
|
||||
|
||||
// New history structure:
|
||||
// 1. System Prompt
|
||||
// 2. [Summary of dropped part] - synthesized
|
||||
// 3. Second half of conversation
|
||||
// 4. Last message
|
||||
|
||||
// Simplified approach for emergency: Drop first half of conversation
|
||||
// and rely on existing summary if present, or create a placeholder.
|
||||
|
||||
droppedCount := mid
|
||||
keptConversation := conversation[mid:]
|
||||
|
||||
newHistory := make([]providers.Message, 0)
|
||||
newHistory = append(newHistory, history[0]) // System prompt
|
||||
|
||||
// Add a note about compression
|
||||
compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
|
||||
// If there was an existing summary, we might lose it if it was in the dropped part (which is just messages).
|
||||
// The summary is stored separately in session.Summary, so it persists!
|
||||
// We just need to ensure the user knows there's a gap.
|
||||
|
||||
// We only modify the messages list here
|
||||
newHistory = append(newHistory, providers.Message{
|
||||
Role: "system",
|
||||
Content: compressionNote,
|
||||
})
|
||||
|
||||
newHistory = append(newHistory, keptConversation...)
|
||||
newHistory = append(newHistory, history[len(history)-1]) // Last message
|
||||
|
||||
// Update session
|
||||
agent.Sessions.SetHistory(sessionKey, newHistory)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(newHistory),
|
||||
})
|
||||
}
|
||||
|
||||
// GetStartupInfo returns information about loaded tools and skills for logging.
|
||||
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
|
||||
info := make(map[string]interface{})
|
||||
@@ -693,7 +812,7 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
result += "[\n"
|
||||
for i, msg := range messages {
|
||||
result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
|
||||
if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
result += " ToolCalls:\n"
|
||||
for _, tc := range msg.ToolCalls {
|
||||
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
@@ -758,7 +877,7 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
|
||||
if m.Role != "user" && m.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
msgTokens := len(m.Content) / 4
|
||||
msgTokens := len(m.Content) / 2
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
@@ -827,12 +946,105 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, b
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
|
||||
// overheads better than the previous 3 chars/token.
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
totalChars := 0
|
||||
for _, m := range messages {
|
||||
total += len(m.Content) / 4
|
||||
totalChars += utf8.RuneCountInString(m.Content)
|
||||
}
|
||||
return total
|
||||
// 2.5 chars per token = totalChars * 2 / 5
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if !strings.HasPrefix(content, "/") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.Fields(content)
|
||||
if len(parts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cmd := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "/show":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /show [model|channel|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
|
||||
case "channel":
|
||||
return fmt.Sprintf("Current channel: %s", msg.Channel), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown show target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/list":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /list [models|channels|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "models":
|
||||
return "Available models: configured in config.json per agent", true
|
||||
case "channels":
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
channels := al.channelManager.GetEnabledChannels()
|
||||
if len(channels) == 0 {
|
||||
return "No channels enabled", true
|
||||
}
|
||||
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown list target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/switch":
|
||||
if len(args) < 3 || args[1] != "to" {
|
||||
return "Usage: /switch [model|channel] to <name>", true
|
||||
}
|
||||
target := args[0]
|
||||
value := args[2]
|
||||
|
||||
switch target {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
oldModel := defaultAgent.Model
|
||||
defaultAgent.Model = value
|
||||
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
|
||||
case "channel":
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
|
||||
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
|
||||
}
|
||||
return fmt.Sprintf("Switched target channel to %s", value), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown switch target: %s", target), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// extractPeer extracts the routing peer from inbound message metadata.
|
||||
|
||||
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -527,3 +528,103 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
|
||||
t.Errorf("Expected 'Command output: hello world', got: %s", response)
|
||||
}
|
||||
}
|
||||
|
||||
// failFirstMockProvider fails on the first N calls with a specific error
|
||||
type failFirstMockProvider struct {
|
||||
failures int
|
||||
currentCall int
|
||||
failError error
|
||||
successResp string
|
||||
}
|
||||
|
||||
func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
m.currentCall++
|
||||
if m.currentCall <= m.failures {
|
||||
return nil, m.failError
|
||||
}
|
||||
return &providers.LLMResponse{
|
||||
Content: m.successResp,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *failFirstMockProvider) GetDefaultModel() string {
|
||||
return "mock-fail-model"
|
||||
}
|
||||
|
||||
// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors
|
||||
func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
// Create a provider that fails once with a context error
|
||||
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
||||
provider := &failFirstMockProvider{
|
||||
failures: 1,
|
||||
failError: contextErr,
|
||||
successResp: "Recovered from context error",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Inject some history to simulate a full context
|
||||
sessionKey := "test-session-context"
|
||||
// Create dummy history
|
||||
history := []providers.Message{
|
||||
{Role: "system", Content: "System prompt"},
|
||||
{Role: "user", Content: "Old message 1"},
|
||||
{Role: "assistant", Content: "Old response 1"},
|
||||
{Role: "user", Content: "Old message 2"},
|
||||
{Role: "assistant", Content: "Old response 2"},
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
}
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, history)
|
||||
|
||||
// Call ProcessDirectWithChannel
|
||||
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected success after retry, got error: %v", err)
|
||||
}
|
||||
|
||||
if response != "Recovered from context error" {
|
||||
t.Errorf("Expected 'Recovered from context error', got '%s'", response)
|
||||
}
|
||||
|
||||
// We expect 2 calls: 1st failed, 2nd succeeded
|
||||
if provider.currentCall != 2 {
|
||||
t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
|
||||
}
|
||||
|
||||
// Check final history length
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
// We verify that the history has been modified (compressed)
|
||||
// Original length: 6
|
||||
// Expected behavior: compression drops ~50% of history (mid slice)
|
||||
// We can assert that the length is NOT what it would be without compression.
|
||||
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
|
||||
if len(finalHistory) >= 8 {
|
||||
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
|
||||
}
|
||||
}
|
||||
|
||||
+82
-30
@@ -19,18 +19,20 @@ import (
|
||||
)
|
||||
|
||||
type OAuthProviderConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Port int
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Originator string
|
||||
Port int
|
||||
}
|
||||
|
||||
func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||
return OAuthProviderConfig{
|
||||
Issuer: "https://auth.openai.com",
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
Scopes: "openid profile email offline_access",
|
||||
Port: 1455,
|
||||
Issuer: "https://auth.openai.com",
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
Scopes: "openid profile email offline_access",
|
||||
Originator: "codex_cli_rs",
|
||||
Port: 1455,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
return parseTokenResponse(body, cred.Provider)
|
||||
refreshed, err := parseTokenResponse(body, cred.Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshed.RefreshToken == "" {
|
||||
refreshed.RefreshToken = cred.RefreshToken
|
||||
}
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
return refreshed, nil
|
||||
}
|
||||
|
||||
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
@@ -288,15 +300,23 @@ func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
||||
|
||||
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"state": {state},
|
||||
}
|
||||
return cfg.Issuer + "/authorize?" + params.Encode()
|
||||
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
|
||||
params.Set("originator", "picoclaw")
|
||||
}
|
||||
if cfg.Originator != "" {
|
||||
params.Set("originator", cfg.Originator)
|
||||
}
|
||||
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
|
||||
}
|
||||
|
||||
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||
@@ -350,19 +370,57 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||
cred.AccountID = accountID
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func extractAccountID(accessToken string) string {
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) < 2 {
|
||||
func extractAccountID(token string) string {
|
||||
claims, err := parseJWTClaims(token)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
if orgs, ok := claims["organizations"].([]interface{}); ok {
|
||||
for _, org := range orgs {
|
||||
if orgMap, ok := org.(map[string]interface{}); ok {
|
||||
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseJWTClaims(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("token is not a JWT")
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
@@ -373,21 +431,15 @@ func extractAccountID(accessToken string) string {
|
||||
|
||||
decoded, err := base64URLDecode(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func base64URLDecode(s string) ([]byte, error) {
|
||||
|
||||
+139
-5
@@ -1,19 +1,34 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
|
||||
t.Helper()
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payloadJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal claims: %v", err)
|
||||
}
|
||||
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
return header + "." + payload + ".sig"
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURL(t *testing.T) {
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: "https://auth.example.com",
|
||||
ClientID: "test-client-id",
|
||||
Scopes: "openid profile",
|
||||
Port: 1455,
|
||||
Issuer: "https://auth.example.com",
|
||||
ClientID: "test-client-id",
|
||||
Scopes: "openid profile",
|
||||
Originator: "codex_cli_rs",
|
||||
Port: 1455,
|
||||
}
|
||||
pkce := PKCECodes{
|
||||
CodeVerifier: "test-verifier",
|
||||
@@ -22,7 +37,7 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
|
||||
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
|
||||
if !strings.HasPrefix(u, "https://auth.example.com/oauth/authorize?") {
|
||||
t.Errorf("URL does not start with expected prefix: %s", u)
|
||||
}
|
||||
if !strings.Contains(u, "client_id=test-client-id") {
|
||||
@@ -40,6 +55,37 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
if !strings.Contains(u, "response_type=code") {
|
||||
t.Error("URL missing response_type")
|
||||
}
|
||||
if !strings.Contains(u, "id_token_add_organizations=true") {
|
||||
t.Error("URL missing id_token_add_organizations")
|
||||
}
|
||||
if !strings.Contains(u, "codex_cli_simplified_flow=true") {
|
||||
t.Error("URL missing codex_cli_simplified_flow")
|
||||
}
|
||||
if !strings.Contains(u, "originator=codex_cli_rs") {
|
||||
t.Error("URL missing originator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
|
||||
if q.Get("id_token_add_organizations") != "true" {
|
||||
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
|
||||
}
|
||||
if q.Get("codex_cli_simplified_flow") != "true" {
|
||||
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
|
||||
}
|
||||
if q.Get("originator") != "codex_cli_rs" {
|
||||
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponse(t *testing.T) {
|
||||
@@ -73,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
|
||||
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "opaque-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
if cred.AccountID != "acc-id-from-id-token" {
|
||||
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
|
||||
token := makeJWTForClaims(t, map[string]interface{}{
|
||||
"organizations": []interface{}{
|
||||
map[string]interface{}{"id": "org_from_orgs"},
|
||||
},
|
||||
})
|
||||
|
||||
if got := extractAccountID(token); got != "org_from_orgs" {
|
||||
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
body := []byte(`{"refresh_token": "test"}`)
|
||||
_, err := parseTokenResponse(body, "openai")
|
||||
@@ -81,6 +158,32 @@ func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) {
|
||||
idToken := makeJWTWithAccountID("acc-from-id")
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "not-a-jwt",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
|
||||
if cred.AccountID != "acc-from-id" {
|
||||
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-from-id")
|
||||
}
|
||||
}
|
||||
|
||||
func makeJWTWithAccountID(accountID string) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte(`{"https://api.openai.com/auth":{"chatgpt_account_id":"` + accountID + `"}}`))
|
||||
return header + "." + payload + ".sig"
|
||||
}
|
||||
|
||||
func TestExchangeCodeForTokens(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
@@ -185,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "new-access-token-only",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "old-access",
|
||||
RefreshToken: "existing-refresh",
|
||||
AccountID: "acc_existing",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||
}
|
||||
if refreshed.RefreshToken != "existing-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
|
||||
}
|
||||
if refreshed.AccountID != "acc_existing" {
|
||||
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
if cfg.Issuer != "https://auth.openai.com" {
|
||||
|
||||
@@ -9,6 +9,7 @@ type MessageBus struct {
|
||||
inbound chan InboundMessage
|
||||
outbound chan OutboundMessage
|
||||
handlers map[string]MessageHandler
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.inbound <- msg
|
||||
}
|
||||
|
||||
@@ -34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.outbound <- msg
|
||||
}
|
||||
|
||||
@@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.closed = true
|
||||
close(mb.inbound)
|
||||
close(mb.outbound)
|
||||
}
|
||||
|
||||
+150
-2
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
@@ -100,15 +101,156 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
return fmt.Errorf("channel ID is empty")
|
||||
}
|
||||
|
||||
message := msg.Content
|
||||
runes := []rune(msg.Content)
|
||||
if len(runes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitMessage splits long messages into chunks, preserving code block integrity
|
||||
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
|
||||
func splitMessage(content string, limit int) []string {
|
||||
var messages []string
|
||||
|
||||
for len(content) > 0 {
|
||||
if len(content) <= limit {
|
||||
messages = append(messages, content)
|
||||
break
|
||||
}
|
||||
|
||||
msgEnd := limit
|
||||
|
||||
// Find natural split point within the limit
|
||||
msgEnd = findLastNewline(content[:limit], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:limit], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
// Check if this would end with an incomplete code block
|
||||
candidate := content[:msgEnd]
|
||||
unclosedIdx := findLastUnclosedCodeBlock(candidate)
|
||||
|
||||
if unclosedIdx >= 0 {
|
||||
// Message would end with incomplete code block
|
||||
// Try to extend to include the closing ``` (with some buffer)
|
||||
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
|
||||
if len(content) > extendedLimit {
|
||||
closingIdx := findNextClosingCodeBlock(content, msgEnd)
|
||||
if closingIdx > 0 && closingIdx <= extendedLimit {
|
||||
// Extend to include the closing ```
|
||||
msgEnd = closingIdx
|
||||
} else {
|
||||
// Can't find closing, split before the code block
|
||||
msgEnd = findLastNewline(content[:unclosedIdx], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:unclosedIdx], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = unclosedIdx
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Remaining content fits within extended limit
|
||||
msgEnd = len(content)
|
||||
}
|
||||
}
|
||||
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
messages = append(messages, content[:msgEnd])
|
||||
content = strings.TrimSpace(content[msgEnd:])
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
|
||||
// Returns the position of the opening ``` or -1 if all code blocks are complete
|
||||
func findLastUnclosedCodeBlock(text string) int {
|
||||
count := 0
|
||||
lastOpenIdx := -1
|
||||
|
||||
for i := 0; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
if count == 0 {
|
||||
lastOpenIdx = i
|
||||
}
|
||||
count++
|
||||
i += 2
|
||||
}
|
||||
}
|
||||
|
||||
// If odd number of ``` markers, last one is unclosed
|
||||
if count%2 == 1 {
|
||||
return lastOpenIdx
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findNextClosingCodeBlock finds the next closing ``` starting from a position
|
||||
// Returns the position after the closing ``` or -1 if not found
|
||||
func findNextClosingCodeBlock(text string, startIdx int) int {
|
||||
for i := startIdx; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
return i + 3
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastNewline finds the last newline character within the last N characters
|
||||
// Returns the position of the newline or -1 if not found
|
||||
func findLastNewline(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == '\n' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastSpace finds the last space character within the last N characters
|
||||
// Returns the position of the space or -1 if not found
|
||||
func findLastSpace(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == ' ' || s[i] == '\t' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
|
||||
// 使用传入的 ctx 进行超时控制
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := c.session.ChannelMessageSend(channelID, message)
|
||||
_, err := c.session.ChannelMessageSend(channelID, content)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
@@ -140,6 +282,12 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.session.ChannelTyping(m.ChannelID); err != nil {
|
||||
logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件和转录
|
||||
if !c.IsAllowed(m.Author.ID) {
|
||||
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
|
||||
|
||||
+14
-1
@@ -48,7 +48,7 @@ func (m *Manager) initChannels() error {
|
||||
|
||||
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize Telegram channel")
|
||||
telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus)
|
||||
telegram, err := NewTelegramChannel(m.config, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
@@ -163,6 +163,19 @@ func (m *Manager) initChannels() error {
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize OneBot channel")
|
||||
onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
m.channels["onebot"] = onebot
|
||||
logger.InfoC("channels", "OneBot channel enabled successfully")
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
|
||||
"enabled_channels": len(m.channels),
|
||||
})
|
||||
|
||||
@@ -0,0 +1,686 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type OneBotChannel struct {
|
||||
*BaseChannel
|
||||
config config.OneBotConfig
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
dedup map[string]struct{}
|
||||
dedupRing []string
|
||||
dedupIdx int
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
echoCounter int64
|
||||
}
|
||||
|
||||
type oneBotRawEvent struct {
|
||||
PostType string `json:"post_type"`
|
||||
MessageType string `json:"message_type"`
|
||||
SubType string `json:"sub_type"`
|
||||
MessageID json.RawMessage `json:"message_id"`
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
GroupID json.RawMessage `json:"group_id"`
|
||||
RawMessage string `json:"raw_message"`
|
||||
Message json.RawMessage `json:"message"`
|
||||
Sender json.RawMessage `json:"sender"`
|
||||
SelfID json.RawMessage `json:"self_id"`
|
||||
Time json.RawMessage `json:"time"`
|
||||
MetaEventType string `json:"meta_event_type"`
|
||||
Echo string `json:"echo"`
|
||||
RetCode json.RawMessage `json:"retcode"`
|
||||
Status BotStatus `json:"status"`
|
||||
}
|
||||
|
||||
type BotStatus struct {
|
||||
Online bool `json:"online"`
|
||||
Good bool `json:"good"`
|
||||
}
|
||||
|
||||
type oneBotSender struct {
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
Card string `json:"card"`
|
||||
}
|
||||
|
||||
type oneBotEvent struct {
|
||||
PostType string
|
||||
MessageType string
|
||||
SubType string
|
||||
MessageID string
|
||||
UserID int64
|
||||
GroupID int64
|
||||
Content string
|
||||
RawContent string
|
||||
IsBotMentioned bool
|
||||
Sender oneBotSender
|
||||
SelfID int64
|
||||
Time int64
|
||||
MetaEventType string
|
||||
}
|
||||
|
||||
type oneBotAPIRequest struct {
|
||||
Action string `json:"action"`
|
||||
Params interface{} `json:"params"`
|
||||
Echo string `json:"echo,omitempty"`
|
||||
}
|
||||
|
||||
type oneBotSendPrivateMsgParams struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type oneBotSendGroupMsgParams struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
|
||||
base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
const dedupSize = 1024
|
||||
return &OneBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
dedup: make(map[string]struct{}, dedupSize),
|
||||
dedupRing: make([]string, dedupSize),
|
||||
dedupIdx: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Start(ctx context.Context) error {
|
||||
if c.config.WSUrl == "" {
|
||||
return fmt.Errorf("OneBot ws_url not configured")
|
||||
}
|
||||
|
||||
logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{
|
||||
"ws_url": c.config.WSUrl,
|
||||
})
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
}
|
||||
|
||||
if c.config.ReconnectInterval > 0 {
|
||||
go c.reconnectLoop()
|
||||
} else {
|
||||
// If reconnect is disabled but initial connection failed, we cannot recover
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("onebot", "OneBot channel started successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) connect() error {
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer.HandshakeTimeout = 10 * time.Second
|
||||
|
||||
header := make(map[string][]string)
|
||||
if c.config.AccessToken != "" {
|
||||
header["Authorization"] = []string{"Bearer " + c.config.AccessToken}
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(c.config.WSUrl, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
logger.InfoC("onebot", "WebSocket connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) reconnectLoop() {
|
||||
interval := time.Duration(c.config.ReconnectInterval) * time.Second
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(interval):
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.InfoC("onebot", "Attempting to reconnect...")
|
||||
if err := c.connect(); err != nil {
|
||||
logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("onebot", "Stopping OneBot channel")
|
||||
c.setRunning(false)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("OneBot channel not running")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("OneBot WebSocket not connected")
|
||||
}
|
||||
|
||||
action, params, err := c.buildSendRequest(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
c.echoCounter++
|
||||
echo := fmt.Sprintf("send_%d", c.echoCounter)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
req := oneBotAPIRequest{
|
||||
Action: action,
|
||||
Params: params,
|
||||
Echo: echo,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal OneBot request: %w", err)
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
|
||||
chatID := msg.ChatID
|
||||
|
||||
if len(chatID) > 6 && chatID[:6] == "group:" {
|
||||
groupID, err := strconv.ParseInt(chatID[6:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid group ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_group_msg", oneBotSendGroupMsgParams{
|
||||
GroupID: groupID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(chatID) > 8 && chatID[:8] == "private:" {
|
||||
userID, err := strconv.ParseInt(chatID[8:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid user ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(chatID, 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", chatID)
|
||||
}
|
||||
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) listen() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
|
||||
return
|
||||
}
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{
|
||||
"length": len(message),
|
||||
"payload": string(message),
|
||||
})
|
||||
|
||||
var raw oneBotRawEvent
|
||||
if err := json.Unmarshal(message, &raw); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"payload": string(message),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Echo != "" || raw.Status.Online || raw.Status.Good {
|
||||
logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
"message_type": raw.MessageType,
|
||||
"sub_type": raw.SubType,
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
|
||||
c.handleRawEvent(&raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONInt64(raw json.RawMessage) (int64, error) {
|
||||
if len(raw) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var n int64
|
||||
if err := json.Unmarshal(raw, &n); err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
return 0, fmt.Errorf("cannot parse as int64: %s", string(raw))
|
||||
}
|
||||
|
||||
func parseJSONString(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
type parseMessageResult struct {
|
||||
Text string
|
||||
IsBotMentioned bool
|
||||
}
|
||||
|
||||
func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult {
|
||||
if len(raw) == 0 {
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
mentioned := false
|
||||
if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(s, cqAt) {
|
||||
mentioned = true
|
||||
s = strings.ReplaceAll(s, cqAt, "")
|
||||
s = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return parseMessageResult{Text: s, IsBotMentioned: mentioned}
|
||||
}
|
||||
|
||||
var segments []map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &segments); err == nil {
|
||||
var text string
|
||||
mentioned := false
|
||||
selfIDStr := strconv.FormatInt(selfID, 10)
|
||||
for _, seg := range segments {
|
||||
segType, _ := seg["type"].(string)
|
||||
data, _ := seg["data"].(map[string]interface{})
|
||||
switch segType {
|
||||
case "text":
|
||||
if data != nil {
|
||||
if t, ok := data["text"].(string); ok {
|
||||
text += t
|
||||
}
|
||||
}
|
||||
case "at":
|
||||
if data != nil && selfID > 0 {
|
||||
qqVal := fmt.Sprintf("%v", data["qq"])
|
||||
if qqVal == selfIDStr || qqVal == "all" {
|
||||
mentioned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return parseMessageResult{Text: strings.TrimSpace(text), IsBotMentioned: mentioned}
|
||||
}
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
|
||||
switch raw.PostType {
|
||||
case "message":
|
||||
evt, err := c.normalizeMessageEvent(raw)
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.handleMessage(evt)
|
||||
case "meta_event":
|
||||
c.handleMetaEvent(raw)
|
||||
case "notice":
|
||||
logger.DebugCF("onebot", "Notice event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "request":
|
||||
logger.DebugCF("onebot", "Request event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "":
|
||||
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) {
|
||||
userID, err := parseJSONInt64(raw.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID))
|
||||
}
|
||||
|
||||
groupID, _ := parseJSONInt64(raw.GroupID)
|
||||
selfID, _ := parseJSONInt64(raw.SelfID)
|
||||
ts, _ := parseJSONInt64(raw.Time)
|
||||
messageID := parseJSONString(raw.MessageID)
|
||||
|
||||
parsed := parseMessageContentEx(raw.Message, selfID)
|
||||
isBotMentioned := parsed.IsBotMentioned
|
||||
|
||||
content := raw.RawMessage
|
||||
if content == "" {
|
||||
content = parsed.Text
|
||||
} else if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(content, cqAt) {
|
||||
isBotMentioned = true
|
||||
content = strings.ReplaceAll(content, cqAt, "")
|
||||
content = strings.TrimSpace(content)
|
||||
}
|
||||
}
|
||||
|
||||
var sender oneBotSender
|
||||
if len(raw.Sender) > 0 {
|
||||
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"sender": string(raw.Sender),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{
|
||||
"message_type": raw.MessageType,
|
||||
"user_id": userID,
|
||||
"group_id": groupID,
|
||||
"message_id": messageID,
|
||||
"content_len": len(content),
|
||||
"nickname": sender.Nickname,
|
||||
})
|
||||
|
||||
return &oneBotEvent{
|
||||
PostType: raw.PostType,
|
||||
MessageType: raw.MessageType,
|
||||
SubType: raw.SubType,
|
||||
MessageID: messageID,
|
||||
UserID: userID,
|
||||
GroupID: groupID,
|
||||
Content: content,
|
||||
RawContent: raw.RawMessage,
|
||||
IsBotMentioned: isBotMentioned,
|
||||
Sender: sender,
|
||||
SelfID: selfID,
|
||||
Time: ts,
|
||||
MetaEventType: raw.MetaEventType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
|
||||
switch raw.MetaEventType {
|
||||
case "lifecycle":
|
||||
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "heartbeat":
|
||||
logger.DebugC("onebot", "Heartbeat received")
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMessage(evt *oneBotEvent) {
|
||||
if c.isDuplicate(evt.MessageID) {
|
||||
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
content := evt.Content
|
||||
if content == "" {
|
||||
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := strconv.FormatInt(evt.UserID, 10)
|
||||
var chatID string
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": evt.MessageID,
|
||||
}
|
||||
|
||||
switch evt.MessageType {
|
||||
case "private":
|
||||
chatID = "private:" + senderID
|
||||
logger.InfoCF("onebot", "Received private message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"message_id": evt.MessageID,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(evt.GroupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(evt.Sender.UserID)
|
||||
if senderUserID > 0 {
|
||||
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
|
||||
}
|
||||
|
||||
if evt.Sender.Card != "" {
|
||||
metadata["sender_name"] = evt.Sender.Card
|
||||
} else if evt.Sender.Nickname != "" {
|
||||
metadata["sender_name"] = evt.Sender.Nickname
|
||||
}
|
||||
|
||||
triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned)
|
||||
if !triggered {
|
||||
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
return
|
||||
}
|
||||
content = strippedContent
|
||||
|
||||
logger.InfoCF("onebot", "Received group message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"message_id": evt.MessageID,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
default:
|
||||
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
|
||||
"type": evt.MessageType,
|
||||
"message_id": evt.MessageID,
|
||||
"user_id": evt.UserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if evt.Sender.Nickname != "" {
|
||||
metadata["nickname"] = evt.Sender.Nickname
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
if messageID == "" || messageID == "0" {
|
||||
return false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if _, exists := c.dedup[messageID]; exists {
|
||||
return true
|
||||
}
|
||||
|
||||
if old := c.dedupRing[c.dedupIdx]; old != "" {
|
||||
delete(c.dedup, old)
|
||||
}
|
||||
c.dedupRing[c.dedupIdx] = messageID
|
||||
c.dedup[messageID] = struct{}{}
|
||||
c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= n {
|
||||
return s
|
||||
}
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) {
|
||||
if isBotMentioned {
|
||||
return true, strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
for _, prefix := range c.config.GroupTriggerPrefix {
|
||||
if prefix == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(content, prefix) {
|
||||
return true, strings.TrimSpace(strings.TrimPrefix(content, prefix))
|
||||
}
|
||||
}
|
||||
|
||||
return false, content
|
||||
}
|
||||
@@ -308,6 +308,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
if !c.IsAllowed(ev.User) {
|
||||
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{
|
||||
"user_id": ev.User,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
@@ -367,6 +374,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
|
||||
if !c.IsAllowed(cmd.UserID) {
|
||||
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{
|
||||
"user_id": cmd.UserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := cmd.UserID
|
||||
channelID := cmd.ChannelID
|
||||
chatID := channelID
|
||||
|
||||
+60
-65
@@ -11,7 +11,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
th "github.com/mymmrac/telego/telegohandler"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/mymmrac/telego/telegohandler"
|
||||
tu "github.com/mymmrac/telego/telegoutil"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -24,7 +27,8 @@ import (
|
||||
type TelegramChannel struct {
|
||||
*BaseChannel
|
||||
bot *telego.Bot
|
||||
config config.TelegramConfig
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
transcriber *voice.GroqTranscriber
|
||||
placeholders sync.Map // chatID -> messageID
|
||||
@@ -41,13 +45,14 @@ func (c *thinkingCancel) Cancel() {
|
||||
}
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
var opts []telego.BotOption
|
||||
telegramCfg := cfg.Channels.Telegram
|
||||
|
||||
if cfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(cfg.Proxy)
|
||||
if telegramCfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
|
||||
}
|
||||
opts = append(opts, telego.WithHTTPClient(&http.Client{
|
||||
Transport: &http.Transport{
|
||||
@@ -56,15 +61,16 @@ func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*Telegr
|
||||
}))
|
||||
}
|
||||
|
||||
bot, err := telego.NewBot(cfg.Token, opts...)
|
||||
bot, err := telego.NewBot(telegramCfg.Token, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||
}
|
||||
|
||||
base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom)
|
||||
base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
commands: NewTelegramCommands(bot, cfg),
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
chatIDs: make(map[string]int64),
|
||||
@@ -88,31 +94,45 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
bh, err := telegohandler.NewBotHandler(c.bot, updates)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create bot handler: %w", err)
|
||||
}
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
c.commands.Help(ctx, message)
|
||||
return nil
|
||||
}, th.CommandEqual("help"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
}, th.CommandEqual("show"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.List(ctx, message)
|
||||
}, th.CommandEqual("list"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.handleMessage(ctx, &message)
|
||||
}, th.AnyMessage())
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go bh.Start()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case update, ok := <-updates:
|
||||
if !ok {
|
||||
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
|
||||
return
|
||||
}
|
||||
if update.Message != nil {
|
||||
c.handleMessage(ctx, update)
|
||||
}
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
bh.Stop()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||
c.setRunning(false)
|
||||
@@ -166,30 +186,27 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
|
||||
message := update.Message
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
|
||||
if message == nil {
|
||||
return
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
user := message.From
|
||||
if user == nil {
|
||||
return
|
||||
return fmt.Errorf("message sender (user) is nil")
|
||||
}
|
||||
|
||||
userID := fmt.Sprintf("%d", user.ID)
|
||||
senderID := userID
|
||||
senderID := fmt.Sprintf("%d", user.ID)
|
||||
if user.Username != "" {
|
||||
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
|
||||
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件
|
||||
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
|
||||
if !c.IsAllowed(senderID) {
|
||||
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"username": user.Username,
|
||||
"user_id": senderID,
|
||||
})
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
chatID := message.Chat.ID
|
||||
@@ -222,7 +239,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
content += message.Caption
|
||||
}
|
||||
|
||||
if message.Photo != nil && len(message.Photo) > 0 {
|
||||
if len(message.Photo) > 0 {
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
photoPath := c.downloadPhoto(ctx, photo.FileID)
|
||||
if photoPath != "" {
|
||||
@@ -231,7 +248,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[image: photo]")
|
||||
content += "[image: photo]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,7 +269,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
"error": err.Error(),
|
||||
"path": voicePath,
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
|
||||
transcribedText = "[voice (transcription failed)]"
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
||||
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
|
||||
@@ -260,7 +277,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice]")
|
||||
transcribedText = "[voice]"
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
@@ -278,7 +295,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[audio]")
|
||||
content += "[audio]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,7 +307,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[file]")
|
||||
content += "[file]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,37 +337,14 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
}
|
||||
}
|
||||
|
||||
// Create new context for thinking animation with timeout
|
||||
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
// Create cancel function for thinking state
|
||||
_, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
|
||||
|
||||
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
|
||||
if err == nil {
|
||||
pID := pMsg.MessageID
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
|
||||
go func(cid int64, mid int) {
|
||||
dots := []string{".", "..", "..."}
|
||||
emotes := []string{"💭", "🤔", "☁️"}
|
||||
i := 0
|
||||
ticker := time.NewTicker(2000 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-thinkCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
i++
|
||||
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
|
||||
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
|
||||
if editErr != nil {
|
||||
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
|
||||
"error": editErr.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}(chatID, pID)
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
@@ -370,7 +364,8 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
"peer_id": peerID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type TelegramCommander interface {
|
||||
Help(ctx context.Context, message telego.Message) error
|
||||
Start(ctx context.Context, message telego.Message) error
|
||||
Show(ctx context.Context, message telego.Message) error
|
||||
List(ctx context.Context, message telego.Message) error
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
bot *telego.Bot
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
|
||||
return &cmd{
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func commandArgs(text string) string {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
|
||||
msg := `/start - Start the bot
|
||||
/help - Show this help message
|
||||
/show [model|channel] - Show current configuration
|
||||
/list [models|channels] - List available options
|
||||
`
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: msg,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Hello! I am PicoClaw 🦞",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /show [model|channel]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "model":
|
||||
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
|
||||
c.config.Agents.Defaults.Model,
|
||||
c.config.Agents.Defaults.Provider)
|
||||
case "channel":
|
||||
response = "Current Channel: telegram"
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
func (c *cmd) List(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /list [models|channels]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "models":
|
||||
provider := c.config.Agents.Defaults.Provider
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
|
||||
c.config.Agents.Defaults.Model, provider)
|
||||
|
||||
case "channels":
|
||||
var enabled []string
|
||||
if c.config.Channels.Telegram.Enabled {
|
||||
enabled = append(enabled, "telegram")
|
||||
}
|
||||
if c.config.Channels.WhatsApp.Enabled {
|
||||
enabled = append(enabled, "whatsapp")
|
||||
}
|
||||
if c.config.Channels.Feishu.Enabled {
|
||||
enabled = append(enabled, "feishu")
|
||||
}
|
||||
if c.config.Channels.Discord.Enabled {
|
||||
enabled = append(enabled, "discord")
|
||||
}
|
||||
if c.config.Channels.Slack.Enabled {
|
||||
enabled = append(enabled, "slack")
|
||||
}
|
||||
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
|
||||
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
+42
-21
@@ -161,6 +161,7 @@ type ChannelsConfig struct {
|
||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||
Slack SlackConfig `json:"slack"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
}
|
||||
|
||||
type WhatsAppConfig struct {
|
||||
@@ -213,10 +214,10 @@ type DingTalkConfig struct {
|
||||
}
|
||||
|
||||
type SlackConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
|
||||
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
|
||||
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
|
||||
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
|
||||
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
|
||||
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type LINEConfig struct {
|
||||
@@ -229,6 +230,15 @@ type LINEConfig struct {
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type OneBotConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
|
||||
WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
|
||||
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
|
||||
ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
|
||||
GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
|
||||
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
|
||||
@@ -240,24 +250,27 @@ type DevicesConfig struct {
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Ollama ProviderConfig `json:"ollama"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
@@ -344,7 +357,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: false,
|
||||
BotToken: "",
|
||||
AppToken: "",
|
||||
AllowFrom: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
@@ -355,6 +368,14 @@ func DefaultConfig() *Config {
|
||||
WebhookPath: "/webhook/line",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
OneBot: OneBotConfig{
|
||||
Enabled: false,
|
||||
WSUrl: "ws://127.0.0.1:3001",
|
||||
AccessToken: "",
|
||||
ReconnectInterval: 5,
|
||||
GroupTriggerPrefix: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
@@ -432,7 +453,7 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(path, data, 0644)
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
func (c *Config) WorkspacePath() string {
|
||||
|
||||
@@ -2,6 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -297,6 +300,30 @@ func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveConfig_FilePermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permission bits are not enforced on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := SaveConfig(path, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig failed: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("config file has permission %04o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfig_Complete verifies all config fields are set
|
||||
func TestConfig_Complete(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
+1
-1
@@ -340,7 +340,7 @@ func (cs *CronService) saveStoreUnsafe() error {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(cs.storePath, data, 0644)
|
||||
return os.WriteFile(cs.storePath, data, 0600)
|
||||
}
|
||||
|
||||
func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) {
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSaveStore_FilePermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permission bits are not enforced on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
storePath := filepath.Join(tmpDir, "cron", "jobs.json")
|
||||
|
||||
cs := NewCronService(storePath, nil)
|
||||
|
||||
_, err := cs.AddJob("test", CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}, "hello", false, "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("AddJob failed: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("cron store has permission %04o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
server *http.Server
|
||||
mu sync.RWMutex
|
||||
ready bool
|
||||
checks map[string]Check
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
type Check struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
func NewServer(host string, port int) *Server {
|
||||
mux := http.NewServeMux()
|
||||
s := &Server{
|
||||
ready: false,
|
||||
checks: make(map[string]Check),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
mux.HandleFunc("/health", s.healthHandler)
|
||||
mux.HandleFunc("/ready", s.readyHandler)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
s.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
s.ready = true
|
||||
s.mu.Unlock()
|
||||
return s.server.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) StartContext(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
s.ready = true
|
||||
s.mu.Unlock()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.server.ListenAndServe()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return s.server.Shutdown(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
s.ready = false
|
||||
s.mu.Unlock()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *Server) SetReady(ready bool) {
|
||||
s.mu.Lock()
|
||||
s.ready = ready
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
status, msg := checkFn()
|
||||
s.checks[name] = Check{
|
||||
Name: name,
|
||||
Status: statusString(status),
|
||||
Message: msg,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
uptime := time.Since(s.startTime)
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
s.mu.RLock()
|
||||
ready := s.ready
|
||||
checks := make(map[string]Check)
|
||||
for k, v := range s.checks {
|
||||
checks[k] = v
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ready {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
json.NewEncoder(w).Encode(StatusResponse{
|
||||
Status: "not ready",
|
||||
Checks: checks,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
for _, check := range checks {
|
||||
if check.Status == "fail" {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
json.NewEncoder(w).Encode(StatusResponse{
|
||||
Status: "not ready",
|
||||
Checks: checks,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
uptime := time.Since(s.startTime)
|
||||
json.NewEncoder(w).Encode(StatusResponse{
|
||||
Status: "ready",
|
||||
Uptime: uptime.String(),
|
||||
Checks: checks,
|
||||
})
|
||||
}
|
||||
|
||||
func statusString(ok bool) string {
|
||||
if ok {
|
||||
return "ok"
|
||||
}
|
||||
return "fail"
|
||||
}
|
||||
@@ -171,68 +171,14 @@ func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractToolCalls parses tool call JSON from the response text.
|
||||
// extractToolCalls delegates to the shared extractToolCallsFromText function.
|
||||
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonStr := text[start:end]
|
||||
|
||||
var wrapper struct {
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ToolCall
|
||||
for _, tc := range wrapper.ToolCalls {
|
||||
var args map[string]interface{}
|
||||
json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
|
||||
result = append(result, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
return extractToolCallsFromText(text)
|
||||
}
|
||||
|
||||
// stripToolCallsJSON removes tool call JSON from response text.
|
||||
// stripToolCallsJSON delegates to the shared stripToolCallsFromText function.
|
||||
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return text
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return text
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text[:start] + text[end:])
|
||||
return stripToolCallsFromText(text)
|
||||
}
|
||||
|
||||
// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CodexCliAuth represents the ~/.codex/auth.json file structure.
|
||||
type CodexCliAuth struct {
|
||||
Tokens struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccountID string `json:"account_id"`
|
||||
} `json:"tokens"`
|
||||
}
|
||||
|
||||
// ReadCodexCliCredentials reads OAuth tokens from the Codex CLI's auth.json file.
|
||||
// Expiry is estimated as file modification time + 1 hour (same approach as moltbot).
|
||||
func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) {
|
||||
authPath, err := resolveCodexAuthPath()
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(authPath)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("reading %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
var auth CodexCliAuth
|
||||
if err := json.Unmarshal(data, &auth); err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
if auth.Tokens.AccessToken == "" {
|
||||
return "", "", time.Time{}, fmt.Errorf("no access_token in %s", authPath)
|
||||
}
|
||||
|
||||
stat, err := os.Stat(authPath)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(time.Hour)
|
||||
} else {
|
||||
expiresAt = stat.ModTime().Add(time.Hour)
|
||||
}
|
||||
|
||||
return auth.Tokens.AccessToken, auth.Tokens.AccountID, expiresAt, nil
|
||||
}
|
||||
|
||||
// CreateCodexCliTokenSource creates a token source that reads from ~/.codex/auth.json.
|
||||
// This allows the existing CodexProvider to reuse Codex CLI credentials.
|
||||
func CreateCodexCliTokenSource() func() (string, string, error) {
|
||||
return func() (string, string, error) {
|
||||
token, accountID, expiresAt, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("reading codex cli credentials: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(expiresAt) {
|
||||
return "", "", fmt.Errorf("codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login")
|
||||
}
|
||||
|
||||
return token, accountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
func resolveCodexAuthPath() (string, error) {
|
||||
codexHome := os.Getenv("CODEX_HOME")
|
||||
if codexHome == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting home dir: %w", err)
|
||||
}
|
||||
codexHome = filepath.Join(home, ".codex")
|
||||
}
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReadCodexCliCredentials_Valid(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{
|
||||
"tokens": {
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"account_id": "org-test123"
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
token, accountID, expiresAt, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadCodexCliCredentials() error: %v", err)
|
||||
}
|
||||
if token != "test-access-token" {
|
||||
t.Errorf("token = %q, want %q", token, "test-access-token")
|
||||
}
|
||||
if accountID != "org-test123" {
|
||||
t.Errorf("accountID = %q, want %q", accountID, "org-test123")
|
||||
}
|
||||
// Expiry should be within ~1 hour from now (file was just written)
|
||||
if expiresAt.Before(time.Now()) {
|
||||
t.Errorf("expiresAt = %v, should be in the future", expiresAt)
|
||||
}
|
||||
if expiresAt.After(time.Now().Add(2 * time.Hour)) {
|
||||
t.Errorf("expiresAt = %v, should be within ~1 hour", expiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_MissingFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
_, _, _, err := ReadCodexCliCredentials()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing auth.json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_EmptyToken(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
_, _, _, err := ReadCodexCliCredentials()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty access_token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
if err := os.WriteFile(authPath, []byte("not json"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
_, _, _, err := ReadCodexCliCredentials()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
token, accountID, _, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "tok123" {
|
||||
t.Errorf("token = %q, want %q", token, "tok123")
|
||||
}
|
||||
if accountID != "" {
|
||||
t.Errorf("accountID = %q, want empty", accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
customDir := filepath.Join(tmpDir, "custom-codex")
|
||||
if err := os.MkdirAll(customDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}`
|
||||
if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", customDir)
|
||||
|
||||
token, _, _, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "custom-token" {
|
||||
t.Errorf("token = %q, want %q", token, "custom-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCodexCliTokenSource_Valid(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
source := CreateCodexCliTokenSource()
|
||||
token, accountID, err := source()
|
||||
if err != nil {
|
||||
t.Fatalf("token source error: %v", err)
|
||||
}
|
||||
if token != "fresh-token" {
|
||||
t.Errorf("token = %q, want %q", token, "fresh-token")
|
||||
}
|
||||
if accountID != "acc" {
|
||||
t.Errorf("accountID = %q, want %q", accountID, "acc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCodexCliTokenSource_Expired(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set file modification time to 2 hours ago
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
if err := os.Chtimes(authPath, oldTime, oldTime); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
source := CreateCodexCliTokenSource()
|
||||
_, _, err := source()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for expired credentials")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,251 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess.
|
||||
type CodexCliProvider struct {
|
||||
command string
|
||||
workspace string
|
||||
}
|
||||
|
||||
// NewCodexCliProvider creates a new Codex CLI provider.
|
||||
func NewCodexCliProvider(workspace string) *CodexCliProvider {
|
||||
return &CodexCliProvider{
|
||||
command: "codex",
|
||||
workspace: workspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode.
|
||||
func (p *CodexCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if p.command == "" {
|
||||
return nil, fmt.Errorf("codex command not configured")
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
args := []string{
|
||||
"exec",
|
||||
"--json",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
"--skip-git-repo-check",
|
||||
"--color", "never",
|
||||
}
|
||||
if model != "" && model != "codex-cli" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
if p.workspace != "" {
|
||||
args = append(args, "-C", p.workspace)
|
||||
}
|
||||
args = append(args, "-") // read prompt from stdin
|
||||
|
||||
cmd := exec.CommandContext(ctx, p.command, args...)
|
||||
cmd.Stdin = bytes.NewReader([]byte(prompt))
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
|
||||
// Parse JSONL from stdout even if exit code is non-zero,
|
||||
// because codex writes diagnostic noise to stderr (e.g. rollout errors)
|
||||
// but still produces valid JSONL output.
|
||||
if stdoutStr := stdout.String(); stdoutStr != "" {
|
||||
resp, parseErr := p.parseJSONLEvents(stdoutStr)
|
||||
if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if stderrStr := stderr.String(); stderrStr != "" {
|
||||
return nil, fmt.Errorf("codex cli error: %s", stderrStr)
|
||||
}
|
||||
return nil, fmt.Errorf("codex cli error: %w", err)
|
||||
}
|
||||
|
||||
return p.parseJSONLEvents(stdout.String())
|
||||
}
|
||||
|
||||
// GetDefaultModel returns the default model identifier.
|
||||
func (p *CodexCliProvider) GetDefaultModel() string {
|
||||
return "codex-cli"
|
||||
}
|
||||
|
||||
// buildPrompt converts messages to a prompt string for the Codex CLI.
|
||||
// System messages are prepended as instructions since Codex CLI has no --system-prompt flag.
|
||||
func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string {
|
||||
var systemParts []string
|
||||
var conversationParts []string
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
systemParts = append(systemParts, msg.Content)
|
||||
case "user":
|
||||
conversationParts = append(conversationParts, msg.Content)
|
||||
case "assistant":
|
||||
conversationParts = append(conversationParts, "Assistant: "+msg.Content)
|
||||
case "tool":
|
||||
conversationParts = append(conversationParts,
|
||||
fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
|
||||
}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
if len(systemParts) > 0 {
|
||||
sb.WriteString("## System Instructions\n\n")
|
||||
sb.WriteString(strings.Join(systemParts, "\n\n"))
|
||||
sb.WriteString("\n\n## Task\n\n")
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(p.buildToolsPrompt(tools))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// Simplify single user message (no prefix)
|
||||
if len(conversationParts) == 1 && len(systemParts) == 0 && len(tools) == 0 {
|
||||
return conversationParts[0]
|
||||
}
|
||||
|
||||
sb.WriteString(strings.Join(conversationParts, "\n"))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildToolsPrompt creates a tool definitions section for the prompt.
|
||||
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// codexEvent represents a single JSONL event from `codex exec --json`.
|
||||
type codexEvent struct {
|
||||
Type string `json:"type"`
|
||||
ThreadID string `json:"thread_id,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Item *codexEventItem `json:"item,omitempty"`
|
||||
Usage *codexUsage `json:"usage,omitempty"`
|
||||
Error *codexEventErr `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type codexEventItem struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
ExitCode *int `json:"exit_code,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
type codexUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CachedInputTokens int `json:"cached_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type codexEventErr struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// parseJSONLEvents processes the JSONL output from codex exec --json.
|
||||
func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) {
|
||||
var contentParts []string
|
||||
var usage *UsageInfo
|
||||
var lastError string
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event codexEvent
|
||||
if err := json.Unmarshal([]byte(line), &event); err != nil {
|
||||
continue // skip malformed lines
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "item.completed":
|
||||
if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" {
|
||||
contentParts = append(contentParts, event.Item.Text)
|
||||
}
|
||||
case "turn.completed":
|
||||
if event.Usage != nil {
|
||||
promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: event.Usage.OutputTokens,
|
||||
TotalTokens: promptTokens + event.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
case "error":
|
||||
lastError = event.Message
|
||||
case "turn.failed":
|
||||
if event.Error != nil {
|
||||
lastError = event.Error.Message
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastError != "" && len(contentParts) == 0 {
|
||||
return nil, fmt.Errorf("codex cli: %s", lastError)
|
||||
}
|
||||
|
||||
content := strings.Join(contentParts, "\n")
|
||||
|
||||
// Extract tool calls from response text (same pattern as ClaudeCliProvider)
|
||||
toolCalls := extractToolCallsFromText(content)
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
content = stripToolCallsFromText(content)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.TrimSpace(content),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,585 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- JSONL Event Parsing Tests ---
|
||||
|
||||
func TestParseJSONLEvents_AgentMessage(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"thread.started","thread_id":"abc-123"}
|
||||
{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}}
|
||||
{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello from Codex!")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 150 {
|
||||
t.Errorf("PromptTokens = %d, want 150", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 170 {
|
||||
t.Errorf("TotalTokens = %d, want 170", resp.Usage.TotalTokens)
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
toolCallText := `Let me read that file.
|
||||
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}`
|
||||
// Build valid JSONL by marshaling the event
|
||||
item := codexEvent{
|
||||
Type: "item.completed",
|
||||
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
|
||||
}
|
||||
itemJSON, _ := json.Marshal(item)
|
||||
usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}`
|
||||
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "read_file" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
|
||||
}
|
||||
if resp.ToolCalls[0].ID != "call_1" {
|
||||
t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
|
||||
}
|
||||
if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` {
|
||||
t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
// Content should have the tool call JSON stripped
|
||||
if strings.Contains(resp.Content, "tool_calls") {
|
||||
t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}`
|
||||
item := codexEvent{
|
||||
Type: "item.completed",
|
||||
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
|
||||
}
|
||||
itemJSON, _ := json.Marshal(item)
|
||||
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 2 {
|
||||
t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "read_file" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
|
||||
}
|
||||
if resp.ToolCalls[1].Name != "write_file" {
|
||||
t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file")
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MultipleMessages(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}}
|
||||
{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}}
|
||||
{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Content != "First part.\nSecond part." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "First part.\nSecond part.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ErrorEvent(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"thread.started","thread_id":"abc"}
|
||||
{"type":"turn.started"}
|
||||
{"type":"error","message":"token expired"}
|
||||
{"type":"turn.failed","error":{"message":"token expired"}}`
|
||||
|
||||
_, err := p.parseJSONLEvents(events)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "token expired") {
|
||||
t.Errorf("error = %q, want to contain 'token expired'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_TurnFailed(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"turn.failed","error":{"message":"rate limit exceeded"}}`
|
||||
|
||||
_, err := p.parseJSONLEvents(events)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "rate limit exceeded") {
|
||||
t.Errorf("error = %q, want to contain 'rate limit exceeded'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ErrorWithContent(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
// If there's an error but also content, return the content (partial success)
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}}
|
||||
{"type":"error","message":"connection reset"}
|
||||
{"type":"turn.failed","error":{"message":"connection reset"}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("should not error when content exists: %v", err)
|
||||
}
|
||||
if resp.Content != "Partial result." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Partial result.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_EmptyOutput(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
resp, err := p.parseJSONLEvents("")
|
||||
if err != nil {
|
||||
t.Fatalf("empty output should not error: %v", err)
|
||||
}
|
||||
if resp.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MalformedLines(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `not json at all
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}}
|
||||
another bad line
|
||||
{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("should skip malformed lines: %v", err)
|
||||
}
|
||||
if resp.Content != "Good line." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Good line.")
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 15 {
|
||||
t.Errorf("Usage.TotalTokens = %v, want 15", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_CommandExecution(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}}
|
||||
{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
// command_execution items should be skipped; only agent_message text is returned
|
||||
if resp.Content != "Found 2 files." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Found 2 files.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_NoUsage(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Prompt Building Tests ---
|
||||
|
||||
func TestBuildPrompt_SystemAsInstructions(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi there"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "## System Instructions") {
|
||||
t.Error("prompt should contain '## System Instructions'")
|
||||
}
|
||||
if !strings.Contains(prompt, "You are helpful.") {
|
||||
t.Error("prompt should contain system content")
|
||||
}
|
||||
if !strings.Contains(prompt, "## Task") {
|
||||
t.Error("prompt should contain '## Task'")
|
||||
}
|
||||
if !strings.Contains(prompt, "Hi there") {
|
||||
t.Error("prompt should contain user message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_NoSystem(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Just a question"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if strings.Contains(prompt, "## System Instructions") {
|
||||
t.Error("prompt should not contain system instructions header")
|
||||
}
|
||||
if prompt != "Just a question" {
|
||||
t.Errorf("prompt = %q, want %q", prompt, "Just a question")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_WithTools(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Get weather"},
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
if !strings.Contains(prompt, "## Available Tools") {
|
||||
t.Error("prompt should contain tools section")
|
||||
}
|
||||
if !strings.Contains(prompt, "get_weather") {
|
||||
t.Error("prompt should contain tool name")
|
||||
}
|
||||
if !strings.Contains(prompt, "Get current weather") {
|
||||
t.Error("prompt should contain tool description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_MultipleMessages(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi! How can I help?"},
|
||||
{Role: "user", Content: "Tell me about Go"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "Hello") {
|
||||
t.Error("prompt should contain first user message")
|
||||
}
|
||||
if !strings.Contains(prompt, "Assistant: Hi! How can I help?") {
|
||||
t.Error("prompt should contain assistant message with prefix")
|
||||
}
|
||||
if !strings.Contains(prompt, "Tell me about Go") {
|
||||
t.Error("prompt should contain second user message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_ToolResults(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "[Tool Result for call_1]") {
|
||||
t.Error("prompt should contain tool result")
|
||||
}
|
||||
if !strings.Contains(prompt, `{"temp": 72}`) {
|
||||
t.Error("prompt should contain tool result content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_SystemAndTools(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "Be concise."},
|
||||
{Role: "user", Content: "Do something"},
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "my_tool",
|
||||
Description: "A tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
// System instructions should come first
|
||||
sysIdx := strings.Index(prompt, "## System Instructions")
|
||||
toolIdx := strings.Index(prompt, "## Available Tools")
|
||||
taskIdx := strings.Index(prompt, "## Task")
|
||||
|
||||
if sysIdx == -1 || toolIdx == -1 || taskIdx == -1 {
|
||||
t.Fatal("prompt should contain all sections")
|
||||
}
|
||||
if sysIdx >= taskIdx {
|
||||
t.Error("system instructions should come before task")
|
||||
}
|
||||
if taskIdx >= toolIdx {
|
||||
t.Error("task section should come before tools in the output")
|
||||
}
|
||||
}
|
||||
|
||||
// --- CLI Argument Tests ---
|
||||
|
||||
func TestCodexCliProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewCodexCliProvider("")
|
||||
if got := p.GetDefaultModel(); got != "codex-cli" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Mock CLI Integration Test ---
|
||||
|
||||
func createMockCodexCLI(t *testing.T, events []string) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("#!/bin/bash\n")
|
||||
for _, event := range events {
|
||||
sb.WriteString(fmt.Sprintf("echo '%s'\n", event))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(sb.String()), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return scriptPath
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_Success(t *testing.T) {
|
||||
scriptPath := createMockCodexCLI(t, []string{
|
||||
`{"type":"thread.started","thread_id":"test-123"}`,
|
||||
`{"type":"turn.started"}`,
|
||||
`{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`,
|
||||
`{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`,
|
||||
})
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Mock response from Codex CLI" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Mock response from Codex CLI")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 60 {
|
||||
t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 15 {
|
||||
t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_Error(t *testing.T) {
|
||||
scriptPath := createMockCodexCLI(t, []string{
|
||||
`{"type":"thread.started","thread_id":"test-err"}`,
|
||||
`{"type":"turn.started"}`,
|
||||
`{"type":"error","message":"auth token expired"}`,
|
||||
`{"type":"turn.failed","error":{"message":"auth token expired"}}`,
|
||||
})
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "auth token expired") {
|
||||
t.Errorf("error = %q, want to contain 'auth token expired'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) {
|
||||
// Mock script that captures args to verify model flag is passed
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
script := `#!/bin/bash
|
||||
# Write args to a file for verification
|
||||
echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `"
|
||||
echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}'
|
||||
echo '{"type":"turn.completed"}'`
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "/tmp/test-workspace",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "gpt-5.2-codex", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the args
|
||||
argsData, err := os.ReadFile(filepath.Join(tmpDir, "args.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("reading args: %v", err)
|
||||
}
|
||||
args := string(argsData)
|
||||
|
||||
if !strings.Contains(args, "-m gpt-5.2-codex") {
|
||||
t.Errorf("args should contain model flag, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "-C /tmp/test-workspace") {
|
||||
t.Errorf("args should contain workspace flag, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "--json") {
|
||||
t.Errorf("args should contain --json, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
|
||||
t.Errorf("args should contain bypass flag, got: %s", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) {
|
||||
// Script that sleeps forever
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
script := "#!/bin/bash\nsleep 60"
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(ctx, messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error on canceled context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_EmptyCommand(t *testing.T) {
|
||||
p := &CodexCliProvider{command: ""}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty command")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Integration Test (requires real codex CLI with valid auth) ---
|
||||
|
||||
func TestCodexCliProvider_Integration(t *testing.T) {
|
||||
if os.Getenv("PICOCLAW_INTEGRATION_TESTS") == "" {
|
||||
t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)")
|
||||
}
|
||||
|
||||
// Verify codex is available
|
||||
codexPath, err := exec.LookPath("codex")
|
||||
if err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: codexPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Respond with just the word 'hello' and nothing else."},
|
||||
}
|
||||
|
||||
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
lower := strings.ToLower(strings.TrimSpace(resp.Content))
|
||||
if !strings.Contains(lower, "hello") {
|
||||
t.Errorf("Content = %q, expected to contain 'hello'", resp.Content)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package providers
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -10,18 +11,26 @@ import (
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const codexDefaultModel = "gpt-5.2"
|
||||
const codexDefaultInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
type CodexProvider struct {
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
}
|
||||
|
||||
const defaultCodexInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||
option.WithAPIKey(token),
|
||||
option.WithHeader("originator", "codex_cli_rs"),
|
||||
option.WithHeader("OpenAI-Beta", "responses=experimental"),
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
@@ -41,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func()
|
||||
|
||||
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
accountID := p.accountID
|
||||
resolvedModel, fallbackReason := resolveCodexModel(model)
|
||||
if fallbackReason != "" {
|
||||
logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"reason": fallbackReason,
|
||||
})
|
||||
}
|
||||
if p.tokenSource != nil {
|
||||
tok, accID, err := p.tokenSource()
|
||||
if err != nil {
|
||||
@@ -48,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
|
||||
}
|
||||
opts = append(opts, option.WithAPIKey(tok))
|
||||
if accID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
||||
accountID = accID
|
||||
}
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
} else {
|
||||
logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
})
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, model, options)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options)
|
||||
|
||||
resp, err := p.client.Responses.New(ctx, params, opts...)
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
|
||||
var resp *responses.Response
|
||||
for stream.Next() {
|
||||
evt := stream.Current()
|
||||
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
|
||||
evtResp := evt.Response
|
||||
if evtResp.ID != "" {
|
||||
copy := evtResp
|
||||
resp = ©
|
||||
}
|
||||
}
|
||||
}
|
||||
err := stream.Err()
|
||||
if err != nil {
|
||||
fields := map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(tools),
|
||||
"account_id_present": accountID != "",
|
||||
"error": err.Error(),
|
||||
}
|
||||
var apiErr *openai.Error
|
||||
if errors.As(err, &apiErr) {
|
||||
fields["status_code"] = apiErr.StatusCode
|
||||
fields["api_type"] = apiErr.Type
|
||||
fields["api_code"] = apiErr.Code
|
||||
fields["api_param"] = apiErr.Param
|
||||
fields["api_message"] = apiErr.Message
|
||||
if apiErr.StatusCode == 400 {
|
||||
fields["hint"] = "verify account id header and model compatibility for codex backend"
|
||||
}
|
||||
if apiErr.Response != nil {
|
||||
fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
|
||||
}
|
||||
}
|
||||
logger.ErrorCF("provider.codex", "Codex API call failed", fields)
|
||||
return nil, fmt.Errorf("codex API call: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
fields := map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(tools),
|
||||
"account_id_present": accountID != "",
|
||||
}
|
||||
logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields)
|
||||
return nil, fmt.Errorf("codex API call: stream ended without completed response")
|
||||
}
|
||||
|
||||
return parseCodexResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) GetDefaultModel() string {
|
||||
return "gpt-4o"
|
||||
return codexDefaultModel
|
||||
}
|
||||
|
||||
func resolveCodexModel(model string) (string, string) {
|
||||
m := strings.ToLower(strings.TrimSpace(model))
|
||||
if m == "" {
|
||||
return codexDefaultModel, "empty model"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "openai/") {
|
||||
m = strings.TrimPrefix(m, "openai/")
|
||||
} else if strings.Contains(m, "/") {
|
||||
return codexDefaultModel, "non-openai model namespace"
|
||||
}
|
||||
|
||||
unsupportedPrefixes := []string{
|
||||
"glm",
|
||||
"claude",
|
||||
"anthropic",
|
||||
"gemini",
|
||||
"google",
|
||||
"moonshot",
|
||||
"kimi",
|
||||
"qwen",
|
||||
"deepseek",
|
||||
"llama",
|
||||
"meta-llama",
|
||||
"mistral",
|
||||
"grok",
|
||||
"xai",
|
||||
"zhipu",
|
||||
}
|
||||
for _, prefix := range unsupportedPrefixes {
|
||||
if strings.HasPrefix(m, prefix) {
|
||||
return codexDefaultModel, "unsupported model prefix"
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") {
|
||||
return m, ""
|
||||
}
|
||||
|
||||
return codexDefaultModel, "unsupported model family"
|
||||
}
|
||||
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||
@@ -133,21 +249,21 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
Input: responses.ResponseNewParamsInputUnion{
|
||||
OfInputItemList: inputItems,
|
||||
},
|
||||
Store: openai.Opt(false),
|
||||
Instructions: openai.Opt(instructions),
|
||||
Store: openai.Opt(false),
|
||||
}
|
||||
|
||||
if instructions != "" {
|
||||
params.Instructions = openai.Opt(instructions)
|
||||
} else {
|
||||
// ChatGPT Codex backend requires instructions to be present.
|
||||
params.Instructions = openai.Opt(defaultCodexInstructions)
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = openai.Opt(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForCodex(tools)
|
||||
}
|
||||
@@ -237,6 +353,9 @@ func createCodexTokenSource() func() (string, string, error) {
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -16,11 +17,18 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||
"max_tokens": 2048,
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
}
|
||||
if !params.Instructions.Valid() {
|
||||
t.Fatal("Instructions should be set")
|
||||
}
|
||||
if params.Instructions.Or("") != defaultCodexInstructions {
|
||||
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
@@ -197,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
@@ -220,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -244,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer refreshed-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["instructions"]; !ok {
|
||||
http.Error(w, "missing instructions", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["instructions"] == "" {
|
||||
http.Error(w, "instructions must not be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["temperature"]; ok {
|
||||
http.Error(w, "temperature is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 8,
|
||||
"output_tokens": 4,
|
||||
"total_tokens": 12,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("stale-token", "acc-123")
|
||||
provider.client = createOpenAITestClient(server.URL, "stale-token", "")
|
||||
provider.tokenSource = func() (string, string, error) {
|
||||
return "refreshed-token", "", nil
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["model"] != codexDefaultModel {
|
||||
http.Error(w, "unsupported model", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["instructions"] != codexDefaultInstructions {
|
||||
http.Error(w, "missing default instructions", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 8,
|
||||
"output_tokens": 4,
|
||||
"total_tokens": 12,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("test-token", "acc-123")
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewCodexProvider("test-token", "")
|
||||
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
||||
if got := p.GetDefaultModel(); got != codexDefaultModel {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCodexModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantModel string
|
||||
wantFallback bool
|
||||
}{
|
||||
{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
|
||||
{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotModel, reason := resolveCodexModel(tt.input)
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel)
|
||||
}
|
||||
if tt.wantFallback && reason == "" {
|
||||
t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input)
|
||||
}
|
||||
if !tt.wantFallback && reason != "" {
|
||||
t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
||||
c := openai.NewClient(opts...)
|
||||
return &c
|
||||
}
|
||||
|
||||
func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
|
||||
event := map[string]interface{}{
|
||||
"type": "response.completed",
|
||||
"sequence_number": 1,
|
||||
"response": response,
|
||||
}
|
||||
b, _ := json.Marshal(event)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
fmt.Fprintf(w, "event: response.completed\n")
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(b))
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
json "encoding/json"
|
||||
|
||||
copilot "github.com/github/copilot-sdk/go"
|
||||
)
|
||||
|
||||
type GitHubCopilotProvider struct {
|
||||
uri string
|
||||
connectMode string // `stdio` or `grpc``
|
||||
|
||||
session *copilot.Session
|
||||
}
|
||||
|
||||
func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) {
|
||||
|
||||
var session *copilot.Session
|
||||
if connectMode == "" {
|
||||
connectMode = "grpc"
|
||||
}
|
||||
switch connectMode {
|
||||
|
||||
case "stdio":
|
||||
//todo
|
||||
case "grpc":
|
||||
client := copilot.NewClient(&copilot.ClientOptions{
|
||||
CLIUrl: uri,
|
||||
})
|
||||
if err := client.Start(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details")
|
||||
}
|
||||
defer client.Stop()
|
||||
session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{
|
||||
Model: model,
|
||||
Hooks: &copilot.SessionHooks{},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return &GitHubCopilotProvider{
|
||||
uri: uri,
|
||||
connectMode: connectMode,
|
||||
session: session,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat sends a chat request to GitHub Copilot
|
||||
func (p *GitHubCopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
type tempMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
out := make([]tempMessage, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
out = append(out, tempMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
fullcontent, _ := json.Marshal(out)
|
||||
|
||||
content, _ := p.session.Send(ctx, copilot.MessageOptions{
|
||||
Prompt: string(fullcontent),
|
||||
})
|
||||
|
||||
return &LLMResponse{
|
||||
FinishReason: "stop",
|
||||
Content: content,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *GitHubCopilotProvider) GetDefaultModel() string {
|
||||
|
||||
return "gpt-4.1"
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -28,7 +29,7 @@ type HTTPProvider struct {
|
||||
|
||||
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
client := &http.Client{
|
||||
Timeout: 0,
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
@@ -52,10 +53,10 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
|
||||
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
|
||||
if idx := strings.Index(model, "/"); idx != -1 {
|
||||
prefix := model[:idx]
|
||||
if prefix == "moonshot" || prefix == "nvidia" {
|
||||
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
|
||||
model = model[idx+1:]
|
||||
}
|
||||
}
|
||||
@@ -239,6 +240,9 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
@@ -298,11 +302,17 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.Agents.Defaults.Workspace
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), nil
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
@@ -314,7 +324,16 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
@@ -390,7 +409,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// extractToolCallsFromText parses tool call JSON from response text.
|
||||
// Both ClaudeCliProvider and CodexCliProvider use this to extract
|
||||
// tool calls that the model outputs in its response text.
|
||||
func extractToolCallsFromText(text string) []ToolCall {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonStr := text[start:end]
|
||||
|
||||
var wrapper struct {
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ToolCall
|
||||
for _, tc := range wrapper.ToolCalls {
|
||||
var args map[string]interface{}
|
||||
json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
|
||||
result = append(result, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// stripToolCallsFromText removes tool call JSON from response text.
|
||||
func stripToolCallsFromText(text string) string {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return text
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return text
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text[:start] + text[end:])
|
||||
}
|
||||
+33
-3
@@ -145,13 +145,27 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
|
||||
// sanitizeFilename converts a session key into a cross-platform safe filename.
|
||||
// Session keys use "channel:chatID" (e.g. "telegram:123456") but ':' is the
|
||||
// volume separator on Windows, so filepath.Base would misinterpret the key.
|
||||
// We replace it with '_'. The original key is preserved inside the JSON file,
|
||||
// so loadSessions still maps back to the right in-memory key.
|
||||
func sanitizeFilename(key string) string {
|
||||
return strings.ReplaceAll(key, ":", "_")
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(key string) error {
|
||||
if sm.storage == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate key to avoid invalid filenames and path traversal.
|
||||
if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
|
||||
filename := sanitizeFilename(key)
|
||||
|
||||
// filepath.IsLocal rejects empty names, "..", absolute paths, and
|
||||
// OS-reserved device names (NUL, COM1 … on Windows).
|
||||
// The extra checks reject "." and any directory separators so that
|
||||
// the session file is always written directly inside sm.storage.
|
||||
if filename == "." || !filepath.IsLocal(filename) || strings.ContainsAny(filename, `/\`) {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
@@ -182,7 +196,7 @@ func (sm *SessionManager) Save(key string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionPath := filepath.Join(sm.storage, key+".json")
|
||||
sessionPath := filepath.Join(sm.storage, filename+".json")
|
||||
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -250,3 +264,19 @@ func (sm *SessionManager) loadSessions() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHistory updates the messages of a session.
|
||||
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
session, ok := sm.sessions[key]
|
||||
if ok {
|
||||
// Create a deep copy to strictly isolate internal state
|
||||
// from the caller's slice.
|
||||
msgs := make([]providers.Message, len(history))
|
||||
copy(msgs, history)
|
||||
session.Messages = msgs
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "simple"},
|
||||
{"telegram:123456", "telegram_123456"},
|
||||
{"discord:987654321", "discord_987654321"},
|
||||
{"slack:C01234", "slack_C01234"},
|
||||
{"no-colons-here", "no-colons-here"},
|
||||
{"multiple:colons:here", "multiple_colons_here"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := sanitizeFilename(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_WithColonInKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
sm := NewSessionManager(tmpDir)
|
||||
|
||||
// Create a session with a key containing colon (typical channel session key).
|
||||
key := "telegram:123456"
|
||||
sm.GetOrCreate(key)
|
||||
sm.AddMessage(key, "user", "hello")
|
||||
|
||||
// Save should succeed even though the key contains ':'
|
||||
if err := sm.Save(key); err != nil {
|
||||
t.Fatalf("Save(%q) failed: %v", key, err)
|
||||
}
|
||||
|
||||
// The file on disk should use sanitized name.
|
||||
expectedFile := filepath.Join(tmpDir, "telegram_123456.json")
|
||||
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
|
||||
t.Fatalf("expected session file %s to exist", expectedFile)
|
||||
}
|
||||
|
||||
// Load into a fresh manager and verify the session round-trips.
|
||||
sm2 := NewSessionManager(tmpDir)
|
||||
history := sm2.GetHistory(key)
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 message after reload, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "hello" {
|
||||
t.Errorf("expected message content %q, got %q", "hello", history[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_RejectsPathTraversal(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
sm := NewSessionManager(tmpDir)
|
||||
|
||||
badKeys := []string{"", ".", "..", "foo/bar", "foo\\bar"}
|
||||
for _, key := range badKeys {
|
||||
sm.GetOrCreate(key)
|
||||
if err := sm.Save(key); err == nil {
|
||||
t.Errorf("Save(%q) should have failed but didn't", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,22 @@ package skills
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
|
||||
|
||||
const (
|
||||
MaxNameLength = 64
|
||||
MaxDescriptionLength = 1024
|
||||
)
|
||||
|
||||
type SkillMetadata struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
@@ -21,6 +30,27 @@ type SkillInfo struct {
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (info SkillInfo) validate() error {
|
||||
var errs error
|
||||
if info.Name == "" {
|
||||
errs = errors.Join(errs, errors.New("name is required"))
|
||||
} else {
|
||||
if len(info.Name) > MaxNameLength {
|
||||
errs = errors.Join(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength))
|
||||
}
|
||||
if !namePattern.MatchString(info.Name) {
|
||||
errs = errors.Join(errs, errors.New("name must be alphanumeric with hyphens"))
|
||||
}
|
||||
}
|
||||
|
||||
if info.Description == "" {
|
||||
errs = errors.Join(errs, errors.New("description is required"))
|
||||
} else if len(info.Description) > MaxDescriptionLength {
|
||||
errs = errors.Join(errs, fmt.Errorf("description exceeds %d character", MaxDescriptionLength))
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
type SkillsLoader struct {
|
||||
workspace string
|
||||
workspaceSkills string // workspace skills (项目级别)
|
||||
@@ -54,6 +84,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
|
||||
metadata := sl.getSkillMetadata(skillFile)
|
||||
if metadata != nil {
|
||||
info.Description = metadata.Description
|
||||
info.Name = metadata.Name
|
||||
}
|
||||
if err := info.validate(); err != nil {
|
||||
slog.Warn("invalid skill from workspace", "name", info.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
skills = append(skills, info)
|
||||
}
|
||||
@@ -89,6 +124,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
|
||||
metadata := sl.getSkillMetadata(skillFile)
|
||||
if metadata != nil {
|
||||
info.Description = metadata.Description
|
||||
info.Name = metadata.Name
|
||||
}
|
||||
if err := info.validate(); err != nil {
|
||||
slog.Warn("invalid skill from global", "name", info.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
skills = append(skills, info)
|
||||
}
|
||||
@@ -123,6 +163,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
|
||||
metadata := sl.getSkillMetadata(skillFile)
|
||||
if metadata != nil {
|
||||
info.Description = metadata.Description
|
||||
info.Name = metadata.Name
|
||||
}
|
||||
if err := info.validate(); err != nil {
|
||||
slog.Warn("invalid skill from builtin", "name", info.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
skills = append(skills, info)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSkillsInfoValidate(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
skillName string
|
||||
description string
|
||||
wantErr bool
|
||||
errContains []string
|
||||
}{
|
||||
{
|
||||
name: "valid-skill",
|
||||
skillName: "valid-skill",
|
||||
description: "a valid skill description",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty-name",
|
||||
skillName: "",
|
||||
description: "description without name",
|
||||
wantErr: true,
|
||||
errContains: []string{"name is required"},
|
||||
},
|
||||
{
|
||||
name: "empty-description",
|
||||
skillName: "skill-without-description",
|
||||
description: "",
|
||||
wantErr: true,
|
||||
errContains: []string{"description is required"},
|
||||
},
|
||||
{
|
||||
name: "empty-both",
|
||||
skillName: "",
|
||||
description: "",
|
||||
wantErr: true,
|
||||
errContains: []string{"name is required", "description is required"},
|
||||
},
|
||||
{
|
||||
name: "name-with-spaces",
|
||||
skillName: "skill with spaces",
|
||||
description: "invalid name with spaces",
|
||||
wantErr: true,
|
||||
errContains: []string{"name must be alphanumeric with hyphens"},
|
||||
},
|
||||
{
|
||||
name: "name-with-underscore",
|
||||
skillName: "skill_underscore",
|
||||
description: "invalid name with underscore",
|
||||
wantErr: true,
|
||||
errContains: []string{"name must be alphanumeric with hyphens"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
info := SkillInfo{
|
||||
Name: tc.skillName,
|
||||
Description: tc.description,
|
||||
}
|
||||
err := info.validate()
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
for _, msg := range tc.errContains {
|
||||
assert.ErrorContains(t, err, msg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+2
-2
@@ -28,12 +28,12 @@ type CronTool struct {
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool {
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool {
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: NewExecTool(workspace, false),
|
||||
execTool: NewExecTool(workspace, restrict),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+43
-2
@@ -29,13 +29,54 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
if restrict {
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
|
||||
workspaceReal := absWorkspace
|
||||
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
|
||||
workspaceReal = resolved
|
||||
}
|
||||
|
||||
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
|
||||
if !isWithinWorkspace(resolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
} else if os.IsNotExist(err) {
|
||||
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
|
||||
if !isWithinWorkspace(parentResolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
} else {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
return resolved, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
|
||||
@@ -247,3 +247,35 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
|
||||
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Block paths that look inside workspace but point outside via symlink.
|
||||
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
|
||||
root := t.TempDir()
|
||||
workspace := filepath.Join(root, "workspace")
|
||||
if err := os.MkdirAll(workspace, 0755); err != nil {
|
||||
t.Fatalf("failed to create workspace: %v", err)
|
||||
}
|
||||
|
||||
secret := filepath.Join(root, "secret.txt")
|
||||
if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
|
||||
t.Fatalf("failed to write secret file: %v", err)
|
||||
}
|
||||
|
||||
link := filepath.Join(workspace, "leak.txt")
|
||||
if err := os.Symlink(secret, link); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileTool(workspace, true)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"path": link,
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
|
||||
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+10
-6
@@ -173,19 +173,23 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that nil is returned when no provider is configured
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "", BraveMaxResults: 5})
|
||||
|
||||
// Should return nil when no provider is enabled
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil when no search provider is configured")
|
||||
t.Errorf("Expected nil tool when Brave API key is empty")
|
||||
}
|
||||
|
||||
// Also nil when nothing is enabled
|
||||
tool = NewWebSearchTool(WebSearchToolOptions{})
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when no provider is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "test-key", BraveMaxResults: 5, BraveEnabled: true})
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
|
||||
+1
-2
@@ -73,9 +73,8 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
}
|
||||
|
||||
// Generate unique filename with UUID prefix to prevent conflicts
|
||||
ext := filepath.Ext(filename)
|
||||
safeName := SanitizeFilename(filename)
|
||||
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
|
||||
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
|
||||
Reference in New Issue
Block a user