merge: sync upstream/main into feat/multi-agent-routing

Resolve conflicts:
- pkg/agent/loop.go: integrate context compression, command handling,
  utf8 token estimation, and summarization notification into
  multi-agent routing architecture
- pkg/config/config_test.go: merge imports from both branches
- pkg/agent/loop_test.go: update test to use registry-based sessions
This commit is contained in:
Leandro Barbosa
2026-02-16 10:34:55 -03:00
69 changed files with 4550 additions and 559 deletions
+242 -30
View File
@@ -14,8 +14,10 @@ import (
"sync"
"sync/atomic"
"time"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -27,13 +29,14 @@ import (
)
type AgentLoop struct {
bus *bus.MessageBus
cfg *config.Config
registry *AgentRegistry
state *state.Manager
running atomic.Bool
summarizing sync.Map
fallback *providers.FallbackChain
bus *bus.MessageBus
cfg *config.Config
registry *AgentRegistry
state *state.Manager
running atomic.Bool
summarizing sync.Map
fallback *providers.FallbackChain
channelManager *channels.Manager
}
// processOptions configures how a message is processed
@@ -183,6 +186,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
}
}
// SetChannelManager stores cm on the loop. handleCommand reads it to answer
// "/list channels" and to validate "/switch channel to <name>"; while it is
// nil those commands reply "Channel manager not initialized".
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
al.channelManager = cm
}
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
@@ -254,6 +261,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return al.processSystemMessage(ctx, msg)
}
// Check for commands
if response, handled := al.handleCommand(ctx, msg); handled {
return response, nil
}
// Route to determine agent and session key
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
@@ -404,7 +416,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
// 7. Optional: summarization
if opts.EnableSummary {
al.maybeSummarize(agent, opts.SessionKey)
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
}
// 8. Optional: send response via bus
@@ -472,32 +484,72 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
var response *providers.LLMResponse
var err error
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
},
)
if fbErr != nil {
err = fbErr
} else {
response = fbResult.Response
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
},
)
if fbErr != nil {
return nil, fbErr
}
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
map[string]interface{}{"agent_id": agent.ID, "iteration": iteration})
}
return fbResult.Response, nil
}
} else {
response, err = agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
}
// Retry loop for context/token errors
maxRetries := 2
for retry := 0; retry <= maxRetries; retry++ {
response, err = callLLM()
if err == nil {
break
}
errMsg := strings.ToLower(err.Error())
isContextError := strings.Contains(errMsg, "token") ||
strings.Contains(errMsg, "context") ||
strings.Contains(errMsg, "invalidparameter") ||
strings.Contains(errMsg, "length")
if isContextError && retry < maxRetries {
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
"error": err.Error(),
"retry": retry,
})
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: "Context window exceeded. Compressing history and retrying...",
})
}
al.forceCompression(agent, opts.SessionKey)
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
messages = agent.ContextBuilder.BuildMessages(
newHistory, newSummary, "",
nil, opts.Channel, opts.ChatID,
)
continue
}
break
}
if err != nil {
logger.ErrorCF("agent", "LLM call failed",
map[string]interface{}{
@@ -505,7 +557,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
"iteration": iteration,
"error": err.Error(),
})
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
// Check if no tool calls - we're done
@@ -639,7 +691,7 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string) {
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := agent.ContextWindow * 75 / 100
@@ -649,12 +701,79 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string) {
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
defer al.summarizing.Delete(summarizeKey)
if !constants.IsInternalChannel(channel) {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: "Memory threshold reached. Optimizing conversation history...",
})
}
al.summarizeSession(agent, sessionKey)
}()
}
}
}
// forceCompression is the emergency path taken when an LLM call fails with a
// context-window error. It rewrites the stored history for sessionKey as:
//
//  1. the first message (assumed to be the system prompt),
//  2. a synthetic system note recording how many messages were dropped,
//  3. the newer half of the messages between the first and the last,
//  4. the last message (assumed to be the one that triggered the call).
//
// Histories of four messages or fewer are left untouched. The separately
// stored session summary is not read or modified here.
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
	history := agent.Sessions.GetHistory(sessionKey)
	if len(history) <= 4 {
		return
	}

	lastIdx := len(history) - 1
	middle := history[1:lastIdx]
	if len(middle) == 0 {
		return
	}

	// The older half of the middle section is discarded outright.
	droppedCount := len(middle) / 2
	survivors := middle[droppedCount:]

	// Leave a marker in the transcript so the gap is visible.
	note := providers.Message{
		Role:    "system",
		Content: fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount),
	}

	newHistory := make([]providers.Message, 0, len(survivors)+3)
	newHistory = append(newHistory, history[0], note)
	newHistory = append(newHistory, survivors...)
	newHistory = append(newHistory, history[lastIdx])

	// Persist the shortened history.
	agent.Sessions.SetHistory(sessionKey, newHistory)
	agent.Sessions.Save(sessionKey)

	logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
		"session_key":  sessionKey,
		"dropped_msgs": droppedCount,
		"new_count":    len(newHistory),
	})
}
// GetStartupInfo returns information about loaded tools and skills for logging.
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
info := make(map[string]interface{})
@@ -693,7 +812,7 @@ func formatMessagesForLog(messages []providers.Message) string {
result += "[\n"
for i, msg := range messages {
result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
if len(msg.ToolCalls) > 0 {
result += " ToolCalls:\n"
for _, tc := range msg.ToolCalls {
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
@@ -758,7 +877,7 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
if m.Role != "user" && m.Role != "assistant" {
continue
}
msgTokens := len(m.Content) / 4
msgTokens := len(m.Content) / 2
if msgTokens > maxMessageTokens {
omitted = true
continue
@@ -827,12 +946,105 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, b
}
// estimateTokens estimates the number of tokens in a message list.
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
// overheads better than the previous 3 chars/token.
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
total := 0
totalChars := 0
for _, m := range messages {
total += len(m.Content) / 4
totalChars += utf8.RuneCountInString(m.Content)
}
return total
// 2.5 chars per token = totalChars * 2 / 5
return totalChars * 2 / 5
}
// handleCommand intercepts slash commands in an inbound message.
//
// It returns (reply, true) when msg.Content, after trimming, starts with one
// of the known commands ("/show", "/list", "/switch"), including usage and
// error replies. It returns ("", false) for anything else so the caller can
// continue with normal processing.
//
// Side effects: "/switch model to <name>" overwrites the Model field of the
// default agent. "/switch channel to <name>" only validates the name against
// the channel manager (accepting "cli" unconditionally) and reports success;
// it stores nothing. ctx is currently unused.
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
	trimmed := strings.TrimSpace(msg.Content)
	if !strings.HasPrefix(trimmed, "/") {
		return "", false
	}

	words := strings.Fields(trimmed)
	if len(words) == 0 {
		return "", false
	}
	name, args := words[0], words[1:]

	if name == "/show" {
		if len(args) < 1 {
			return "Usage: /show [model|channel|agents]", true
		}
		switch what := args[0]; what {
		case "model":
			agent := al.registry.GetDefaultAgent()
			if agent == nil {
				return "No default agent configured", true
			}
			return fmt.Sprintf("Current model: %s", agent.Model), true
		case "channel":
			return fmt.Sprintf("Current channel: %s", msg.Channel), true
		case "agents":
			ids := al.registry.ListAgentIDs()
			return fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", ")), true
		default:
			return fmt.Sprintf("Unknown show target: %s", what), true
		}
	}

	if name == "/list" {
		if len(args) < 1 {
			return "Usage: /list [models|channels|agents]", true
		}
		switch what := args[0]; what {
		case "models":
			return "Available models: configured in config.json per agent", true
		case "channels":
			if al.channelManager == nil {
				return "Channel manager not initialized", true
			}
			enabled := al.channelManager.GetEnabledChannels()
			if len(enabled) == 0 {
				return "No channels enabled", true
			}
			return fmt.Sprintf("Enabled channels: %s", strings.Join(enabled, ", ")), true
		case "agents":
			ids := al.registry.ListAgentIDs()
			return fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", ")), true
		default:
			return fmt.Sprintf("Unknown list target: %s", what), true
		}
	}

	if name == "/switch" {
		// Expected shape: /switch <target> to <value>
		if len(args) < 3 || args[1] != "to" {
			return "Usage: /switch [model|channel] to <name>", true
		}
		target, value := args[0], args[2]

		switch target {
		case "model":
			agent := al.registry.GetDefaultAgent()
			if agent == nil {
				return "No default agent configured", true
			}
			previous := agent.Model
			agent.Model = value
			return fmt.Sprintf("Switched model from %s to %s", previous, value), true
		case "channel":
			if al.channelManager == nil {
				return "Channel manager not initialized", true
			}
			_, found := al.channelManager.GetChannel(value)
			if !found && value != "cli" {
				return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
			}
			return fmt.Sprintf("Switched target channel to %s", value), true
		default:
			return fmt.Sprintf("Unknown switch target: %s", target), true
		}
	}

	// Starts with "/" but is not a command we know: let the agent see it.
	return "", false
}
// extractPeer extracts the routing peer from inbound message metadata.
+101
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
@@ -527,3 +528,103 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}
// failFirstMockProvider is a test double whose Chat method fails on the first
// N calls with a specific error and succeeds afterwards.
type failFirstMockProvider struct {
failures int // number of leading Chat calls that return failError
currentCall int // number of Chat calls made so far (incremented on entry to Chat)
failError error // error returned while currentCall <= failures
successResp string // Content of the response returned once the failures are used up
}
// Chat counts the invocation and then either returns failError (for the first
// `failures` calls) or a response whose Content is successResp and whose
// ToolCalls slice is empty. All arguments are ignored. The call counter is
// not synchronised.
func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
	m.currentCall++

	if m.currentCall > m.failures {
		ok := &providers.LLMResponse{
			Content:   m.successResp,
			ToolCalls: []providers.ToolCall{},
		}
		return ok, nil
	}

	return nil, m.failError
}
// GetDefaultModel returns the fixed model name "mock-fail-model".
func (m *failFirstMockProvider) GetDefaultModel() string {
return "mock-fail-model"
}
// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors
func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
// Create a provider that fails once with a context error
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
successResp: "Recovered from context error",
}
al := NewAgentLoop(cfg, msgBus, provider)
// Inject some history to simulate a full context
sessionKey := "test-session-context"
// Create dummy history
history := []providers.Message{
{Role: "system", Content: "System prompt"},
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
{Role: "assistant", Content: "Old response 2"},
{Role: "user", Content: "Trigger message"},
}
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("No default agent found")
}
defaultAgent.Sessions.SetHistory(sessionKey, history)
// Call ProcessDirectWithChannel
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
if err != nil {
t.Fatalf("Expected success after retry, got error: %v", err)
}
if response != "Recovered from context error" {
t.Errorf("Expected 'Recovered from context error', got '%s'", response)
}
// We expect 2 calls: 1st failed, 2nd succeeded
if provider.currentCall != 2 {
t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
}
// Check final history length
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
// Original length: 6
// Expected behavior: compression drops ~50% of history (mid slice)
// We can assert that the length is NOT what it would be without compression.
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
if len(finalHistory) >= 8 {
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
}
}
+82 -30
View File
@@ -19,18 +19,20 @@ import (
)
type OAuthProviderConfig struct {
Issuer string
ClientID string
Scopes string
Port int
Issuer string
ClientID string
Scopes string
Originator string
Port int
}
func OpenAIOAuthConfig() OAuthProviderConfig {
return OAuthProviderConfig{
Issuer: "https://auth.openai.com",
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
Scopes: "openid profile email offline_access",
Port: 1455,
Issuer: "https://auth.openai.com",
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
Scopes: "openid profile email offline_access",
Originator: "codex_cli_rs",
Port: 1455,
}
}
@@ -279,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
refreshed, err := parseTokenResponse(body, cred.Provider)
if err != nil {
return nil, err
}
if refreshed.RefreshToken == "" {
refreshed.RefreshToken = cred.RefreshToken
}
if refreshed.AccountID == "" {
refreshed.AccountID = cred.AccountID
}
return refreshed, nil
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
@@ -288,15 +300,23 @@ func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
params := url.Values{
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {redirectURI},
"scope": {cfg.Scopes},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {redirectURI},
"scope": {cfg.Scopes},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
"id_token_add_organizations": {"true"},
"codex_cli_simplified_flow": {"true"},
"state": {state},
}
return cfg.Issuer + "/authorize?" + params.Encode()
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
params.Set("originator", "picoclaw")
}
if cfg.Originator != "" {
params.Set("originator", cfg.Originator)
}
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
}
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
@@ -350,19 +370,57 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
cred.AccountID = accountID
}
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
// extractAccountID returns the account identifier carried in the claims of a
// JWT, or "" when the token cannot be parsed or no candidate claim holds a
// non-empty string. Candidates are tried in this order:
//
//  1. top-level "chatgpt_account_id"
//  2. top-level "https://api.openai.com/auth.chatgpt_account_id"
//  3. "chatgpt_account_id" inside the "https://api.openai.com/auth" object
//  4. the first non-empty "id" among the entries of "organizations"
func extractAccountID(token string) string {
	claims, err := parseJWTClaims(token)
	if err != nil {
		return ""
	}

	// stringAt yields m[key] when it is a string, and "" otherwise.
	stringAt := func(m map[string]interface{}, key string) string {
		s, _ := m[key].(string)
		return s
	}

	for _, key := range []string{
		"chatgpt_account_id",
		"https://api.openai.com/auth.chatgpt_account_id",
	} {
		if id := stringAt(claims, key); id != "" {
			return id
		}
	}

	if nested, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
		if id := stringAt(nested, "chatgpt_account_id"); id != "" {
			return id
		}
	}

	// Ranging over a nil slice is a no-op, so a missing or mistyped
	// "organizations" claim simply falls through.
	orgs, _ := claims["organizations"].([]interface{})
	for _, entry := range orgs {
		org, ok := entry.(map[string]interface{})
		if !ok {
			continue
		}
		if id := stringAt(org, "id"); id != "" {
			return id
		}
	}

	return ""
}
func parseJWTClaims(token string) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("token is not a JWT")
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
@@ -373,21 +431,15 @@ func extractAccountID(accessToken string) string {
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
return nil, err
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
return nil, err
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
return claims, nil
}
func base64URLDecode(s string) ([]byte, error) {
+139 -5
View File
@@ -1,19 +1,34 @@
package auth
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
// makeJWTForClaims builds an unsigned JWT-shaped string for tests: a fixed
// {"alg":"none","typ":"JWT"} header, the JSON encoding of claims as payload,
// and the literal "sig" as signature, with header and payload encoded as
// unpadded base64url. The test is failed immediately if claims cannot be
// marshalled.
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
	t.Helper()

	body, err := json.Marshal(claims)
	if err != nil {
		t.Fatalf("marshal claims: %v", err)
	}

	enc := base64.RawURLEncoding
	token := enc.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
	token += "." + enc.EncodeToString(body)
	return token + ".sig"
}
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
ClientID: "test-client-id",
Scopes: "openid profile",
Port: 1455,
Issuer: "https://auth.example.com",
ClientID: "test-client-id",
Scopes: "openid profile",
Originator: "codex_cli_rs",
Port: 1455,
}
pkce := PKCECodes{
CodeVerifier: "test-verifier",
@@ -22,7 +37,7 @@ func TestBuildAuthorizeURL(t *testing.T) {
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
if !strings.HasPrefix(u, "https://auth.example.com/oauth/authorize?") {
t.Errorf("URL does not start with expected prefix: %s", u)
}
if !strings.Contains(u, "client_id=test-client-id") {
@@ -40,6 +55,37 @@ func TestBuildAuthorizeURL(t *testing.T) {
if !strings.Contains(u, "response_type=code") {
t.Error("URL missing response_type")
}
if !strings.Contains(u, "id_token_add_organizations=true") {
t.Error("URL missing id_token_add_organizations")
}
if !strings.Contains(u, "codex_cli_simplified_flow=true") {
t.Error("URL missing codex_cli_simplified_flow")
}
if !strings.Contains(u, "originator=codex_cli_rs") {
t.Error("URL missing originator")
}
}
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
cfg := OpenAIOAuthConfig()
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
parsed, err := url.Parse(u)
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
q := parsed.Query()
if q.Get("id_token_add_organizations") != "true" {
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
}
if q.Get("codex_cli_simplified_flow") != "true" {
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
}
if q.Get("originator") != "codex_cli_rs" {
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
}
}
func TestParseTokenResponse(t *testing.T) {
@@ -73,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
}
}
// TestParseTokenResponseExtractsAccountIDFromIDToken checks that
// parseTokenResponse takes the account ID from a top-level
// "chatgpt_account_id" claim of the id_token when the access token is an
// opaque (non-JWT) string.
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
	idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
	resp := map[string]interface{}{
		"access_token":  "opaque-access-token",
		"refresh_token": "test-refresh-token",
		"expires_in":    3600,
		"id_token":      idToken,
	}

	// A marshalling failure must stop the test rather than feed an empty
	// body to the function under test.
	body, err := json.Marshal(resp)
	if err != nil {
		t.Fatalf("json.Marshal() error: %v", err)
	}

	cred, err := parseTokenResponse(body, "openai")
	if err != nil {
		t.Fatalf("parseTokenResponse() error: %v", err)
	}

	if cred.AccountID != "acc-id-from-id-token" {
		t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
	}
}
// TestExtractAccountIDFromOrganizationsFallback checks that, when no
// chatgpt_account_id claim is present, extractAccountID falls back to the
// "id" of an entry in the "organizations" claim.
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
	org := map[string]interface{}{"id": "org_from_orgs"}
	claims := map[string]interface{}{
		"organizations": []interface{}{org},
	}
	token := makeJWTForClaims(t, claims)

	got := extractAccountID(token)
	if got != "org_from_orgs" {
		t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
	}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
