mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge remote-tracking branch 'upstream/main' into feat/cron-exec-timeout-config
This commit is contained in:
+20
-6
@@ -16,6 +16,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -70,11 +71,21 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg
|
||||
// Shell execution
|
||||
registry.Register(tools.NewExecTool(workspace, restrict))
|
||||
|
||||
// Web tools
|
||||
braveAPIKey := cfg.Tools.Web.Search.APIKey
|
||||
registry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults))
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
}); searchTool != nil {
|
||||
registry.Register(searchTool)
|
||||
}
|
||||
registry.Register(tools.NewWebFetchTool(50000))
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
registry.Register(tools.NewI2CTool())
|
||||
registry.Register(tools.NewSPITool())
|
||||
|
||||
// Message tool - available to both agent and subagent
|
||||
// Subagent uses it to communicate directly with user
|
||||
messageTool := tools.NewMessageTool()
|
||||
@@ -368,7 +379,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
|
||||
// 6. Save final assistant message to session
|
||||
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey))
|
||||
al.sessions.Save(opts.SessionKey)
|
||||
|
||||
// 7. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
@@ -732,7 +743,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
if finalSummary != "" {
|
||||
al.sessions.SetSummary(sessionKey, finalSummary)
|
||||
al.sessions.TruncateHistory(sessionKey, 4)
|
||||
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
|
||||
al.sessions.Save(sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -758,10 +769,13 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Uses rune count instead of byte length so that CJK and other multi-byte
|
||||
// characters are not over-counted (a Chinese character is 3 bytes but roughly
|
||||
// one token).
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
for _, m := range messages {
|
||||
total += len(m.Content) / 4 // Simple heuristic: 4 chars per token
|
||||
total += utf8.RuneCountInString(m.Content) / 3
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
+26
-16
@@ -19,18 +19,20 @@ import (
|
||||
)
|
||||
|
||||
type OAuthProviderConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Port int
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Originator string
|
||||
Port int
|
||||
}
|
||||
|
||||
func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||
return OAuthProviderConfig{
|
||||
Issuer: "https://auth.openai.com",
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
Scopes: "openid profile email offline_access",
|
||||
Port: 1455,
|
||||
Issuer: "https://auth.openai.com",
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
Scopes: "openid profile email offline_access",
|
||||
Originator: "codex_cli_rs",
|
||||
Port: 1455,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,15 +290,20 @@ func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
||||
|
||||
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"state": {state},
|
||||
}
|
||||
return cfg.Issuer + "/authorize?" + params.Encode()
|
||||
if cfg.Originator != "" {
|
||||
params.Set("originator", cfg.Originator)
|
||||
}
|
||||
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
|
||||
}
|
||||
|
||||
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||
@@ -352,6 +359,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
|
||||
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||
cred.AccountID = accountID
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
|
||||
+42
-5
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -10,10 +11,11 @@ import (
|
||||
|
||||
func TestBuildAuthorizeURL(t *testing.T) {
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: "https://auth.example.com",
|
||||
ClientID: "test-client-id",
|
||||
Scopes: "openid profile",
|
||||
Port: 1455,
|
||||
Issuer: "https://auth.example.com",
|
||||
ClientID: "test-client-id",
|
||||
Scopes: "openid profile",
|
||||
Originator: "codex_cli_rs",
|
||||
Port: 1455,
|
||||
}
|
||||
pkce := PKCECodes{
|
||||
CodeVerifier: "test-verifier",
|
||||
@@ -22,7 +24,7 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
|
||||
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
|
||||
if !strings.HasPrefix(u, "https://auth.example.com/oauth/authorize?") {
|
||||
t.Errorf("URL does not start with expected prefix: %s", u)
|
||||
}
|
||||
if !strings.Contains(u, "client_id=test-client-id") {
|
||||
@@ -40,6 +42,15 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
if !strings.Contains(u, "response_type=code") {
|
||||
t.Error("URL missing response_type")
|
||||
}
|
||||
if !strings.Contains(u, "id_token_add_organizations=true") {
|
||||
t.Error("URL missing id_token_add_organizations")
|
||||
}
|
||||
if !strings.Contains(u, "codex_cli_simplified_flow=true") {
|
||||
t.Error("URL missing codex_cli_simplified_flow")
|
||||
}
|
||||
if !strings.Contains(u, "originator=codex_cli_rs") {
|
||||
t.Error("URL missing originator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponse(t *testing.T) {
|
||||
@@ -81,6 +92,32 @@ func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) {
|
||||
idToken := makeJWTWithAccountID("acc-from-id")
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "not-a-jwt",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
|
||||
if cred.AccountID != "acc-from-id" {
|
||||
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-from-id")
|
||||
}
|
||||
}
|
||||
|
||||
func makeJWTWithAccountID(accountID string) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte(`{"https://api.openai.com/auth":{"chatgpt_account_id":"` + accountID + `"}}`))
|
||||
return header + "." + payload + ".sig"
|
||||
}
|
||||
|
||||
func TestExchangeCodeForTokens(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
|
||||
+16
-1
@@ -59,7 +59,22 @@ func (c *BaseChannel) IsAllowed(senderID string) bool {
|
||||
for _, allowed := range c.allowList {
|
||||
// Strip leading "@" from allowed value for username matching
|
||||
trimmed := strings.TrimPrefix(allowed, "@")
|
||||
if senderID == allowed || idPart == allowed || senderID == trimmed || idPart == trimmed || (userPart != "" && (userPart == allowed || userPart == trimmed)) {
|
||||
allowedID := trimmed
|
||||
allowedUser := ""
|
||||
if idx := strings.Index(trimmed, "|"); idx > 0 {
|
||||
allowedID = trimmed[:idx]
|
||||
allowedUser = trimmed[idx+1:]
|
||||
}
|
||||
|
||||
// Support either side using "id|username" compound form.
|
||||
// This keeps backward compatibility with legacy Telegram allowlist entries.
|
||||
if senderID == allowed ||
|
||||
idPart == allowed ||
|
||||
senderID == trimmed ||
|
||||
idPart == trimmed ||
|
||||
idPart == allowedID ||
|
||||
(allowedUser != "" && senderID == allowedUser) ||
|
||||
(userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package channels
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBaseChannelIsAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowList []string
|
||||
senderID string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty allowlist allows all",
|
||||
allowList: nil,
|
||||
senderID: "anyone",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "compound sender matches numeric allowlist",
|
||||
allowList: []string{"123456"},
|
||||
senderID: "123456|alice",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "compound sender matches username allowlist",
|
||||
allowList: []string{"@alice"},
|
||||
senderID: "123456|alice",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "numeric sender matches legacy compound allowlist",
|
||||
allowList: []string{"123456|alice"},
|
||||
senderID: "123456",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non matching sender is denied",
|
||||
allowList: []string{"123456"},
|
||||
senderID: "654321|bob",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ch := NewBaseChannel("test", nil, nil, tt.allowList)
|
||||
if got := ch.IsAllowed(tt.senderID); got != tt.want {
|
||||
t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// FeishuChannel is a stub implementation for 32-bit architectures
|
||||
type FeishuChannel struct {
|
||||
*BaseChannel
|
||||
}
|
||||
|
||||
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
return nil, errors.New("feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config")
|
||||
}
|
||||
|
||||
// Start is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
@@ -0,0 +1,598 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
lineAPIBase = "https://api.line.me/v2/bot"
|
||||
lineDataAPIBase = "https://api-data.line.me/v2/bot"
|
||||
lineReplyEndpoint = lineAPIBase + "/message/reply"
|
||||
linePushEndpoint = lineAPIBase + "/message/push"
|
||||
lineContentEndpoint = lineDataAPIBase + "/message/%s/content"
|
||||
lineBotInfoEndpoint = lineAPIBase + "/info"
|
||||
lineLoadingEndpoint = lineAPIBase + "/chat/loading/start"
|
||||
lineReplyTokenMaxAge = 25 * time.Second
|
||||
)
|
||||
|
||||
type replyTokenEntry struct {
|
||||
token string
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// LINEChannel implements the Channel interface for LINE Official Account
|
||||
// using the LINE Messaging API with HTTP webhook for receiving messages
|
||||
// and REST API for sending messages.
|
||||
type LINEChannel struct {
|
||||
*BaseChannel
|
||||
config config.LINEConfig
|
||||
httpServer *http.Server
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
botDisplayName string // Bot's display name for text-based mention detection
|
||||
replyTokens sync.Map // chatID -> replyTokenEntry
|
||||
quoteTokens sync.Map // chatID -> quoteToken (string)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewLINEChannel creates a new LINE channel instance.
|
||||
func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) {
|
||||
if cfg.ChannelSecret == "" || cfg.ChannelAccessToken == "" {
|
||||
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
|
||||
}
|
||||
|
||||
base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &LINEChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start launches the HTTP webhook server.
|
||||
func (c *LINEChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("line", "Starting LINE channel (Webhook Mode)")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Fetch bot profile to get bot's userId for mention detection
|
||||
if err := c.fetchBotInfo(); err != nil {
|
||||
logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("line", "Bot info fetched", map[string]interface{}{
|
||||
"bot_user_id": c.botUserID,
|
||||
"basic_id": c.botBasicID,
|
||||
"display_name": c.botDisplayName,
|
||||
})
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
path := c.config.WebhookPath
|
||||
if path == "" {
|
||||
path = "/webhook/line"
|
||||
}
|
||||
mux.HandleFunc(path, c.webhookHandler)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
|
||||
c.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.InfoCF("line", "LINE webhook server listening", map[string]interface{}{
|
||||
"addr": addr,
|
||||
"path": path,
|
||||
})
|
||||
if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("line", "Webhook server error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("line", "LINE channel started (Webhook Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchBotInfo retrieves the bot's userId, basicId, and displayName from the LINE API.
|
||||
func (c *LINEChannel) fetchBotInfo() error {
|
||||
req, err := http.NewRequest(http.MethodGet, lineBotInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bot info API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var info struct {
|
||||
UserID string `json:"userId"`
|
||||
BasicID string `json:"basicId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.botUserID = info.UserID
|
||||
c.botBasicID = info.BasicID
|
||||
c.botDisplayName = info.DisplayName
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the HTTP server.
|
||||
func (c *LINEChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("line", "Stopping LINE channel")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.httpServer != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
if err := c.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
logger.ErrorCF("line", "Webhook server shutdown error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("line", "LINE channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// webhookHandler handles incoming LINE webhook requests.
|
||||
func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
logger.ErrorCF("line", "Failed to read request body", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
signature := r.Header.Get("X-Line-Signature")
|
||||
if !c.verifySignature(body, signature) {
|
||||
logger.WarnC("line", "Invalid webhook signature")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Events []lineEvent `json:"events"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
logger.ErrorCF("line", "Failed to parse webhook payload", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Return 200 immediately, process events asynchronously
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, event := range payload.Events {
|
||||
go c.processEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// verifySignature validates the X-Line-Signature using HMAC-SHA256.
|
||||
func (c *LINEChannel) verifySignature(body []byte, signature string) bool {
|
||||
if signature == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(c.config.ChannelSecret))
|
||||
mac.Write(body)
|
||||
expected := base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||
|
||||
return hmac.Equal([]byte(expected), []byte(signature))
|
||||
}
|
||||
|
||||
// LINE webhook event types
|
||||
type lineEvent struct {
|
||||
Type string `json:"type"`
|
||||
ReplyToken string `json:"replyToken"`
|
||||
Source lineSource `json:"source"`
|
||||
Message json.RawMessage `json:"message"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type lineSource struct {
|
||||
Type string `json:"type"` // "user", "group", "room"
|
||||
UserID string `json:"userId"`
|
||||
GroupID string `json:"groupId"`
|
||||
RoomID string `json:"roomId"`
|
||||
}
|
||||
|
||||
type lineMessage struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "text", "image", "video", "audio", "file", "sticker"
|
||||
Text string `json:"text"`
|
||||
QuoteToken string `json:"quoteToken"`
|
||||
Mention *struct {
|
||||
Mentionees []lineMentionee `json:"mentionees"`
|
||||
} `json:"mention"`
|
||||
ContentProvider struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"contentProvider"`
|
||||
}
|
||||
|
||||
type lineMentionee struct {
|
||||
Index int `json:"index"`
|
||||
Length int `json:"length"`
|
||||
Type string `json:"type"` // "user", "all"
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
|
||||
func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
if event.Type != "message" {
|
||||
logger.DebugCF("line", "Ignoring non-message event", map[string]interface{}{
|
||||
"type": event.Type,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := event.Source.UserID
|
||||
chatID := c.resolveChatID(event.Source)
|
||||
isGroup := event.Source.Type == "group" || event.Source.Type == "room"
|
||||
|
||||
var msg lineMessage
|
||||
if err := json.Unmarshal(event.Message, &msg); err != nil {
|
||||
logger.ErrorCF("line", "Failed to parse message", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// In group chats, only respond when the bot is mentioned
|
||||
if isGroup && !c.isBotMentioned(msg) {
|
||||
logger.DebugCF("line", "Ignoring group message without mention", map[string]interface{}{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Store reply token for later use
|
||||
if event.ReplyToken != "" {
|
||||
c.replyTokens.Store(chatID, replyTokenEntry{
|
||||
token: event.ReplyToken,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// Store quote token for quoting the original message in reply
|
||||
if msg.QuoteToken != "" {
|
||||
c.quoteTokens.Store(chatID, msg.QuoteToken)
|
||||
}
|
||||
|
||||
var content string
|
||||
var mediaPaths []string
|
||||
localFiles := []string{}
|
||||
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("line", "Failed to cleanup temp file", map[string]interface{}{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
switch msg.Type {
|
||||
case "text":
|
||||
content = msg.Text
|
||||
// Strip bot mention from text in group chats
|
||||
if isGroup {
|
||||
content = c.stripBotMention(content, msg)
|
||||
}
|
||||
case "image":
|
||||
localPath := c.downloadContent(msg.ID, "image.jpg")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[image]"
|
||||
}
|
||||
case "audio":
|
||||
localPath := c.downloadContent(msg.ID, "audio.m4a")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[audio]"
|
||||
}
|
||||
case "video":
|
||||
localPath := c.downloadContent(msg.ID, "video.mp4")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[video]"
|
||||
}
|
||||
case "file":
|
||||
content = "[file]"
|
||||
case "sticker":
|
||||
content = "[sticker]"
|
||||
default:
|
||||
content = fmt.Sprintf("[%s]", msg.Type)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"platform": "line",
|
||||
"source_type": event.Source.Type,
|
||||
"message_id": msg.ID,
|
||||
}
|
||||
|
||||
logger.DebugCF("line", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_type": msg.Type,
|
||||
"is_group": isGroup,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Show typing/loading indicator (requires user ID, not group ID)
|
||||
c.sendLoading(senderID)
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message.
|
||||
// It first checks the mention metadata (userId match), then falls back
|
||||
// to text-based detection using the bot's display name, since LINE may
|
||||
// not include userId in mentionees for Official Accounts.
|
||||
func (c *LINEChannel) isBotMentioned(msg lineMessage) bool {
|
||||
// Check mention metadata
|
||||
if msg.Mention != nil {
|
||||
for _, m := range msg.Mention.Mentionees {
|
||||
if m.Type == "all" {
|
||||
return true
|
||||
}
|
||||
if c.botUserID != "" && m.UserID == c.botUserID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Mention metadata exists with mentionees but bot not matched by userId.
|
||||
// The bot IS likely mentioned (LINE includes mention struct when bot is @-ed),
|
||||
// so check if any mentionee overlaps with bot display name in text.
|
||||
if c.botDisplayName != "" {
|
||||
for _, m := range msg.Mention.Mentionees {
|
||||
if m.Index >= 0 && m.Length > 0 {
|
||||
runes := []rune(msg.Text)
|
||||
end := m.Index + m.Length
|
||||
if end <= len(runes) {
|
||||
mentionText := string(runes[m.Index:end])
|
||||
if strings.Contains(mentionText, c.botDisplayName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: text-based detection with display name
|
||||
if c.botDisplayName != "" && strings.Contains(msg.Text, "@"+c.botDisplayName) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// stripBotMention removes the @BotName mention text from the message.
|
||||
func (c *LINEChannel) stripBotMention(text string, msg lineMessage) string {
|
||||
stripped := false
|
||||
|
||||
// Try to strip using mention metadata indices
|
||||
if msg.Mention != nil {
|
||||
runes := []rune(text)
|
||||
for i := len(msg.Mention.Mentionees) - 1; i >= 0; i-- {
|
||||
m := msg.Mention.Mentionees[i]
|
||||
// Strip if userId matches OR if the mention text contains the bot display name
|
||||
shouldStrip := false
|
||||
if c.botUserID != "" && m.UserID == c.botUserID {
|
||||
shouldStrip = true
|
||||
} else if c.botDisplayName != "" && m.Index >= 0 && m.Length > 0 {
|
||||
end := m.Index + m.Length
|
||||
if end <= len(runes) {
|
||||
mentionText := string(runes[m.Index:end])
|
||||
if strings.Contains(mentionText, c.botDisplayName) {
|
||||
shouldStrip = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if shouldStrip {
|
||||
start := m.Index
|
||||
end := m.Index + m.Length
|
||||
if start >= 0 && end <= len(runes) {
|
||||
runes = append(runes[:start], runes[end:]...)
|
||||
stripped = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if stripped {
|
||||
return strings.TrimSpace(string(runes))
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: strip @DisplayName from text
|
||||
if c.botDisplayName != "" {
|
||||
text = strings.ReplaceAll(text, "@"+c.botDisplayName, "")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// resolveChatID determines the chat ID from the event source.
|
||||
// For group/room messages, use the group/room ID; for 1:1, use the user ID.
|
||||
func (c *LINEChannel) resolveChatID(source lineSource) string {
|
||||
switch source.Type {
|
||||
case "group":
|
||||
return source.GroupID
|
||||
case "room":
|
||||
return source.RoomID
|
||||
default:
|
||||
return source.UserID
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends a message to LINE. It first tries the Reply API (free)
|
||||
// using a cached reply token, then falls back to the Push API.
|
||||
func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("line channel not running")
|
||||
}
|
||||
|
||||
// Load and consume quote token for this chat
|
||||
var quoteToken string
|
||||
if qt, ok := c.quoteTokens.LoadAndDelete(msg.ChatID); ok {
|
||||
quoteToken = qt.(string)
|
||||
}
|
||||
|
||||
// Try reply token first (free, valid for ~25 seconds)
|
||||
if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok {
|
||||
tokenEntry := entry.(replyTokenEntry)
|
||||
if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge {
|
||||
if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil {
|
||||
logger.DebugCF("line", "Message sent via Reply API", map[string]interface{}{
|
||||
"chat_id": msg.ChatID,
|
||||
"quoted": quoteToken != "",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
logger.DebugC("line", "Reply API failed, falling back to Push API")
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to Push API
|
||||
return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken)
|
||||
}
|
||||
|
||||
// buildTextMessage creates a text message object, optionally with quoteToken.
|
||||
func buildTextMessage(content, quoteToken string) map[string]string {
|
||||
msg := map[string]string{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
}
|
||||
if quoteToken != "" {
|
||||
msg["quoteToken"] = quoteToken
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// sendReply sends a message using the LINE Reply API.
|
||||
func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error {
|
||||
payload := map[string]interface{}{
|
||||
"replyToken": replyToken,
|
||||
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
|
||||
}
|
||||
|
||||
return c.callAPI(ctx, lineReplyEndpoint, payload)
|
||||
}
|
||||
|
||||
// sendPush sends a message using the LINE Push API.
|
||||
func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error {
|
||||
payload := map[string]interface{}{
|
||||
"to": to,
|
||||
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
|
||||
}
|
||||
|
||||
return c.callAPI(ctx, linePushEndpoint, payload)
|
||||
}
|
||||
|
||||
// sendLoading sends a loading animation indicator to the chat.
|
||||
func (c *LINEChannel) sendLoading(chatID string) {
|
||||
payload := map[string]interface{}{
|
||||
"chatId": chatID,
|
||||
"loadingSeconds": 60,
|
||||
}
|
||||
if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil {
|
||||
logger.DebugCF("line", "Failed to send loading indicator", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// callAPI makes an authenticated POST request to the LINE API.
|
||||
func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload interface{}) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadContent downloads media content from the LINE API.
|
||||
func (c *LINEChannel) downloadContent(messageID, filename string) string {
|
||||
url := fmt.Sprintf(lineContentEndpoint, messageID)
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "line",
|
||||
ExtraHeaders: map[string]string{
|
||||
"Authorization": "Bearer " + c.config.ChannelAccessToken,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -150,6 +150,32 @@ func (m *Manager) initChannels() error {
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize LINE channel")
|
||||
line, err := NewLINEChannel(m.config.Channels.LINE, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
m.channels["line"] = line
|
||||
logger.InfoC("channels", "LINE channel enabled successfully")
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize OneBot channel")
|
||||
onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
m.channels["onebot"] = onebot
|
||||
logger.InfoC("channels", "OneBot channel enabled successfully")
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
|
||||
"enabled_channels": len(m.channels),
|
||||
})
|
||||
|
||||
@@ -0,0 +1,686 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type OneBotChannel struct {
|
||||
*BaseChannel
|
||||
config config.OneBotConfig
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
dedup map[string]struct{}
|
||||
dedupRing []string
|
||||
dedupIdx int
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
echoCounter int64
|
||||
}
|
||||
|
||||
type oneBotRawEvent struct {
|
||||
PostType string `json:"post_type"`
|
||||
MessageType string `json:"message_type"`
|
||||
SubType string `json:"sub_type"`
|
||||
MessageID json.RawMessage `json:"message_id"`
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
GroupID json.RawMessage `json:"group_id"`
|
||||
RawMessage string `json:"raw_message"`
|
||||
Message json.RawMessage `json:"message"`
|
||||
Sender json.RawMessage `json:"sender"`
|
||||
SelfID json.RawMessage `json:"self_id"`
|
||||
Time json.RawMessage `json:"time"`
|
||||
MetaEventType string `json:"meta_event_type"`
|
||||
Echo string `json:"echo"`
|
||||
RetCode json.RawMessage `json:"retcode"`
|
||||
Status BotStatus `json:"status"`
|
||||
}
|
||||
|
||||
type BotStatus struct {
|
||||
Online bool `json:"online"`
|
||||
Good bool `json:"good"`
|
||||
}
|
||||
|
||||
type oneBotSender struct {
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
Card string `json:"card"`
|
||||
}
|
||||
|
||||
type oneBotEvent struct {
|
||||
PostType string
|
||||
MessageType string
|
||||
SubType string
|
||||
MessageID string
|
||||
UserID int64
|
||||
GroupID int64
|
||||
Content string
|
||||
RawContent string
|
||||
IsBotMentioned bool
|
||||
Sender oneBotSender
|
||||
SelfID int64
|
||||
Time int64
|
||||
MetaEventType string
|
||||
}
|
||||
|
||||
type oneBotAPIRequest struct {
|
||||
Action string `json:"action"`
|
||||
Params interface{} `json:"params"`
|
||||
Echo string `json:"echo,omitempty"`
|
||||
}
|
||||
|
||||
type oneBotSendPrivateMsgParams struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type oneBotSendGroupMsgParams struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
|
||||
base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
const dedupSize = 1024
|
||||
return &OneBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
dedup: make(map[string]struct{}, dedupSize),
|
||||
dedupRing: make([]string, dedupSize),
|
||||
dedupIdx: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Start(ctx context.Context) error {
|
||||
if c.config.WSUrl == "" {
|
||||
return fmt.Errorf("OneBot ws_url not configured")
|
||||
}
|
||||
|
||||
logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{
|
||||
"ws_url": c.config.WSUrl,
|
||||
})
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
}
|
||||
|
||||
if c.config.ReconnectInterval > 0 {
|
||||
go c.reconnectLoop()
|
||||
} else {
|
||||
// If reconnect is disabled but initial connection failed, we cannot recover
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("onebot", "OneBot channel started successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) connect() error {
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer.HandshakeTimeout = 10 * time.Second
|
||||
|
||||
header := make(map[string][]string)
|
||||
if c.config.AccessToken != "" {
|
||||
header["Authorization"] = []string{"Bearer " + c.config.AccessToken}
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(c.config.WSUrl, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
logger.InfoC("onebot", "WebSocket connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) reconnectLoop() {
|
||||
interval := time.Duration(c.config.ReconnectInterval) * time.Second
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(interval):
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.InfoC("onebot", "Attempting to reconnect...")
|
||||
if err := c.connect(); err != nil {
|
||||
logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("onebot", "Stopping OneBot channel")
|
||||
c.setRunning(false)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("OneBot channel not running")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("OneBot WebSocket not connected")
|
||||
}
|
||||
|
||||
action, params, err := c.buildSendRequest(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
c.echoCounter++
|
||||
echo := fmt.Sprintf("send_%d", c.echoCounter)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
req := oneBotAPIRequest{
|
||||
Action: action,
|
||||
Params: params,
|
||||
Echo: echo,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal OneBot request: %w", err)
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
|
||||
chatID := msg.ChatID
|
||||
|
||||
if len(chatID) > 6 && chatID[:6] == "group:" {
|
||||
groupID, err := strconv.ParseInt(chatID[6:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid group ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_group_msg", oneBotSendGroupMsgParams{
|
||||
GroupID: groupID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(chatID) > 8 && chatID[:8] == "private:" {
|
||||
userID, err := strconv.ParseInt(chatID[8:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid user ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(chatID, 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", chatID)
|
||||
}
|
||||
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) listen() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
|
||||
return
|
||||
}
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{
|
||||
"length": len(message),
|
||||
"payload": string(message),
|
||||
})
|
||||
|
||||
var raw oneBotRawEvent
|
||||
if err := json.Unmarshal(message, &raw); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"payload": string(message),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Echo != "" || raw.Status.Online || raw.Status.Good {
|
||||
logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
"message_type": raw.MessageType,
|
||||
"sub_type": raw.SubType,
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
|
||||
c.handleRawEvent(&raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONInt64(raw json.RawMessage) (int64, error) {
|
||||
if len(raw) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var n int64
|
||||
if err := json.Unmarshal(raw, &n); err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
return 0, fmt.Errorf("cannot parse as int64: %s", string(raw))
|
||||
}
|
||||
|
||||
func parseJSONString(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
type parseMessageResult struct {
|
||||
Text string
|
||||
IsBotMentioned bool
|
||||
}
|
||||
|
||||
func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult {
|
||||
if len(raw) == 0 {
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
mentioned := false
|
||||
if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(s, cqAt) {
|
||||
mentioned = true
|
||||
s = strings.ReplaceAll(s, cqAt, "")
|
||||
s = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return parseMessageResult{Text: s, IsBotMentioned: mentioned}
|
||||
}
|
||||
|
||||
var segments []map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &segments); err == nil {
|
||||
var text string
|
||||
mentioned := false
|
||||
selfIDStr := strconv.FormatInt(selfID, 10)
|
||||
for _, seg := range segments {
|
||||
segType, _ := seg["type"].(string)
|
||||
data, _ := seg["data"].(map[string]interface{})
|
||||
switch segType {
|
||||
case "text":
|
||||
if data != nil {
|
||||
if t, ok := data["text"].(string); ok {
|
||||
text += t
|
||||
}
|
||||
}
|
||||
case "at":
|
||||
if data != nil && selfID > 0 {
|
||||
qqVal := fmt.Sprintf("%v", data["qq"])
|
||||
if qqVal == selfIDStr || qqVal == "all" {
|
||||
mentioned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return parseMessageResult{Text: strings.TrimSpace(text), IsBotMentioned: mentioned}
|
||||
}
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
|
||||
switch raw.PostType {
|
||||
case "message":
|
||||
evt, err := c.normalizeMessageEvent(raw)
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.handleMessage(evt)
|
||||
case "meta_event":
|
||||
c.handleMetaEvent(raw)
|
||||
case "notice":
|
||||
logger.DebugCF("onebot", "Notice event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "request":
|
||||
logger.DebugCF("onebot", "Request event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "":
|
||||
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) {
|
||||
userID, err := parseJSONInt64(raw.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID))
|
||||
}
|
||||
|
||||
groupID, _ := parseJSONInt64(raw.GroupID)
|
||||
selfID, _ := parseJSONInt64(raw.SelfID)
|
||||
ts, _ := parseJSONInt64(raw.Time)
|
||||
messageID := parseJSONString(raw.MessageID)
|
||||
|
||||
parsed := parseMessageContentEx(raw.Message, selfID)
|
||||
isBotMentioned := parsed.IsBotMentioned
|
||||
|
||||
content := raw.RawMessage
|
||||
if content == "" {
|
||||
content = parsed.Text
|
||||
} else if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(content, cqAt) {
|
||||
isBotMentioned = true
|
||||
content = strings.ReplaceAll(content, cqAt, "")
|
||||
content = strings.TrimSpace(content)
|
||||
}
|
||||
}
|
||||
|
||||
var sender oneBotSender
|
||||
if len(raw.Sender) > 0 {
|
||||
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"sender": string(raw.Sender),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{
|
||||
"message_type": raw.MessageType,
|
||||
"user_id": userID,
|
||||
"group_id": groupID,
|
||||
"message_id": messageID,
|
||||
"content_len": len(content),
|
||||
"nickname": sender.Nickname,
|
||||
})
|
||||
|
||||
return &oneBotEvent{
|
||||
PostType: raw.PostType,
|
||||
MessageType: raw.MessageType,
|
||||
SubType: raw.SubType,
|
||||
MessageID: messageID,
|
||||
UserID: userID,
|
||||
GroupID: groupID,
|
||||
Content: content,
|
||||
RawContent: raw.RawMessage,
|
||||
IsBotMentioned: isBotMentioned,
|
||||
Sender: sender,
|
||||
SelfID: selfID,
|
||||
Time: ts,
|
||||
MetaEventType: raw.MetaEventType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
|
||||
switch raw.MetaEventType {
|
||||
case "lifecycle":
|
||||
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "heartbeat":
|
||||
logger.DebugC("onebot", "Heartbeat received")
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMessage(evt *oneBotEvent) {
|
||||
if c.isDuplicate(evt.MessageID) {
|
||||
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
content := evt.Content
|
||||
if content == "" {
|
||||
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := strconv.FormatInt(evt.UserID, 10)
|
||||
var chatID string
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": evt.MessageID,
|
||||
}
|
||||
|
||||
switch evt.MessageType {
|
||||
case "private":
|
||||
chatID = "private:" + senderID
|
||||
logger.InfoCF("onebot", "Received private message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"message_id": evt.MessageID,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(evt.GroupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(evt.Sender.UserID)
|
||||
if senderUserID > 0 {
|
||||
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
|
||||
}
|
||||
|
||||
if evt.Sender.Card != "" {
|
||||
metadata["sender_name"] = evt.Sender.Card
|
||||
} else if evt.Sender.Nickname != "" {
|
||||
metadata["sender_name"] = evt.Sender.Nickname
|
||||
}
|
||||
|
||||
triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned)
|
||||
if !triggered {
|
||||
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
return
|
||||
}
|
||||
content = strippedContent
|
||||
|
||||
logger.InfoCF("onebot", "Received group message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"message_id": evt.MessageID,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
default:
|
||||
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
|
||||
"type": evt.MessageType,
|
||||
"message_id": evt.MessageID,
|
||||
"user_id": evt.UserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if evt.Sender.Nickname != "" {
|
||||
metadata["nickname"] = evt.Sender.Nickname
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
if messageID == "" || messageID == "0" {
|
||||
return false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if _, exists := c.dedup[messageID]; exists {
|
||||
return true
|
||||
}
|
||||
|
||||
if old := c.dedupRing[c.dedupIdx]; old != "" {
|
||||
delete(c.dedup, old)
|
||||
}
|
||||
c.dedupRing[c.dedupIdx] = messageID
|
||||
c.dedup[messageID] = struct{}{}
|
||||
c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= n {
|
||||
return s
|
||||
}
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) {
|
||||
if isBotMentioned {
|
||||
return true, strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
for _, prefix := range c.config.GroupTriggerPrefix {
|
||||
if prefix == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(content, prefix) {
|
||||
return true, strings.TrimSpace(strings.TrimPrefix(content, prefix))
|
||||
}
|
||||
}
|
||||
|
||||
return false, content
|
||||
}
|
||||
+17
-32
@@ -177,15 +177,17 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
return
|
||||
}
|
||||
|
||||
senderID := fmt.Sprintf("%d", user.ID)
|
||||
userID := fmt.Sprintf("%d", user.ID)
|
||||
senderID := userID
|
||||
if user.Username != "" {
|
||||
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
|
||||
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件
|
||||
if !c.IsAllowed(senderID) {
|
||||
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
|
||||
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": senderID,
|
||||
"user_id": userID,
|
||||
"username": user.Username,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -318,37 +320,14 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
}
|
||||
}
|
||||
|
||||
// Create new context for thinking animation with timeout
|
||||
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
// Create cancel function for thinking state
|
||||
_, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
|
||||
|
||||
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
|
||||
if err == nil {
|
||||
pID := pMsg.MessageID
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
|
||||
go func(cid int64, mid int) {
|
||||
dots := []string{".", "..", "..."}
|
||||
emotes := []string{"💭", "🤔", "☁️"}
|
||||
i := 0
|
||||
ticker := time.NewTicker(2000 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-thinkCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
i++
|
||||
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
|
||||
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
|
||||
if editErr != nil {
|
||||
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
|
||||
"error": editErr.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}(chatID, pID)
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -359,7 +338,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
}
|
||||
|
||||
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||
@@ -470,8 +449,11 @@ func extractCodeBlocks(text string) codeBlockMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00CB%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return codeBlockMatch{text: text, codes: codes}
|
||||
@@ -491,8 +473,11 @@ func extractInlineCodes(text string) inlineCodeMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00IC%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return inlineCodeMatch{text: text, codes: codes}
|
||||
|
||||
+87
-24
@@ -50,6 +50,7 @@ type Config struct {
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
Devices DevicesConfig `json:"devices"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -76,6 +77,8 @@ type ChannelsConfig struct {
|
||||
QQ QQConfig `json:"qq"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||
Slack SlackConfig `json:"slack"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
}
|
||||
|
||||
type WhatsAppConfig struct {
|
||||
@@ -128,10 +131,29 @@ type DingTalkConfig struct {
|
||||
}
|
||||
|
||||
type SlackConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
|
||||
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
|
||||
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
|
||||
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
|
||||
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
|
||||
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type LINEConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
|
||||
ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
|
||||
ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type OneBotConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
|
||||
WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
|
||||
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
|
||||
ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
|
||||
GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
@@ -139,24 +161,32 @@ type HeartbeatConfig struct {
|
||||
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
|
||||
}
|
||||
|
||||
type DevicesConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_DEVICES_ENABLED"`
|
||||
MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"`
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
@@ -164,13 +194,20 @@ type GatewayConfig struct {
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
}
|
||||
|
||||
type WebSearchConfig struct {
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_SEARCH_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARCH_MAX_RESULTS"`
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type DuckDuckGoConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
Search WebSearchConfig `json:"search"`
|
||||
Brave BraveConfig `json:"brave"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
@@ -241,7 +278,24 @@ func DefaultConfig() *Config {
|
||||
Enabled: false,
|
||||
BotToken: "",
|
||||
AppToken: "",
|
||||
AllowFrom: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
ChannelSecret: "",
|
||||
ChannelAccessToken: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18791,
|
||||
WebhookPath: "/webhook/line",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
OneBot: OneBotConfig{
|
||||
Enabled: false,
|
||||
WSUrl: "ws://127.0.0.1:3001",
|
||||
AccessToken: "",
|
||||
ReconnectInterval: 5,
|
||||
GroupTriggerPrefix: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
@@ -262,10 +316,15 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
Web: WebToolsConfig{
|
||||
Search: WebSearchConfig{
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
@@ -275,6 +334,10 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
Interval: 30, // default 30 minutes
|
||||
},
|
||||
Devices: DevicesConfig{
|
||||
Enabled: false,
|
||||
MonitorUSB: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -136,11 +136,14 @@ func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify web tools defaults
|
||||
if cfg.Tools.Web.Search.MaxResults != 5 {
|
||||
t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults)
|
||||
if cfg.Tools.Web.Brave.MaxResults != 5 {
|
||||
t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults)
|
||||
}
|
||||
if cfg.Tools.Web.Search.APIKey != "" {
|
||||
t.Error("Search API key should be empty by default")
|
||||
if cfg.Tools.Web.Brave.APIKey != "" {
|
||||
t.Error("Brave API key should be empty by default")
|
||||
}
|
||||
if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 {
|
||||
t.Error("Expected DuckDuckGo MaxResults 5, got ", cfg.Tools.Web.DuckDuckGo.MaxResults)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+66
-45
@@ -71,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
|
||||
cs := &CronService{
|
||||
storePath: storePath,
|
||||
onJob: onJob,
|
||||
stopChan: make(chan struct{}),
|
||||
gronx: gronx.New(),
|
||||
}
|
||||
// Initialize and load store on creation
|
||||
@@ -96,8 +95,9 @@ func (cs *CronService) Start() error {
|
||||
return fmt.Errorf("failed to save store: %w", err)
|
||||
}
|
||||
|
||||
cs.stopChan = make(chan struct{})
|
||||
cs.running = true
|
||||
go cs.runLoop()
|
||||
go cs.runLoop(cs.stopChan)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -111,16 +111,19 @@ func (cs *CronService) Stop() {
|
||||
}
|
||||
|
||||
cs.running = false
|
||||
close(cs.stopChan)
|
||||
if cs.stopChan != nil {
|
||||
close(cs.stopChan)
|
||||
cs.stopChan = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CronService) runLoop() {
|
||||
func (cs *CronService) runLoop(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cs.stopChan:
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
cs.checkJobs()
|
||||
@@ -137,27 +140,23 @@ func (cs *CronService) checkJobs() {
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
var dueJobs []*CronJob
|
||||
var dueJobIDs []string
|
||||
|
||||
// Collect jobs that are due (we need to copy them to execute outside lock)
|
||||
for i := range cs.store.Jobs {
|
||||
job := &cs.store.Jobs[i]
|
||||
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
|
||||
// Create a shallow copy of the job for execution
|
||||
jobCopy := *job
|
||||
dueJobs = append(dueJobs, &jobCopy)
|
||||
dueJobIDs = append(dueJobIDs, job.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Update next run times for due jobs immediately (before executing)
|
||||
// Use map for O(n) lookup instead of O(n²) nested loop
|
||||
dueMap := make(map[string]bool, len(dueJobs))
|
||||
for _, job := range dueJobs {
|
||||
dueMap[job.ID] = true
|
||||
// Reset next run for due jobs before unlocking to avoid duplicate execution.
|
||||
dueMap := make(map[string]bool, len(dueJobIDs))
|
||||
for _, jobID := range dueJobIDs {
|
||||
dueMap[jobID] = true
|
||||
}
|
||||
for i := range cs.store.Jobs {
|
||||
if dueMap[cs.store.Jobs[i].ID] {
|
||||
// Reset NextRunAtMS temporarily so we don't re-execute
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||
}
|
||||
}
|
||||
@@ -168,53 +167,75 @@ func (cs *CronService) checkJobs() {
|
||||
|
||||
cs.mu.Unlock()
|
||||
|
||||
// Execute jobs outside the lock
|
||||
for _, job := range dueJobs {
|
||||
cs.executeJob(job)
|
||||
// Execute jobs outside lock.
|
||||
for _, jobID := range dueJobIDs {
|
||||
cs.executeJobByID(jobID)
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CronService) executeJob(job *CronJob) {
|
||||
func (cs *CronService) executeJobByID(jobID string) {
|
||||
startTime := time.Now().UnixMilli()
|
||||
|
||||
cs.mu.RLock()
|
||||
var callbackJob *CronJob
|
||||
for i := range cs.store.Jobs {
|
||||
job := &cs.store.Jobs[i]
|
||||
if job.ID == jobID {
|
||||
jobCopy := *job
|
||||
callbackJob = &jobCopy
|
||||
break
|
||||
}
|
||||
}
|
||||
cs.mu.RUnlock()
|
||||
|
||||
if callbackJob == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if cs.onJob != nil {
|
||||
_, err = cs.onJob(job)
|
||||
_, err = cs.onJob(callbackJob)
|
||||
}
|
||||
|
||||
// Now acquire lock to update state
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
|
||||
// Find the job in store and update it
|
||||
var job *CronJob
|
||||
for i := range cs.store.Jobs {
|
||||
if cs.store.Jobs[i].ID == job.ID {
|
||||
cs.store.Jobs[i].State.LastRunAtMS = &startTime
|
||||
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
|
||||
|
||||
if err != nil {
|
||||
cs.store.Jobs[i].State.LastStatus = "error"
|
||||
cs.store.Jobs[i].State.LastError = err.Error()
|
||||
} else {
|
||||
cs.store.Jobs[i].State.LastStatus = "ok"
|
||||
cs.store.Jobs[i].State.LastError = ""
|
||||
}
|
||||
|
||||
// Compute next run time
|
||||
if cs.store.Jobs[i].Schedule.Kind == "at" {
|
||||
if cs.store.Jobs[i].DeleteAfterRun {
|
||||
cs.removeJobUnsafe(job.ID)
|
||||
} else {
|
||||
cs.store.Jobs[i].Enabled = false
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||
}
|
||||
} else {
|
||||
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli())
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nextRun
|
||||
}
|
||||
if cs.store.Jobs[i].ID == jobID {
|
||||
job = &cs.store.Jobs[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if job == nil {
|
||||
log.Printf("[cron] job %s disappeared before state update", jobID)
|
||||
return
|
||||
}
|
||||
|
||||
job.State.LastRunAtMS = &startTime
|
||||
job.UpdatedAtMS = time.Now().UnixMilli()
|
||||
|
||||
if err != nil {
|
||||
job.State.LastStatus = "error"
|
||||
job.State.LastError = err.Error()
|
||||
} else {
|
||||
job.State.LastStatus = "ok"
|
||||
job.State.LastError = ""
|
||||
}
|
||||
|
||||
// Compute next run time
|
||||
if job.Schedule.Kind == "at" {
|
||||
if job.DeleteAfterRun {
|
||||
cs.removeJobUnsafe(job.ID)
|
||||
} else {
|
||||
job.Enabled = false
|
||||
job.State.NextRunAtMS = nil
|
||||
}
|
||||
} else {
|
||||
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
|
||||
job.State.NextRunAtMS = nextRun
|
||||
}
|
||||
|
||||
if err := cs.saveStoreUnsafe(); err != nil {
|
||||
log.Printf("[cron] failed to save store: %v", err)
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package events
|
||||
|
||||
import "context"
|
||||
|
||||
type EventSource interface {
|
||||
Kind() Kind
|
||||
Start(ctx context.Context) (<-chan *DeviceEvent, error)
|
||||
Stop() error
|
||||
}
|
||||
|
||||
type Action string
|
||||
|
||||
const (
|
||||
ActionAdd Action = "add"
|
||||
ActionRemove Action = "remove"
|
||||
ActionChange Action = "change"
|
||||
)
|
||||
|
||||
type Kind string
|
||||
|
||||
const (
|
||||
KindUSB Kind = "usb"
|
||||
KindBluetooth Kind = "bluetooth"
|
||||
KindPCI Kind = "pci"
|
||||
KindGeneric Kind = "generic"
|
||||
)
|
||||
|
||||
type DeviceEvent struct {
|
||||
Action Action
|
||||
Kind Kind
|
||||
DeviceID string // e.g. "1-2" for USB bus 1 dev 2
|
||||
Vendor string // Vendor name or ID
|
||||
Product string // Product name or ID
|
||||
Serial string // Serial number if available
|
||||
Capabilities string // Human-readable capability description
|
||||
Raw map[string]string // Raw properties for extensibility
|
||||
}
|
||||
|
||||
func (e *DeviceEvent) FormatMessage() string {
|
||||
actionEmoji := "🔌"
|
||||
actionText := "Connected"
|
||||
if e.Action == ActionRemove {
|
||||
actionEmoji = "🔌"
|
||||
actionText = "Disconnected"
|
||||
}
|
||||
|
||||
msg := actionEmoji + " Device " + actionText + "\n\n"
|
||||
msg += "Type: " + string(e.Kind) + "\n"
|
||||
msg += "Device: " + e.Vendor + " " + e.Product + "\n"
|
||||
if e.Capabilities != "" {
|
||||
msg += "Capabilities: " + e.Capabilities + "\n"
|
||||
}
|
||||
if e.Serial != "" {
|
||||
msg += "Serial: " + e.Serial + "\n"
|
||||
}
|
||||
return msg
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package devices
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
"github.com/sipeed/picoclaw/pkg/devices/sources"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
bus *bus.MessageBus
|
||||
state *state.Manager
|
||||
sources []events.EventSource
|
||||
enabled bool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
MonitorUSB bool // When true, monitor USB hotplug (Linux only)
|
||||
// Future: MonitorBluetooth, MonitorPCI, etc.
|
||||
}
|
||||
|
||||
func NewService(cfg Config, stateMgr *state.Manager) *Service {
|
||||
s := &Service{
|
||||
state: stateMgr,
|
||||
enabled: cfg.Enabled,
|
||||
sources: make([]EventSource, 0),
|
||||
}
|
||||
|
||||
if cfg.Enabled && cfg.MonitorUSB {
|
||||
s.sources = append(s.sources, sources.NewUSBMonitor())
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Service) SetBus(msgBus *bus.MessageBus) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.bus = msgBus
|
||||
}
|
||||
|
||||
func (s *Service) Start(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.enabled || len(s.sources) == 0 {
|
||||
logger.InfoC("devices", "Device event service disabled or no sources")
|
||||
return nil
|
||||
}
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
|
||||
for _, src := range s.sources {
|
||||
eventCh, err := src.Start(s.ctx)
|
||||
if err != nil {
|
||||
logger.ErrorCF("devices", "Failed to start source", map[string]interface{}{
|
||||
"kind": src.Kind(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
go s.handleEvents(src.Kind(), eventCh)
|
||||
logger.InfoCF("devices", "Device source started", map[string]interface{}{
|
||||
"kind": src.Kind(),
|
||||
})
|
||||
}
|
||||
|
||||
logger.InfoC("devices", "Device event service started")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Stop() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
s.cancel = nil
|
||||
}
|
||||
|
||||
for _, src := range s.sources {
|
||||
src.Stop()
|
||||
}
|
||||
|
||||
logger.InfoC("devices", "Device event service stopped")
|
||||
}
|
||||
|
||||
func (s *Service) handleEvents(kind events.Kind, eventCh <-chan *events.DeviceEvent) {
|
||||
for ev := range eventCh {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
s.sendNotification(ev)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) sendNotification(ev *events.DeviceEvent) {
|
||||
s.mu.RLock()
|
||||
msgBus := s.bus
|
||||
s.mu.RUnlock()
|
||||
|
||||
if msgBus == nil {
|
||||
return
|
||||
}
|
||||
|
||||
lastChannel := s.state.GetLastChannel()
|
||||
if lastChannel == "" {
|
||||
logger.DebugCF("devices", "No last channel, skipping notification", map[string]interface{}{
|
||||
"event": ev.FormatMessage(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
platform, userID := parseLastChannel(lastChannel)
|
||||
if platform == "" || userID == "" || constants.IsInternalChannel(platform) {
|
||||
return
|
||||
}
|
||||
|
||||
msg := ev.FormatMessage()
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Content: msg,
|
||||
})
|
||||
|
||||
logger.InfoCF("devices", "Device notification sent", map[string]interface{}{
|
||||
"kind": ev.Kind,
|
||||
"action": ev.Action,
|
||||
"to": platform,
|
||||
})
|
||||
}
|
||||
|
||||
func parseLastChannel(lastChannel string) (platform, userID string) {
|
||||
if lastChannel == "" {
|
||||
return "", ""
|
||||
}
|
||||
parts := strings.SplitN(lastChannel, ":", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", ""
|
||||
}
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package devices
|
||||
|
||||
import "github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
|
||||
type EventSource = events.EventSource
|
||||
@@ -0,0 +1,198 @@
|
||||
//go:build linux
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
var usbClassToCapability = map[string]string{
|
||||
"00": "Interface Definition (by interface)",
|
||||
"01": "Audio",
|
||||
"02": "CDC Communication (Network Card/Modem)",
|
||||
"03": "HID (Keyboard/Mouse/Gamepad)",
|
||||
"05": "Physical Interface",
|
||||
"06": "Image (Scanner/Camera)",
|
||||
"07": "Printer",
|
||||
"08": "Mass Storage (USB Flash Drive/Hard Disk)",
|
||||
"09": "USB Hub",
|
||||
"0a": "CDC Data",
|
||||
"0b": "Smart Card",
|
||||
"0e": "Video (Camera)",
|
||||
"dc": "Diagnostic Device",
|
||||
"e0": "Wireless Controller (Bluetooth)",
|
||||
"ef": "Miscellaneous",
|
||||
"fe": "Application Specific",
|
||||
"ff": "Vendor Specific",
|
||||
}
|
||||
|
||||
type USBMonitor struct {
|
||||
cmd *exec.Cmd
|
||||
cancel context.CancelFunc
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewUSBMonitor() *USBMonitor {
|
||||
return &USBMonitor{}
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Kind() events.Kind {
|
||||
return events.KindUSB
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Start(ctx context.Context) (<-chan *events.DeviceEvent, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// udevadm monitor outputs: UDEV/KERNEL [timestamp] action devpath (subsystem)
|
||||
// Followed by KEY=value lines, empty line separates events
|
||||
// Use -s/--subsystem-match (eudev) or --udev-subsystem-match (systemd udev)
|
||||
cmd := exec.CommandContext(ctx, "udevadm", "monitor", "--property", "--subsystem-match=usb")
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("udevadm stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("udevadm start: %w (is udevadm installed?)", err)
|
||||
}
|
||||
|
||||
m.cmd = cmd
|
||||
eventCh := make(chan *events.DeviceEvent, 16)
|
||||
|
||||
go func() {
|
||||
defer close(eventCh)
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
var props map[string]string
|
||||
var action string
|
||||
isUdev := false // Only UDEV events have complete info (ID_VENDOR, ID_MODEL); KERNEL events come first with less info
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// End of event block - only process UDEV events (skip KERNEL to avoid duplicate/incomplete notifications)
|
||||
if isUdev && props != nil && (action == "add" || action == "remove") {
|
||||
if ev := parseUSBEvent(action, props); ev != nil {
|
||||
select {
|
||||
case eventCh <- ev:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
props = nil
|
||||
action = ""
|
||||
isUdev = false
|
||||
continue
|
||||
}
|
||||
|
||||
idx := strings.Index(line, "=")
|
||||
// First line of block: "UDEV [ts] action devpath" or "KERNEL[ts] action devpath" - no KEY=value
|
||||
if idx <= 0 {
|
||||
isUdev = strings.HasPrefix(strings.TrimSpace(line), "UDEV")
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse KEY=value
|
||||
key := line[:idx]
|
||||
val := line[idx+1:]
|
||||
if props == nil {
|
||||
props = make(map[string]string)
|
||||
}
|
||||
props[key] = val
|
||||
|
||||
if key == "ACTION" {
|
||||
action = val
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.ErrorCF("devices", "udevadm scan error", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
cmd.Wait()
|
||||
}()
|
||||
|
||||
return eventCh, nil
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Stop() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.cmd != nil && m.cmd.Process != nil {
|
||||
m.cmd.Process.Kill()
|
||||
m.cmd = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseUSBEvent(action string, props map[string]string) *events.DeviceEvent {
|
||||
// Only care about add/remove for physical devices (not interfaces)
|
||||
subsystem := props["SUBSYSTEM"]
|
||||
if subsystem != "usb" {
|
||||
return nil
|
||||
}
|
||||
// Skip interface events - we want device-level only to avoid duplicates
|
||||
devType := props["DEVTYPE"]
|
||||
if devType == "usb_interface" {
|
||||
return nil
|
||||
}
|
||||
// Prefer usb_device, but accept if DEVTYPE not set (varies by udev version)
|
||||
if devType != "" && devType != "usb_device" {
|
||||
return nil
|
||||
}
|
||||
|
||||
ev := &events.DeviceEvent{
|
||||
Raw: props,
|
||||
}
|
||||
switch action {
|
||||
case "add":
|
||||
ev.Action = events.ActionAdd
|
||||
case "remove":
|
||||
ev.Action = events.ActionRemove
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
ev.Kind = events.KindUSB
|
||||
|
||||
ev.Vendor = props["ID_VENDOR"]
|
||||
if ev.Vendor == "" {
|
||||
ev.Vendor = props["ID_VENDOR_ID"]
|
||||
}
|
||||
if ev.Vendor == "" {
|
||||
ev.Vendor = "Unknown Vendor"
|
||||
}
|
||||
|
||||
ev.Product = props["ID_MODEL"]
|
||||
if ev.Product == "" {
|
||||
ev.Product = props["ID_MODEL_ID"]
|
||||
}
|
||||
if ev.Product == "" {
|
||||
ev.Product = "Unknown Device"
|
||||
}
|
||||
|
||||
ev.Serial = props["ID_SERIAL_SHORT"]
|
||||
ev.DeviceID = props["DEVPATH"]
|
||||
if bus := props["BUSNUM"]; bus != "" {
|
||||
if dev := props["DEVNUM"]; dev != "" {
|
||||
ev.DeviceID = bus + ":" + dev
|
||||
}
|
||||
}
|
||||
|
||||
// Map USB class to capability
|
||||
if class := props["ID_USB_CLASS"]; class != "" {
|
||||
ev.Capabilities = usbClassToCapability[strings.ToLower(class)]
|
||||
}
|
||||
if ev.Capabilities == "" {
|
||||
ev.Capabilities = "USB Device"
|
||||
}
|
||||
|
||||
return ev
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
//go:build !linux
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
)
|
||||
|
||||
type USBMonitor struct{}
|
||||
|
||||
func NewUSBMonitor() *USBMonitor {
|
||||
return &USBMonitor{}
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Kind() events.Kind {
|
||||
return events.KindUSB
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Start(ctx context.Context) (<-chan *events.DeviceEvent, error) {
|
||||
ch := make(chan *events.DeviceEvent)
|
||||
close(ch) // Immediately close, no events
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Stop() error {
|
||||
return nil
|
||||
}
|
||||
+12
-12
@@ -40,7 +40,6 @@ type HeartbeatService struct {
|
||||
interval time.Duration
|
||||
enabled bool
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
@@ -60,7 +59,6 @@ func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *H
|
||||
interval: time.Duration(intervalMinutes) * time.Minute,
|
||||
enabled: enabled,
|
||||
state: state.NewManager(workspace),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,7 +81,7 @@ func (hs *HeartbeatService) Start() error {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
|
||||
if hs.started {
|
||||
if hs.stopChan != nil {
|
||||
logger.InfoC("heartbeat", "Heartbeat service already running")
|
||||
return nil
|
||||
}
|
||||
@@ -93,10 +91,8 @@ func (hs *HeartbeatService) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hs.started = true
|
||||
hs.stopChan = make(chan struct{})
|
||||
|
||||
go hs.runLoop()
|
||||
go hs.runLoop(hs.stopChan)
|
||||
|
||||
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
|
||||
"interval_minutes": hs.interval.Minutes(),
|
||||
@@ -110,24 +106,24 @@ func (hs *HeartbeatService) Stop() {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
|
||||
if !hs.started {
|
||||
if hs.stopChan == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoC("heartbeat", "Stopping heartbeat service")
|
||||
close(hs.stopChan)
|
||||
hs.started = false
|
||||
hs.stopChan = nil
|
||||
}
|
||||
|
||||
// IsRunning returns whether the service is running
|
||||
func (hs *HeartbeatService) IsRunning() bool {
|
||||
hs.mu.RLock()
|
||||
defer hs.mu.RUnlock()
|
||||
return hs.started
|
||||
return hs.stopChan != nil
|
||||
}
|
||||
|
||||
// runLoop runs the heartbeat ticker
|
||||
func (hs *HeartbeatService) runLoop() {
|
||||
func (hs *HeartbeatService) runLoop(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(hs.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -138,7 +134,7 @@ func (hs *HeartbeatService) runLoop() {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hs.stopChan:
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hs.executeHeartbeat()
|
||||
@@ -149,8 +145,12 @@ func (hs *HeartbeatService) runLoop() {
|
||||
// executeHeartbeat performs a single heartbeat check
|
||||
func (hs *HeartbeatService) executeHeartbeat() {
|
||||
hs.mu.RLock()
|
||||
enabled := hs.enabled && hs.started
|
||||
enabled := hs.enabled
|
||||
handler := hs.handler
|
||||
if !hs.enabled || hs.stopChan == nil {
|
||||
hs.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
hs.mu.RUnlock()
|
||||
|
||||
if !enabled {
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
asyncCalled := false
|
||||
asyncResult := &tools.ToolResult{
|
||||
@@ -55,7 +55,7 @@ func TestExecuteHeartbeat_Error(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
@@ -93,7 +93,7 @@ func TestExecuteHeartbeat_Silent(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
@@ -167,7 +167,7 @@ func TestExecuteHeartbeat_NilResult(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return nil
|
||||
|
||||
@@ -212,12 +212,17 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
|
||||
|
||||
if tools, ok := getMap(data, "tools"); ok {
|
||||
if web, ok := getMap(tools, "web"); ok {
|
||||
// Migrate old "search" config to "brave" if api_key is present
|
||||
if search, ok := getMap(web, "search"); ok {
|
||||
if v, ok := getString(search, "api_key"); ok {
|
||||
cfg.Tools.Web.Search.APIKey = v
|
||||
cfg.Tools.Web.Brave.APIKey = v
|
||||
if v != "" {
|
||||
cfg.Tools.Web.Brave.Enabled = true
|
||||
}
|
||||
}
|
||||
if v, ok := getFloat(search, "max_results"); ok {
|
||||
cfg.Tools.Web.Search.MaxResults = int(v)
|
||||
cfg.Tools.Web.Brave.MaxResults = int(v)
|
||||
cfg.Tools.Web.DuckDuckGo.MaxResults = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -271,8 +276,8 @@ func MergeConfig(existing, incoming *config.Config) *config.Config {
|
||||
existing.Channels.MaixCam = incoming.Channels.MaixCam
|
||||
}
|
||||
|
||||
if existing.Tools.Web.Search.APIKey == "" {
|
||||
existing.Tools.Web.Search = incoming.Tools.Web.Search
|
||||
if existing.Tools.Web.Brave.APIKey == "" {
|
||||
existing.Tools.Web.Brave = incoming.Tools.Web.Brave
|
||||
}
|
||||
|
||||
return existing
|
||||
|
||||
@@ -18,6 +18,8 @@ type CodexProvider struct {
|
||||
tokenSource func() (string, string, error)
|
||||
}
|
||||
|
||||
const defaultCodexInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||
@@ -138,6 +140,9 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
|
||||
if instructions != "" {
|
||||
params.Instructions = openai.Opt(instructions)
|
||||
} else {
|
||||
// ChatGPT Codex backend requires instructions to be present.
|
||||
params.Instructions = openai.Opt(defaultCodexInstructions)
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
|
||||
@@ -21,6 +21,12 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
}
|
||||
if !params.Instructions.Valid() {
|
||||
t.Fatal("Instructions should be set")
|
||||
}
|
||||
if params.Instructions.Or("") != defaultCodexInstructions {
|
||||
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
json "encoding/json"
|
||||
|
||||
copilot "github.com/github/copilot-sdk/go"
|
||||
)
|
||||
|
||||
type GitHubCopilotProvider struct {
|
||||
uri string
|
||||
connectMode string // `stdio` or `grpc``
|
||||
|
||||
session *copilot.Session
|
||||
}
|
||||
|
||||
func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) {
|
||||
|
||||
var session *copilot.Session
|
||||
if connectMode == "" {
|
||||
connectMode = "grpc"
|
||||
}
|
||||
switch connectMode {
|
||||
|
||||
case "stdio":
|
||||
//todo
|
||||
case "grpc":
|
||||
client := copilot.NewClient(&copilot.ClientOptions{
|
||||
CLIUrl: uri,
|
||||
})
|
||||
if err := client.Start(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details")
|
||||
}
|
||||
defer client.Stop()
|
||||
session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{
|
||||
Model: model,
|
||||
Hooks: &copilot.SessionHooks{},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return &GitHubCopilotProvider{
|
||||
uri: uri,
|
||||
connectMode: connectMode,
|
||||
session: session,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat sends a chat request to GitHub Copilot
|
||||
func (p *GitHubCopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
type tempMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
out := make([]tempMessage, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
out = append(out, tempMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
fullcontent, _ := json.Marshal(out)
|
||||
|
||||
content, _ := p.session.Send(ctx, copilot.MessageOptions{
|
||||
Prompt: string(fullcontent),
|
||||
})
|
||||
|
||||
return &LLMResponse{
|
||||
FinishReason: "stop",
|
||||
Content: content,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *GitHubCopilotProvider) GetDefaultModel() string {
|
||||
|
||||
return "gpt-4.1"
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -28,7 +29,7 @@ type HTTPProvider struct {
|
||||
|
||||
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
client := &http.Client{
|
||||
Timeout: 0,
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
@@ -42,7 +43,7 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
|
||||
return &HTTPProvider{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error: %s", string(body))
|
||||
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return p.parseResponse(body)
|
||||
@@ -303,7 +304,27 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
|
||||
+96
-19
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager {
|
||||
}
|
||||
|
||||
func (sm *SessionManager) GetOrCreate(key string) *Session {
|
||||
sm.mu.RLock()
|
||||
session, ok := sm.sessions[key]
|
||||
sm.mu.RUnlock()
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
sm.mu.Lock()
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
sm.mu.Unlock()
|
||||
session, ok := sm.sessions[key]
|
||||
if ok {
|
||||
return session
|
||||
}
|
||||
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
return
|
||||
}
|
||||
|
||||
if keepLast <= 0 {
|
||||
session.Messages = []providers.Message{}
|
||||
session.Updated = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
if len(session.Messages) <= keepLast {
|
||||
return
|
||||
}
|
||||
@@ -138,22 +145,92 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(session *Session) error {
|
||||
// sanitizeFilename converts a session key into a cross-platform safe filename.
|
||||
// Session keys use "channel:chatID" (e.g. "telegram:123456") but ':' is the
|
||||
// volume separator on Windows, so filepath.Base would misinterpret the key.
|
||||
// We replace it with '_'. The original key is preserved inside the JSON file,
|
||||
// so loadSessions still maps back to the right in-memory key.
|
||||
func sanitizeFilename(key string) string {
|
||||
return strings.ReplaceAll(key, ":", "_")
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(key string) error {
|
||||
if sm.storage == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
filename := sanitizeFilename(key)
|
||||
|
||||
sessionPath := filepath.Join(sm.storage, session.Key+".json")
|
||||
// filepath.IsLocal rejects empty names, "..", absolute paths, and
|
||||
// OS-reserved device names (NUL, COM1 … on Windows).
|
||||
// The extra checks reject "." and any directory separators so that
|
||||
// the session file is always written directly inside sm.storage.
|
||||
if filename == "." || !filepath.IsLocal(filename) || strings.ContainsAny(filename, `/\`) {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(session, "", " ")
|
||||
// Snapshot under read lock, then perform slow file I/O after unlock.
|
||||
sm.mu.RLock()
|
||||
stored, ok := sm.sessions[key]
|
||||
if !ok {
|
||||
sm.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
snapshot := Session{
|
||||
Key: stored.Key,
|
||||
Summary: stored.Summary,
|
||||
Created: stored.Created,
|
||||
Updated: stored.Updated,
|
||||
}
|
||||
if len(stored.Messages) > 0 {
|
||||
snapshot.Messages = make([]providers.Message, len(stored.Messages))
|
||||
copy(snapshot.Messages, stored.Messages)
|
||||
} else {
|
||||
snapshot.Messages = []providers.Message{}
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
data, err := json.MarshalIndent(snapshot, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(sessionPath, data, 0644)
|
||||
sessionPath := filepath.Join(sm.storage, filename+".json")
|
||||
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := true
|
||||
defer func() {
|
||||
if cleanup {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Chmod(0644); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, sessionPath); err != nil {
|
||||
return err
|
||||
}
|
||||
cleanup = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SessionManager) loadSessions() error {
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "simple"},
|
||||
{"telegram:123456", "telegram_123456"},
|
||||
{"discord:987654321", "discord_987654321"},
|
||||
{"slack:C01234", "slack_C01234"},
|
||||
{"no-colons-here", "no-colons-here"},
|
||||
{"multiple:colons:here", "multiple_colons_here"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := sanitizeFilename(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_WithColonInKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
sm := NewSessionManager(tmpDir)
|
||||
|
||||
// Create a session with a key containing colon (typical channel session key).
|
||||
key := "telegram:123456"
|
||||
sm.GetOrCreate(key)
|
||||
sm.AddMessage(key, "user", "hello")
|
||||
|
||||
// Save should succeed even though the key contains ':'
|
||||
if err := sm.Save(key); err != nil {
|
||||
t.Fatalf("Save(%q) failed: %v", key, err)
|
||||
}
|
||||
|
||||
// The file on disk should use sanitized name.
|
||||
expectedFile := filepath.Join(tmpDir, "telegram_123456.json")
|
||||
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
|
||||
t.Fatalf("expected session file %s to exist", expectedFile)
|
||||
}
|
||||
|
||||
// Load into a fresh manager and verify the session round-trips.
|
||||
sm2 := NewSessionManager(tmpDir)
|
||||
history := sm2.GetHistory(key)
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 message after reload, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "hello" {
|
||||
t.Errorf("expected message content %q, got %q", "hello", history[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_RejectsPathTraversal(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
sm := NewSessionManager(tmpDir)
|
||||
|
||||
badKeys := []string{"", ".", "..", "foo/bar", "foo\\bar"}
|
||||
for _, key := range badKeys {
|
||||
sm.GetOrCreate(key)
|
||||
if err := sm.Save(key); err == nil {
|
||||
t.Errorf("Save(%q) should have failed but didn't", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// I2CTool provides I2C bus interaction for reading sensors and controlling peripherals.
|
||||
type I2CTool struct{}
|
||||
|
||||
func NewI2CTool() *I2CTool {
|
||||
return &I2CTool{}
|
||||
}
|
||||
|
||||
func (t *I2CTool) Name() string {
|
||||
return "i2c"
|
||||
}
|
||||
|
||||
func (t *I2CTool) Description() string {
|
||||
return "Interact with I2C bus devices for reading sensors and controlling peripherals. Actions: detect (list buses), scan (find devices on a bus), read (read bytes from device), write (send bytes to device). Linux only."
|
||||
}
|
||||
|
||||
func (t *I2CTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"detect", "scan", "read", "write"},
|
||||
"description": "Action to perform: detect (list available I2C buses), scan (find devices on a bus), read (read bytes from a device), write (send bytes to a device)",
|
||||
},
|
||||
"bus": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "I2C bus number (e.g. \"1\" for /dev/i2c-1). Required for scan/read/write.",
|
||||
},
|
||||
"address": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "7-bit I2C device address (0x03-0x77). Required for read/write.",
|
||||
},
|
||||
"register": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Register address to read from or write to. If set, sends register byte before read/write.",
|
||||
},
|
||||
"data": map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{"type": "integer"},
|
||||
"description": "Bytes to write (0-255 each). Required for write action.",
|
||||
},
|
||||
"length": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Number of bytes to read (1-256). Default: 1. Used with read action.",
|
||||
},
|
||||
"confirm": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "Must be true for write operations. Safety guard to prevent accidental writes.",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *I2CTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
|
||||
}
|
||||
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "detect":
|
||||
return t.detect()
|
||||
case "scan":
|
||||
return t.scan(args)
|
||||
case "read":
|
||||
return t.readDevice(args)
|
||||
case "write":
|
||||
return t.writeDevice(args)
|
||||
default:
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action))
|
||||
}
|
||||
}
|
||||
|
||||
// detect lists available I2C buses by globbing /dev/i2c-*
|
||||
func (t *I2CTool) detect() *ToolResult {
|
||||
matches, err := filepath.Glob("/dev/i2c-*")
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to scan for I2C buses: %v", err))
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return SilentResult("No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)")
|
||||
}
|
||||
|
||||
type busInfo struct {
|
||||
Path string `json:"path"`
|
||||
Bus string `json:"bus"`
|
||||
}
|
||||
|
||||
buses := make([]busInfo, 0, len(matches))
|
||||
re := regexp.MustCompile(`/dev/i2c-(\d+)`)
|
||||
for _, m := range matches {
|
||||
if sub := re.FindStringSubmatch(m); sub != nil {
|
||||
buses = append(buses, busInfo{Path: m, Bus: sub[1]})
|
||||
}
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(buses, "", " ")
|
||||
return SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result)))
|
||||
}
|
||||
|
||||
// isValidBusID checks that a bus identifier is a simple number (prevents path injection)
|
||||
func isValidBusID(id string) bool {
|
||||
matched, _ := regexp.MatchString(`^\d+$`, id)
|
||||
return matched
|
||||
}
|
||||
|
||||
// parseI2CAddress extracts and validates an I2C address from args
|
||||
func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
|
||||
addrFloat, ok := args["address"].(float64)
|
||||
if !ok {
|
||||
return 0, ErrorResult("address is required (e.g. 0x38 for AHT20)")
|
||||
}
|
||||
addr := int(addrFloat)
|
||||
if addr < 0x03 || addr > 0x77 {
|
||||
return 0, ErrorResult("address must be in valid 7-bit range (0x03-0x77)")
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// parseI2CBus extracts and validates an I2C bus from args
|
||||
func parseI2CBus(args map[string]interface{}) (string, *ToolResult) {
|
||||
bus, ok := args["bus"].(string)
|
||||
if !ok || bus == "" {
|
||||
return "", ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)")
|
||||
}
|
||||
if !isValidBusID(bus) {
|
||||
return "", ErrorResult("invalid bus identifier: must be a number (e.g. \"1\")")
|
||||
}
|
||||
return bus, nil
|
||||
}
|
||||
@@ -0,0 +1,282 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// I2C ioctl constants from Linux kernel headers (<linux/i2c-dev.h>, <linux/i2c.h>)
|
||||
const (
|
||||
i2cSlave = 0x0703 // Set slave address (fails if in use by driver)
|
||||
i2cFuncs = 0x0705 // Query adapter functionality bitmask
|
||||
i2cSmbus = 0x0720 // Perform SMBus transaction
|
||||
|
||||
// I2C_FUNC capability bits
|
||||
i2cFuncSmbusQuick = 0x00010000
|
||||
i2cFuncSmbusReadByte = 0x00020000
|
||||
|
||||
// SMBus transaction types
|
||||
i2cSmbusRead = 0
|
||||
i2cSmbusWrite = 1
|
||||
|
||||
// SMBus protocol sizes
|
||||
i2cSmbusQuick = 0
|
||||
i2cSmbusByte = 1
|
||||
)
|
||||
|
||||
// i2cSmbusData matches the kernel union i2c_smbus_data (34 bytes max).
|
||||
// For quick and byte transactions only the first byte is used (if at all).
|
||||
type i2cSmbusData [34]byte
|
||||
|
||||
// i2cSmbusArgs matches the kernel struct i2c_smbus_ioctl_data.
|
||||
type i2cSmbusArgs struct {
|
||||
readWrite uint8
|
||||
command uint8
|
||||
size uint32
|
||||
data *i2cSmbusData
|
||||
}
|
||||
|
||||
// smbusProbe performs a single SMBus probe at the given address.
|
||||
// Uses SMBus Quick Write (safest) or falls back to SMBus Read Byte for
|
||||
// EEPROM address ranges where quick write can corrupt AT24RF08 chips.
|
||||
// This matches i2cdetect's MODE_AUTO behavior.
|
||||
func smbusProbe(fd int, addr int, hasQuick bool) bool {
|
||||
// EEPROM ranges: use read byte (quick write can corrupt AT24RF08)
|
||||
useReadByte := (addr >= 0x30 && addr <= 0x37) || (addr >= 0x50 && addr <= 0x5F)
|
||||
|
||||
if !useReadByte && hasQuick {
|
||||
// SMBus Quick Write: [START] [ADDR|W] [ACK/NACK] [STOP]
|
||||
// Safest probe — no data transferred
|
||||
args := i2cSmbusArgs{
|
||||
readWrite: i2cSmbusWrite,
|
||||
command: 0,
|
||||
size: i2cSmbusQuick,
|
||||
data: nil,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
// SMBus Read Byte: [START] [ADDR|R] [ACK/NACK] [DATA] [STOP]
|
||||
var data i2cSmbusData
|
||||
args := i2cSmbusArgs{
|
||||
readWrite: i2cSmbusRead,
|
||||
command: 0,
|
||||
size: i2cSmbusByte,
|
||||
data: &data,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
// scan probes valid 7-bit addresses on a bus for connected devices.
|
||||
// Uses the same hybrid probe strategy as i2cdetect's MODE_AUTO:
|
||||
// SMBus Quick Write for most addresses, SMBus Read Byte for EEPROM ranges.
|
||||
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Query adapter capabilities to determine available probe methods.
|
||||
// I2C_FUNCS writes an unsigned long, which is word-sized on Linux.
|
||||
var funcs uintptr
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cFuncs, uintptr(unsafe.Pointer(&funcs)))
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno))
|
||||
}
|
||||
|
||||
hasQuick := funcs&i2cFuncSmbusQuick != 0
|
||||
hasReadByte := funcs&i2cFuncSmbusReadByte != 0
|
||||
|
||||
if !hasQuick && !hasReadByte {
|
||||
return ErrorResult(fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath))
|
||||
}
|
||||
|
||||
type deviceEntry struct {
|
||||
Address string `json:"address"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
var found []deviceEntry
|
||||
// Scan 0x08-0x77, skipping I2C reserved addresses 0x00-0x07
|
||||
for addr := 0x08; addr <= 0x77; addr++ {
|
||||
// Set slave address — EBUSY means a kernel driver owns this address
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr))
|
||||
if errno != 0 {
|
||||
if errno == syscall.EBUSY {
|
||||
found = append(found, deviceEntry{
|
||||
Address: fmt.Sprintf("0x%02x", addr),
|
||||
Status: "busy (in use by kernel driver)",
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if smbusProbe(fd, addr, hasQuick) {
|
||||
found = append(found, deviceEntry{
|
||||
Address: fmt.Sprintf("0x%02x", addr),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(found) == 0 {
|
||||
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
"bus": devPath,
|
||||
"devices": found,
|
||||
"count": len(found),
|
||||
}, "", " ")
|
||||
return SilentResult(fmt.Sprintf("Scan of %s:\n%s", devPath, string(result)))
|
||||
}
|
||||
|
||||
// readDevice reads bytes from an I2C device, optionally at a specific register
|
||||
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
addr, errResult := parseI2CAddress(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
length := 1
|
||||
if l, ok := args["length"].(float64); ok {
|
||||
length = int(l)
|
||||
}
|
||||
if length < 1 || length > 256 {
|
||||
return ErrorResult("length must be between 1 and 256")
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Set slave address
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr))
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno))
|
||||
}
|
||||
|
||||
// If register is specified, write it first
|
||||
if regFloat, ok := args["register"].(float64); ok {
|
||||
reg := int(regFloat)
|
||||
if reg < 0 || reg > 255 {
|
||||
return ErrorResult("register must be between 0x00 and 0xFF")
|
||||
}
|
||||
_, err := syscall.Write(fd, []byte{byte(reg)})
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err))
|
||||
}
|
||||
}
|
||||
|
||||
// Read data
|
||||
buf := make([]byte, length)
|
||||
n, err := syscall.Read(fd, buf)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read from device 0x%02x: %v", addr, err))
|
||||
}
|
||||
|
||||
// Format as hex bytes
|
||||
hexBytes := make([]string, n)
|
||||
intBytes := make([]int, n)
|
||||
for i := 0; i < n; i++ {
|
||||
hexBytes[i] = fmt.Sprintf("0x%02x", buf[i])
|
||||
intBytes[i] = int(buf[i])
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
"bus": devPath,
|
||||
"address": fmt.Sprintf("0x%02x", addr),
|
||||
"bytes": intBytes,
|
||||
"hex": hexBytes,
|
||||
"length": n,
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
// writeDevice writes bytes to an I2C device, optionally at a specific register
|
||||
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult("write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.")
|
||||
}
|
||||
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
addr, errResult := parseI2CAddress(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
dataRaw, ok := args["data"].([]interface{})
|
||||
if !ok || len(dataRaw) == 0 {
|
||||
return ErrorResult("data is required for write (array of byte values 0-255)")
|
||||
}
|
||||
if len(dataRaw) > 256 {
|
||||
return ErrorResult("data too long: maximum 256 bytes per I2C transaction")
|
||||
}
|
||||
|
||||
data := make([]byte, 0, len(dataRaw)+1)
|
||||
|
||||
// If register is specified, prepend it to the data
|
||||
if regFloat, ok := args["register"].(float64); ok {
|
||||
reg := int(regFloat)
|
||||
if reg < 0 || reg > 255 {
|
||||
return ErrorResult("register must be between 0x00 and 0xFF")
|
||||
}
|
||||
data = append(data, byte(reg))
|
||||
}
|
||||
|
||||
for i, v := range dataRaw {
|
||||
f, ok := v.(float64)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i))
|
||||
}
|
||||
b := int(f)
|
||||
if b < 0 || b > 255 {
|
||||
return ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b))
|
||||
}
|
||||
data = append(data, byte(b))
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Set slave address
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr))
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno))
|
||||
}
|
||||
|
||||
// Write data
|
||||
n, err := syscall.Write(fd, data)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write to device 0x%02x: %v", addr, err))
|
||||
}
|
||||
|
||||
return SilentResult(fmt.Sprintf("Wrote %d byte(s) to device 0x%02x on %s", n, addr, devPath))
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//go:build !linux
|
||||
|
||||
package tools
|
||||
|
||||
// scan is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
|
||||
// readDevice is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
|
||||
// writeDevice is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// SPITool provides SPI bus interaction for high-speed peripheral communication.
|
||||
type SPITool struct{}
|
||||
|
||||
func NewSPITool() *SPITool {
|
||||
return &SPITool{}
|
||||
}
|
||||
|
||||
func (t *SPITool) Name() string {
|
||||
return "spi"
|
||||
}
|
||||
|
||||
func (t *SPITool) Description() string {
|
||||
return "Interact with SPI bus devices for high-speed peripheral communication. Actions: list (find SPI devices), transfer (full-duplex send/receive), read (receive bytes). Linux only."
|
||||
}
|
||||
|
||||
func (t *SPITool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"list", "transfer", "read"},
|
||||
"description": "Action to perform: list (find available SPI devices), transfer (full-duplex send/receive), read (receive bytes by sending zeros)",
|
||||
},
|
||||
"device": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "SPI device identifier (e.g. \"2.0\" for /dev/spidev2.0). Required for transfer/read.",
|
||||
},
|
||||
"speed": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "SPI clock speed in Hz. Default: 1000000 (1 MHz).",
|
||||
},
|
||||
"mode": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "SPI mode (0-3). Default: 0. Mode sets CPOL and CPHA: 0=0,0 1=0,1 2=1,0 3=1,1.",
|
||||
},
|
||||
"bits": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Bits per word. Default: 8.",
|
||||
},
|
||||
"data": map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{"type": "integer"},
|
||||
"description": "Bytes to send (0-255 each). Required for transfer action.",
|
||||
},
|
||||
"length": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Number of bytes to read (1-4096). Required for read action.",
|
||||
},
|
||||
"confirm": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "Must be true for transfer operations. Safety guard to prevent accidental writes.",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SPITool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.")
|
||||
}
|
||||
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
return t.list()
|
||||
case "transfer":
|
||||
return t.transfer(args)
|
||||
case "read":
|
||||
return t.readDevice(args)
|
||||
default:
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, transfer, read)", action))
|
||||
}
|
||||
}
|
||||
|
||||
// list finds available SPI devices by globbing /dev/spidev*
|
||||
func (t *SPITool) list() *ToolResult {
|
||||
matches, err := filepath.Glob("/dev/spidev*")
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to scan for SPI devices: %v", err))
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return SilentResult("No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded")
|
||||
}
|
||||
|
||||
type devInfo struct {
|
||||
Path string `json:"path"`
|
||||
Device string `json:"device"`
|
||||
}
|
||||
|
||||
devices := make([]devInfo, 0, len(matches))
|
||||
re := regexp.MustCompile(`/dev/spidev(\d+\.\d+)`)
|
||||
for _, m := range matches {
|
||||
if sub := re.FindStringSubmatch(m); sub != nil {
|
||||
devices = append(devices, devInfo{Path: m, Device: sub[1]})
|
||||
}
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(devices, "", " ")
|
||||
return SilentResult(fmt.Sprintf("Found %d SPI device(s):\n%s", len(devices), string(result)))
|
||||
}
|
||||
|
||||
// parseSPIArgs extracts and validates common SPI parameters
|
||||
func parseSPIArgs(args map[string]interface{}) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
|
||||
dev, ok := args["device"].(string)
|
||||
if !ok || dev == "" {
|
||||
return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)"
|
||||
}
|
||||
matched, _ := regexp.MatchString(`^\d+\.\d+$`, dev)
|
||||
if !matched {
|
||||
return "", 0, 0, 0, "invalid device identifier: must be in format \"X.Y\" (e.g. \"2.0\")"
|
||||
}
|
||||
|
||||
speed = 1000000 // default 1 MHz
|
||||
if s, ok := args["speed"].(float64); ok {
|
||||
if s < 1 || s > 125000000 {
|
||||
return "", 0, 0, 0, "speed must be between 1 Hz and 125 MHz"
|
||||
}
|
||||
speed = uint32(s)
|
||||
}
|
||||
|
||||
mode = 0
|
||||
if m, ok := args["mode"].(float64); ok {
|
||||
if int(m) < 0 || int(m) > 3 {
|
||||
return "", 0, 0, 0, "mode must be 0-3"
|
||||
}
|
||||
mode = uint8(m)
|
||||
}
|
||||
|
||||
bits = 8
|
||||
if b, ok := args["bits"].(float64); ok {
|
||||
if int(b) < 1 || int(b) > 32 {
|
||||
return "", 0, 0, 0, "bits must be between 1 and 32"
|
||||
}
|
||||
bits = uint8(b)
|
||||
}
|
||||
|
||||
return dev, speed, mode, bits, ""
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// SPI ioctl constants from Linux kernel headers.
|
||||
// Calculated from _IOW('k', nr, size) macro:
|
||||
//
|
||||
// direction(1)<<30 | size<<16 | type(0x6B)<<8 | nr
|
||||
const (
|
||||
spiIocWrMode = 0x40016B01 // _IOW('k', 1, __u8)
|
||||
spiIocWrBitsPerWord = 0x40016B03 // _IOW('k', 3, __u8)
|
||||
spiIocWrMaxSpeedHz = 0x40046B04 // _IOW('k', 4, __u32)
|
||||
spiIocMessage1 = 0x40206B00 // _IOW('k', 0, struct spi_ioc_transfer) — 32 bytes
|
||||
)
|
||||
|
||||
// spiTransfer matches Linux kernel struct spi_ioc_transfer (32 bytes on all architectures).
|
||||
type spiTransfer struct {
|
||||
txBuf uint64
|
||||
rxBuf uint64
|
||||
length uint32
|
||||
speedHz uint32
|
||||
delayUsecs uint16
|
||||
bitsPerWord uint8
|
||||
csChange uint8
|
||||
txNbits uint8
|
||||
rxNbits uint8
|
||||
wordDelay uint8
|
||||
pad uint8
|
||||
}
|
||||
|
||||
// configureSPI opens an SPI device and sets mode, bits per word, and speed
|
||||
func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *ToolResult) {
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err))
|
||||
}
|
||||
|
||||
// Set SPI mode
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMode, uintptr(unsafe.Pointer(&mode)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set SPI mode %d: %v", mode, errno))
|
||||
}
|
||||
|
||||
// Set bits per word
|
||||
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrBitsPerWord, uintptr(unsafe.Pointer(&bits)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set bits per word %d: %v", bits, errno))
|
||||
}
|
||||
|
||||
// Set max speed
|
||||
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMaxSpeedHz, uintptr(unsafe.Pointer(&speed)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set SPI speed %d Hz: %v", speed, errno))
|
||||
}
|
||||
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
// transfer performs a full-duplex SPI transfer
|
||||
func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult("transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.")
|
||||
}
|
||||
|
||||
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
|
||||
if errMsg != "" {
|
||||
return ErrorResult(errMsg)
|
||||
}
|
||||
|
||||
dataRaw, ok := args["data"].([]interface{})
|
||||
if !ok || len(dataRaw) == 0 {
|
||||
return ErrorResult("data is required for transfer (array of byte values 0-255)")
|
||||
}
|
||||
if len(dataRaw) > 4096 {
|
||||
return ErrorResult("data too long: maximum 4096 bytes per SPI transfer")
|
||||
}
|
||||
|
||||
txBuf := make([]byte, len(dataRaw))
|
||||
for i, v := range dataRaw {
|
||||
f, ok := v.(float64)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i))
|
||||
}
|
||||
b := int(f)
|
||||
if b < 0 || b > 255 {
|
||||
return ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b))
|
||||
}
|
||||
txBuf[i] = byte(b)
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/spidev%s", dev)
|
||||
fd, errResult := configureSPI(devPath, mode, bits, speed)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
rxBuf := make([]byte, len(txBuf))
|
||||
|
||||
xfer := spiTransfer{
|
||||
txBuf: uint64(uintptr(unsafe.Pointer(&txBuf[0]))),
|
||||
rxBuf: uint64(uintptr(unsafe.Pointer(&rxBuf[0]))),
|
||||
length: uint32(len(txBuf)),
|
||||
speedHz: speed,
|
||||
bitsPerWord: bits,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
|
||||
runtime.KeepAlive(txBuf)
|
||||
runtime.KeepAlive(rxBuf)
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("SPI transfer failed: %v", errno))
|
||||
}
|
||||
|
||||
// Format received bytes
|
||||
hexBytes := make([]string, len(rxBuf))
|
||||
intBytes := make([]int, len(rxBuf))
|
||||
for i, b := range rxBuf {
|
||||
hexBytes[i] = fmt.Sprintf("0x%02x", b)
|
||||
intBytes[i] = int(b)
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
"device": devPath,
|
||||
"sent": len(txBuf),
|
||||
"received": intBytes,
|
||||
"hex": hexBytes,
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
// readDevice reads bytes from SPI by sending zeros (read-only, no confirm needed)
|
||||
func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
|
||||
if errMsg != "" {
|
||||
return ErrorResult(errMsg)
|
||||
}
|
||||
|
||||
length := 0
|
||||
if l, ok := args["length"].(float64); ok {
|
||||
length = int(l)
|
||||
}
|
||||
if length < 1 || length > 4096 {
|
||||
return ErrorResult("length is required for read (1-4096)")
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/spidev%s", dev)
|
||||
fd, errResult := configureSPI(devPath, mode, bits, speed)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
txBuf := make([]byte, length) // zeros
|
||||
rxBuf := make([]byte, length)
|
||||
|
||||
xfer := spiTransfer{
|
||||
txBuf: uint64(uintptr(unsafe.Pointer(&txBuf[0]))),
|
||||
rxBuf: uint64(uintptr(unsafe.Pointer(&rxBuf[0]))),
|
||||
length: uint32(length),
|
||||
speedHz: speed,
|
||||
bitsPerWord: bits,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
|
||||
runtime.KeepAlive(txBuf)
|
||||
runtime.KeepAlive(rxBuf)
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("SPI read failed: %v", errno))
|
||||
}
|
||||
|
||||
hexBytes := make([]string, len(rxBuf))
|
||||
intBytes := make([]int, len(rxBuf))
|
||||
for i, b := range rxBuf {
|
||||
hexBytes[i] = fmt.Sprintf("0x%02x", b)
|
||||
intBytes[i] = int(b)
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
"device": devPath,
|
||||
"bytes": intBytes,
|
||||
"hex": hexBytes,
|
||||
"length": len(rxBuf),
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
//go:build !linux
|
||||
|
||||
package tools
|
||||
|
||||
// transfer is a stub for non-Linux platforms.
|
||||
func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("SPI is only supported on Linux")
|
||||
}
|
||||
|
||||
// readDevice is a stub for non-Linux platforms.
|
||||
func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("SPI is only supported on Linux")
|
||||
}
|
||||
+193
-68
@@ -13,20 +13,203 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
userAgent = "Mozilla/5.0 (compatible; picoclaw/1.0)"
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
type SearchProvider interface {
|
||||
Search(ctx context.Context, query string, count int) (string, error)
|
||||
}
|
||||
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
|
||||
url.QueryEscape(query), count)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
// Log error body for debugging
|
||||
fmt.Printf("Brave API Error Body: %s\n", string(body))
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Web.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Description != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Description))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct{}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
return p.extractResults(string(body), count, query)
|
||||
}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query string) (string, error) {
|
||||
// Simple regex based extraction for DDG HTML
|
||||
// Strategy: Find all result containers or key anchors directly
|
||||
|
||||
// Try finding the result links directly first, as they are the most critical
|
||||
// Pattern: <a class="result__a" href="...">Title</a>
|
||||
// The previous regex was a bit strict. Let's make it more flexible for attributes order/content
|
||||
reLink := regexp.MustCompile(`<a[^>]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)</a>`)
|
||||
matches := reLink.FindAllStringSubmatch(html, count+5)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via DuckDuckGo)", query))
|
||||
|
||||
// Pre-compile snippet regex to run inside the loop
|
||||
// We'll search for snippets relative to the link position or just globally if needed
|
||||
// But simple global search for snippets might mismatch order.
|
||||
// Since we only have the raw HTML string, let's just extract snippets globally and assume order matches (risky but simple for regex)
|
||||
// Or better: Let's assume the snippet follows the link in the HTML
|
||||
|
||||
// A better regex approach: iterate through text and find matches in order
|
||||
// But for now, let's grab all snippets too
|
||||
reSnippet := regexp.MustCompile(`<a class="result__snippet[^"]*".*?>([\s\S]*?)</a>`)
|
||||
snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5)
|
||||
|
||||
maxItems := min(len(matches), count)
|
||||
|
||||
for i := 0; i < maxItems; i++ {
|
||||
urlStr := matches[i][1]
|
||||
title := stripTags(matches[i][2])
|
||||
title = strings.TrimSpace(title)
|
||||
|
||||
// URL decoding if needed
|
||||
if strings.Contains(urlStr, "uddg=") {
|
||||
if u, err := url.QueryUnescape(urlStr); err == nil {
|
||||
idx := strings.Index(u, "uddg=")
|
||||
if idx != -1 {
|
||||
urlStr = u[idx+5:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, title, urlStr))
|
||||
|
||||
// Attempt to attach snippet if available and index aligns
|
||||
if i < len(snippetMatches) {
|
||||
snippet := stripTags(snippetMatches[i][1])
|
||||
snippet = strings.TrimSpace(snippet)
|
||||
if snippet != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", snippet))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
func stripTags(content string) string {
|
||||
re := regexp.MustCompile(`<[^>]+>`)
|
||||
return re.ReplaceAllString(content, "")
|
||||
}
|
||||
|
||||
type WebSearchTool struct {
|
||||
apiKey string
|
||||
provider SearchProvider
|
||||
maxResults int
|
||||
}
|
||||
|
||||
func NewWebSearchTool(apiKey string, maxResults int) *WebSearchTool {
|
||||
if maxResults <= 0 || maxResults > 10 {
|
||||
maxResults = 5
|
||||
type WebSearchToolOptions struct {
|
||||
BraveAPIKey string
|
||||
BraveMaxResults int
|
||||
BraveEnabled bool
|
||||
DuckDuckGoMaxResults int
|
||||
DuckDuckGoEnabled bool
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Brave > DuckDuckGo
|
||||
if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.DuckDuckGoEnabled {
|
||||
provider = &DuckDuckGoSearchProvider{}
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WebSearchTool{
|
||||
apiKey: apiKey,
|
||||
provider: provider,
|
||||
maxResults: maxResults,
|
||||
}
|
||||
}
|
||||
@@ -59,10 +242,6 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if t.apiKey == "" {
|
||||
return ErrorResult("BRAVE_API_KEY not configured")
|
||||
}
|
||||
|
||||
query, ok := args["query"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("query is required")
|
||||
@@ -75,68 +254,14 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
}
|
||||
}
|
||||
|
||||
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
|
||||
url.QueryEscape(query), count)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
result, err := t.provider.Search(ctx, query, count)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
|
||||
return ErrorResult(fmt.Sprintf("search failed: %v", err))
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", t.apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to parse response: %v", err))
|
||||
}
|
||||
|
||||
results := searchResp.Web.Results
|
||||
if len(results) == 0 {
|
||||
msg := fmt.Sprintf("No results for: %s", query)
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
ForUser: msg,
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Description != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Description))
|
||||
}
|
||||
}
|
||||
|
||||
output := strings.Join(lines, "\n")
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Found %d results for: %s", len(results), query),
|
||||
ForUser: output,
|
||||
ForLLM: result,
|
||||
ForUser: result,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+6
-17
@@ -173,30 +173,19 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that nil is returned when no provider is configured
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool("", 5)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"query": "test",
|
||||
}
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "", BraveMaxResults: 5})
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when API key is missing")
|
||||
}
|
||||
|
||||
// Should mention missing API key
|
||||
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
|
||||
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
|
||||
// Should return nil when no provider is enabled
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil when no search provider is configured")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool("test-key", 5)
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "test-key", BraveMaxResults: 5, BraveEnabled: true})
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user