mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'upstream-main' into feat/subturn-poc
This commit is contained in:
+41
-2
@@ -76,7 +76,8 @@ type processOptions struct {
|
||||
}
|
||||
|
||||
const (
|
||||
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit."
|
||||
toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
|
||||
sessionKeyAgentPrefix = "agent:"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
@@ -1130,7 +1131,11 @@ func (al *AgentLoop) runAgentLoop(
|
||||
|
||||
// 4. Handle empty response
|
||||
if finalContent == "" {
|
||||
finalContent = opts.DefaultResponse
|
||||
if iteration >= agent.MaxIterations && agent.MaxIterations > 0 {
|
||||
finalContent = toolLimitResponse
|
||||
} else {
|
||||
finalContent = opts.DefaultResponse
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Save final assistant message to session
|
||||
@@ -1221,6 +1226,7 @@ func (al *AgentLoop) handleReasoning(
|
||||
}
|
||||
|
||||
// runLLMIteration executes the LLM call loop with tool handling.
|
||||
// Returns (finalContent, iteration, error).
|
||||
func (al *AgentLoop) runLLMIteration(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
@@ -1248,6 +1254,13 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
}
|
||||
|
||||
// Check if both the provider and channel support streaming
|
||||
streamProvider, providerCanStream := agent.Provider.(providers.StreamingProvider)
|
||||
var streamer bus.Streamer
|
||||
if providerCanStream && !opts.NoHistory && !constants.IsInternalChannel(opts.Channel) {
|
||||
streamer, _ = al.bus.GetStreamer(ctx, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
// all tool-follow-up iterations within the same turn so that a multi-step
|
||||
@@ -1364,6 +1377,16 @@ func (al *AgentLoop) runLLMIteration(
|
||||
al.activeRequests.Add(1)
|
||||
defer al.activeRequests.Done()
|
||||
|
||||
// Use streaming when available (streamer obtained, provider supports it)
|
||||
if streamer != nil && streamProvider != nil {
|
||||
return streamProvider.ChatStream(
|
||||
ctx, messages, providerToolDefs, activeModel, llmOpts,
|
||||
func(accumulated string) {
|
||||
streamer.Update(ctx, accumulated)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
@@ -1500,15 +1523,31 @@ func (al *AgentLoop) runLLMIteration(
|
||||
if finalContent == "" && response.ReasoningContent != "" {
|
||||
finalContent = response.ReasoningContent
|
||||
}
|
||||
|
||||
// If we were streaming, finalize the message (sends the permanent message)
|
||||
if streamer != nil {
|
||||
if err := streamer.Finalize(ctx, finalContent); err != nil {
|
||||
logger.WarnCF("agent", "Stream finalize failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_chars": len(finalContent),
|
||||
"streamed": streamer != nil,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
// Tool calls detected — cancel any active stream (draft auto-expires)
|
||||
if streamer != nil {
|
||||
streamer.Cancel(ctx)
|
||||
}
|
||||
|
||||
normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls))
|
||||
for _, tc := range response.ToolCalls {
|
||||
normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc))
|
||||
|
||||
@@ -420,6 +420,29 @@ func (m *countingMockProvider) GetDefaultModel() string {
|
||||
return "counting-mock-model"
|
||||
}
|
||||
|
||||
type toolLimitOnlyProvider struct{}
|
||||
|
||||
func (m *toolLimitOnlyProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_tool_limit_test",
|
||||
Type: "function",
|
||||
Name: "tool_limit_test_tool",
|
||||
Arguments: map[string]any{"value": "x"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *toolLimitOnlyProvider) GetDefaultModel() string {
|
||||
return "tool-limit-only-model"
|
||||
}
|
||||
|
||||
// mockCustomTool is a simple mock tool for registration testing
|
||||
type mockCustomTool struct{}
|
||||
|
||||
@@ -442,6 +465,29 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
type toolLimitTestTool struct{}
|
||||
|
||||
func (m *toolLimitTestTool) Name() string {
|
||||
return "tool_limit_test_tool"
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Description() string {
|
||||
return "Tool used to exhaust the iteration budget in tests"
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"value": map[string]any{"type": "string"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.SilentResult("tool limit test result")
|
||||
}
|
||||
|
||||
// testHelper executes a message and returns the response
|
||||
type testHelper struct {
|
||||
al *AgentLoop
|
||||
@@ -1083,6 +1129,89 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: ""}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "empty-response", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if response != defaultResponse {
|
||||
t.Fatalf("response = %q, want %q", response, defaultResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolLimitOnlyProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(&toolLimitTestTool{})
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "tool-limit", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if response != toolLimitResponse {
|
||||
t.Fatalf("response = %q, want %q", response, toolLimitResponse)
|
||||
}
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: "test",
|
||||
Peer: &routing.RoutePeer{
|
||||
Kind: "direct",
|
||||
ID: "cron",
|
||||
},
|
||||
})
|
||||
history := defaultAgent.Sessions.GetHistory(route.SessionKey)
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("history len = %d, want 4", len(history))
|
||||
}
|
||||
assertRoles(t, history, "user", "assistant", "tool", "assistant")
|
||||
if history[3].Content != toolLimitResponse {
|
||||
t.Fatalf("final assistant content = %q, want %q", history[3].Content, toolLimitResponse)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessDirectWithChannel_TriggersMCPInitialization verifies that
|
||||
// ProcessDirectWithChannel triggers MCP initialization when MCP is enabled.
|
||||
// Note: Manager is only initialized when at least one MCP server is configured
|
||||
|
||||
+34
-4
@@ -14,15 +14,32 @@ var ErrBusClosed = errors.New("message bus closed")
|
||||
|
||||
const defaultBusBufferSize = 64
|
||||
|
||||
// StreamDelegate is implemented by the channel Manager to provide streaming
|
||||
// capabilities to the agent loop without tight coupling.
|
||||
type StreamDelegate interface {
|
||||
// GetStreamer returns a Streamer for the given channel+chatID if the channel
|
||||
// supports streaming. Returns nil, false if streaming is unavailable.
|
||||
GetStreamer(ctx context.Context, channel, chatID string) (Streamer, bool)
|
||||
}
|
||||
|
||||
// Streamer pushes incremental content to a streaming-capable channel.
|
||||
// Defined here so the agent loop can use it without importing pkg/channels.
|
||||
type Streamer interface {
|
||||
Update(ctx context.Context, content string) error
|
||||
Finalize(ctx context.Context, content string) error
|
||||
Cancel(ctx context.Context)
|
||||
}
|
||||
|
||||
type MessageBus struct {
|
||||
inbound chan InboundMessage
|
||||
outbound chan OutboundMessage
|
||||
outboundMedia chan OutboundMediaMessage
|
||||
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
closed atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
closed atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
streamDelegate atomic.Value // stores StreamDelegate
|
||||
}
|
||||
|
||||
func NewMessageBus() *MessageBus {
|
||||
@@ -86,6 +103,19 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
|
||||
return mb.outboundMedia
|
||||
}
|
||||
|
||||
// SetStreamDelegate registers a StreamDelegate (typically the channel Manager).
|
||||
func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) {
|
||||
mb.streamDelegate.Store(d)
|
||||
}
|
||||
|
||||
// GetStreamer returns a Streamer for the given channel+chatID via the delegate.
|
||||
func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (Streamer, bool) {
|
||||
if d, ok := mb.streamDelegate.Load().(StreamDelegate); ok && d != nil {
|
||||
return d.GetStreamer(ctx, channel, chatID)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
mb.closeOnce.Do(func() {
|
||||
// notify all blocked publishers to exit
|
||||
|
||||
@@ -275,14 +275,18 @@ func (c *BaseChannel) HandleMessage(
|
||||
|
||||
// Auto-trigger typing indicator, message reaction, and placeholder before publishing.
|
||||
// Each capability is independent — all three may fire for the same message.
|
||||
// Note: even when streaming is available, we still show typing + placeholder on inbound.
|
||||
// If streaming actually activates, preSend will skip the placeholder edit (streamActive map)
|
||||
// and the typing stop will still be called. This avoids the problem of compile-time interface
|
||||
// checks incorrectly skipping indicators when streaming may not work at runtime.
|
||||
if c.owner != nil && c.placeholderRecorder != nil {
|
||||
// Typing — independent pipeline
|
||||
// Typing
|
||||
if tc, ok := c.owner.(TypingCapable); ok {
|
||||
if stop, err := tc.StartTyping(ctx, chatID); err == nil {
|
||||
c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
|
||||
}
|
||||
}
|
||||
// Reaction — independent pipeline
|
||||
// Reaction
|
||||
if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
|
||||
if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
|
||||
c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
|
||||
|
||||
@@ -84,3 +84,64 @@ func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) s
|
||||
content = mentionPlaceholderRegex.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
// extractCardImageKeys recursively extracts all image keys from a Feishu interactive card.
|
||||
// Image keys are used to download images from Feishu API.
|
||||
// Returns two slices: Feishu-hosted keys and external URLs.
|
||||
func extractCardImageKeys(rawContent string) (feishuKeys []string, externalURLs []string) {
|
||||
if rawContent == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var card map[string]any
|
||||
if err := json.Unmarshal([]byte(rawContent), &card); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
extractImageKeysRecursive(card, &feishuKeys, &externalURLs)
|
||||
return feishuKeys, externalURLs
|
||||
}
|
||||
|
||||
// isExternalURL returns true if the string is an external HTTP/HTTPS URL.
|
||||
func isExternalURL(s string) bool {
|
||||
return strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://")
|
||||
}
|
||||
|
||||
// extractImageKeysRecursive traverses card structure to find all image keys.
|
||||
// Collects both Feishu-hosted keys and external URLs separately.
|
||||
func extractImageKeysRecursive(v any, feishuKeys, externalURLs *[]string) {
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
// Check if this is an img element
|
||||
if tag, ok := val["tag"].(string); ok {
|
||||
switch tag {
|
||||
case "img":
|
||||
// Try img_key first (always Feishu-hosted)
|
||||
if imgKey, ok := val["img_key"].(string); ok && imgKey != "" {
|
||||
*feishuKeys = append(*feishuKeys, imgKey)
|
||||
}
|
||||
// Check src - could be Feishu key or external URL
|
||||
if src, ok := val["src"].(string); ok && src != "" {
|
||||
if isExternalURL(src) {
|
||||
*externalURLs = append(*externalURLs, src)
|
||||
} else {
|
||||
*feishuKeys = append(*feishuKeys, src)
|
||||
}
|
||||
}
|
||||
case "icon":
|
||||
// Icon elements use icon_key
|
||||
if iconKey, ok := val["icon_key"].(string); ok && iconKey != "" {
|
||||
*feishuKeys = append(*feishuKeys, iconKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Recurse into all nested structures
|
||||
for _, child := range val {
|
||||
extractImageKeysRecursive(child, feishuKeys, externalURLs)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range val {
|
||||
extractImageKeysRecursive(item, feishuKeys, externalURLs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,3 +290,119 @@ func TestStripMentionPlaceholders(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCardImageKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantFeishuKeys []string
|
||||
wantExternalURLs []string
|
||||
}{
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
wantFeishuKeys: nil,
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
content: "not json",
|
||||
wantFeishuKeys: nil,
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "card with no images",
|
||||
content: `{"schema":"2.0","body":{"elements":[{"tag":"markdown","content":"text"}]}}`,
|
||||
wantFeishuKeys: nil,
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "single image with img_key",
|
||||
content: `{"elements":[{"tag":"img","img_key":"img_abc123"}]}`,
|
||||
wantFeishuKeys: []string{"img_abc123"},
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "single image with src as Feishu key",
|
||||
content: `{"elements":[{"tag":"img","src":"img_xyz789"}]}`,
|
||||
wantFeishuKeys: []string{"img_xyz789"},
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "multiple images",
|
||||
content: `{"elements":[{"tag":"img","img_key":"img_1"},{"tag":"div","text":{"content":"text"}},{"tag":"img","img_key":"img_2"}]}`,
|
||||
wantFeishuKeys: []string{"img_1", "img_2"},
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "nested image in columns",
|
||||
content: `{"elements":[{"tag":"div","columns":[{"tag":"img","img_key":"img_col1"},{"tag":"img","img_key":"img_col2"}]}]}`,
|
||||
wantFeishuKeys: []string{"img_col1", "img_col2"},
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "image in action",
|
||||
content: `{"elements":[{"tag":"action","actions":[{"tag":"img","img_key":"img_action"}]}]}`,
|
||||
wantFeishuKeys: []string{"img_action"},
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "icon element",
|
||||
content: `{"elements":[{"tag":"icon","icon_key":"icon_123"}]}`,
|
||||
wantFeishuKeys: []string{"icon_123"},
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "complex card with text and images",
|
||||
content: `{"header":{"title":{"content":"Title"}},"elements":[{"tag":"div","text":{"content":"Description"}},{"tag":"img","img_key":"img_main"}]}`,
|
||||
wantFeishuKeys: []string{"img_main"},
|
||||
wantExternalURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "external URL in src",
|
||||
content: `{"elements":[{"tag":"img","src":"https://example.com/image.png"}]}`,
|
||||
wantFeishuKeys: nil,
|
||||
wantExternalURLs: []string{"https://example.com/image.png"},
|
||||
},
|
||||
{
|
||||
name: "mixed Feishu keys and external URLs",
|
||||
content: `{"elements":[{"tag":"img","img_key":"img_feishu"},{"tag":"img","src":"https://cdn.example.com/external.jpg"},{"tag":"img","src":"img_another"}]}`,
|
||||
wantFeishuKeys: []string{"img_feishu", "img_another"},
|
||||
wantExternalURLs: []string{"https://cdn.example.com/external.jpg"},
|
||||
},
|
||||
{
|
||||
name: "multiple external URLs",
|
||||
content: `{"elements":[{"tag":"img","src":"https://a.com/1.png"},{"tag":"img","src":"http://b.com/2.jpg"}]}`,
|
||||
wantFeishuKeys: nil,
|
||||
wantExternalURLs: []string{"https://a.com/1.png", "http://b.com/2.jpg"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotFeishuKeys, gotExternalURLs := extractCardImageKeys(tt.content)
|
||||
|
||||
// Compare Feishu keys
|
||||
if len(gotFeishuKeys) != len(tt.wantFeishuKeys) {
|
||||
t.Errorf("extractCardImageKeys() feishuKeys = %v, want %v", gotFeishuKeys, tt.wantFeishuKeys)
|
||||
return
|
||||
}
|
||||
for i, v := range gotFeishuKeys {
|
||||
if v != tt.wantFeishuKeys[i] {
|
||||
t.Errorf("extractCardImageKeys() feishuKeys[%d] = %q, want %q", i, v, tt.wantFeishuKeys[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Compare external URLs
|
||||
if len(gotExternalURLs) != len(tt.wantExternalURLs) {
|
||||
t.Errorf("extractCardImageKeys() externalURLs = %v, want %v", gotExternalURLs, tt.wantExternalURLs)
|
||||
return
|
||||
}
|
||||
for i, v := range gotExternalURLs {
|
||||
if v != tt.wantExternalURLs[i] {
|
||||
t.Errorf("extractCardImageKeys() externalURLs[%d] = %q, want %q", i, v, tt.wantExternalURLs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,6 +424,15 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store)
|
||||
}
|
||||
|
||||
// For interactive cards, pass external image URLs via media refs.
|
||||
// Keep content as valid raw JSON for downstream parsing.
|
||||
if messageType == larkim.MsgTypeInteractive {
|
||||
_, externalURLs := extractCardImageKeys(rawContent)
|
||||
if len(externalURLs) > 0 {
|
||||
mediaRefs = append(mediaRefs, externalURLs...)
|
||||
}
|
||||
}
|
||||
|
||||
// Append media tags to content (like Telegram does)
|
||||
content = appendMediaTags(content, messageType, mediaRefs)
|
||||
|
||||
@@ -559,6 +568,10 @@ func extractContent(messageType, rawContent string) string {
|
||||
// Pass raw JSON to LLM — structured rich text is more informative than flattened plain text
|
||||
return rawContent
|
||||
|
||||
case larkim.MsgTypeInteractive:
|
||||
// Pass raw JSON to LLM — structured card is more informative than flattened text
|
||||
return rawContent
|
||||
|
||||
case larkim.MsgTypeImage:
|
||||
// Image messages don't have text content
|
||||
return ""
|
||||
@@ -596,6 +609,18 @@ func (c *FeishuChannel) downloadInboundMedia(
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
|
||||
case larkim.MsgTypeInteractive:
|
||||
// Extract and download images embedded in interactive cards
|
||||
feishuKeys, _ := extractCardImageKeys(rawContent)
|
||||
// Download Feishu-hosted images via API
|
||||
for _, imageKey := range feishuKeys {
|
||||
ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
// External URLs are passed directly to LLM, not downloaded
|
||||
|
||||
case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
|
||||
fileKey := extractFileKey(rawContent)
|
||||
if fileKey == "" {
|
||||
@@ -716,11 +741,18 @@ func (c *FeishuChannel) downloadResource(
|
||||
}
|
||||
|
||||
// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]").
|
||||
// For interactive cards, media tags are not appended because content is raw JSON
|
||||
// and appending would produce invalid JSON format.
|
||||
func appendMediaTags(content, messageType string, mediaRefs []string) string {
|
||||
if len(mediaRefs) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
// Don't append tags to JSON content (interactive cards) - would produce invalid JSON
|
||||
if messageType == larkim.MsgTypeInteractive {
|
||||
return content
|
||||
}
|
||||
|
||||
var tag string
|
||||
switch messageType {
|
||||
case larkim.MsgTypeImage:
|
||||
|
||||
@@ -75,6 +75,24 @@ func TestExtractContent(t *testing.T) {
|
||||
rawContent: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "interactive card returns raw JSON",
|
||||
messageType: "interactive",
|
||||
rawContent: `{"schema":"2.0","body":{"elements":[{"tag":"markdown","content":"Hello from card"}]}}`,
|
||||
want: `{"schema":"2.0","body":{"elements":[{"tag":"markdown","content":"Hello from card"}]}}`,
|
||||
},
|
||||
{
|
||||
name: "interactive card with complex structure returns raw JSON",
|
||||
messageType: "interactive",
|
||||
rawContent: `{"header":{"title":{"tag":"plain_text","content":"Title"}},"elements":[{"tag":"div","text":{"tag":"lark_md","content":"Card content"}}]}`,
|
||||
want: `{"header":{"title":{"tag":"plain_text","content":"Title"}},"elements":[{"tag":"div","text":{"tag":"lark_md","content":"Card content"}}]}`,
|
||||
},
|
||||
{
|
||||
name: "interactive card invalid JSON returns as-is",
|
||||
messageType: "interactive",
|
||||
rawContent: `not valid json`,
|
||||
want: `not valid json`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -151,6 +169,13 @@ func TestAppendMediaTags(t *testing.T) {
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "something [attachment]",
|
||||
},
|
||||
{
|
||||
name: "interactive card with images returns content unchanged",
|
||||
content: `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`,
|
||||
messageType: "interactive",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -3,6 +3,7 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,11 @@ type MessageEditor interface {
|
||||
EditMessage(ctx context.Context, chatID string, messageID string, content string) error
|
||||
}
|
||||
|
||||
// MessageDeleter — channels that can delete a message by ID.
|
||||
type MessageDeleter interface {
|
||||
DeleteMessage(ctx context.Context, chatID string, messageID string) error
|
||||
}
|
||||
|
||||
// ReactionCapable — channels that can add a reaction (e.g. 👀) to an inbound message.
|
||||
// ReactToMessage adds a reaction and returns an undo function to remove it.
|
||||
// The undo function MUST be idempotent and safe to call multiple times.
|
||||
@@ -35,6 +41,18 @@ type PlaceholderCapable interface {
|
||||
SendPlaceholder(ctx context.Context, chatID string) (messageID string, err error)
|
||||
}
|
||||
|
||||
// StreamingCapable — channels that can show partial LLM output in real-time.
|
||||
// The channel SHOULD gracefully degrade if the platform rejects streaming
|
||||
// (e.g. Telegram bot without forum mode). In that case, Update becomes a no-op
|
||||
// and Finalize still delivers the final message.
|
||||
type StreamingCapable interface {
|
||||
BeginStream(ctx context.Context, chatID string) (Streamer, error)
|
||||
}
|
||||
|
||||
// Streamer is defined in pkg/bus to avoid circular imports.
|
||||
// This alias keeps channel implementations using channels.Streamer unchanged.
|
||||
type Streamer = bus.Streamer
|
||||
|
||||
// PlaceholderRecorder is injected into channels by Manager.
|
||||
// Channels call these methods on inbound to register typing/placeholder state.
|
||||
// Manager uses the registered state on outbound to stop typing and edit placeholders.
|
||||
|
||||
+72
-2
@@ -89,6 +89,7 @@ type Manager struct {
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
reactionUndos sync.Map // "channel:chatID" → reactionEntry
|
||||
streamActive sync.Map // "channel:chatID" → true (set when streamer.Finalize sent the message)
|
||||
channelHashes map[string]string // channel name → config hash
|
||||
}
|
||||
|
||||
@@ -157,7 +158,7 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
|
||||
}
|
||||
|
||||
// preSend handles typing stop, reaction undo, and placeholder editing before sending a message.
|
||||
// Returns true if the message was edited into a placeholder (skip Send).
|
||||
// Returns true if the message was already delivered (skip Send).
|
||||
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool {
|
||||
key := name + ":" + msg.ChatID
|
||||
|
||||
@@ -175,7 +176,22 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Try editing placeholder
|
||||
// 3. If a stream already finalized this message, delete the placeholder and skip send
|
||||
if _, loaded := m.streamActive.LoadAndDelete(key); loaded {
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
// Prefer deleting the placeholder (cleaner UX than editing to same content)
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
} else if editor, ok := ch.(MessageEditor); ok {
|
||||
editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 4. Try editing placeholder
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if editor, ok := ch.(MessageEditor); ok {
|
||||
@@ -200,6 +216,9 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi
|
||||
channelHashes: make(map[string]string),
|
||||
}
|
||||
|
||||
// Register as streaming delegate so the agent loop can obtain streamers
|
||||
messageBus.SetStreamDelegate(m)
|
||||
|
||||
if err := m.initChannels(&cfg.Channels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -210,6 +229,53 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetStreamer implements bus.StreamDelegate.
|
||||
// It checks if the named channel supports streaming and returns a Streamer.
|
||||
func (m *Manager) GetStreamer(ctx context.Context, channelName, chatID string) (bus.Streamer, bool) {
|
||||
m.mu.RLock()
|
||||
ch, exists := m.channels[channelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
sc, ok := ch.(StreamingCapable)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
streamer, err := sc.BeginStream(ctx, chatID)
|
||||
if err != nil {
|
||||
logger.DebugCF("channels", "Streaming unavailable, falling back to placeholder", map[string]any{
|
||||
"channel": channelName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Mark streamActive on Finalize so preSend knows to clean up the placeholder
|
||||
key := channelName + ":" + chatID
|
||||
return &finalizeHookStreamer{
|
||||
Streamer: streamer,
|
||||
onFinalize: func() { m.streamActive.Store(key, true) },
|
||||
}, true
|
||||
}
|
||||
|
||||
// finalizeHookStreamer wraps a Streamer to run a hook on Finalize.
|
||||
type finalizeHookStreamer struct {
|
||||
Streamer
|
||||
onFinalize func()
|
||||
}
|
||||
|
||||
func (s *finalizeHookStreamer) Finalize(ctx context.Context, content string) error {
|
||||
if err := s.Streamer.Finalize(ctx, content); err != nil {
|
||||
return err
|
||||
}
|
||||
s.onFinalize()
|
||||
return nil
|
||||
}
|
||||
|
||||
// initChannel is a helper that looks up a factory by name and creates the channel.
|
||||
func (m *Manager) initChannel(name, displayName string) {
|
||||
f, ok := getFactory(name)
|
||||
@@ -323,6 +389,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
m.initChannel("pico", "Pico")
|
||||
}
|
||||
|
||||
if channels.PicoClient.Enabled && channels.PicoClient.URL != "" {
|
||||
m.initChannel("pico_client", "Pico Client")
|
||||
}
|
||||
|
||||
if channels.IRC.Enabled && channels.IRC.Server != "" {
|
||||
m.initChannel("irc", "IRC")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,319 @@
|
||||
package pico
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// PicoClientChannel connects to a remote Pico Protocol WebSocket server.
|
||||
type PicoClientChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.PicoClientConfig
|
||||
conn *picoConn
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewPicoClientChannel creates a new Pico Protocol client channel.
|
||||
func NewPicoClientChannel(
|
||||
cfg config.PicoClientConfig,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*PicoClientChannel, error) {
|
||||
if cfg.URL == "" {
|
||||
return nil, fmt.Errorf("pico_client url is required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("pico_client", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &PicoClientChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start dials the remote server and begins reading.
|
||||
func (c *PicoClientChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("pico_client", "Starting Pico Client channel")
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("pico_client initial connect: %w", err)
|
||||
}
|
||||
|
||||
c.SetRunning(true)
|
||||
go c.reconnectLoop()
|
||||
|
||||
logger.InfoCF("pico_client", "Connected", map[string]any{"url": c.config.URL})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the connection.
|
||||
func (c *PicoClientChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("pico_client", "Stopping Pico Client channel")
|
||||
c.SetRunning(false)
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
logger.InfoC("pico_client", "Pico Client channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PicoClientChannel) dial() error {
|
||||
header := http.Header{}
|
||||
if c.config.Token != "" {
|
||||
header.Set("Authorization", "Bearer "+c.config.Token)
|
||||
}
|
||||
|
||||
ws, resp, err := websocket.DefaultDialer.DialContext(c.ctx, c.config.URL, header)
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connCtx, connCancel := context.WithCancel(c.ctx)
|
||||
|
||||
pc := &picoConn{
|
||||
id: uuid.New().String(),
|
||||
conn: ws,
|
||||
sessionID: c.config.SessionID,
|
||||
cancel: connCancel,
|
||||
}
|
||||
if pc.sessionID == "" {
|
||||
pc.sessionID = uuid.New().String()
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = pc
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.readLoop(connCtx, pc)
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconnectLoop re-dials when the connection drops.
|
||||
func (c *PicoClientChannel) reconnectLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
pc := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if pc == nil || pc.closed.Load() {
|
||||
backoff := 5 * time.Second
|
||||
logger.InfoC("pico_client", "Reconnecting...")
|
||||
if err := c.dial(); err != nil {
|
||||
logger.WarnCF("pico_client", "Reconnect failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
continue
|
||||
}
|
||||
logger.InfoC("pico_client", "Reconnected")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PicoClientChannel) readLoop(connCtx context.Context, pc *picoConn) {
|
||||
defer pc.close()
|
||||
|
||||
readTimeout := time.Duration(c.config.ReadTimeout) * time.Second
|
||||
if readTimeout <= 0 {
|
||||
readTimeout = 60 * time.Second
|
||||
}
|
||||
|
||||
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
pc.conn.SetPongHandler(func(string) error {
|
||||
return pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
})
|
||||
|
||||
pingInterval := time.Duration(c.config.PingInterval) * time.Second
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = 30 * time.Second
|
||||
}
|
||||
go c.pingLoop(connCtx, pc, pingInterval)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
_, raw, err := pc.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(
|
||||
err,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNormalClosure,
|
||||
) {
|
||||
logger.DebugCF("pico_client", "Read error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
|
||||
var msg PicoMessage
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
c.handleInbound(pc, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PicoClientChannel) pingLoop(connCtx context.Context, pc *picoConn, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if pc.closed.Load() {
|
||||
return
|
||||
}
|
||||
pc.writeMu.Lock()
|
||||
err := pc.conn.WriteMessage(websocket.PingMessage, nil)
|
||||
pc.writeMu.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleInbound processes messages from the remote server.
|
||||
// In client mode the server sends message.create (responses) and the client
|
||||
// sends message.send (user input). We treat message.create from the server
|
||||
// as inbound user messages to feed into the agent loop.
|
||||
func (c *PicoClientChannel) handleInbound(pc *picoConn, msg PicoMessage) {
|
||||
switch msg.Type {
|
||||
case TypePong:
|
||||
// response to our ping, ignore
|
||||
case TypeMessageCreate:
|
||||
// Server sent us a message — treat as inbound
|
||||
c.handleServerMessage(pc, msg)
|
||||
default:
|
||||
logger.DebugCF("pico_client", "Ignoring message type", map[string]any{
|
||||
"type": msg.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
content, _ := msg.Payload["content"].(string)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := msg.SessionID
|
||||
if sessionID == "" {
|
||||
sessionID = pc.sessionID
|
||||
}
|
||||
|
||||
chatID := "pico_client:" + sessionID
|
||||
senderID := "pico-remote"
|
||||
peer := bus.Peer{Kind: "direct", ID: chatID}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "pico_client",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("pico_client", senderID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{
|
||||
"platform": "pico_client",
|
||||
"session_id": sessionID,
|
||||
}, sender)
|
||||
}
|
||||
|
||||
// Send sends a message to the remote server.
|
||||
func (c *PicoClientChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
c.mu.Lock()
|
||||
pc := c.conn
|
||||
c.mu.Unlock()
|
||||
if pc == nil || pc.closed.Load() {
|
||||
return channels.ErrSendFailed
|
||||
}
|
||||
|
||||
outMsg := newMessage(TypeMessageSend, map[string]any{
|
||||
"content": msg.Content,
|
||||
})
|
||||
outMsg.SessionID = strings.TrimPrefix(msg.ChatID, "pico_client:")
|
||||
return pc.writeJSON(outMsg)
|
||||
}
|
||||
|
||||
// StartTyping implements channels.TypingCapable.
|
||||
func (c *PicoClientChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
||||
c.mu.Lock()
|
||||
pc := c.conn
|
||||
c.mu.Unlock()
|
||||
if pc == nil || pc.closed.Load() {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
startMsg := newMessage(TypeTypingStart, nil)
|
||||
startMsg.SessionID = strings.TrimPrefix(chatID, "pico_client:")
|
||||
if err := pc.writeJSON(startMsg); err != nil {
|
||||
return func() {}, err
|
||||
}
|
||||
return func() {
|
||||
c.mu.Lock()
|
||||
currentPC := c.conn
|
||||
c.mu.Unlock()
|
||||
if currentPC == nil {
|
||||
return
|
||||
}
|
||||
stopMsg := newMessage(TypeTypingStop, nil)
|
||||
stopMsg.SessionID = strings.TrimPrefix(chatID, "pico_client:")
|
||||
currentPC.writeJSON(stopMsg)
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
package pico
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestNewPicoClientChannel_MissingURL(t *testing.T) {
|
||||
_, err := NewPicoClientChannel(config.PicoClientConfig{}, bus.NewMessageBus())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing URL")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "url is required") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPicoClientChannel_OK(t *testing.T) {
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: "ws://localhost:9999/ws",
|
||||
}, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "pico_client" {
|
||||
t.Fatalf("name = %q, want pico_client", ch.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_NotRunning(t *testing.T) {
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: "ws://localhost:9999/ws",
|
||||
}, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ch.Send(context.Background(), bus.OutboundMessage{Content: "hi"})
|
||||
if !errors.Is(err, channels.ErrNotRunning) {
|
||||
t.Fatalf("expected ErrNotRunning, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// testServer starts a WS server that echoes message.send back as message.create.
|
||||
func testServer(t *testing.T, token string) *httptest.Server {
|
||||
t.Helper()
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if token != "" {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != "Bearer "+token {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
_, raw, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var msg PicoMessage
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.Type == TypeMessageSend {
|
||||
reply := newMessage(TypeMessageCreate, msg.Payload)
|
||||
reply.SessionID = msg.SessionID
|
||||
if err := conn.WriteJSON(reply); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func wsURL(httpURL string) string {
|
||||
return "ws" + strings.TrimPrefix(httpURL, "http")
|
||||
}
|
||||
|
||||
func TestClientChannel_ConnectAndSend(t *testing.T) {
|
||||
srv := testServer(t, "test-token")
|
||||
defer srv.Close()
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: wsURL(srv.URL),
|
||||
Token: "test-token",
|
||||
SessionID: "sess-1",
|
||||
PingInterval: 60,
|
||||
ReadTimeout: 10,
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err = ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
// Send a message
|
||||
err = ch.Send(ctx, bus.OutboundMessage{
|
||||
ChatID: "pico_client:sess-1",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientChannel_AuthFailure(t *testing.T) {
|
||||
srv := testServer(t, "correct-token")
|
||||
defer srv.Close()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: wsURL(srv.URL),
|
||||
Token: "wrong-token",
|
||||
}, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = ch.Start(ctx)
|
||||
if err == nil {
|
||||
ch.Stop(ctx)
|
||||
t.Fatal("expected auth failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientChannel_ReceivesServerMessage(t *testing.T) {
|
||||
srv := testServer(t, "")
|
||||
defer srv.Close()
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: wsURL(srv.URL),
|
||||
SessionID: "sess-echo",
|
||||
ReadTimeout: 10,
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err = ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
// Send a message; the echo server replies with message.create
|
||||
err = ch.Send(ctx, bus.OutboundMessage{
|
||||
ChatID: "pico_client:sess-echo",
|
||||
Content: "ping",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
|
||||
// The echoed message.create is processed by handleServerMessage which
|
||||
// calls HandleMessage → PublishInbound. Consume it from the bus.
|
||||
select {
|
||||
case msg := <-mb.InboundChan():
|
||||
if msg.Content != "ping" {
|
||||
t.Fatalf("received = %q, want %q", msg.Content, "ping")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for echoed message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientChannel_StartTyping(t *testing.T) {
|
||||
srv := testServer(t, "")
|
||||
defer srv.Close()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: wsURL(srv.URL),
|
||||
SessionID: "sess-type",
|
||||
ReadTimeout: 10,
|
||||
}, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err = ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
stop, err := ch.StartTyping(ctx, "pico_client:sess-type")
|
||||
if err != nil {
|
||||
t.Fatalf("StartTyping: %v", err)
|
||||
}
|
||||
stop() // should not panic
|
||||
}
|
||||
|
||||
func TestSend_ClosedConnection(t *testing.T) {
|
||||
srv := testServer(t, "")
|
||||
defer srv.Close()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: wsURL(srv.URL),
|
||||
SessionID: "sess-close",
|
||||
ReadTimeout: 10,
|
||||
}, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err = ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Start: %v", err)
|
||||
}
|
||||
|
||||
// Force close the underlying connection
|
||||
ch.mu.Lock()
|
||||
ch.conn.close()
|
||||
ch.mu.Unlock()
|
||||
|
||||
err = ch.Send(ctx, bus.OutboundMessage{
|
||||
ChatID: "pico_client:sess-close",
|
||||
Content: "should fail",
|
||||
})
|
||||
if !errors.Is(err, channels.ErrSendFailed) {
|
||||
t.Fatalf("expected ErrSendFailed, got %v", err)
|
||||
}
|
||||
|
||||
ch.Stop(ctx)
|
||||
}
|
||||
@@ -10,4 +10,7 @@ func init() {
|
||||
channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewPicoChannel(cfg.Channels.Pico, b)
|
||||
})
|
||||
channels.RegisterFactory("pico_client", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewPicoClientChannel(cfg.Channels.PicoClient, b)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ type picoConn struct {
|
||||
sessionID string
|
||||
writeMu sync.Mutex
|
||||
closed atomic.Bool
|
||||
cancel context.CancelFunc // cancels per-connection goroutines (e.g. pingLoop)
|
||||
}
|
||||
|
||||
// writeJSON sends a JSON message to the connection with write locking.
|
||||
@@ -42,6 +43,9 @@ func (pc *picoConn) writeJSON(v any) error {
|
||||
// close closes the connection.
|
||||
func (pc *picoConn) close() {
|
||||
if pc.closed.CompareAndSwap(false, true) {
|
||||
if pc.cancel != nil {
|
||||
pc.cancel()
|
||||
}
|
||||
pc.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +12,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
@@ -374,6 +377,22 @@ func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messag
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMessage implements channels.MessageDeleter.
|
||||
func (c *TelegramChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error {
|
||||
cid, _, err := parseTelegramChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mid, err := strconv.Atoi(messageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.bot.DeleteMessage(ctx, &telego.DeleteMessageParams{
|
||||
ChatID: tu.ID(cid),
|
||||
MessageID: mid,
|
||||
})
|
||||
}
|
||||
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
// It sends a placeholder message (e.g. "Thinking... 💭") that will later be
|
||||
// edited to the actual response via EditMessage (channels.MessageEditor).
|
||||
@@ -847,3 +866,107 @@ func (c *TelegramChannel) stripBotMention(content string) string {
|
||||
content = re.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
// BeginStream implements channels.StreamingCapable.
|
||||
func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (channels.Streamer, error) {
|
||||
if !c.config.Channels.Telegram.Streaming.Enabled {
|
||||
return nil, fmt.Errorf("streaming disabled in config")
|
||||
}
|
||||
|
||||
cid, _, err := parseTelegramChatID(chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
streamCfg := c.config.Channels.Telegram.Streaming
|
||||
return &telegramStreamer{
|
||||
bot: c.bot,
|
||||
chatID: cid,
|
||||
draftID: cryptoRandInt(),
|
||||
throttleInterval: time.Duration(streamCfg.ThrottleSeconds) * time.Second,
|
||||
minGrowth: streamCfg.MinGrowthChars,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// telegramStreamer streams partial LLM output via Telegram's sendMessageDraft API.
|
||||
// On first API error (e.g. bot lacks forum mode), it silently degrades: Update
|
||||
// becomes a no-op, while Finalize still delivers the final message.
|
||||
type telegramStreamer struct {
|
||||
bot *telego.Bot
|
||||
chatID int64
|
||||
draftID int
|
||||
throttleInterval time.Duration
|
||||
minGrowth int
|
||||
lastLen int
|
||||
lastAt time.Time
|
||||
failed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *telegramStreamer) Update(ctx context.Context, content string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.failed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Throttle: skip if not enough time or content has passed
|
||||
now := time.Now()
|
||||
growth := len(content) - s.lastLen
|
||||
if s.lastLen > 0 && now.Sub(s.lastAt) < s.throttleInterval && growth < s.minGrowth {
|
||||
return nil
|
||||
}
|
||||
|
||||
htmlContent := markdownToTelegramHTML(content)
|
||||
|
||||
err := s.bot.SendMessageDraft(ctx, &telego.SendMessageDraftParams{
|
||||
ChatID: s.chatID,
|
||||
DraftID: s.draftID,
|
||||
Text: htmlContent,
|
||||
ParseMode: telego.ModeHTML,
|
||||
})
|
||||
if err != nil {
|
||||
// First error → degrade silently (e.g. no forum mode)
|
||||
logger.WarnCF("telegram", "sendMessageDraft failed, disabling streaming", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
s.failed = true
|
||||
return nil // don't propagate — Finalize will still deliver
|
||||
}
|
||||
|
||||
s.lastLen = len(content)
|
||||
s.lastAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *telegramStreamer) Finalize(ctx context.Context, content string) error {
|
||||
htmlContent := markdownToTelegramHTML(content)
|
||||
tgMsg := tu.Message(tu.ID(s.chatID), htmlContent)
|
||||
tgMsg.ParseMode = telego.ModeHTML
|
||||
|
||||
if _, err := s.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
// Fallback to plain text
|
||||
tgMsg.ParseMode = ""
|
||||
if _, err = s.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
logger.ErrorCF("telegram", "Finalize failed after HTML and plain-text attempts", map[string]any{
|
||||
"chat_id": s.chatID,
|
||||
"error": err.Error(),
|
||||
"len": len(content),
|
||||
})
|
||||
return fmt.Errorf("telegram finalize: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *telegramStreamer) Cancel(ctx context.Context) {
|
||||
// Draft auto-expires on Telegram's side; nothing to clean up.
|
||||
}
|
||||
|
||||
// cryptoRandInt returns a non-zero random int using crypto/rand.
|
||||
func cryptoRandInt() int {
|
||||
var b [4]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return int(binary.BigEndian.Uint32(b[:])) | 1 // ensure non-zero
|
||||
}
|
||||
|
||||
@@ -305,6 +305,7 @@ type ChannelsConfig struct {
|
||||
WeComApp WeComAppConfig `json:"wecom_app"`
|
||||
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
|
||||
Pico PicoConfig `json:"pico"`
|
||||
PicoClient PicoClientConfig `json:"pico_client"`
|
||||
IRC IRCConfig `json:"irc"`
|
||||
}
|
||||
|
||||
@@ -325,6 +326,12 @@ type PlaceholderConfig struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type StreamingConfig struct {
|
||||
Enabled bool `json:"enabled,omitempty" env:"PICOCLAW_CHANNELS_TELEGRAM_STREAMING_ENABLED"`
|
||||
ThrottleSeconds int `json:"throttle_seconds,omitempty" env:"PICOCLAW_CHANNELS_TELEGRAM_STREAMING_THROTTLE_SECONDS"`
|
||||
MinGrowthChars int `json:"min_growth_chars,omitempty" env:"PICOCLAW_CHANNELS_TELEGRAM_STREAMING_MIN_GROWTH_CHARS"`
|
||||
}
|
||||
|
||||
type WhatsAppConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
|
||||
BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
|
||||
@@ -343,6 +350,7 @@ type TelegramConfig struct {
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Typing TypingConfig `json:"typing,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
Streaming StreamingConfig `json:"streaming,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_TELEGRAM_REASONING_CHANNEL_ID"`
|
||||
UseMarkdownV2 bool `json:"use_markdown_v2" env:"PICOCLAW_CHANNELS_TELEGRAM_USE_MARKDOWN_V2"`
|
||||
}
|
||||
@@ -512,6 +520,16 @@ type PicoConfig struct {
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
}
|
||||
|
||||
type PicoClientConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_CLIENT_ENABLED"`
|
||||
URL string `json:"url" env:"PICOCLAW_CHANNELS_PICO_CLIENT_URL"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_CLIENT_TOKEN"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
PingInterval int `json:"ping_interval,omitempty"`
|
||||
ReadTimeout int `json:"read_timeout,omitempty"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_CLIENT_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type IRCConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_IRC_ENABLED"`
|
||||
Server string `json:"server" env:"PICOCLAW_CHANNELS_IRC_SERVER"`
|
||||
|
||||
@@ -63,6 +63,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
Text: "Thinking... 💭",
|
||||
},
|
||||
Streaming: StreamingConfig{Enabled: true, ThrottleSeconds: 3, MinGrowthChars: 200},
|
||||
UseMarkdownV2: false,
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
|
||||
@@ -52,6 +52,19 @@ func (p *HTTPProvider) Chat(
|
||||
return p.delegate.Chat(ctx, messages, tools, model, options)
|
||||
}
|
||||
|
||||
// ChatStream implements providers.StreamingProvider by delegating to the
|
||||
// OpenAI-compatible streaming endpoint (SSE with stream: true).
|
||||
func (p *HTTPProvider) ChatStream(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error) {
|
||||
return p.delegate.ChatStream(ctx, messages, tools, model, options, onChunk)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -85,17 +88,10 @@ func NewProviderWithMaxTokensFieldAndTimeout(
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
// buildRequestBody constructs the common request body for Chat and ChatStream.
|
||||
func (p *Provider) buildRequestBody(
|
||||
messages []Message, tools []ToolDefinition, model string, options map[string]any,
|
||||
) map[string]any {
|
||||
model = normalizeModel(model, p.apiBase)
|
||||
|
||||
requestBody := map[string]any{
|
||||
@@ -112,10 +108,8 @@ func (p *Provider) Chat(
|
||||
}
|
||||
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
|
||||
fieldName := p.maxTokensField
|
||||
if fieldName == "" {
|
||||
// Fallback: detect from model name for backward compatibility
|
||||
lowerModel := strings.ToLower(model)
|
||||
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") ||
|
||||
strings.Contains(lowerModel, "gpt-5") {
|
||||
@@ -129,7 +123,6 @@ func (p *Provider) Chat(
|
||||
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1.
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
requestBody["temperature"] = 1.0
|
||||
} else {
|
||||
@@ -139,17 +132,30 @@ func (p *Provider) Chat(
|
||||
|
||||
// Prompt caching: pass a stable cache key so OpenAI can bucket requests
|
||||
// with the same key and reuse prefix KV cache across calls.
|
||||
// The key is typically the agent ID — stable per agent, shared across requests.
|
||||
// See: https://platform.openai.com/docs/guides/prompt-caching
|
||||
// Prompt caching is only supported by OpenAI-native endpoints.
|
||||
// Non-OpenAI providers (Mistral, Gemini, DeepSeek, etc.) reject unknown
|
||||
// fields with 422 errors, so only include it for OpenAI APIs.
|
||||
// Non-OpenAI providers reject unknown fields with 422 errors.
|
||||
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
|
||||
if supportsPromptCacheKey(p.apiBase) {
|
||||
requestBody["prompt_cache_key"] = cacheKey
|
||||
}
|
||||
}
|
||||
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
requestBody := p.buildRequestBody(messages, tools, model, options)
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
@@ -178,6 +184,195 @@ func (p *Provider) Chat(
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
// ChatStream implements streaming via OpenAI-compatible SSE (stream: true).
|
||||
// onChunk receives the accumulated text so far on each text delta.
|
||||
func (p *Provider) ChatStream(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
requestBody := p.buildRequestBody(messages, tools, model, options)
|
||||
requestBody["stream"] = true
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
// Use a client without Timeout for streaming — the http.Client.Timeout covers
|
||||
// the entire request lifecycle including body reads, which would kill long streams.
|
||||
// Context cancellation still provides the safety net.
|
||||
streamClient := &http.Client{Transport: p.httpClient.Transport}
|
||||
resp, err := streamClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
return parseStreamResponse(ctx, resp.Body, onChunk)
|
||||
}
|
||||
|
||||
// parseStreamResponse parses an OpenAI-compatible SSE stream.
|
||||
func parseStreamResponse(
|
||||
ctx context.Context,
|
||||
reader io.Reader,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error) {
|
||||
var textContent strings.Builder
|
||||
var finishReason string
|
||||
var usage *UsageInfo
|
||||
|
||||
// Tool call assembly: OpenAI streams tool calls as incremental deltas
|
||||
type toolAccum struct {
|
||||
id string
|
||||
name string
|
||||
argsJSON strings.Builder
|
||||
}
|
||||
activeTools := map[int]*toolAccum{}
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024) // 1MB initial, 10MB max
|
||||
for scanner.Scan() {
|
||||
// Check for context cancellation between chunks
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []struct {
|
||||
Index int `json:"index"`
|
||||
ID string `json:"id"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue // skip malformed chunks
|
||||
}
|
||||
|
||||
if chunk.Usage != nil {
|
||||
usage = chunk.Usage
|
||||
}
|
||||
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
choice := chunk.Choices[0]
|
||||
|
||||
// Accumulate text content
|
||||
if choice.Delta.Content != "" {
|
||||
textContent.WriteString(choice.Delta.Content)
|
||||
if onChunk != nil {
|
||||
onChunk(textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate tool call deltas
|
||||
for _, tc := range choice.Delta.ToolCalls {
|
||||
acc, ok := activeTools[tc.Index]
|
||||
if !ok {
|
||||
acc = &toolAccum{}
|
||||
activeTools[tc.Index] = acc
|
||||
}
|
||||
if tc.ID != "" {
|
||||
acc.id = tc.ID
|
||||
}
|
||||
if tc.Function != nil {
|
||||
if tc.Function.Name != "" {
|
||||
acc.name = tc.Function.Name
|
||||
}
|
||||
if tc.Function.Arguments != "" {
|
||||
acc.argsJSON.WriteString(tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if choice.FinishReason != nil {
|
||||
finishReason = *choice.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("streaming read error: %w", err)
|
||||
}
|
||||
|
||||
// Assemble tool calls from accumulated deltas
|
||||
var toolCalls []ToolCall
|
||||
for i := 0; i < len(activeTools); i++ {
|
||||
acc, ok := activeTools[i]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
args := make(map[string]any)
|
||||
raw := acc.argsJSON.String()
|
||||
if raw != "" {
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
log.Printf("openai_compat stream: failed to decode tool call arguments for %q: %v", acc.name, err)
|
||||
args["raw"] = raw
|
||||
}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: acc.id,
|
||||
Name: acc.name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
|
||||
if finishReason == "" {
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: textContent.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
before, after, ok := strings.Cut(model, "/")
|
||||
if !ok {
|
||||
|
||||
@@ -37,6 +37,20 @@ type StatefulProvider interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
// StreamingProvider is an optional interface for providers that support token streaming.
|
||||
// onChunk receives the accumulated text so far (not individual deltas).
|
||||
// The returned LLMResponse is the same complete response for compatibility with tool-call handling.
|
||||
type StreamingProvider interface {
|
||||
ChatStream(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error)
|
||||
}
|
||||
|
||||
// ThinkingCapable is an optional interface for providers that support
|
||||
// extended thinking (e.g. Anthropic). Used by the agent loop to warn
|
||||
// when thinking_level is configured but the active provider cannot use it.
|
||||
|
||||
Reference in New Issue
Block a user