@@ -81,6 +158,32 @@ func TestParseTokenResponseNoAccessToken(t *testing.T) {
}
}
// TestParseTokenResponseAccountIDFromIDToken checks that parseTokenResponse
// finds the account ID nested under the "https://api.openai.com/auth" claim
// of the id_token when the access token is not a JWT.
func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) {
	idToken := makeJWTWithAccountID("acc-from-id")
	resp := map[string]interface{}{
		"access_token":  "not-a-jwt",
		"refresh_token": "test-refresh-token",
		"expires_in":    3600,
		"id_token":      idToken,
	}

	// A marshalling failure must stop the test rather than feed an empty
	// body to the function under test.
	body, err := json.Marshal(resp)
	if err != nil {
		t.Fatalf("json.Marshal() error: %v", err)
	}

	cred, err := parseTokenResponse(body, "openai")
	if err != nil {
		t.Fatalf("parseTokenResponse() error: %v", err)
	}

	if cred.AccountID != "acc-from-id" {
		t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-from-id")
	}
}
// makeJWTWithAccountID builds an unsigned JWT-shaped string for tests whose
// payload is {"https://api.openai.com/auth":{"chatgpt_account_id":<accountID>}}.
// The payload is produced with json.Marshal so that account IDs containing
// quotes, backslashes or control characters still yield valid JSON; for plain
// IDs the output is byte-identical to hand-written JSON.
func makeJWTWithAccountID(accountID string) string {
	claims := map[string]map[string]string{
		"https://api.openai.com/auth": {"chatgpt_account_id": accountID},
	}
	payloadJSON, err := json.Marshal(claims)
	if err != nil {
		// Unreachable for a map of strings; panic rather than hand a
		// malformed token to the code under test.
		panic("makeJWTWithAccountID: " + err.Error())
	}

	header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
	payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
	return header + "." + payload + ".sig"
}
func TestExchangeCodeForTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
@@ -185,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
}
}
// TestRefreshAccessTokenPreservesRefreshAndAccountID checks that when the
// token endpoint answers a refresh with only a new access token,
// RefreshAccessToken carries the previous refresh token and account ID over
// into the returned credential.
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := map[string]interface{}{
			"access_token": "new-access-token-only",
			"expires_in":   3600,
		}
		// Surface a broken stub instead of letting the client see a
		// truncated body. t.Errorf is safe to call from the handler goroutine.
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode stub token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
	cred := &AuthCredential{
		AccessToken:  "old-access",
		RefreshToken: "existing-refresh",
		AccountID:    "acc_existing",
		Provider:     "openai",
		AuthMethod:   "oauth",
	}

	refreshed, err := RefreshAccessToken(cred, cfg)
	if err != nil {
		t.Fatalf("RefreshAccessToken() error: %v", err)
	}
	if refreshed.RefreshToken != "existing-refresh" {
		t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
	}
	if refreshed.AccountID != "acc_existing" {
		t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
	}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
+17
View File
@@ -9,6 +9,7 @@ type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
handlers map[string]MessageHandler
closed bool
mu sync.RWMutex
}
@@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus {
}
// PublishInbound sends msg on the inbound channel. If the bus has already
// been closed the message is dropped silently.
//
// The read lock is held across the channel send so that Close, which takes
// the write lock before closing the channels, cannot close the channel while
// a send is in progress.
// NOTE(review): the send itself is blocking. If the inbound channel is full
// and nobody is consuming, this call waits while holding the read lock and
// Close would wait behind it — confirm consumers always drain the channel.
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
}
mb.inbound <- msg
}
@@ -34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool)
}
// PublishOutbound sends msg on the outbound channel. If the bus has already
// been closed the message is dropped silently.
//
// The read lock is held across the channel send so that Close, which takes
// the write lock before closing the channels, cannot close the channel while
// a send is in progress.
// NOTE(review): the send itself is blocking. If the outbound channel is full
// and nobody is consuming, this call waits while holding the read lock and
// Close would wait behind it — confirm consumers always drain the channel.
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
}
mb.outbound <- msg
}
@@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
}
// Close marks the bus as closed and closes the inbound and outbound channels.
// It takes the write lock, so it waits for any publisher currently holding
// the read lock. Calling Close more than once is safe: the closed flag makes
// later calls return without closing the channels again.
func (mb *MessageBus) Close() {
mb.mu.Lock()
defer mb.mu.Unlock()
if mb.closed {
return
}
mb.closed = true
close(mb.inbound)
close(mb.outbound)
}
+150 -2
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/bwmarrin/discordgo"
@@ -100,15 +101,156 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return fmt.Errorf("channel ID is empty")
}
message := msg.Content
runes := []rune(msg.Content)
if len(runes) == 0 {
return nil
}
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
for _, chunk := range chunks {
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
return err
}
}
return nil
}
// splitMessage splits content into chunks of roughly limit bytes, preferring
// natural boundaries and trying not to break fenced code blocks.
//
// For each chunk it looks for a newline in the last 200 bytes of the window,
// then a space/tab in the last 100 bytes, and only then cuts at limit. If the
// chunk would end inside an unclosed ``` block, it either extends the chunk up
// to limit+500 bytes to include the closing fence, or cuts just before the
// opening fence. A chunk may therefore be longer than limit, but never longer
// than limit+500. Whitespace around each cut is trimmed from the start of the
// following chunk.
//
// limit is measured in bytes. A cut never lands inside a multi-byte UTF-8
// sequence (unless limit is smaller than a single rune). A non-positive limit
// disables splitting and returns content as a single chunk. Empty content
// yields a nil slice.
func splitMessage(content string, limit int) []string {
	if limit <= 0 {
		// Slicing with a non-positive limit would either panic or never
		// make progress; treat it as "do not split".
		if content == "" {
			return nil
		}
		return []string{content}
	}

	var messages []string

	for len(content) > 0 {
		if len(content) <= limit {
			messages = append(messages, content)
			break
		}

		// Find a natural split point within the limit.
		msgEnd := findLastNewline(content[:limit], 200)
		if msgEnd <= 0 {
			msgEnd = findLastSpace(content[:limit], 100)
		}
		if msgEnd <= 0 {
			msgEnd = limit
		}

		// Check whether this chunk would end inside an unclosed code block.
		if unclosedIdx := findLastUnclosedCodeBlock(content[:msgEnd]); unclosedIdx >= 0 {
			extendedLimit := limit + 500 // allow a 500-byte buffer for code blocks
			if len(content) > extendedLimit {
				closingIdx := findNextClosingCodeBlock(content, msgEnd)
				if closingIdx > 0 && closingIdx <= extendedLimit {
					// Extend to include the closing fence.
					msgEnd = closingIdx
				} else {
					// No closing fence in reach: split before the code block.
					msgEnd = findLastNewline(content[:unclosedIdx], 200)
					if msgEnd <= 0 {
						msgEnd = findLastSpace(content[:unclosedIdx], 100)
					}
					if msgEnd <= 0 {
						msgEnd = unclosedIdx
					}
				}
			} else {
				// The remainder fits within the extended limit.
				msgEnd = len(content)
			}
		}

		if msgEnd <= 0 {
			msgEnd = limit
		}

		// A hard cut at limit can fall on a UTF-8 continuation byte
		// (10xxxxxx). Back up at most three bytes to the start of that rune
		// so neither chunk contains a broken character.
		if msgEnd < len(content) {
			cut := msgEnd
			for steps := 0; steps < 3 && cut > 0 && content[cut]&0xC0 == 0x80; steps++ {
				cut--
			}
			if cut > 0 && content[cut]&0xC0 != 0x80 {
				msgEnd = cut
			}
		}

		messages = append(messages, content[:msgEnd])
		content = strings.TrimSpace(content[msgEnd:])
	}

	return messages
}
// findLastUnclosedCodeBlock scans text for ``` fences, treating them
// alternately as opening and closing markers. If text ends inside a code
// block it returns the byte index of that block's opening fence; otherwise
// (no fences, or all blocks closed) it returns -1.
func findLastUnclosedCodeBlock(text string) int {
	const fence = "```"

	inBlock := false
	openIdx := -1

	for i := 0; i+len(fence) <= len(text); i++ {
		if text[i:i+len(fence)] != fence {
			continue
		}
		// Record the position every time a block is opened, not just for
		// the first fence in the text, so the index refers to the block
		// that is actually left open.
		if !inBlock {
			openIdx = i
		}
		inBlock = !inBlock
		i += len(fence) - 1 // skip the rest of this fence
	}

	if inBlock {
		return openIdx
	}
	return -1
}
// findNextClosingCodeBlock looks for the first ``` fence that begins at or
// after byte offset startIdx in text. It returns the offset just past that
// fence, or -1 when there is none.
func findNextClosingCodeBlock(text string, startIdx int) int {
	const fence = "```"

	for pos := startIdx; pos+len(fence) <= len(text); pos++ {
		if text[pos:pos+len(fence)] == fence {
			return pos + len(fence)
		}
	}
	return -1
}
// findLastNewline returns the byte index of the last '\n' located within the
// final searchWindow bytes of s, or -1 if that tail contains no newline.
// A window larger than s covers the whole string.
func findLastNewline(s string, searchWindow int) int {
	floor := len(s) - searchWindow
	if floor < 0 {
		floor = 0
	}

	// Walk backwards from the end, stopping at the window's lower edge.
	for end := len(s); end > floor; end-- {
		if s[end-1] == '\n' {
			return end - 1
		}
	}
	return -1
}
// findLastSpace returns the byte index of the last space or tab located
// within the final searchWindow bytes of s, or -1 if that tail contains
// neither. A window larger than s covers the whole string.
func findLastSpace(s string, searchWindow int) int {
	floor := len(s) - searchWindow
	if floor < 0 {
		floor = 0
	}

	// Walk backwards from the end, stopping at the window's lower edge.
	for end := len(s); end > floor; end-- {
		switch s[end-1] {
		case ' ', '\t':
			return end - 1
		}
	}
	return -1
}
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
_, err := c.session.ChannelMessageSend(channelID, content)
done <- err
}()
@@ -140,6 +282,12 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
if err := c.session.ChannelTyping(m.ChannelID); err != nil {
logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{
"error": err.Error(),
})
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
+14 -1
View File
@@ -48,7 +48,7 @@ func (m *Manager) initChannels() error {
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
logger.DebugC("channels", "Attempting to initialize Telegram channel")
telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus)
telegram, err := NewTelegramChannel(m.config, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{
"error": err.Error(),
@@ -163,6 +163,19 @@ func (m *Manager) initChannels() error {
}
}
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
logger.DebugC("channels", "Attempting to initialize OneBot channel")
onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]interface{}{
"error": err.Error(),
})
} else {
m.channels["onebot"] = onebot
logger.InfoC("channels", "OneBot channel enabled successfully")
}
}
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
"enabled_channels": len(m.channels),
})
+686
View File
@@ -0,0 +1,686 @@
package channels
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// OneBotChannel is the channel implementation that exchanges messages with a
// OneBot endpoint over a single WebSocket connection.
type OneBotChannel struct {
*BaseChannel
config config.OneBotConfig
// conn is the current WebSocket connection; it is nil while disconnected.
// Reads and writes of this field are done under mu.
conn *websocket.Conn
// ctx and cancel are created in Start and cancelled in Stop; the listen and
// reconnect goroutines exit when ctx is done.
ctx context.Context
cancel context.CancelFunc
// dedup, dedupRing and dedupIdx form a fixed-size "recently seen message ID"
// set: dedup gives O(1) lookup, dedupRing remembers insertion order so the
// oldest ID can be evicted, and dedupIdx is the next ring slot to overwrite.
dedup map[string]struct{}
dedupRing []string
dedupIdx int
// mu guards conn and the dedup state above.
mu sync.Mutex
// writeMu serializes WebSocket writes and increments of echoCounter.
writeMu sync.Mutex
// echoCounter numbers outgoing API requests ("send_<n>" echo values).
echoCounter int64
}
// oneBotRawEvent is the shape every incoming WebSocket frame is decoded into.
// ID-like and numeric fields are kept as json.RawMessage so they can later be
// read whether the peer encodes them as JSON numbers or as strings
// (see parseJSONInt64 / parseJSONString).
type oneBotRawEvent struct {
PostType string `json:"post_type"`
MessageType string `json:"message_type"`
SubType string `json:"sub_type"`
MessageID json.RawMessage `json:"message_id"`
UserID json.RawMessage `json:"user_id"`
GroupID json.RawMessage `json:"group_id"`
RawMessage string `json:"raw_message"`
// Message may be either a JSON string or an array of segments; it is
// decoded by parseMessageContentEx.
Message json.RawMessage `json:"message"`
Sender json.RawMessage `json:"sender"`
SelfID json.RawMessage `json:"self_id"`
Time json.RawMessage `json:"time"`
MetaEventType string `json:"meta_event_type"`
// Echo is non-empty on frames that answer an API request sent by Send.
Echo string `json:"echo"`
RetCode json.RawMessage `json:"retcode"`
Status BotStatus `json:"status"`
}
// BotStatus is the object decoded from the "status" field of an incoming frame.
// The listener skips any frame in which either flag is set.
type BotStatus struct {
Online bool `json:"online"`
Good bool `json:"good"`
}
// oneBotSender is the decoded "sender" object of a message event.
// When naming the sender of a group message, Card is preferred over Nickname.
type oneBotSender struct {
UserID json.RawMessage `json:"user_id"`
Nickname string `json:"nickname"`
Card string `json:"card"`
}
// oneBotEvent is the normalized form of a message event, produced by
// normalizeMessageEvent from a oneBotRawEvent with IDs parsed to integers.
type oneBotEvent struct {
PostType string
MessageType string
SubType string
MessageID string
UserID int64
GroupID int64
// Content is the message text with the bot's own [CQ:at,...] code removed.
Content string
// RawContent is the unmodified raw_message field of the source event.
RawContent string
// IsBotMentioned reports whether the message @-mentioned the bot.
IsBotMentioned bool
Sender oneBotSender
SelfID int64
Time int64
MetaEventType string
}
// oneBotAPIRequest is the JSON envelope written to the WebSocket by Send:
// an action name, its parameters, and an optional echo tag.
type oneBotAPIRequest struct {
Action string `json:"action"`
Params interface{} `json:"params"`
Echo string `json:"echo,omitempty"`
}
// oneBotSendPrivateMsgParams is the params payload of a "send_private_msg" request.
type oneBotSendPrivateMsgParams struct {
UserID int64 `json:"user_id"`
Message string `json:"message"`
}
// oneBotSendGroupMsgParams is the params payload of a "send_group_msg" request.
type oneBotSendGroupMsgParams struct {
GroupID int64 `json:"group_id"`
Message string `json:"message"`
}
// NewOneBotChannel builds a OneBot channel from cfg and wires it to messageBus.
// No network connection is made here; the deduplication set is pre-sized to
// remember the 1024 most recent message IDs. The returned error is always nil.
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
	const dedupSize = 1024

	ch := &OneBotChannel{
		BaseChannel: NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom),
		config:      cfg,
		dedup:       make(map[string]struct{}, dedupSize),
		dedupRing:   make([]string, dedupSize),
	}
	return ch, nil
}
// Start opens the WebSocket connection and launches the background goroutines.
//
// A failed initial connection is tolerated when reconnecting is enabled
// (ReconnectInterval > 0): the reconnect loop keeps retrying in the background.
// When reconnecting is disabled and no connection is available, Start returns
// an error and releases the context it created.
func (c *OneBotChannel) Start(ctx context.Context) error {
	if c.config.WSUrl == "" {
		return fmt.Errorf("OneBot ws_url not configured")
	}

	logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{
		"ws_url": c.config.WSUrl,
	})

	c.ctx, c.cancel = context.WithCancel(ctx)

	if err := c.connect(); err != nil {
		logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{
			"error": err.Error(),
		})
	} else {
		go c.listen()
	}

	if c.config.ReconnectInterval > 0 {
		go c.reconnectLoop()
	} else {
		// conn is shared with the listener goroutine (which clears it on a read
		// error), so it must be read under the mutex.
		c.mu.Lock()
		connected := c.conn != nil
		c.mu.Unlock()

		// If reconnect is disabled but there is no connection, we cannot recover.
		if !connected {
			// Release the derived context; nothing else will ever cancel it
			// because the caller is not expected to call Stop after a failed Start.
			c.cancel()
			return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
		}
	}

	c.setRunning(true)
	logger.InfoC("onebot", "OneBot channel started successfully")
	return nil
}
// connect dials the configured WebSocket URL and, on success, stores the new
// connection in c.conn. When an access token is configured it is sent as a
// Bearer Authorization header. The handshake is limited to 10 seconds.
func (c *OneBotChannel) connect() error {
	// Work on a copy: websocket.DefaultDialer is a shared package-level pointer,
	// so setting HandshakeTimeout through it would silently change the timeout
	// for every other user of the default dialer in the process.
	dialer := *websocket.DefaultDialer
	dialer.HandshakeTimeout = 10 * time.Second

	header := make(map[string][]string)
	if c.config.AccessToken != "" {
		header["Authorization"] = []string{"Bearer " + c.config.AccessToken}
	}

	conn, _, err := dialer.Dial(c.config.WSUrl, header)
	if err != nil {
		return err
	}

	c.mu.Lock()
	c.conn = conn
	c.mu.Unlock()

	logger.InfoC("onebot", "WebSocket connected")
	return nil
}
// reconnectLoop periodically checks whether the channel is disconnected and,
// if so, dials again and restarts the listener. The polling period comes from
// config.ReconnectInterval (seconds) with a floor of five seconds. The loop
// ends when the channel context is cancelled.
func (c *OneBotChannel) reconnectLoop() {
	const minWait = 5 * time.Second

	wait := time.Duration(c.config.ReconnectInterval) * time.Second
	if wait < minWait {
		wait = minWait
	}

	for {
		select {
		case <-c.ctx.Done():
			return
		case <-time.After(wait):
		}

		c.mu.Lock()
		connected := c.conn != nil
		c.mu.Unlock()
		if connected {
			continue
		}

		logger.InfoC("onebot", "Attempting to reconnect...")
		if err := c.connect(); err != nil {
			logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{
				"error": err.Error(),
			})
			continue
		}
		go c.listen()
	}
}
// Stop marks the channel as not running, cancels the channel context so the
// background goroutines exit, and closes the WebSocket connection if one is
// open. It always returns nil.
func (c *OneBotChannel) Stop(ctx context.Context) error {
	logger.InfoC("onebot", "Stopping OneBot channel")
	c.setRunning(false)

	if stop := c.cancel; stop != nil {
		stop()
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	if c.conn == nil {
		return nil
	}
	c.conn.Close()
	c.conn = nil
	return nil
}
// Send delivers an outbound message through the OneBot WebSocket.
// It fails when the channel is not running, when there is no open connection,
// when msg.ChatID cannot be mapped to a OneBot target, or when the write fails.
// Each request carries a unique "send_<n>" echo tag.
func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
	if !c.IsRunning() {
		return fmt.Errorf("OneBot channel not running")
	}

	c.mu.Lock()
	ws := c.conn
	c.mu.Unlock()
	if ws == nil {
		return fmt.Errorf("OneBot WebSocket not connected")
	}

	action, params, err := c.buildSendRequest(msg)
	if err != nil {
		return err
	}

	// Reserve the next sequence number for the echo tag.
	c.writeMu.Lock()
	c.echoCounter++
	seq := c.echoCounter
	c.writeMu.Unlock()

	payload, err := json.Marshal(oneBotAPIRequest{
		Action: action,
		Params: params,
		Echo:   fmt.Sprintf("send_%d", seq),
	})
	if err != nil {
		return fmt.Errorf("failed to marshal OneBot request: %w", err)
	}

	// Writes are serialized so concurrent Send calls do not interleave on the connection.
	c.writeMu.Lock()
	writeErr := ws.WriteMessage(websocket.TextMessage, payload)
	c.writeMu.Unlock()
	if writeErr == nil {
		return nil
	}

	logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{
		"error": writeErr.Error(),
	})
	return writeErr
}
// buildSendRequest translates msg.ChatID into a OneBot action and its params.
//
//   - "group:<id>"   -> send_group_msg
//   - "private:<id>" -> send_private_msg
//   - "<id>"         -> send_private_msg
//
// An error is returned when the ID part is not a base-10 64-bit integer.
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
	const (
		groupPrefix   = "group:"
		privatePrefix = "private:"
	)
	target := msg.ChatID

	switch {
	case len(target) > len(groupPrefix) && strings.HasPrefix(target, groupPrefix):
		id, err := strconv.ParseInt(target[len(groupPrefix):], 10, 64)
		if err != nil {
			return "", nil, fmt.Errorf("invalid group ID in chatID: %s", target)
		}
		return "send_group_msg", oneBotSendGroupMsgParams{GroupID: id, Message: msg.Content}, nil

	case len(target) > len(privatePrefix) && strings.HasPrefix(target, privatePrefix):
		id, err := strconv.ParseInt(target[len(privatePrefix):], 10, 64)
		if err != nil {
			return "", nil, fmt.Errorf("invalid user ID in chatID: %s", target)
		}
		return "send_private_msg", oneBotSendPrivateMsgParams{UserID: id, Message: msg.Content}, nil
	}

	// No recognised prefix: treat the whole chat ID as a user ID.
	id, err := strconv.ParseInt(target, 10, 64)
	if err != nil {
		return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", target)
	}
	return "send_private_msg", oneBotSendPrivateMsgParams{UserID: id, Message: msg.Content}, nil
}
// listen is the read loop for the current WebSocket connection. It decodes
// each incoming frame into a oneBotRawEvent and passes it to handleRawEvent.
// The loop ends when the channel context is cancelled, when no connection is
// set, or on the first read error (after which the connection is closed and
// cleared so the reconnect loop can dial again).
func (c *OneBotChannel) listen() {
for {
select {
case <-c.ctx.Done():
return
default:
// Snapshot the connection under the lock; Stop or a previous error may have cleared it.
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
return
}
// NOTE(review): ReadMessage blocks, so the ctx.Done check above is only
// reached between frames; presumably Stop closing the connection is what
// unblocks this call during shutdown — confirm.
_, message, err := conn.ReadMessage()
if err != nil {
logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
"error": err.Error(),
})
// Drop the connection so reconnectLoop sees nil and dials again.
c.mu.Lock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.mu.Unlock()
return
}
logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{
"length": len(message),
"payload": string(message),
})
var raw oneBotRawEvent
if err := json.Unmarshal(message, &raw); err != nil {
// A frame that does not fit the event shape is logged and skipped; the loop keeps reading.
logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{
"error": err.Error(),
"payload": string(message),
})
continue
}
// Frames carrying an echo tag or a set status flag are treated as API
// responses and are not dispatched as events.
// NOTE(review): any event whose "status" object has online/good set is also
// skipped here (heartbeat meta events may be affected) — confirm this is intended.
if raw.Echo != "" || raw.Status.Online || raw.Status.Good {
logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{
"echo": raw.Echo,
"status": raw.Status,
})
continue
}
logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{
"post_type": raw.PostType,
"message_type": raw.MessageType,
"sub_type": raw.SubType,
"meta_event_type": raw.MetaEventType,
})
c.handleRawEvent(&raw)
}
}
}
// parseJSONInt64 reads an integer that may be encoded either as a JSON number
// or as a JSON string containing base-10 digits.
// Empty input yields (0, nil). Input that is neither form, or a string that is
// not a valid integer, yields an error.
func parseJSONInt64(raw json.RawMessage) (int64, error) {
	if len(raw) == 0 {
		return 0, nil
	}

	var asNumber int64
	if json.Unmarshal(raw, &asNumber) == nil {
		return asNumber, nil
	}

	var asText string
	if json.Unmarshal(raw, &asText) != nil {
		return 0, fmt.Errorf("cannot parse as int64: %s", string(raw))
	}
	return strconv.ParseInt(asText, 10, 64)
}
// parseJSONString returns the decoded value when raw is a JSON string, and the
// raw JSON text unchanged otherwise (so a numeric ID such as 123 becomes "123").
// Empty input yields "".
func parseJSONString(raw json.RawMessage) string {
	if len(raw) == 0 {
		return ""
	}

	var decoded string
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return string(raw)
	}
	return decoded
}
// parseMessageResult is the outcome of decoding a OneBot "message" field.
type parseMessageResult struct {
	// Text is the plain text of the message, trimmed of surrounding whitespace.
	Text string
	// IsBotMentioned reports whether the message @-mentioned the bot.
	IsBotMentioned bool
}

// parseMessageContentEx decodes a OneBot "message" field, which is either a
// JSON string (possibly containing CQ codes) or a JSON array of segments.
//
// String form: when selfID > 0 and the text contains "[CQ:at,qq=<selfID>]",
// the bot counts as mentioned and that code is removed from the text.
//
// Array form: "text" segments are concatenated; an "at" segment whose qq equals
// selfID or "all" marks the bot as mentioned (only when selfID > 0). The qq
// value is accepted both as a JSON string and as a JSON number.
//
// Empty or undecodable input yields a zero parseMessageResult.
func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult {
	if len(raw) == 0 {
		return parseMessageResult{}
	}

	selfIDStr := strconv.FormatInt(selfID, 10)

	var s string
	if err := json.Unmarshal(raw, &s); err == nil {
		mentioned := false
		if selfID > 0 {
			cqAt := "[CQ:at,qq=" + selfIDStr + "]"
			if strings.Contains(s, cqAt) {
				mentioned = true
				s = strings.TrimSpace(strings.ReplaceAll(s, cqAt, ""))
			}
		}
		return parseMessageResult{Text: s, IsBotMentioned: mentioned}
	}

	var segments []map[string]interface{}
	if err := json.Unmarshal(raw, &segments); err != nil {
		return parseMessageResult{}
	}

	var text strings.Builder
	mentioned := false
	for _, seg := range segments {
		segType, _ := seg["type"].(string)
		data, _ := seg["data"].(map[string]interface{})
		if data == nil {
			continue
		}

		switch segType {
		case "text":
			if t, ok := data["text"].(string); ok {
				text.WriteString(t)
			}
		case "at":
			if selfID <= 0 {
				continue
			}
			var qq string
			switch v := data["qq"].(type) {
			case string:
				qq = v
			case float64:
				// encoding/json decodes JSON numbers into float64; formatting
				// with %v would give exponent notation (1.23456789e+09) for
				// typical IDs and never match, so render it as a plain integer.
				qq = strconv.FormatFloat(v, 'f', -1, 64)
			default:
				qq = fmt.Sprintf("%v", v)
			}
			if qq == selfIDStr || qq == "all" {
				mentioned = true
			}
		}
	}
	return parseMessageResult{Text: strings.TrimSpace(text.String()), IsBotMentioned: mentioned}
}
// handleRawEvent dispatches a decoded frame according to its post_type.
// Message events are normalized and forwarded to handleMessage, meta events go
// to handleMetaEvent, and every other kind is only logged at debug level.
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
	postType := raw.PostType

	if postType == "message" {
		evt, err := c.normalizeMessageEvent(raw)
		if err != nil {
			logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{
				"error": err.Error(),
			})
			return
		}
		c.handleMessage(evt)
		return
	}

	if postType == "meta_event" {
		c.handleMetaEvent(raw)
		return
	}

	switch postType {
	case "notice":
		logger.DebugCF("onebot", "Notice event received", map[string]interface{}{
			"sub_type": raw.SubType,
		})
	case "request":
		logger.DebugCF("onebot", "Request event received", map[string]interface{}{
			"sub_type": raw.SubType,
		})
	case "":
		logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
			"echo":   raw.Echo,
			"status": raw.Status,
		})
	default:
		logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
			"post_type": postType,
		})
	}
}
// normalizeMessageEvent converts a raw message event into a oneBotEvent.
//
// user_id must parse as an integer, otherwise an error is returned; group_id,
// self_id and time are parsed best-effort and default to 0. The content is
// taken from raw_message when present (with the bot's own [CQ:at,...] code
// stripped), and from the decoded "message" field otherwise. A sender object
// that fails to decode is logged and left at whatever was decoded so far.
func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) {
	userID, err := parseJSONInt64(raw.UserID)
	if err != nil {
		return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID))
	}

	// These fields are optional; a parse failure simply leaves them at zero.
	groupID, _ := parseJSONInt64(raw.GroupID)
	selfID, _ := parseJSONInt64(raw.SelfID)
	ts, _ := parseJSONInt64(raw.Time)

	parsed := parseMessageContentEx(raw.Message, selfID)

	evt := &oneBotEvent{
		PostType:       raw.PostType,
		MessageType:    raw.MessageType,
		SubType:        raw.SubType,
		MessageID:      parseJSONString(raw.MessageID),
		UserID:         userID,
		GroupID:        groupID,
		Content:        raw.RawMessage,
		RawContent:     raw.RawMessage,
		IsBotMentioned: parsed.IsBotMentioned,
		SelfID:         selfID,
		Time:           ts,
		MetaEventType:  raw.MetaEventType,
	}

	switch {
	case evt.Content == "":
		evt.Content = parsed.Text
	case selfID > 0:
		if cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID); strings.Contains(evt.Content, cqAt) {
			evt.IsBotMentioned = true
			evt.Content = strings.TrimSpace(strings.ReplaceAll(evt.Content, cqAt, ""))
		}
	}

	if len(raw.Sender) > 0 {
		if err := json.Unmarshal(raw.Sender, &evt.Sender); err != nil {
			logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{
				"error":  err.Error(),
				"sender": string(raw.Sender),
			})
		}
	}

	logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{
		"message_type": raw.MessageType,
		"user_id":      userID,
		"group_id":     groupID,
		"message_id":   evt.MessageID,
		"content_len":  len(evt.Content),
		"nickname":     evt.Sender.Nickname,
	})

	return evt, nil
}
// handleMetaEvent logs meta events. Lifecycle events are logged at info level
// with their sub_type; heartbeats and unrecognised kinds are logged at debug level.
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
	kind := raw.MetaEventType

	if kind == "lifecycle" {
		logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{
			"sub_type": raw.SubType,
		})
		return
	}
	if kind == "heartbeat" {
		logger.DebugC("onebot", "Heartbeat received")
		return
	}

	logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{
		"meta_event_type": kind,
	})
}
// handleMessage routes a normalized message event to the message bus.
// Duplicates and empty messages are dropped. Private messages are always
// forwarded; group messages are forwarded only when checkGroupTrigger reports a
// trigger (bot mention or configured prefix). Messages of any other type are
// logged and dropped.
func (c *OneBotChannel) handleMessage(evt *oneBotEvent) {
// The ID is recorded as seen here, before the empty/trigger checks below, so
// a message that is later dropped is still treated as a duplicate if redelivered.
if c.isDuplicate(evt.MessageID) {
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
"message_id": evt.MessageID,
})
return
}
content := evt.Content
if content == "" {
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
"message_id": evt.MessageID,
})
return
}
senderID := strconv.FormatInt(evt.UserID, 10)
var chatID string
metadata := map[string]string{
"message_id": evt.MessageID,
}
switch evt.MessageType {
case "private":
// Chat IDs use the same "private:<id>" / "group:<id>" form that buildSendRequest parses.
chatID = "private:" + senderID
logger.InfoCF("onebot", "Received private message", map[string]interface{}{
"sender": senderID,
"message_id": evt.MessageID,
"length": len(content),
"content": truncate(content, 100),
})
case "group":
groupIDStr := strconv.FormatInt(evt.GroupID, 10)
chatID = "group:" + groupIDStr
metadata["group_id"] = groupIDStr
senderUserID, _ := parseJSONInt64(evt.Sender.UserID)
if senderUserID > 0 {
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
}
// Prefer the group card name over the account nickname.
if evt.Sender.Card != "" {
metadata["sender_name"] = evt.Sender.Card
} else if evt.Sender.Nickname != "" {
metadata["sender_name"] = evt.Sender.Nickname
}
triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned)
if !triggered {
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
"sender": senderID,
"group": groupIDStr,
"is_mentioned": evt.IsBotMentioned,
"content": truncate(content, 100),
})
return
}
// From here on, work with the content minus the trigger prefix.
content = strippedContent
logger.InfoCF("onebot", "Received group message", map[string]interface{}{
"sender": senderID,
"group": groupIDStr,
"message_id": evt.MessageID,
"is_mentioned": evt.IsBotMentioned,
"length": len(content),
"content": truncate(content, 100),
})
default:
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
"type": evt.MessageType,
"message_id": evt.MessageID,
"user_id": evt.UserID,
})
return
}
if evt.Sender.Nickname != "" {
metadata["nickname"] = evt.Sender.Nickname
}
logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"content": truncate(content, 100),
})
// No media attachments are forwarded for OneBot messages (empty slice).
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
}
// isDuplicate reports whether messageID was seen recently, and records it as
// seen when it was not. IDs "" and "0" are never tracked and always report
// false. Only the most recent len(dedupRing) IDs are remembered: recording a
// new ID evicts the one that previously occupied its ring slot.
func (c *OneBotChannel) isDuplicate(messageID string) bool {
	switch messageID {
	case "", "0":
		return false
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if _, seen := c.dedup[messageID]; seen {
		return true
	}

	slot := c.dedupIdx
	if evicted := c.dedupRing[slot]; evicted != "" {
		delete(c.dedup, evicted)
	}
	c.dedupRing[slot] = messageID
	c.dedup[messageID] = struct{}{}
	c.dedupIdx = (slot + 1) % len(c.dedupRing)
	return false
}
// truncate shortens s to at most n runes (not bytes), appending "..." when
// anything was cut off. Strings of n runes or fewer are returned unchanged.
func truncate(s string, n int) string {
	if chars := []rune(s); len(chars) > n {
		return string(chars[:n]) + "..."
	}
	return s
}
// checkGroupTrigger decides whether a group message should be processed.
// A message triggers when the bot was mentioned, or when it starts with one of
// the non-empty prefixes in config.GroupTriggerPrefix (first match wins).
// On a trigger, strippedContent is the content with the matched prefix removed
// and surrounding whitespace trimmed; otherwise it is the content unchanged.
func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) {
	if !isBotMentioned {
		for _, candidate := range c.config.GroupTriggerPrefix {
			if candidate != "" && strings.HasPrefix(content, candidate) {
				return true, strings.TrimSpace(content[len(candidate):])
			}
		}
		return false, content
	}
	return true, strings.TrimSpace(content)
}
+14
View File
@@ -308,6 +308,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
return
}
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
@@ -367,6 +374,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
c.socketClient.Ack(*event.Request)
}
if !c.IsAllowed(cmd.UserID) {
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{
"user_id": cmd.UserID,
})
return
}
senderID := cmd.UserID
channelID := cmd.ChannelID
chatID := channelID
+60 -65
View File
@@ -11,7 +11,10 @@ import (
"sync"
"time"
th "github.com/mymmrac/telego/telegohandler"
"github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegohandler"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -24,7 +27,8 @@ import (
type TelegramChannel struct {
*BaseChannel
bot *telego.Bot
config config.TelegramConfig
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID
@@ -41,13 +45,14 @@ func (c *thinkingCancel) Cancel() {
}
}
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
var opts []telego.BotOption
telegramCfg := cfg.Channels.Telegram
if cfg.Proxy != "" {
proxyURL, parseErr := url.Parse(cfg.Proxy)
if telegramCfg.Proxy != "" {
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
if parseErr != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
}
opts = append(opts, telego.WithHTTPClient(&http.Client{
Transport: &http.Transport{
@@ -56,15 +61,16 @@ func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*Telegr
}))
}
bot, err := telego.NewBot(cfg.Token, opts...)
bot, err := telego.NewBot(telegramCfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
}
base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom)
base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
return &TelegramChannel{
BaseChannel: base,
commands: NewTelegramCommands(bot, cfg),
bot: bot,
config: cfg,
chatIDs: make(map[string]int64),
@@ -88,31 +94,45 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return fmt.Errorf("failed to start long polling: %w", err)
}
bh, err := telegohandler.NewBotHandler(c.bot, updates)
if err != nil {
return fmt.Errorf("failed to create bot handler: %w", err)
}
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
c.commands.Help(ctx, message)
return nil
}, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Start(ctx, message)
}, th.CommandEqual("start"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Show(ctx, message)
}, th.CommandEqual("show"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.List(ctx, message)
}, th.CommandEqual("list"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.handleMessage(ctx, &message)
}, th.AnyMessage())
c.setRunning(true)
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go bh.Start()
go func() {
for {
select {
case <-ctx.Done():
return
case update, ok := <-updates:
if !ok {
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return
}
if update.Message != nil {
c.handleMessage(ctx, update)
}
}
}
<-ctx.Done()
bh.Stop()
}()
return nil
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
@@ -166,30 +186,27 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return nil
}
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
message := update.Message
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
if message == nil {
return
return fmt.Errorf("message is nil")
}
user := message.From
if user == nil {
return
return fmt.Errorf("message sender (user) is nil")
}
userID := fmt.Sprintf("%d", user.ID)
senderID := userID
senderID := fmt.Sprintf("%d", user.ID)
if user.Username != "" {
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
if !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": userID,
"username": user.Username,
"user_id": senderID,
})
return
return nil
}
chatID := message.Chat.ID
@@ -222,7 +239,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
content += message.Caption
}
if message.Photo != nil && len(message.Photo) > 0 {
if len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
@@ -231,7 +248,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[image: photo]")
content += "[image: photo]"
}
}
@@ -252,7 +269,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
"error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
transcribedText = "[voice (transcription failed)]"
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
@@ -260,7 +277,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
})
}
} else {
transcribedText = fmt.Sprintf("[voice]")
transcribedText = "[voice]"
}
if content != "" {
@@ -278,7 +295,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[audio]")
content += "[audio]"
}
}
@@ -290,7 +307,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[file]")
content += "[file]"
}
}
@@ -320,37 +337,14 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
}
}
// Create new context for thinking animation with timeout
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
// Create cancel function for thinking state
_, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
if err == nil {
pID := pMsg.MessageID
c.placeholders.Store(chatIDStr, pID)
go func(cid int64, mid int) {
dots := []string{".", "..", "..."}
emotes := []string{"💭", "🤔", "☁️"}
i := 0
ticker := time.NewTicker(2000 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-thinkCtx.Done():
return
case <-ticker.C:
i++
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
if editErr != nil {
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
"error": editErr.Error(),
})
}
}
}
}(chatID, pID)
}
peerKind := "direct"
@@ -370,7 +364,8 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
"peer_id": peerID,
}
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
return nil
}
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
+153
View File
@@ -0,0 +1,153 @@
package channels
import (
"context"
"fmt"
"strings"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/config"
)
// TelegramCommander handles the bot's slash commands. Each method receives the
// message that carried the command and returns the error, if any, from
// replying to it.
type TelegramCommander interface {
// Help replies with the list of supported commands.
Help(ctx context.Context, message telego.Message) error
// Start replies with a greeting.
Start(ctx context.Context, message telego.Message) error
// Show replies with the current model or channel, depending on the argument.
Show(ctx context.Context, message telego.Message) error
// List replies with the configured model or the enabled channels, depending on the argument.
List(ctx context.Context, message telego.Message) error
}
// cmd is the TelegramCommander implementation returned by NewTelegramCommands.
// It replies through bot and reads the values it reports from config.
type cmd struct {
bot *telego.Bot
config *config.Config
}
// NewTelegramCommands returns a TelegramCommander that replies through bot and
// reports values taken from cfg.
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
	handler := new(cmd)
	handler.bot = bot
	handler.config = cfg
	return handler
}
// commandArgs returns everything after the first space in text, trimmed of
// surrounding whitespace. It returns "" when text contains no space.
func commandArgs(text string) string {
	sep := strings.Index(text, " ")
	if sep < 0 {
		return ""
	}
	return strings.TrimSpace(text[sep+1:])
}
// Help replies to message with the list of supported commands.
// It returns the error from sending the reply, if any.
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
	const helpText = `/start - Start the bot
/help - Show this help message
/show [model|channel] - Show current configuration
/list [models|channels] - List available options
`
	reply := &telego.SendMessageParams{
		ChatID: telego.ChatID{ID: message.Chat.ID},
		Text:   helpText,
		ReplyParameters: &telego.ReplyParameters{
			MessageID: message.MessageID,
		},
	}
	_, err := c.bot.SendMessage(ctx, reply)
	return err
}
// Start replies to message with a greeting.
// It returns the error from sending the reply, if any.
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
	reply := &telego.SendMessageParams{
		ChatID: telego.ChatID{ID: message.Chat.ID},
		Text:   "Hello! I am PicoClaw 🦞",
		ReplyParameters: &telego.ReplyParameters{
			MessageID: message.MessageID,
		},
	}
	_, err := c.bot.SendMessage(ctx, reply)
	return err
}
// Show replies to a "/show <what>" command.
// "model" reports the default model and provider from the configuration,
// "channel" reports that the current channel is telegram, a missing argument
// yields a usage hint, and anything else yields an "unknown parameter" reply.
// It returns the error from sending the reply, if any.
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
	var text string
	switch what := commandArgs(message.Text); what {
	case "":
		text = "Usage: /show [model|channel]"
	case "model":
		defaults := c.config.Agents.Defaults
		text = fmt.Sprintf("Current Model: %s (Provider: %s)", defaults.Model, defaults.Provider)
	case "channel":
		text = "Current Channel: telegram"
	default:
		text = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", what)
	}

	_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
		ChatID: telego.ChatID{ID: message.Chat.ID},
		Text:   text,
		ReplyParameters: &telego.ReplyParameters{
			MessageID: message.MessageID,
		},
	})
	return err
}
// List replies to a "/list <what>" command.
// "models" reports the configured default model and provider, "channels" lists
// which of the telegram, whatsapp, feishu, discord and slack channels are
// enabled, a missing argument yields a usage hint, and anything else yields an
// "unknown parameter" reply.
// It returns the error from sending the reply, if any.
func (c *cmd) List(ctx context.Context, message telego.Message) error {
	var text string
	switch what := commandArgs(message.Text); what {
	case "":
		text = "Usage: /list [models|channels]"

	case "models":
		provider := c.config.Agents.Defaults.Provider
		if provider == "" {
			provider = "configured default"
		}
		text = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
			c.config.Agents.Defaults.Model, provider)

	case "channels":
		// Reported in a fixed order.
		known := []struct {
			name string
			on   bool
		}{
			{"telegram", c.config.Channels.Telegram.Enabled},
			{"whatsapp", c.config.Channels.WhatsApp.Enabled},
			{"feishu", c.config.Channels.Feishu.Enabled},
			{"discord", c.config.Channels.Discord.Enabled},
			{"slack", c.config.Channels.Slack.Enabled},
		}
		var enabled []string
		for _, ch := range known {
			if ch.on {
				enabled = append(enabled, ch.name)
			}
		}
		text = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))

	default:
		text = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", what)
	}

	_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
		ChatID: telego.ChatID{ID: message.Chat.ID},
		Text:   text,
		ReplyParameters: &telego.ReplyParameters{
			MessageID: message.MessageID,
		},
	})
	return err
}
+42 -21
View File
@@ -161,6 +161,7 @@ type ChannelsConfig struct {
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
LINE LINEConfig `json:"line"`
OneBot OneBotConfig `json:"onebot"`
}
type WhatsAppConfig struct {
@@ -213,10 +214,10 @@ type DingTalkConfig struct {
}
type SlackConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
}
type LINEConfig struct {
@@ -229,6 +230,15 @@ type LINEConfig struct {
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
}
// OneBotConfig holds the settings for the OneBot channel.
type OneBotConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
// WSUrl is the WebSocket URL of the OneBot endpoint to dial.
WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
// AccessToken, when non-empty, is sent as a Bearer Authorization header on connect.
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
// ReconnectInterval is in seconds. Values <= 0 disable background
// reconnecting; positive values below 5 are raised to 5 seconds.
ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
// GroupTriggerPrefix lists the prefixes that make the bot respond to a group
// message that does not mention it.
GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
}
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
@@ -240,24 +250,27 @@ type DevicesConfig struct {
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Ollama ProviderConfig `json:"ollama"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
}
type ProviderConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
}
type GatewayConfig struct {
@@ -344,7 +357,7 @@ func DefaultConfig() *Config {
Enabled: false,
BotToken: "",
AppToken: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
LINE: LINEConfig{
Enabled: false,
@@ -355,6 +368,14 @@ func DefaultConfig() *Config {
WebhookPath: "/webhook/line",
AllowFrom: FlexibleStringSlice{},
},
OneBot: OneBotConfig{
Enabled: false,
WSUrl: "ws://127.0.0.1:3001",
AccessToken: "",
ReconnectInterval: 5,
GroupTriggerPrefix: []string{},
AllowFrom: FlexibleStringSlice{},
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
@@ -432,7 +453,7 @@ func SaveConfig(path string, cfg *Config) error {
return err
}
return os.WriteFile(path, data, 0644)
return os.WriteFile(path, data, 0600)
}
func (c *Config) WorkspacePath() string {
+27
View File
@@ -2,6 +2,9 @@ package config
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
)
@@ -297,6 +300,30 @@ func TestDefaultConfig_WebTools(t *testing.T) {
}
}
func TestSaveConfig_FilePermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("file permission bits are not enforced on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json")
cfg := DefaultConfig()
if err := SaveConfig(path, cfg); err != nil {
t.Fatalf("SaveConfig failed: %v", err)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("config file has permission %04o, want 0600", perm)
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
+1 -1
View File
@@ -340,7 +340,7 @@ func (cs *CronService) saveStoreUnsafe() error {
return err
}
return os.WriteFile(cs.storePath, data, 0644)
return os.WriteFile(cs.storePath, data, 0600)
}
func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) {
+38
View File
@@ -0,0 +1,38 @@
package cron
import (
"os"
"path/filepath"
"runtime"
"testing"
)
// TestSaveStore_FilePermissions verifies that adding a job persists the cron
// store file with owner-only (0600) permission bits.
func TestSaveStore_FilePermissions(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("file permission bits are not enforced on Windows")
	}

	storePath := filepath.Join(t.TempDir(), "cron", "jobs.json")
	svc := NewCronService(storePath, nil)

	everyMinute := CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}
	if _, err := svc.AddJob("test", everyMinute, "hello", false, "cli", "direct"); err != nil {
		t.Fatalf("AddJob failed: %v", err)
	}

	fi, err := os.Stat(storePath)
	if err != nil {
		t.Fatalf("Stat failed: %v", err)
	}

	if got := fi.Mode().Perm(); got != 0600 {
		t.Errorf("cron store has permission %04o, want 0600", got)
	}
}
// int64Ptr returns a pointer to a freshly allocated int64 holding v.
func int64Ptr(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
+164
View File
@@ -0,0 +1,164 @@
package health
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"
)
// Server is a small HTTP server that serves the /health and /ready endpoints
// and tracks the readiness flag and named check results they report.
type Server struct {
// server is the underlying HTTP server, configured by NewServer.
server *http.Server
// mu guards ready and checks.
mu sync.RWMutex
// ready is the readiness flag reported by the /ready endpoint.
ready bool
// checks holds the most recently recorded result for each named check.
checks map[string]Check
// startTime is set once in NewServer and used to compute uptime.
startTime time.Time
}
// Check is the recorded result of a named health check, as serialized in the
// readiness response.
type Check struct {
// Name is the key the check was registered under.
Name string `json:"name"`
// Status is "ok" or "fail" when recorded through RegisterCheck.
Status string `json:"status"`
// Message is the optional detail string returned by the check function.
Message string `json:"message,omitempty"`
// Timestamp is when the result was recorded.
Timestamp time.Time `json:"timestamp"`
}
// StatusResponse is the JSON body returned by the health and readiness
// handlers.
type StatusResponse struct {
// Status is "ok" for /health, and "ready" or "not ready" for /ready.
Status string `json:"status"`
// Uptime is the time since the server was constructed, formatted with
// time.Duration.String. The handlers leave it empty in "not ready" responses.
Uptime string `json:"uptime"`
// Checks is a snapshot of recorded check results; omitted when empty.
Checks map[string]Check `json:"checks,omitempty"`
}
// NewServer builds a health Server that will listen on host:port and serve
// the /health and /ready endpoints. The server starts in the not-ready state
// and does not listen until Start or StartContext is called.
//
// host may be empty (all interfaces), a hostname, an IPv4 address, or an IPv6
// literal with or without surrounding brackets.
func NewServer(host string, port int) *Server {
	s := &Server{
		ready:     false,
		checks:    make(map[string]Check),
		startTime: time.Now(),
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/health", s.healthHandler)
	mux.HandleFunc("/ready", s.readyHandler)

	// Accept an already-bracketed IPv6 literal ("[::1]") so configurations
	// that worked with plain "%s:%d" formatting keep working.
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}

	// net.JoinHostPort brackets IPv6 literals; plain "%s:%d" formatting would
	// yield an unparsable address such as "::1:8080".
	addr := net.JoinHostPort(host, fmt.Sprint(port))

	s.server = &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 5 * time.Second,
	}

	return s
}
// Start flips the server to ready and then blocks in ListenAndServe,
// returning whatever error that call reports.
func (s *Server) Start() error {
	s.SetReady(true)
	return s.server.ListenAndServe()
}
// StartContext marks the server ready, serves HTTP in a background goroutine,
// and blocks until either the listener fails or ctx is done.
//
// If the listener fails, the readiness flag is cleared and the listener error
// is returned. If ctx is done, the server is stopped through Stop, which
// clears the readiness flag before shutting down, and the shutdown error is
// returned.
func (s *Server) StartContext(ctx context.Context) error {
	s.SetReady(true)

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even after this function has returned via the ctx branch.
	errCh := make(chan error, 1)
	go func() {
		errCh <- s.server.ListenAndServe()
	}()

	select {
	case err := <-errCh:
		// The listener is gone; stop advertising readiness.
		s.SetReady(false)
		return err
	case <-ctx.Done():
		// Reuse Stop so readiness is withdrawn before connections drain,
		// matching an explicit Stop call.
		return s.Stop(context.Background())
	}
}
// Stop clears the readiness flag and then shuts the HTTP server down using
// ctx, returning the error from Shutdown.
func (s *Server) Stop(ctx context.Context) error {
	s.SetReady(false)
	return s.server.Shutdown(ctx)
}
// SetReady updates the readiness flag reported by the /ready endpoint.
func (s *Server) SetReady(ready bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.ready = ready
}
// RegisterCheck runs checkFn once and stores its outcome under name,
// replacing any previous result with the same name. The stored result is
// what the /ready handler reports; this method does not re-run checkFn later.
//
// checkFn is invoked before the server lock is taken, so a slow check does
// not block readiness probes and a check that calls back into the Server
// (for example SetReady) cannot deadlock.
func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) {
	ok, msg := checkFn()
	result := Check{
		Name:      name,
		Status:    statusString(ok),
		Message:   msg,
		Timestamp: time.Now(),
	}

	s.mu.Lock()
	s.checks[name] = result
	s.mu.Unlock()
}
// healthHandler always answers 200 with a JSON body containing status "ok"
// and the uptime since the server was constructed.
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	json.NewEncoder(w).Encode(StatusResponse{
		Status: "ok",
		Uptime: time.Since(s.startTime).String(),
	})
}
// readyHandler reports readiness as JSON. It answers 503 "not ready" when the
// readiness flag is off or any recorded check has status "fail"; otherwise it
// answers 200 "ready" with the uptime. Both responses carry the check snapshot.
func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// Snapshot state under the read lock so the response is built lock-free.
	s.mu.RLock()
	healthy := s.ready
	snapshot := make(map[string]Check)
	for name, c := range s.checks {
		snapshot[name] = c
	}
	s.mu.RUnlock()

	// A single failing check overrides the readiness flag.
	if healthy {
		for _, c := range snapshot {
			if c.Status == "fail" {
				healthy = false
				break
			}
		}
	}

	if !healthy {
		w.WriteHeader(http.StatusServiceUnavailable)
		json.NewEncoder(w).Encode(StatusResponse{
			Status: "not ready",
			Checks: snapshot,
		})
		return
	}

	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(StatusResponse{
		Status: "ready",
		Uptime: time.Since(s.startTime).String(),
		Checks: snapshot,
	})
}
// statusString maps a check outcome to its wire label: "ok" for true and
// "fail" for false.
func statusString(ok bool) string {
	label := "fail"
	if ok {
		label = "ok"
	}
	return label
}
+4 -58
View File
@@ -171,68 +171,14 @@ func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse,
}, nil
}
// extractToolCalls parses tool call JSON from the response text.
// extractToolCalls delegates to the shared extractToolCallsFromText function.
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return nil
}
end := findMatchingBrace(text, start)
if end == start {
return nil
}
jsonStr := text[start:end]
var wrapper struct {
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
return nil
}
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
var args map[string]interface{}
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: tc.Function.Name,
Arguments: args,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return result
return extractToolCallsFromText(text)
}
// stripToolCallsJSON removes tool call JSON from response text.
// stripToolCallsJSON delegates to the shared stripToolCallsFromText function.
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return text
}
end := findMatchingBrace(text, start)
if end == start {
return text
}
return strings.TrimSpace(text[:start] + text[end:])
return stripToolCallsFromText(text)
}
// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
+79
View File
@@ -0,0 +1,79 @@
package providers
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
// CodexCliAuth represents the ~/.codex/auth.json file structure.
// Only the fields declared here are decoded.
type CodexCliAuth struct {
// Tokens holds the OAuth credentials stored under the "tokens" key.
Tokens struct {
// AccessToken is required by ReadCodexCliCredentials; an empty value is an error there.
AccessToken string `json:"access_token"`
// RefreshToken is decoded but not used by the code in this file.
RefreshToken string `json:"refresh_token"`
// AccountID is returned alongside the access token and may be empty.
AccountID string `json:"account_id"`
} `json:"tokens"`
}
// ReadCodexCliCredentials loads the access token and account ID from the
// Codex CLI auth.json file located by resolveCodexAuthPath.
//
// The file does not carry an expiry here, so expiresAt is estimated as the
// file's modification time plus one hour (the original note says this is the
// same approach as moltbot). If the file cannot be stat'ed, the current time
// plus one hour is used instead.
//
// An error is returned when the path cannot be resolved, the file cannot be
// read or parsed, or the access token is empty.
func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) {
	path, err := resolveCodexAuthPath()
	if err != nil {
		return "", "", time.Time{}, err
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		return "", "", time.Time{}, fmt.Errorf("reading %s: %w", path, err)
	}

	var parsed CodexCliAuth
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", path, err)
	}

	tokens := parsed.Tokens
	if tokens.AccessToken == "" {
		return "", "", time.Time{}, fmt.Errorf("no access_token in %s", path)
	}

	// Base the expiry estimate on the file's mtime; fall back to "now" when
	// the file cannot be stat'ed.
	issuedAt := time.Now()
	if info, statErr := os.Stat(path); statErr == nil {
		issuedAt = info.ModTime()
	}

	return tokens.AccessToken, tokens.AccountID, issuedAt.Add(time.Hour), nil
}
// CreateCodexCliTokenSource returns a function that, on every call, re-reads
// the Codex CLI auth.json via ReadCodexCliCredentials and yields the access
// token and account ID. The function fails when the credentials cannot be
// read or when their estimated expiry has already passed.
func CreateCodexCliTokenSource() func() (string, string, error) {
	source := func() (string, string, error) {
		accessToken, account, expiry, readErr := ReadCodexCliCredentials()
		switch {
		case readErr != nil:
			return "", "", fmt.Errorf("reading codex cli credentials: %w", readErr)
		case time.Now().After(expiry):
			return "", "", fmt.Errorf("codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login")
		}
		return accessToken, account, nil
	}
	return source
}
// resolveCodexAuthPath returns the location of the Codex CLI auth.json file:
// $CODEX_HOME/auth.json when CODEX_HOME is set and non-empty, otherwise
// <user home>/.codex/auth.json. It fails only when the home directory cannot
// be determined.
func resolveCodexAuthPath() (string, error) {
	if dir := os.Getenv("CODEX_HOME"); dir != "" {
		return filepath.Join(dir, "auth.json"), nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("getting home dir: %w", err)
	}
	return filepath.Join(home, ".codex", "auth.json"), nil
}
+181
View File
@@ -0,0 +1,181 @@
package providers
import (
"os"
"path/filepath"
"testing"
"time"
)
// TestReadCodexCliCredentials_Valid checks that a well-formed auth.json yields
// the stored token and account ID plus an expiry roughly one hour ahead.
func TestReadCodexCliCredentials_Valid(t *testing.T) {
	const fixture = `{
"tokens": {
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"account_id": "org-test123"
}
}`

	home := t.TempDir()
	if err := os.WriteFile(filepath.Join(home, "auth.json"), []byte(fixture), 0600); err != nil {
		t.Fatal(err)
	}
	t.Setenv("CODEX_HOME", home)

	gotToken, gotAccount, gotExpiry, err := ReadCodexCliCredentials()
	if err != nil {
		t.Fatalf("ReadCodexCliCredentials() error: %v", err)
	}

	if gotToken != "test-access-token" {
		t.Errorf("token = %q, want %q", gotToken, "test-access-token")
	}
	if gotAccount != "org-test123" {
		t.Errorf("accountID = %q, want %q", gotAccount, "org-test123")
	}

	// The file was just written, so the estimated expiry must land between
	// now and two hours from now.
	switch {
	case gotExpiry.Before(time.Now()):
		t.Errorf("expiresAt = %v, should be in the future", gotExpiry)
	case gotExpiry.After(time.Now().Add(2 * time.Hour)):
		t.Errorf("expiresAt = %v, should be within ~1 hour", gotExpiry)
	}
}
// TestReadCodexCliCredentials_MissingFile expects an error when CODEX_HOME
// points at a directory that contains no auth.json.
func TestReadCodexCliCredentials_MissingFile(t *testing.T) {
	t.Setenv("CODEX_HOME", t.TempDir())

	if _, _, _, err := ReadCodexCliCredentials(); err == nil {
		t.Fatal("expected error for missing auth.json")
	}
}
// TestReadCodexCliCredentials_EmptyToken expects an error when auth.json
// parses but carries an empty access_token.
func TestReadCodexCliCredentials_EmptyToken(t *testing.T) {
	const fixture = `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}`

	home := t.TempDir()
	if err := os.WriteFile(filepath.Join(home, "auth.json"), []byte(fixture), 0600); err != nil {
		t.Fatal(err)
	}
	t.Setenv("CODEX_HOME", home)

	if _, _, _, err := ReadCodexCliCredentials(); err == nil {
		t.Fatal("expected error for empty access_token")
	}
}
// TestReadCodexCliCredentials_InvalidJSON expects an error when auth.json
// does not contain valid JSON.
func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) {
	home := t.TempDir()
	if err := os.WriteFile(filepath.Join(home, "auth.json"), []byte("not json"), 0600); err != nil {
		t.Fatal(err)
	}
	t.Setenv("CODEX_HOME", home)

	if _, _, _, err := ReadCodexCliCredentials(); err == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// TestReadCodexCliCredentials_NoAccountID checks that a missing account_id is
// tolerated and surfaces as an empty string.
func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
	const fixture = `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}`

	home := t.TempDir()
	if err := os.WriteFile(filepath.Join(home, "auth.json"), []byte(fixture), 0600); err != nil {
		t.Fatal(err)
	}
	t.Setenv("CODEX_HOME", home)

	gotToken, gotAccount, _, err := ReadCodexCliCredentials()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if gotToken != "tok123" {
		t.Errorf("token = %q, want %q", gotToken, "tok123")
	}
	if gotAccount != "" {
		t.Errorf("accountID = %q, want empty", gotAccount)
	}
}
// TestReadCodexCliCredentials_CodexHomeEnv checks that CODEX_HOME redirects
// the lookup to a custom directory.
func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) {
	const fixture = `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}`

	codexHome := filepath.Join(t.TempDir(), "custom-codex")
	if err := os.MkdirAll(codexHome, 0755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(codexHome, "auth.json"), []byte(fixture), 0600); err != nil {
		t.Fatal(err)
	}
	t.Setenv("CODEX_HOME", codexHome)

	gotToken, _, _, err := ReadCodexCliCredentials()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if gotToken != "custom-token" {
		t.Errorf("token = %q, want %q", gotToken, "custom-token")
	}
}
// TestCreateCodexCliTokenSource_Valid checks that the token source returns
// the token and account ID from a freshly written auth.json.
func TestCreateCodexCliTokenSource_Valid(t *testing.T) {
	const fixture = `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}`

	home := t.TempDir()
	if err := os.WriteFile(filepath.Join(home, "auth.json"), []byte(fixture), 0600); err != nil {
		t.Fatal(err)
	}
	t.Setenv("CODEX_HOME", home)

	gotToken, gotAccount, err := CreateCodexCliTokenSource()()
	if err != nil {
		t.Fatalf("token source error: %v", err)
	}

	if gotToken != "fresh-token" {
		t.Errorf("token = %q, want %q", gotToken, "fresh-token")
	}
	if gotAccount != "acc" {
		t.Errorf("accountID = %q, want %q", gotAccount, "acc")
	}
}
// TestCreateCodexCliTokenSource_Expired back-dates auth.json by two hours and
// expects the token source to reject the credentials as expired.
func TestCreateCodexCliTokenSource_Expired(t *testing.T) {
	const fixture = `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}`

	home := t.TempDir()
	authFile := filepath.Join(home, "auth.json")
	if err := os.WriteFile(authFile, []byte(fixture), 0600); err != nil {
		t.Fatal(err)
	}

	// Push the modification time past the one-hour validity window.
	stale := time.Now().Add(-2 * time.Hour)
	if err := os.Chtimes(authFile, stale, stale); err != nil {
		t.Fatal(err)
	}
	t.Setenv("CODEX_HOME", home)

	if _, _, err := CreateCodexCliTokenSource()(); err == nil {
		t.Fatal("expected error for expired credentials")
	}
}
+251
View File
@@ -0,0 +1,251 @@
package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
)
// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess.
type CodexCliProvider struct {
// command is the executable to run; Chat rejects an empty value.
command string
// workspace, when non-empty, is passed to the CLI with the -C flag.
workspace string
}
// NewCodexCliProvider returns a provider that runs the "codex" executable and
// passes workspace to it as the working root (when non-empty).
func NewCodexCliProvider(workspace string) *CodexCliProvider {
	p := new(CodexCliProvider)
	p.command = "codex"
	p.workspace = workspace
	return p
}
// Chat implements LLMProvider.Chat by executing `codex exec --json` in
// non-interactive mode, feeding the rendered prompt on stdin and parsing the
// JSONL events written to stdout.
//
// model is forwarded with -m unless it is empty or the placeholder
// "codex-cli"; the provider's workspace, if set, is forwarded with -C.
// options is currently unused.
//
// A usable response (non-empty content or at least one tool call) parsed from
// stdout wins even when the process exits non-zero. Otherwise a failed run is
// reported as, in order of preference: the context error if ctx was cancelled
// or timed out, the stderr text, or the wrapped process error.
func (p *CodexCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
	if p.command == "" {
		return nil, fmt.Errorf("codex command not configured")
	}

	prompt := p.buildPrompt(messages, tools)

	// NOTE(security): as the flag name states, this run bypasses the CLI's
	// approval prompts and sandbox. The prompt may carry untrusted text, so
	// confirm this is intended for every deployment of this provider.
	args := []string{
		"exec",
		"--json",
		"--dangerously-bypass-approvals-and-sandbox",
		"--skip-git-repo-check",
		"--color", "never",
	}
	if model != "" && model != "codex-cli" {
		args = append(args, "-m", model)
	}
	if p.workspace != "" {
		args = append(args, "-C", p.workspace)
	}
	args = append(args, "-") // read prompt from stdin

	cmd := exec.CommandContext(ctx, p.command, args...)
	cmd.Stdin = strings.NewReader(prompt)

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	runErr := cmd.Run()

	// Parse JSONL from stdout even if the exit code is non-zero, because
	// codex writes diagnostic noise to stderr (e.g. rollout errors) but still
	// produces valid JSONL output. Parsing once is enough: the result is
	// reused for the final return below.
	resp, parseErr := p.parseJSONLEvents(stdout.String())
	if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) {
		return resp, nil
	}

	if runErr != nil {
		// Surface cancellation and deadline expiry as the context error so
		// callers can recognise them instead of seeing a killed-process error.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		if msg := stderr.String(); msg != "" {
			return nil, fmt.Errorf("codex cli error: %s", msg)
		}
		return nil, fmt.Errorf("codex cli error: %w", runErr)
	}

	return resp, parseErr
}
// GetDefaultModel reports the placeholder model name for this provider.
// Chat treats this value as "no explicit model" and omits the -m flag.
func (p *CodexCliProvider) GetDefaultModel() string {
	const defaultModel = "codex-cli"
	return defaultModel
}
// buildPrompt flattens a chat transcript into the single text prompt sent to
// the Codex CLI.
//
// System messages are gathered under a "## System Instructions" header
// followed by a "## Task" header, since the CLI has no --system-prompt flag.
// Tool definitions, if any, come next, and then the conversation: user
// content verbatim, assistant content prefixed with "Assistant: ", and tool
// results tagged with their tool call ID. Messages with any other role are
// dropped. A lone conversation entry with no system text and no tools is
// returned as-is.
func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string {
	var instructions, turns []string

	for _, m := range messages {
		switch m.Role {
		case "system":
			instructions = append(instructions, m.Content)
		case "user":
			turns = append(turns, m.Content)
		case "assistant":
			turns = append(turns, "Assistant: "+m.Content)
		case "tool":
			turns = append(turns, fmt.Sprintf("[Tool Result for %s]: %s", m.ToolCallID, m.Content))
		}
	}

	hasInstructions := len(instructions) > 0
	hasTools := len(tools) > 0

	// Simplify single user message (no prefix).
	if len(turns) == 1 && !hasInstructions && !hasTools {
		return turns[0]
	}

	var out strings.Builder
	if hasInstructions {
		out.WriteString("## System Instructions\n\n")
		out.WriteString(strings.Join(instructions, "\n\n"))
		out.WriteString("\n\n## Task\n\n")
	}
	if hasTools {
		out.WriteString(p.buildToolsPrompt(tools))
		out.WriteString("\n\n")
	}
	out.WriteString(strings.Join(turns, "\n"))

	return out.String()
}
// buildToolsPrompt renders the "## Available Tools" prompt section: fixed
// instructions describing the JSON tool-call envelope the model must emit,
// followed by one entry per tool of type "function" with its name, optional
// description and optional JSON-encoded parameters. Other tool types are
// skipped.
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
	var out strings.Builder

	for _, chunk := range []string{
		"## Available Tools\n\n",
		"When you need to use a tool, respond with ONLY a JSON object:\n\n",
		"```json\n",
		`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
		"\n```\n\n",
		"CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n",
		"### Tool Definitions:\n\n",
	} {
		out.WriteString(chunk)
	}

	for _, def := range tools {
		if def.Type != "function" {
			continue
		}

		fn := def.Function
		fmt.Fprintf(&out, "#### %s\n", fn.Name)
		if fn.Description != "" {
			fmt.Fprintf(&out, "Description: %s\n", fn.Description)
		}
		if len(fn.Parameters) > 0 {
			// A marshal failure is deliberately ignored; the parameters
			// block is then rendered empty.
			encoded, _ := json.Marshal(fn.Parameters)
			fmt.Fprintf(&out, "Parameters:\n```json\n%s\n```\n", string(encoded))
		}
		out.WriteString("\n")
	}

	return out.String()
}
// codexEvent represents a single JSONL event from `codex exec --json`.
// parseJSONLEvents dispatches on Type and reads the field that belongs to it.
type codexEvent struct {
// Type is the event discriminator, e.g. "item.completed", "turn.completed",
// "error" or "turn.failed".
Type string `json:"type"`
// ThreadID is decoded but not read by the parser in this file.
ThreadID string `json:"thread_id,omitempty"`
// Message carries the error text of an "error" event.
Message string `json:"message,omitempty"`
// Item is the payload of item events such as "item.completed".
Item *codexEventItem `json:"item,omitempty"`
// Usage carries token counts on "turn.completed" events, when present.
Usage *codexUsage `json:"usage,omitempty"`
// Error carries the failure detail of a "turn.failed" event.
Error *codexEventErr `json:"error,omitempty"`
}
// codexEventItem is the "item" payload of a codex event. The parser in this
// file only consumes Text from items whose Type is "agent_message"; the
// remaining fields are decoded but not read there.
type codexEventItem struct {
ID string `json:"id"`
// Type is the item kind, e.g. "agent_message".
Type string `json:"type"`
// Text is the message text of an "agent_message" item.
Text string `json:"text,omitempty"`
Command string `json:"command,omitempty"`
Status string `json:"status,omitempty"`
// ExitCode is a pointer so an absent exit code can be told apart from 0.
ExitCode *int `json:"exit_code,omitempty"`
Output string `json:"output,omitempty"`
}
// codexUsage holds the token counts attached to a "turn.completed" event.
// The parser reports InputTokens + CachedInputTokens as prompt tokens and
// OutputTokens as completion tokens.
type codexUsage struct {
InputTokens int `json:"input_tokens"`
CachedInputTokens int `json:"cached_input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// codexEventErr is the "error" object of a "turn.failed" event.
type codexEventErr struct {
// Message is the failure description reported by the CLI.
Message string `json:"message"`
}
// parseJSONLEvents turns the JSONL stream produced by `codex exec --json`
// into an LLMResponse.
//
// Blank and malformed lines are skipped. Text from completed "agent_message"
// items is joined with newlines to form the content; usage from
// "turn.completed" is reported with cached input tokens counted as prompt
// tokens. The last "error" / "turn.failed" message becomes the returned
// error, but only when no content was collected, so partial results are
// still delivered. Tool calls embedded in the content are extracted and
// removed from it, and the finish reason is then "tool_calls" instead of
// "stop".
func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) {
	var contentParts []string
	var usage *UsageInfo
	var lastError string

	scanner := bufio.NewScanner(strings.NewReader(output))
	// bufio.Scanner refuses lines longer than 64 KiB by default and would
	// stop the loop silently, dropping every later event. A single event can
	// exceed that, so allow a line to be as long as the whole output.
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), len(output)+1)

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		var event codexEvent
		if err := json.Unmarshal([]byte(line), &event); err != nil {
			continue // skip malformed lines
		}

		switch event.Type {
		case "item.completed":
			if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" {
				contentParts = append(contentParts, event.Item.Text)
			}
		case "turn.completed":
			if event.Usage != nil {
				promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens
				usage = &UsageInfo{
					PromptTokens:     promptTokens,
					CompletionTokens: event.Usage.OutputTokens,
					TotalTokens:      promptTokens + event.Usage.OutputTokens,
				}
			}
		case "error":
			lastError = event.Message
		case "turn.failed":
			if event.Error != nil {
				lastError = event.Error.Message
			}
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("codex cli: reading output: %w", err)
	}

	if lastError != "" && len(contentParts) == 0 {
		return nil, fmt.Errorf("codex cli: %s", lastError)
	}

	content := strings.Join(contentParts, "\n")

	// Extract tool calls from response text (same pattern as ClaudeCliProvider).
	toolCalls := extractToolCallsFromText(content)

	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
		content = stripToolCallsFromText(content)
	}

	return &LLMResponse{
		Content:      strings.TrimSpace(content),
		ToolCalls:    toolCalls,
		FinishReason: finishReason,
		Usage:        usage,
	}, nil
}
+585
View File
@@ -0,0 +1,585 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
)
// --- JSONL Event Parsing Tests ---
// TestParseJSONLEvents_AgentMessage parses a minimal successful turn and
// checks content, finish reason, usage accounting and the absence of tool
// calls.
func TestParseJSONLEvents_AgentMessage(t *testing.T) {
	const stream = `{"type":"thread.started","thread_id":"abc-123"}
{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}}
{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}`

	provider := &CodexCliProvider{}
	got, err := provider.parseJSONLEvents(stream)
	if err != nil {
		t.Fatalf("parseJSONLEvents() error: %v", err)
	}

	if got.Content != "Hello from Codex!" {
		t.Errorf("Content = %q, want %q", got.Content, "Hello from Codex!")
	}
	if got.FinishReason != "stop" {
		t.Errorf("FinishReason = %q, want %q", got.FinishReason, "stop")
	}

	if got.Usage == nil {
		t.Fatal("Usage should not be nil")
	}
	if n := got.Usage.PromptTokens; n != 150 {
		t.Errorf("PromptTokens = %d, want 150", n)
	}
	if n := got.Usage.CompletionTokens; n != 20 {
		t.Errorf("CompletionTokens = %d, want 20", n)
	}
	if n := got.Usage.TotalTokens; n != 170 {
		t.Errorf("TotalTokens = %d, want 170", n)
	}

	if n := len(got.ToolCalls); n != 0 {
		t.Errorf("ToolCalls should be empty, got %d", n)
	}
}
func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) {
p := &CodexCliProvider{}
toolCallText := `Let me read that file.
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}`
// Build valid JSONL by marshaling the event
item := codexEvent{
Type: "item.completed",
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
}
itemJSON, _ := json.Marshal(item)
usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}`
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "read_file" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
}
if resp.ToolCalls[0].ID != "call_1" {
t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
}
if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` {
t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments)
}
// Content should have the tool call JSON stripped
if strings.Contains(resp.Content, "tool_calls") {
t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content)
}
}
func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) {
p := &CodexCliProvider{}
toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}`
item := codexEvent{
Type: "item.completed",
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
}
itemJSON, _ := json.Marshal(item)
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if len(resp.ToolCalls) != 2 {
t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "read_file" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
}
if resp.ToolCalls[1].Name != "write_file" {
t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file")
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
}
// TestParseJSONLEvents_MultipleMessages checks that several agent messages
// are joined with newlines while non-message items are ignored.
func TestParseJSONLEvents_MultipleMessages(t *testing.T) {
	const stream = `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}}
{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}}
{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}}
{"type":"turn.completed"}`

	provider := &CodexCliProvider{}
	got, err := provider.parseJSONLEvents(stream)
	if err != nil {
		t.Fatalf("parseJSONLEvents() error: %v", err)
	}

	if got.Content != "First part.\nSecond part." {
		t.Errorf("Content = %q, want %q", got.Content, "First part.\nSecond part.")
	}
}
// TestParseJSONLEvents_ErrorEvent checks that an error event with no content
// surfaces as an error carrying the reported message.
func TestParseJSONLEvents_ErrorEvent(t *testing.T) {
	const stream = `{"type":"thread.started","thread_id":"abc"}
{"type":"turn.started"}
{"type":"error","message":"token expired"}
{"type":"turn.failed","error":{"message":"token expired"}}`

	provider := &CodexCliProvider{}
	_, err := provider.parseJSONLEvents(stream)
	if err == nil {
		t.Fatal("expected error")
	}

	if msg := err.Error(); !strings.Contains(msg, "token expired") {
		t.Errorf("error = %q, want to contain 'token expired'", msg)
	}
}
// TestParseJSONLEvents_TurnFailed checks that a lone turn.failed event is
// reported as an error carrying its message.
func TestParseJSONLEvents_TurnFailed(t *testing.T) {
	const stream = `{"type":"turn.started"}
{"type":"turn.failed","error":{"message":"rate limit exceeded"}}`

	provider := &CodexCliProvider{}
	_, err := provider.parseJSONLEvents(stream)
	if err == nil {
		t.Fatal("expected error")
	}

	if msg := err.Error(); !strings.Contains(msg, "rate limit exceeded") {
		t.Errorf("error = %q, want to contain 'rate limit exceeded'", msg)
	}
}
// TestParseJSONLEvents_ErrorWithContent checks the partial-success path: when
// content was produced before an error, the content is returned and no error
// is raised.
func TestParseJSONLEvents_ErrorWithContent(t *testing.T) {
	const stream = `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}}
{"type":"error","message":"connection reset"}
{"type":"turn.failed","error":{"message":"connection reset"}}`

	provider := &CodexCliProvider{}
	got, err := provider.parseJSONLEvents(stream)
	if err != nil {
		t.Fatalf("should not error when content exists: %v", err)
	}

	if got.Content != "Partial result." {
		t.Errorf("Content = %q, want %q", got.Content, "Partial result.")
	}
}
// TestParseJSONLEvents_EmptyOutput checks that empty input yields an empty
// response rather than an error.
func TestParseJSONLEvents_EmptyOutput(t *testing.T) {
	provider := &CodexCliProvider{}

	got, err := provider.parseJSONLEvents("")
	if err != nil {
		t.Fatalf("empty output should not error: %v", err)
	}

	if got.Content != "" {
		t.Errorf("Content = %q, want empty", got.Content)
	}
}
// TestParseJSONLEvents_MalformedLines checks that non-JSON lines are skipped
// while the valid events around them are still processed.
func TestParseJSONLEvents_MalformedLines(t *testing.T) {
	const stream = `not json at all
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}}
another bad line
{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}`

	provider := &CodexCliProvider{}
	got, err := provider.parseJSONLEvents(stream)
	if err != nil {
		t.Fatalf("should skip malformed lines: %v", err)
	}

	if got.Content != "Good line." {
		t.Errorf("Content = %q, want %q", got.Content, "Good line.")
	}

	usageOK := got.Usage != nil && got.Usage.TotalTokens == 15
	if !usageOK {
		t.Errorf("Usage.TotalTokens = %v, want 15", got.Usage)
	}
}
// TestParseJSONLEvents_CommandExecution checks that command_execution items
// contribute nothing to the content; only agent_message text is returned.
func TestParseJSONLEvents_CommandExecution(t *testing.T) {
	const stream = `{"type":"turn.started"}
{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}}
{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}}
{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}}
{"type":"turn.completed"}`

	provider := &CodexCliProvider{}
	got, err := provider.parseJSONLEvents(stream)
	if err != nil {
		t.Fatalf("parseJSONLEvents() error: %v", err)
	}

	if got.Content != "Found 2 files." {
		t.Errorf("Content = %q, want %q", got.Content, "Found 2 files.")
	}
}
// TestParseJSONLEvents_NoUsage checks that Usage stays nil when the
// turn.completed event carries no usage object.
func TestParseJSONLEvents_NoUsage(t *testing.T) {
	const stream = `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}}
{"type":"turn.completed"}`

	provider := &CodexCliProvider{}
	got, err := provider.parseJSONLEvents(stream)
	if err != nil {
		t.Fatalf("parseJSONLEvents() error: %v", err)
	}

	if got.Usage != nil {
		t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", got.Usage)
	}
}
// --- Prompt Building Tests ---
// TestBuildPrompt_SystemAsInstructions checks that a system message is
// rendered under the instructions header and followed by the task section
// containing the user message.
func TestBuildPrompt_SystemAsInstructions(t *testing.T) {
	provider := &CodexCliProvider{}
	prompt := provider.buildPrompt([]Message{
		{Role: "system", Content: "You are helpful."},
		{Role: "user", Content: "Hi there"},
	}, nil)

	expectations := []struct {
		needle  string
		failure string
	}{
		{"## System Instructions", "prompt should contain '## System Instructions'"},
		{"You are helpful.", "prompt should contain system content"},
		{"## Task", "prompt should contain '## Task'"},
		{"Hi there", "prompt should contain user message"},
	}
	for _, exp := range expectations {
		if !strings.Contains(prompt, exp.needle) {
			t.Error(exp.failure)
		}
	}
}
// TestBuildPrompt_NoSystem checks that a lone user message is passed through
// verbatim, with no section headers added.
func TestBuildPrompt_NoSystem(t *testing.T) {
	provider := &CodexCliProvider{}
	prompt := provider.buildPrompt([]Message{
		{Role: "user", Content: "Just a question"},
	}, nil)

	if strings.Contains(prompt, "## System Instructions") {
		t.Error("prompt should not contain system instructions header")
	}
	if prompt != "Just a question" {
		t.Errorf("prompt = %q, want %q", prompt, "Just a question")
	}
}
func TestBuildPrompt_WithTools(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Get weather"},
}
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get current weather",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
},
},
},
}
prompt := p.buildPrompt(messages, tools)
if !strings.Contains(prompt, "## Available Tools") {
t.Error("prompt should contain tools section")
}
if !strings.Contains(prompt, "get_weather") {
t.Error("prompt should contain tool name")
}
if !strings.Contains(prompt, "Get current weather") {
t.Error("prompt should contain tool description")
}
}
// TestBuildPrompt_MultipleMessages checks that a multi-turn conversation
// keeps every user message and prefixes assistant replies.
func TestBuildPrompt_MultipleMessages(t *testing.T) {
	provider := &CodexCliProvider{}
	prompt := provider.buildPrompt([]Message{
		{Role: "user", Content: "Hello"},
		{Role: "assistant", Content: "Hi! How can I help?"},
		{Role: "user", Content: "Tell me about Go"},
	}, nil)

	expectations := []struct {
		needle  string
		failure string
	}{
		{"Hello", "prompt should contain first user message"},
		{"Assistant: Hi! How can I help?", "prompt should contain assistant message with prefix"},
		{"Tell me about Go", "prompt should contain second user message"},
	}
	for _, exp := range expectations {
		if !strings.Contains(prompt, exp.needle) {
			t.Error(exp.failure)
		}
	}
}
// TestBuildPrompt_ToolResults checks that tool messages are rendered with a
// tag naming their tool call ID, followed by the raw result content.
func TestBuildPrompt_ToolResults(t *testing.T) {
	provider := &CodexCliProvider{}
	prompt := provider.buildPrompt([]Message{
		{Role: "user", Content: "Weather?"},
		{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
	}, nil)

	expectations := []struct {
		needle  string
		failure string
	}{
		{"[Tool Result for call_1]", "prompt should contain tool result"},
		{`{"temp": 72}`, "prompt should contain tool result content"},
	}
	for _, exp := range expectations {
		if !strings.Contains(prompt, exp.needle) {
			t.Error(exp.failure)
		}
	}
}
// TestBuildPrompt_SystemAndTools checks section ordering when both system
// text and tools are present: instructions, then the task header, then the
// tools section.
func TestBuildPrompt_SystemAndTools(t *testing.T) {
	provider := &CodexCliProvider{}
	prompt := provider.buildPrompt(
		[]Message{
			{Role: "system", Content: "Be concise."},
			{Role: "user", Content: "Do something"},
		},
		[]ToolDefinition{{
			Type: "function",
			Function: ToolFunctionDefinition{
				Name:        "my_tool",
				Description: "A tool",
			},
		}},
	)

	// Locate each section header; -1 means the section is missing.
	sysAt := strings.Index(prompt, "## System Instructions")
	toolsAt := strings.Index(prompt, "## Available Tools")
	taskAt := strings.Index(prompt, "## Task")

	if sysAt == -1 || toolsAt == -1 || taskAt == -1 {
		t.Fatal("prompt should contain all sections")
	}
	if !(sysAt < taskAt) {
		t.Error("system instructions should come before task")
	}
	if !(taskAt < toolsAt) {
		t.Error("task section should come before tools in the output")
	}
}
// --- CLI Argument Tests ---
// TestCodexCliProvider_GetDefaultModel pins the default model name reported
// by a provider built with an empty workspace.
func TestCodexCliProvider_GetDefaultModel(t *testing.T) {
	provider := NewCodexCliProvider("")
	got := provider.GetDefaultModel()
	if got == "codex-cli" {
		return
	}
	t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli")
}
// --- Mock CLI Integration Test ---
// createMockCodexCLI writes an executable bash script named "codex" into a
// per-test temporary directory and returns its path. When run, the script
// echoes each entry of events on its own line, single-quoted, in order.
// The test fails immediately if the script cannot be written.
func createMockCodexCLI(t *testing.T, events []string) string {
	t.Helper()

	var script strings.Builder
	script.WriteString("#!/bin/bash\n")
	for i := range events {
		fmt.Fprintf(&script, "echo '%s'\n", events[i])
	}

	target := filepath.Join(t.TempDir(), "codex")
	err := os.WriteFile(target, []byte(script.String()), 0755)
	if err != nil {
		t.Fatal(err)
	}
	return target
}
func TestCodexCliProvider_MockCLI_Success(t *testing.T) {
scriptPath := createMockCodexCLI(t, []string{
`{"type":"thread.started","thread_id":"test-123"}`,
`{"type":"turn.started"}`,
`{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`,
`{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`,
})
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Mock response from Codex CLI" {
t.Errorf("Content = %q, want %q", resp.Content, "Mock response from Codex CLI")
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 60 {
t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 15 {
t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens)
}
}
func TestCodexCliProvider_MockCLI_Error(t *testing.T) {
scriptPath := createMockCodexCLI(t, []string{
`{"type":"thread.started","thread_id":"test-err"}`,
`{"type":"turn.started"}`,
`{"type":"error","message":"auth token expired"}`,
`{"type":"turn.failed","error":{"message":"auth token expired"}}`,
})
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
messages := []Message{{Role: "user", Content: "Hello"}}
_, err := p.Chat(context.Background(), messages, nil, "", nil)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "auth token expired") {
t.Errorf("error = %q, want to contain 'auth token expired'", err.Error())
}
}
// TestCodexCliProvider_MockCLI_WithModel uses a mock CLI that records its
// command-line arguments to a file, then verifies that Chat passed the model,
// workspace, JSON-output and approval-bypass flags.
func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) {
	dir := t.TempDir()
	argsFile := filepath.Join(dir, "args.txt")
	scriptPath := filepath.Join(dir, "codex")

	// The script dumps "$@" to argsFile and then emits a minimal successful turn.
	script := `#!/bin/bash
# Write args to a file for verification
echo "$@" > "` + argsFile + `"
echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}'
echo '{"type":"turn.completed"}'`
	if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
		t.Fatal(err)
	}

	provider := &CodexCliProvider{
		command:   scriptPath,
		workspace: "/tmp/test-workspace",
	}
	if _, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "test"}}, nil, "gpt-5.2-codex", nil); err != nil {
		t.Fatalf("Chat() error: %v", err)
	}

	// Verify the recorded arguments.
	recorded, err := os.ReadFile(argsFile)
	if err != nil {
		t.Fatalf("reading args: %v", err)
	}
	args := string(recorded)

	required := []struct {
		fragment string
		label    string
	}{
		{"-m gpt-5.2-codex", "model flag"},
		{"-C /tmp/test-workspace", "workspace flag"},
		{"--json", "--json"},
		{"--dangerously-bypass-approvals-and-sandbox", "bypass flag"},
	}
	for _, req := range required {
		if !strings.Contains(args, req.fragment) {
			t.Errorf("args should contain %s, got: %s", req.label, args)
		}
	}
}
// TestCodexCliProvider_MockCLI_ContextCancel checks that Chat returns an
// error when it is invoked with an already-canceled context, using a mock CLI
// that would otherwise sleep for a minute.
func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) {
	scriptPath := filepath.Join(t.TempDir(), "codex")
	if err := os.WriteFile(scriptPath, []byte("#!/bin/bash\nsleep 60"), 0755); err != nil {
		t.Fatal(err)
	}

	provider := &CodexCliProvider{
		command:   scriptPath,
		workspace: "",
	}

	// Cancel before the call so the provider never gets to wait on the script.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	if _, err := provider.Chat(ctx, []Message{{Role: "user", Content: "test"}}, nil, "", nil); err == nil {
		t.Fatal("expected error on canceled context")
	}
}
// TestCodexCliProvider_EmptyCommand checks that Chat fails when the provider
// has no CLI command configured.
func TestCodexCliProvider_EmptyCommand(t *testing.T) {
	provider := &CodexCliProvider{command: ""}
	input := []Message{{Role: "user", Content: "test"}}

	if _, err := provider.Chat(context.Background(), input, nil, "", nil); err == nil {
		t.Fatal("expected error for empty command")
	}
}
// --- Integration Test (requires real codex CLI with valid auth) ---
// TestCodexCliProvider_Integration exercises Chat against the codex binary
// found on PATH. It is skipped unless PICOCLAW_INTEGRATION_TESTS is set and
// the binary can be located.
func TestCodexCliProvider_Integration(t *testing.T) {
	if enabled := os.Getenv("PICOCLAW_INTEGRATION_TESTS"); enabled == "" {
		t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)")
	}

	// Verify codex is available.
	binary, lookErr := exec.LookPath("codex")
	if lookErr != nil {
		t.Skip("codex CLI not found in PATH")
	}

	provider := &CodexCliProvider{
		command:   binary,
		workspace: "",
	}
	request := []Message{
		{Role: "user", Content: "Respond with just the word 'hello' and nothing else."},
	}

	resp, err := provider.Chat(context.Background(), request, nil, "", nil)
	if err != nil {
		t.Fatalf("Chat() error: %v", err)
	}

	// Compare case-insensitively and ignore surrounding whitespace.
	normalized := strings.ToLower(strings.TrimSpace(resp.Content))
	if !strings.Contains(normalized, "hello") {
		t.Errorf("Content = %q, expected to contain 'hello'", resp.Content)
	}
}
+128 -9
View File
@@ -3,6 +3,7 @@ package providers
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
@@ -10,18 +11,26 @@ import (
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/logger"
)
const codexDefaultModel = "gpt-5.2"
const codexDefaultInstructions = "You are Codex, a coding assistant."
// CodexProvider is an LLM provider that sends chat requests to the Codex
// backend through the OpenAI Responses client.
type CodexProvider struct {
// client is the OpenAI SDK client used to issue Responses API calls.
client *openai.Client
// accountID is the account id sent in the Chatgpt-Account-Id header; Chat
// falls back to it when tokenSource yields no account id.
accountID string
// tokenSource, when non-nil, is called by Chat to obtain the API token and
// account id (in that order) for the request.
tokenSource func() (string, string, error)
}
const defaultCodexInstructions = "You are Codex, a coding assistant."
func NewCodexProvider(token, accountID string) *CodexProvider {
opts := []option.RequestOption{
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
option.WithAPIKey(token),
option.WithHeader("originator", "codex_cli_rs"),
option.WithHeader("OpenAI-Beta", "responses=experimental"),
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
@@ -41,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func()
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
accountID := p.accountID
resolvedModel, fallbackReason := resolveCodexModel(model)
if fallbackReason != "" {
logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"reason": fallbackReason,
})
}
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
if err != nil {
@@ -48,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
}
opts = append(opts, option.WithAPIKey(tok))
if accID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
accountID = accID
}
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
} else {
logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
})
}
params := buildCodexParams(messages, tools, model, options)
params := buildCodexParams(messages, tools, resolvedModel, options)
resp, err := p.client.Responses.New(ctx, params, opts...)
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
defer stream.Close()
var resp *responses.Response
for stream.Next() {
evt := stream.Current()
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
evtResp := evt.Response
if evtResp.ID != "" {
copy := evtResp
resp = &copy
}
}
}
err := stream.Err()
if err != nil {
fields := map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
"tools_count": len(tools),
"account_id_present": accountID != "",
"error": err.Error(),
}
var apiErr *openai.Error
if errors.As(err, &apiErr) {
fields["status_code"] = apiErr.StatusCode
fields["api_type"] = apiErr.Type
fields["api_code"] = apiErr.Code
fields["api_param"] = apiErr.Param
fields["api_message"] = apiErr.Message
if apiErr.StatusCode == 400 {
fields["hint"] = "verify account id header and model compatibility for codex backend"
}
if apiErr.Response != nil {
fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
}
}
logger.ErrorCF("provider.codex", "Codex API call failed", fields)
return nil, fmt.Errorf("codex API call: %w", err)
}
if resp == nil {
fields := map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
"tools_count": len(tools),
"account_id_present": accountID != "",
}
logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields)
return nil, fmt.Errorf("codex API call: stream ended without completed response")
}
return parseCodexResponse(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
return "gpt-4o"
return codexDefaultModel
}
// resolveCodexModel maps a requested model name onto one that is sent to the
// Codex backend.
//
// The name is trimmed and lower-cased, and a leading "openai/" namespace is
// removed. The first return value is the model to use; the second is a short
// reason that is non-empty only when the request was replaced by
// codexDefaultModel (empty name, foreign namespace, known non-OpenAI prefix,
// or a family other than gpt-*, o3* and o4*).
func resolveCodexModel(model string) (string, string) {
	name := strings.ToLower(strings.TrimSpace(model))
	if name == "" {
		return codexDefaultModel, "empty model"
	}

	switch {
	case strings.HasPrefix(name, "openai/"):
		name = strings.TrimPrefix(name, "openai/")
	case strings.Contains(name, "/"):
		// Any namespace other than "openai/" belongs to another provider.
		return codexDefaultModel, "non-openai model namespace"
	}

	// Prefixes of well-known model families served by other vendors.
	foreign := [...]string{
		"glm",
		"claude",
		"anthropic",
		"gemini",
		"google",
		"moonshot",
		"kimi",
		"qwen",
		"deepseek",
		"llama",
		"meta-llama",
		"mistral",
		"grok",
		"xai",
		"zhipu",
	}
	for i := range foreign {
		if strings.HasPrefix(name, foreign[i]) {
			return codexDefaultModel, "unsupported model prefix"
		}
	}

	// Only these families are passed through unchanged.
	for _, family := range [...]string{"gpt-", "o3", "o4"} {
		if strings.HasPrefix(name, family) {
			return name, ""
		}
	}

	return codexDefaultModel, "unsupported model family"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
@@ -133,21 +249,21 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Store: openai.Opt(false),
Instructions: openai.Opt(instructions),
Store: openai.Opt(false),
}
if instructions != "" {
params.Instructions = openai.Opt(instructions)
} else {
// ChatGPT Codex backend requires instructions to be present.
params.Instructions = openai.Opt(defaultCodexInstructions)
}
if maxTokens, ok := options["max_tokens"].(int); ok {
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = openai.Opt(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
}
@@ -237,6 +353,9 @@ func createCodexTokenSource() func() (string, string, error) {
if err != nil {
return "", "", fmt.Errorf("refreshing token: %w", err)
}
if refreshed.AccountID == "" {
refreshed.AccountID = cred.AccountID
}
if err := auth.SetCredential("openai", refreshed); err != nil {
return "", "", fmt.Errorf("saving refreshed token: %w", err)
}
+210 -5
View File
@@ -2,6 +2,7 @@ package providers
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -16,11 +17,18 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
{Role: "user", Content: "Hello"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
"max_tokens": 2048,
"temperature": 0.7,
})
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
}
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
if params.Instructions.Or("") != defaultCodexInstructions {
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
}
}
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
@@ -197,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
@@ -220,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
writeCompletedSSE(w, resp)
}))
defer server.Close()
@@ -244,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
}
}
// TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID runs Chat
// against a fake Codex server with a token source that returns a fresh token
// but an empty account id. The server only answers successfully when the
// request carries the refreshed token together with the account id the
// provider was constructed with ("acc-123").
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Each check below rejects the request with a distinct HTTP error so a
// failure points at the specific expectation that was violated.
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
// The token from the token source must be used, not the stale client token.
if r.Header.Get("Authorization") != "Bearer refreshed-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
// The account id must fall back to the one configured on the provider.
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
http.Error(w, "missing account id", http.StatusBadRequest)
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
// Instructions must be present and non-empty.
if _, ok := reqBody["instructions"]; !ok {
http.Error(w, "missing instructions", http.StatusBadRequest)
return
}
if reqBody["instructions"] == "" {
http.Error(w, "instructions must not be empty", http.StatusBadRequest)
return
}
// The body must not carry a temperature field even though the caller
// passes one in the Chat options below.
if _, ok := reqBody["temperature"]; ok {
http.Error(w, "temperature is not supported", http.StatusBadRequest)
return
}
// The request must ask for a streamed response.
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
// Canned completed response with a single assistant text message.
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 8,
"output_tokens": 4,
"total_tokens": 12,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
}))
defer server.Close()
// The test client is built with the stale token and no account id, so both
// expected headers have to come from Chat itself.
provider := NewCodexProvider("stale-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "stale-token", "")
provider.tokenSource = func() (string, string, error) {
return "refreshed-token", "", nil
}
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
}
// TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported requests a
// model the Codex backend does not serve and verifies that the request which
// reaches the server has been rewritten to codexDefaultModel, is streamed,
// and carries the default instructions.
func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/responses" {
			http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
			return
		}

		var reqBody map[string]interface{}
		if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
			http.Error(w, "invalid json", http.StatusBadRequest)
			return
		}
		// The unsupported model must have been replaced before the request
		// was sent.
		if reqBody["model"] != codexDefaultModel {
			http.Error(w, "unsupported model", http.StatusBadRequest)
			return
		}
		if reqBody["stream"] != true {
			http.Error(w, "stream must be true", http.StatusBadRequest)
			return
		}
		// No system message is sent, so the default instructions apply.
		if reqBody["instructions"] != codexDefaultInstructions {
			http.Error(w, "missing default instructions", http.StatusBadRequest)
			return
		}

		resp := map[string]interface{}{
			"id":     "resp_test",
			"object": "response",
			"status": "completed",
			"output": []map[string]interface{}{
				{
					"id":     "msg_1",
					"type":   "message",
					"role":   "assistant",
					"status": "completed",
					"content": []map[string]interface{}{
						{"type": "output_text", "text": "Hi from Codex!"},
					},
				},
			},
			"usage": map[string]interface{}{
				"input_tokens":          8,
				"output_tokens":         4,
				"total_tokens":          12,
				"input_tokens_details":  map[string]interface{}{"cached_tokens": 0},
				"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
			},
		}
		writeCompletedSSE(w, resp)
	}))
	defer server.Close()

	provider := NewCodexProvider("test-token", "acc-123")
	provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")

	messages := []Message{{Role: "user", Content: "Hello"}}
	// "glm-4.7" is not an OpenAI model, so Chat has to fall back to the
	// default. Requesting the default model itself would pass the server-side
	// check without exercising the fallback at all.
	resp, err := provider.Chat(t.Context(), messages, nil, "glm-4.7", nil)
	if err != nil {
		t.Fatalf("Chat() error: %v", err)
	}
	if resp.Content != "Hi from Codex!" {
		t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
	}
}
func TestCodexProvider_GetDefaultModel(t *testing.T) {
p := NewCodexProvider("test-token", "")
if got := p.GetDefaultModel(); got != "gpt-4o" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
if got := p.GetDefaultModel(); got != codexDefaultModel {
t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel)
}
}
// TestResolveCodexModel is a table-driven check of resolveCodexModel: for
// each input it pins the resolved model and whether a fallback reason is
// reported.
func TestResolveCodexModel(t *testing.T) {
	type scenario struct {
		name         string
		input        string
		wantModel    string
		wantFallback bool
	}
	scenarios := []scenario{
		{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
		{name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
		{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
		{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
		{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			resolved, why := resolveCodexModel(sc.input)
			if resolved != sc.wantModel {
				t.Fatalf("resolveCodexModel(%q) model = %q, want %q", sc.input, resolved, sc.wantModel)
			}

			fellBack := why != ""
			switch {
			case sc.wantFallback && !fellBack:
				t.Fatalf("resolveCodexModel(%q) expected fallback reason", sc.input)
			case !sc.wantFallback && fellBack:
				t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", sc.input, why)
			}
		})
	}
}
@@ -262,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
c := openai.NewClient(opts...)
return &c
}
// writeCompletedSSE replies on w with a server-sent-event stream holding a
// single "response.completed" event that wraps response, followed by the
// "[DONE]" terminator frame.
//
// If response cannot be encoded as JSON the handler answers with a 500
// instead, so a broken fixture shows up as a failed request rather than as a
// stream with an empty data frame.
func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
	payload, err := json.Marshal(map[string]interface{}{
		"type":            "response.completed",
		"sequence_number": 1,
		"response":        response,
	})
	if err != nil {
		http.Error(w, "encoding SSE event: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// The header must be set before the first write to the body.
	w.Header().Set("Content-Type", "text/event-stream")
	fmt.Fprint(w, "event: response.completed\n")
	fmt.Fprintf(w, "data: %s\n\n", payload)
	fmt.Fprint(w, "data: [DONE]\n\n")
}
+82
View File
@@ -0,0 +1,82 @@
package providers
import (
"context"
"fmt"
json "encoding/json"
copilot "github.com/github/copilot-sdk/go"
)
// GitHubCopilotProvider is an LLM provider backed by a GitHub Copilot SDK
// session.
type GitHubCopilotProvider struct {
// uri is the address of the Copilot CLI server the provider was created for.
uri string
connectMode string // `stdio` or `grpc`
// session is the Copilot session used by Chat; it is only created in grpc mode.
session *copilot.Session
}
// NewGitHubCopilotProvider builds a provider for the Copilot CLI server at
// uri and opens a session for model.
//
// connectMode selects the transport and defaults to "grpc" when empty:
//   - "grpc": connect to the external CLI server at uri and create a session.
//   - "stdio": not implemented yet; the provider is returned without a
//     session.
//
// It returns an error when the connect mode is unknown, when the CLI server
// cannot be reached, or when the session cannot be created.
func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) {
	if connectMode == "" {
		connectMode = "grpc"
	}

	var session *copilot.Session
	switch connectMode {
	case "stdio":
		// TODO: stdio transport is not implemented; session stays nil.
	case "grpc":
		client := copilot.NewClient(&copilot.ClientOptions{
			CLIUrl: uri,
		})
		if err := client.Start(context.Background()); err != nil {
			return nil, fmt.Errorf("connecting to GitHub Copilot at %q (see https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details): %w", uri, err)
		}

		// The session is served by this client, so the client has to keep
		// running for as long as the provider is in use. It is only stopped
		// here when no session could be created.
		created, err := client.CreateSession(context.Background(), &copilot.SessionConfig{
			Model: model,
			Hooks: &copilot.SessionHooks{},
		})
		if err != nil {
			client.Stop()
			return nil, fmt.Errorf("creating GitHub Copilot session: %w", err)
		}
		session = created
	default:
		return nil, fmt.Errorf("unsupported GitHub Copilot connect mode %q (want \"stdio\" or \"grpc\")", connectMode)
	}

	return &GitHubCopilotProvider{
		uri:         uri,
		connectMode: connectMode,
		session:     session,
	}, nil
}
// Chat sends the conversation to GitHub Copilot through the provider's
// session and returns the reply.
//
// The role and content of every message are serialized as a JSON array and
// sent as a single prompt. tools, model and options are currently not
// forwarded. An error is returned when the provider has no session, when the
// prompt cannot be encoded, or when the session reports a send failure.
func (p *GitHubCopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
	// A provider created in a mode that opens no session cannot send anything.
	if p.session == nil {
		return nil, fmt.Errorf("github copilot: no active session (connect mode %q)", p.connectMode)
	}

	type wireMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}
	transcript := make([]wireMessage, 0, len(messages))
	for _, msg := range messages {
		transcript = append(transcript, wireMessage{
			Role:    msg.Role,
			Content: msg.Content,
		})
	}

	prompt, err := json.Marshal(transcript)
	if err != nil {
		return nil, fmt.Errorf("encoding copilot prompt: %w", err)
	}

	content, err := p.session.Send(ctx, copilot.MessageOptions{
		Prompt: string(prompt),
	})
	if err != nil {
		return nil, fmt.Errorf("sending copilot prompt: %w", err)
	}

	return &LLMResponse{
		FinishReason: "stop",
		Content:      content,
	}, nil
}
// GetDefaultModel returns the model name this provider reports as its
// default, "gpt-4.1".
func (p *GitHubCopilotProvider) GetDefaultModel() string {
return "gpt-4.1"
}
+32 -5
View File
@@ -15,6 +15,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
@@ -28,7 +29,7 @@ type HTTPProvider struct {
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
client := &http.Client{
Timeout: 0,
Timeout: 120 * time.Second,
}
if proxy != "" {
@@ -52,10 +53,10 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" {
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
model = model[idx+1:]
}
}
@@ -239,6 +240,9 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
@@ -298,11 +302,17 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.Agents.Defaults.Workspace
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
apiKey = cfg.Providers.DeepSeek.APIKey
@@ -314,7 +324,16 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model = "deepseek-chat"
}
}
case "github_copilot", "copilot":
if cfg.Providers.GitHubCopilot.APIBase != "" {
apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
apiBase = "localhost:4321"
}
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
}
}
// Fallback: detect provider from model name
@@ -390,7 +409,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
fmt.Println("Ollama provider selected based on model name prefix")
apiKey = cfg.Providers.Ollama.APIKey
apiBase = cfg.Providers.Ollama.APIBase
proxy = cfg.Providers.Ollama.Proxy
if apiBase == "" {
apiBase = "http://localhost:11434/v1"
}
fmt.Println("Ollama apiBase:", apiBase)
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
+72
View File
@@ -0,0 +1,72 @@
package providers
import (
"encoding/json"
"strings"
)
// extractToolCallsFromText pulls tool calls out of a model's text reply.
//
// It looks for the first `{"tool_calls"` object in text, decodes it, and
// converts each entry into a ToolCall. Both ClaudeCliProvider and
// CodexCliProvider rely on it because those CLIs emit tool calls inline in
// the response text. It returns nil when no such object is found, when its
// closing brace cannot be located, or when the object is not valid JSON.
func extractToolCallsFromText(text string) []ToolCall {
	begin := strings.Index(text, `{"tool_calls"`)
	if begin < 0 {
		return nil
	}
	finish := findMatchingBrace(text, begin)
	if finish == begin {
		return nil
	}

	var payload struct {
		ToolCalls []struct {
			ID       string `json:"id"`
			Type     string `json:"type"`
			Function struct {
				Name      string `json:"name"`
				Arguments string `json:"arguments"`
			} `json:"function"`
		} `json:"tool_calls"`
	}
	if json.Unmarshal([]byte(text[begin:finish]), &payload) != nil {
		return nil
	}

	var calls []ToolCall
	for i := range payload.ToolCalls {
		entry := &payload.ToolCalls[i]

		// Arguments arrive as a JSON-encoded string. Decoding is best effort:
		// the decode error is ignored and the raw string is always kept on
		// Function.Arguments.
		var decoded map[string]interface{}
		json.Unmarshal([]byte(entry.Function.Arguments), &decoded)

		calls = append(calls, ToolCall{
			ID:        entry.ID,
			Type:      entry.Type,
			Name:      entry.Function.Name,
			Arguments: decoded,
			Function: &FunctionCall{
				Name:      entry.Function.Name,
				Arguments: entry.Function.Arguments,
			},
		})
	}
	return calls
}
// stripToolCallsFromText returns text with its first `{"tool_calls"` JSON
// object cut out and the result trimmed of surrounding whitespace. The text
// is returned unchanged when no such object is present or its closing brace
// cannot be located.
func stripToolCallsFromText(text string) string {
	begin := strings.Index(text, `{"tool_calls"`)
	if begin >= 0 {
		if finish := findMatchingBrace(text, begin); finish != begin {
			return strings.TrimSpace(text[:begin] + text[finish:])
		}
	}
	return text
}
+33 -3
View File
@@ -145,13 +145,27 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
session.Updated = time.Now()
}
// sanitizeFilename turns a session key into a filename that is safe on every
// platform by replacing each ':' with '_'.
//
// Session keys have the form "channel:chatID" (e.g. "telegram:123456"), but
// ':' is the volume separator on Windows, so filepath.Base would misread the
// key. The original key is still stored inside the JSON file, which lets
// loadSessions map the file back to the right in-memory key.
func sanitizeFilename(key string) string {
	segments := strings.Split(key, ":")
	return strings.Join(segments, "_")
}
func (sm *SessionManager) Save(key string) error {
if sm.storage == "" {
return nil
}
// Validate key to avoid invalid filenames and path traversal.
if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
filename := sanitizeFilename(key)
// filepath.IsLocal rejects empty names, "..", absolute paths, and
// OS-reserved device names (NUL, COM1 … on Windows).
// The extra checks reject "." and any directory separators so that
// the session file is always written directly inside sm.storage.
if filename == "." || !filepath.IsLocal(filename) || strings.ContainsAny(filename, `/\`) {
return os.ErrInvalid
}
@@ -182,7 +196,7 @@ func (sm *SessionManager) Save(key string) error {
return err
}
sessionPath := filepath.Join(sm.storage, key+".json")
sessionPath := filepath.Join(sm.storage, filename+".json")
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
if err != nil {
return err
@@ -250,3 +264,19 @@ func (sm *SessionManager) loadSessions() error {
return nil
}
// SetHistory replaces the stored messages of the session identified by key
// and refreshes its Updated timestamp. Unknown keys are ignored.
//
// The session receives its own copy of the slice, so later changes the
// caller makes to the history slice itself do not reach the stored session.
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	target, found := sm.sessions[key]
	if !found {
		return
	}

	snapshot := make([]providers.Message, len(history))
	for i := range history {
		snapshot[i] = history[i]
	}
	target.Messages = snapshot
	target.Updated = time.Now()
}
+74
View File
@@ -0,0 +1,74 @@
package session
import (
"os"
"path/filepath"
"testing"
)
// TestSanitizeFilename checks that sanitizeFilename leaves colon-free keys
// alone and replaces every ':' with '_'.
func TestSanitizeFilename(t *testing.T) {
	type scenario struct {
		input    string
		expected string
	}
	scenarios := []scenario{
		{"simple", "simple"},
		{"telegram:123456", "telegram_123456"},
		{"discord:987654321", "discord_987654321"},
		{"slack:C01234", "slack_C01234"},
		{"no-colons-here", "no-colons-here"},
		{"multiple:colons:here", "multiple_colons_here"},
	}

	for _, sc := range scenarios {
		t.Run(sc.input, func(t *testing.T) {
			if got := sanitizeFilename(sc.input); got != sc.expected {
				t.Errorf("sanitizeFilename(%q) = %q, want %q", sc.input, got, sc.expected)
			}
		})
	}
}
// TestSave_WithColonInKey saves a session whose key contains ':' and checks
// that the file on disk uses the sanitized name and that the session can be
// loaded back under the original key by a fresh manager.
func TestSave_WithColonInKey(t *testing.T) {
	tmpDir := t.TempDir()
	sm := NewSessionManager(tmpDir)

	// Create a session with a key containing a colon (typical channel session key).
	key := "telegram:123456"
	sm.GetOrCreate(key)
	sm.AddMessage(key, "user", "hello")

	// Save should succeed even though the key contains ':'.
	if err := sm.Save(key); err != nil {
		t.Fatalf("Save(%q) failed: %v", key, err)
	}

	// The file on disk should use the sanitized name. Any stat failure counts,
	// not only "does not exist", so permission or path problems cannot slip
	// through as a pass.
	expectedFile := filepath.Join(tmpDir, "telegram_123456.json")
	if _, err := os.Stat(expectedFile); err != nil {
		t.Fatalf("expected session file %s to exist: %v", expectedFile, err)
	}

	// Load into a fresh manager and verify the session round-trips.
	sm2 := NewSessionManager(tmpDir)
	history := sm2.GetHistory(key)
	if len(history) != 1 {
		t.Fatalf("expected 1 message after reload, got %d", len(history))
	}
	if history[0].Content != "hello" {
		t.Errorf("expected message content %q, got %q", "hello", history[0].Content)
	}
}
// TestSave_RejectsPathTraversal checks that Save refuses keys that are empty,
// refer to the current or parent directory, or contain a path separator.
func TestSave_RejectsPathTraversal(t *testing.T) {
	manager := NewSessionManager(t.TempDir())

	for _, bad := range []string{"", ".", "..", "foo/bar", "foo\\bar"} {
		manager.GetOrCreate(bad)
		err := manager.Save(bad)
		if err != nil {
			continue
		}
		t.Errorf("Save(%q) should have failed but didn't", bad)
	}
}
+45
View File
@@ -2,13 +2,22 @@ package skills
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
)
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
const (
MaxNameLength = 64
MaxDescriptionLength = 1024
)
type SkillMetadata struct {
Name string `json:"name"`
Description string `json:"description"`
@@ -21,6 +30,27 @@ type SkillInfo struct {
Description string `json:"description"`
}
// validate checks the skill's name and description and returns every
// violation found, joined into one error, or nil when the skill is valid.
//
// A name is required, may be at most MaxNameLength characters long, and must
// match namePattern (alphanumeric segments separated by single hyphens). A
// description is required and may be at most MaxDescriptionLength characters
// long. Lengths are counted in characters (runes), not bytes, so multi-byte
// text is not penalized.
func (info SkillInfo) validate() error {
	var problems []error

	if info.Name == "" {
		problems = append(problems, errors.New("name is required"))
	} else {
		if len([]rune(info.Name)) > MaxNameLength {
			problems = append(problems, fmt.Errorf("name exceeds %d characters", MaxNameLength))
		}
		if !namePattern.MatchString(info.Name) {
			problems = append(problems, errors.New("name must be alphanumeric with hyphens"))
		}
	}

	if info.Description == "" {
		problems = append(problems, errors.New("description is required"))
	} else if len([]rune(info.Description)) > MaxDescriptionLength {
		problems = append(problems, fmt.Errorf("description exceeds %d characters", MaxDescriptionLength))
	}

	// errors.Join returns nil when there is nothing to report.
	return errors.Join(problems...)
}
type SkillsLoader struct {
workspace string
workspaceSkills string // workspace skills (项目级别)
@@ -54,6 +84,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from workspace", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
@@ -89,6 +124,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from global", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
@@ -123,6 +163,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from builtin", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
+77
View File
@@ -0,0 +1,77 @@
package skills
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSkillsInfoValidate(t *testing.T) {
testcases := []struct {
name string
skillName string
description string
wantErr bool
errContains []string
}{
{
name: "valid-skill",
skillName: "valid-skill",
description: "a valid skill description",
wantErr: false,
},
{
name: "empty-name",
skillName: "",
description: "description without name",
wantErr: true,
errContains: []string{"name is required"},
},
{
name: "empty-description",
skillName: "skill-without-description",
description: "",
wantErr: true,
errContains: []string{"description is required"},
},
{
name: "empty-both",
skillName: "",
description: "",
wantErr: true,
errContains: []string{"name is required", "description is required"},
},
{
name: "name-with-spaces",
skillName: "skill with spaces",
description: "invalid name with spaces",
wantErr: true,
errContains: []string{"name must be alphanumeric with hyphens"},
},
{
name: "name-with-underscore",
skillName: "skill_underscore",
description: "invalid name with underscore",
wantErr: true,
errContains: []string{"name must be alphanumeric with hyphens"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
info := SkillInfo{
Name: tc.skillName,
Description: tc.description,
}
err := info.validate()
if tc.wantErr {
assert.Error(t, err)
for _, msg := range tc.errContains {
assert.ErrorContains(t, err, msg)
}
} else {
assert.NoError(t, err)
}
})
}
}
+2 -2
View File
@@ -28,12 +28,12 @@ type CronTool struct {
}
// NewCronTool creates a new CronTool that holds the given cron service, job
// executor and message bus, and builds its internal exec tool with
// NewExecTool. In the signature that accepts restrict, workspace and restrict
// are forwarded to NewExecTool unchanged; the earlier form always passed false.
// NOTE(review): restrict presumably toggles workspace confinement of executed
// commands — confirm against NewExecTool.
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool {
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool {
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: NewExecTool(workspace, false),
execTool: NewExecTool(workspace, restrict),
}
}
}
+43 -2
View File
@@ -29,13 +29,54 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
}
}
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
if restrict {
if !isWithinWorkspace(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
workspaceReal := absWorkspace
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
workspaceReal = resolved
}
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
if !isWithinWorkspace(resolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if os.IsNotExist(err) {
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
if !isWithinWorkspace(parentResolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if !os.IsNotExist(err) {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
} else {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
}
return absPath, nil
}
// resolveExistingAncestor walks upward from path until it reaches a directory
// entry that exists, and returns that entry with all symlinks resolved.
//
// It is used for paths whose final components have not been created yet: the
// nearest existing ancestor is what decides where such a path really lands.
// Any error other than "does not exist" aborts the walk and is returned as is.
// If even the filesystem root cannot be resolved, os.ErrNotExist is returned.
func resolveExistingAncestor(path string) (string, error) {
	candidate := filepath.Clean(path)
	for {
		real, err := filepath.EvalSymlinks(candidate)
		if err == nil {
			return real, nil
		}
		if !os.IsNotExist(err) {
			return "", err
		}

		// filepath.Dir is a fixed point at the root (or "."), which means
		// there is nothing left to climb to.
		parent := filepath.Dir(candidate)
		if parent == candidate {
			return "", os.ErrNotExist
		}
		candidate = parent
	}
}
// isWithinWorkspace reports whether candidate is the workspace directory
// itself or lies somewhere beneath it.
//
// The comparison is purely lexical: both paths are cleaned and candidate is
// expressed relative to workspace. It is outside when that relative path
// cannot be computed, is exactly "..", or starts with ".." followed by a path
// separator. Names that merely begin with two dots (for example "..cache")
// are still inside.
func isWithinWorkspace(candidate, workspace string) bool {
	base := filepath.Clean(workspace)
	target := filepath.Clean(candidate)

	rel, err := filepath.Rel(base, target)
	if err != nil {
		return false
	}
	if rel == ".." {
		return false
	}
	return !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
}
type ReadFileTool struct {
workspace string
restrict bool
+32
View File
@@ -247,3 +247,35 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}
// TestFilesystemTool_ReadFile_RejectsSymlinkEscape checks that a restricted
// ReadFileTool refuses a path that sits inside the workspace but is a symlink
// to a file outside it. The test is skipped where symlinks cannot be created.
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
	baseDir := t.TempDir()

	// Layout: <baseDir>/workspace/leak.txt -> <baseDir>/secret.txt
	wsDir := filepath.Join(baseDir, "workspace")
	secretPath := filepath.Join(baseDir, "secret.txt")
	linkPath := filepath.Join(wsDir, "leak.txt")

	if mkErr := os.MkdirAll(wsDir, 0755); mkErr != nil {
		t.Fatalf("failed to create workspace: %v", mkErr)
	}
	if writeErr := os.WriteFile(secretPath, []byte("top secret"), 0644); writeErr != nil {
		t.Fatalf("failed to write secret file: %v", writeErr)
	}
	if linkErr := os.Symlink(secretPath, linkPath); linkErr != nil {
		t.Skipf("symlink not supported in this environment: %v", linkErr)
	}

	args := map[string]interface{}{
		"path": linkPath,
	}
	reader := NewReadFileTool(wsDir, true)
	res := reader.Execute(context.Background(), args)

	if !res.IsError {
		t.Fatalf("expected symlink escape to be blocked")
	}
	if !strings.Contains(res.ForLLM, "symlink resolves outside workspace") {
		t.Fatalf("expected symlink escape error, got: %s", res.ForLLM)
	}
}
+10 -6
View File
@@ -173,19 +173,23 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
}
// TestWebTool_WebSearch_NoApiKey verifies that NewWebSearchTool returns nil
// rather than a tool when no usable search provider is configured: with Brave
// enabled but an empty API key, and with an entirely empty options struct.
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "", BraveMaxResults: 5})
// Brave is enabled here but has no API key, so no tool should be created.
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
if tool != nil {
t.Errorf("Expected nil when no search provider is configured")
t.Errorf("Expected nil tool when Brave API key is empty")
}
// Also nil when nothing is enabled
tool = NewWebSearchTool(WebSearchToolOptions{})
if tool != nil {
t.Errorf("Expected nil tool when no provider is enabled")
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "test-key", BraveMaxResults: 5, BraveEnabled: true})
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
ctx := context.Background()
args := map[string]interface{}{}
+1 -2
View File
@@ -73,9 +73,8 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)