diff --git a/config/config.example.json b/config/config.example.json index 69ac062ac..81c9014ec 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -84,7 +84,10 @@ "proxy": "", "allow_from": ["YOUR_USER_ID"], "use_markdown_v2": false, - "reasoning_channel_id": "" + "reasoning_channel_id": "", + "streaming": { + "enabled": true + } }, "discord": { "enabled": false, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index ef2b9e28f..1ca5db5b8 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1026,6 +1026,7 @@ func (al *AgentLoop) handleReasoning( } // runLLMIteration executes the LLM call loop with tool handling. +// Returns (finalContent, iteration, error). func (al *AgentLoop) runLLMIteration( ctx context.Context, agent *AgentInstance, @@ -1035,6 +1036,13 @@ func (al *AgentLoop) runLLMIteration( iteration := 0 var finalContent string + // Check if both the provider and channel support streaming + streamProvider, providerCanStream := agent.Provider.(providers.StreamingProvider) + var streamer bus.Streamer + if providerCanStream && !opts.NoHistory && !constants.IsInternalChannel(opts.Channel) { + streamer, _ = al.bus.GetStreamer(ctx, opts.Channel, opts.ChatID) + } + // Determine effective model tier for this conversation turn. 
// selectCandidates evaluates routing once and the decision is sticky for // all tool-follow-up iterations within the same turn so that a multi-step @@ -1116,6 +1124,16 @@ func (al *AgentLoop) runLLMIteration( al.activeRequests.Add(1) defer al.activeRequests.Done() + // Use streaming when available (streamer obtained, provider supports it) + if streamer != nil && streamProvider != nil { + return streamProvider.ChatStream( + ctx, messages, providerToolDefs, activeModel, llmOpts, + func(accumulated string) { + streamer.Update(ctx, accumulated) + }, + ) + } + if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( ctx, @@ -1243,15 +1261,31 @@ func (al *AgentLoop) runLLMIteration( if finalContent == "" && response.ReasoningContent != "" { finalContent = response.ReasoningContent } + + // If we were streaming, finalize the message (sends the permanent message) + if streamer != nil { + if err := streamer.Finalize(ctx, finalContent); err != nil { + logger.WarnCF("agent", "Stream finalize failed", map[string]any{ + "error": err.Error(), + }) + } + } + logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]any{ "agent_id": agent.ID, "iteration": iteration, "content_chars": len(finalContent), + "streamed": streamer != nil, }) break } + // Tool calls detected — cancel any active stream (draft auto-expires) + if streamer != nil { + streamer.Cancel(ctx) + } + normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls)) for _, tc := range response.ToolCalls { normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc)) diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 3d08bda4f..37fcb74c5 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -14,15 +14,32 @@ var ErrBusClosed = errors.New("message bus closed") const defaultBusBufferSize = 64 +// StreamDelegate is implemented by the channel Manager to provide streaming +// capabilities to the agent loop without tight coupling. 
+type StreamDelegate interface { + // GetStreamer returns a Streamer for the given channel+chatID if the channel + // supports streaming. Returns nil, false if streaming is unavailable. + GetStreamer(ctx context.Context, channel, chatID string) (Streamer, bool) +} + +// Streamer pushes incremental content to a streaming-capable channel. +// Defined here so the agent loop can use it without importing pkg/channels. +type Streamer interface { + Update(ctx context.Context, content string) error + Finalize(ctx context.Context, content string) error + Cancel(ctx context.Context) +} + type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage outboundMedia chan OutboundMediaMessage - closeOnce sync.Once - done chan struct{} - closed atomic.Bool - wg sync.WaitGroup + closeOnce sync.Once + done chan struct{} + closed atomic.Bool + wg sync.WaitGroup + streamDelegate atomic.Value // stores StreamDelegate } func NewMessageBus() *MessageBus { @@ -86,6 +103,19 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage { return mb.outboundMedia } +// SetStreamDelegate registers a StreamDelegate (typically the channel Manager). +func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) { + mb.streamDelegate.Store(d) +} + +// GetStreamer returns a Streamer for the given channel+chatID via the delegate. +func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (Streamer, bool) { + if d, ok := mb.streamDelegate.Load().(StreamDelegate); ok && d != nil { + return d.GetStreamer(ctx, channel, chatID) + } + return nil, false +} + func (mb *MessageBus) Close() { mb.closeOnce.Do(func() { // notify all blocked publishers to exit diff --git a/pkg/channels/base.go b/pkg/channels/base.go index edb5b6f08..882e72d08 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -275,14 +275,18 @@ func (c *BaseChannel) HandleMessage( // Auto-trigger typing indicator, message reaction, and placeholder before publishing. 
// Each capability is independent — all three may fire for the same message. + // Note: even when streaming is available, we still show typing + placeholder on inbound. + // If streaming actually activates, preSend will skip the placeholder edit (streamActive map) + // and the typing stop will still be called. This avoids the problem of compile-time interface + // checks incorrectly skipping indicators when streaming may not work at runtime. if c.owner != nil && c.placeholderRecorder != nil { - // Typing — independent pipeline + // Typing if tc, ok := c.owner.(TypingCapable); ok { if stop, err := tc.StartTyping(ctx, chatID); err == nil { c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop) } } - // Reaction — independent pipeline + // Reaction if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" { if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil { c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo) diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go index b3a493761..0cfd435b0 100644 --- a/pkg/channels/interfaces.go +++ b/pkg/channels/interfaces.go @@ -3,6 +3,7 @@ package channels import ( "context" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/commands" ) @@ -19,6 +20,11 @@ type MessageEditor interface { EditMessage(ctx context.Context, chatID string, messageID string, content string) error } +// MessageDeleter — channels that can delete a message by ID. +type MessageDeleter interface { + DeleteMessage(ctx context.Context, chatID string, messageID string) error +} + // ReactionCapable — channels that can add a reaction (e.g. 👀) to an inbound message. // ReactToMessage adds a reaction and returns an undo function to remove it. // The undo function MUST be idempotent and safe to call multiple times. 
@@ -35,6 +41,18 @@ type PlaceholderCapable interface { SendPlaceholder(ctx context.Context, chatID string) (messageID string, err error) } +// StreamingCapable — channels that can show partial LLM output in real-time. +// The channel SHOULD gracefully degrade if the platform rejects streaming +// (e.g. Telegram bot without forum mode). In that case, Update becomes a no-op +// and Finalize still delivers the final message. +type StreamingCapable interface { + BeginStream(ctx context.Context, chatID string) (Streamer, error) +} + +// Streamer is defined in pkg/bus to avoid circular imports. +// This alias keeps channel implementations using channels.Streamer unchanged. +type Streamer = bus.Streamer + // PlaceholderRecorder is injected into channels by Manager. // Channels call these methods on inbound to register typing/placeholder state. // Manager uses the registered state on outbound to stop typing and edit placeholders. diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 741fad53e..ff3fa399c 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -89,6 +89,7 @@ type Manager struct { placeholders sync.Map // "channel:chatID" → placeholderID (string) typingStops sync.Map // "channel:chatID" → func() reactionUndos sync.Map // "channel:chatID" → reactionEntry + streamActive sync.Map // "channel:chatID" → true (set when streamer.Finalize sent the message) channelHashes map[string]string // channel name → config hash } @@ -157,7 +158,7 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) { } // preSend handles typing stop, reaction undo, and placeholder editing before sending a message. -// Returns true if the message was edited into a placeholder (skip Send). +// Returns true if the message was already delivered (skip Send). 
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool { key := name + ":" + msg.ChatID @@ -175,7 +176,22 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess } } - // 3. Try editing placeholder + // 3. If a stream already finalized this message, delete the placeholder and skip send + if _, loaded := m.streamActive.LoadAndDelete(key); loaded { + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if entry, ok := v.(placeholderEntry); ok && entry.id != "" { + // Prefer deleting the placeholder (cleaner UX than editing to same content) + if deleter, ok := ch.(MessageDeleter); ok { + deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort + } else if editor, ok := ch.(MessageEditor); ok { + editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback + } + } + } + return true + } + + // 4. Try editing placeholder if v, loaded := m.placeholders.LoadAndDelete(key); loaded { if entry, ok := v.(placeholderEntry); ok && entry.id != "" { if editor, ok := ch.(MessageEditor); ok { @@ -200,6 +216,9 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi channelHashes: make(map[string]string), } + // Register as streaming delegate so the agent loop can obtain streamers + messageBus.SetStreamDelegate(m) + if err := m.initChannels(&cfg.Channels); err != nil { return nil, err } @@ -210,6 +229,53 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi return m, nil } +// GetStreamer implements bus.StreamDelegate. +// It checks if the named channel supports streaming and returns a Streamer. 
+func (m *Manager) GetStreamer(ctx context.Context, channelName, chatID string) (bus.Streamer, bool) { + m.mu.RLock() + ch, exists := m.channels[channelName] + m.mu.RUnlock() + + if !exists { + return nil, false + } + + sc, ok := ch.(StreamingCapable) + if !ok { + return nil, false + } + + streamer, err := sc.BeginStream(ctx, chatID) + if err != nil { + logger.DebugCF("channels", "Streaming unavailable, falling back to placeholder", map[string]any{ + "channel": channelName, + "error": err.Error(), + }) + return nil, false + } + + // Mark streamActive on Finalize so preSend knows to clean up the placeholder + key := channelName + ":" + chatID + return &finalizeHookStreamer{ + Streamer: streamer, + onFinalize: func() { m.streamActive.Store(key, true) }, + }, true +} + +// finalizeHookStreamer wraps a Streamer to run a hook on Finalize. +type finalizeHookStreamer struct { + Streamer + onFinalize func() +} + +func (s *finalizeHookStreamer) Finalize(ctx context.Context, content string) error { + if err := s.Streamer.Finalize(ctx, content); err != nil { + return err + } + s.onFinalize() + return nil +} + // initChannel is a helper that looks up a factory by name and creates the channel. func (m *Manager) initChannel(name, displayName string) { f, ok := getFactory(name) diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 2797bdf4a..3eb89c636 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -2,6 +2,8 @@ package telegram import ( "context" + "crypto/rand" + "encoding/binary" "fmt" "io" "net/http" @@ -10,6 +12,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/mymmrac/telego" @@ -374,6 +377,22 @@ func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messag return err } +// DeleteMessage implements channels.MessageDeleter. 
+func (c *TelegramChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error { + cid, _, err := parseTelegramChatID(chatID) + if err != nil { + return err + } + mid, err := strconv.Atoi(messageID) + if err != nil { + return err + } + return c.bot.DeleteMessage(ctx, &telego.DeleteMessageParams{ + ChatID: tu.ID(cid), + MessageID: mid, + }) +} + // SendPlaceholder implements channels.PlaceholderCapable. // It sends a placeholder message (e.g. "Thinking... 💭") that will later be // edited to the actual response via EditMessage (channels.MessageEditor). @@ -847,3 +866,107 @@ func (c *TelegramChannel) stripBotMention(content string) string { content = re.ReplaceAllString(content, "") return strings.TrimSpace(content) } + +// BeginStream implements channels.StreamingCapable. +func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (channels.Streamer, error) { + if !c.config.Channels.Telegram.Streaming.Enabled { + return nil, fmt.Errorf("streaming disabled in config") + } + + cid, _, err := parseTelegramChatID(chatID) + if err != nil { + return nil, err + } + + streamCfg := c.config.Channels.Telegram.Streaming + return &telegramStreamer{ + bot: c.bot, + chatID: cid, + draftID: cryptoRandInt(), + throttleInterval: time.Duration(streamCfg.ThrottleSeconds) * time.Second, + minGrowth: streamCfg.MinGrowthChars, + }, nil +} + +// telegramStreamer streams partial LLM output via Telegram's sendMessageDraft API. +// On first API error (e.g. bot lacks forum mode), it silently degrades: Update +// becomes a no-op, while Finalize still delivers the final message. 
+type telegramStreamer struct { + bot *telego.Bot + chatID int64 + draftID int + throttleInterval time.Duration + minGrowth int + lastLen int + lastAt time.Time + failed bool + mu sync.Mutex +} + +func (s *telegramStreamer) Update(ctx context.Context, content string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.failed { + return nil + } + + // Throttle: skip if not enough time or content has passed + now := time.Now() + growth := len(content) - s.lastLen + if s.lastLen > 0 && now.Sub(s.lastAt) < s.throttleInterval && growth < s.minGrowth { + return nil + } + + htmlContent := markdownToTelegramHTML(content) + + err := s.bot.SendMessageDraft(ctx, &telego.SendMessageDraftParams{ + ChatID: s.chatID, + DraftID: s.draftID, + Text: htmlContent, + ParseMode: telego.ModeHTML, + }) + if err != nil { + // First error → degrade silently (e.g. no forum mode) + logger.WarnCF("telegram", "sendMessageDraft failed, disabling streaming", map[string]any{ + "error": err.Error(), + }) + s.failed = true + return nil // don't propagate — Finalize will still deliver + } + + s.lastLen = len(content) + s.lastAt = now + return nil +} + +func (s *telegramStreamer) Finalize(ctx context.Context, content string) error { + htmlContent := markdownToTelegramHTML(content) + tgMsg := tu.Message(tu.ID(s.chatID), htmlContent) + tgMsg.ParseMode = telego.ModeHTML + + if _, err := s.bot.SendMessage(ctx, tgMsg); err != nil { + // Fallback to plain text + tgMsg.ParseMode = "" + if _, err = s.bot.SendMessage(ctx, tgMsg); err != nil { + logger.ErrorCF("telegram", "Finalize failed after HTML and plain-text attempts", map[string]any{ + "chat_id": s.chatID, + "error": err.Error(), + "len": len(content), + }) + return fmt.Errorf("telegram finalize: %w", err) + } + } + return nil +} + +func (s *telegramStreamer) Cancel(ctx context.Context) { + // Draft auto-expires on Telegram's side; nothing to clean up. +} + +// cryptoRandInt returns a non-zero random int using crypto/rand. 
// cryptoRandInt returns a random, strictly positive, odd integer that fits in
// 31 bits, drawn from crypto/rand.
//
// The result is used as a draft identifier, which must be non-zero.
// NOTE(review): the value is capped at a signed 32-bit range on the assumption
// that Telegram treats draft_id as a plain 32-bit Integer — confirm against
// the Bot API reference.
func cryptoRandInt() int {
	var b [4]byte
	if _, err := rand.Read(b[:]); err != nil {
		// Without entropy the buffer may still be all zeros, which would hand
		// every stream the same ID; fall back to the clock so IDs still differ.
		binary.BigEndian.PutUint32(b[:], uint32(time.Now().UnixNano()))
	}
	// Clear the sign bit so the value is the same on 32- and 64-bit builds and
	// never exceeds a signed 32-bit integer; set bit 0 to rule out zero.
	return int(binary.BigEndian.Uint32(b[:])&0x7fffffff) | 1
}
💭", }, + Streaming: StreamingConfig{Enabled: true, ThrottleSeconds: 3, MinGrowthChars: 200}, UseMarkdownV2: false, }, Feishu: FeishuConfig{ diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 4d823630e..803165edb 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -52,6 +52,19 @@ func (p *HTTPProvider) Chat( return p.delegate.Chat(ctx, messages, tools, model, options) } +// ChatStream implements providers.StreamingProvider by delegating to the +// OpenAI-compatible streaming endpoint (SSE with stream: true). +func (p *HTTPProvider) ChatStream( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + onChunk func(accumulated string), +) (*LLMResponse, error) { + return p.delegate.ChatStream(ctx, messages, tools, model, options, onChunk) +} + func (p *HTTPProvider) GetDefaultModel() string { return "" } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 463db83c9..938e4ea8b 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -1,10 +1,13 @@ package openai_compat import ( + "bufio" "bytes" "context" "encoding/json" "fmt" + "io" + "log" "net/http" "net/url" "strings" @@ -85,17 +88,10 @@ func NewProviderWithMaxTokensFieldAndTimeout( ) } -func (p *Provider) Chat( - ctx context.Context, - messages []Message, - tools []ToolDefinition, - model string, - options map[string]any, -) (*LLMResponse, error) { - if p.apiBase == "" { - return nil, fmt.Errorf("API base not configured") - } - +// buildRequestBody constructs the common request body for Chat and ChatStream. 
+func (p *Provider) buildRequestBody( + messages []Message, tools []ToolDefinition, model string, options map[string]any, +) map[string]any { model = normalizeModel(model, p.apiBase) requestBody := map[string]any{ @@ -112,10 +108,8 @@ func (p *Provider) Chat( } if maxTokens, ok := common.AsInt(options["max_tokens"]); ok { - // Use configured maxTokensField if specified, otherwise fallback to model-based detection fieldName := p.maxTokensField if fieldName == "" { - // Fallback: detect from model name for backward compatibility lowerModel := strings.ToLower(model) if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") { @@ -129,7 +123,6 @@ func (p *Provider) Chat( if temperature, ok := common.AsFloat(options["temperature"]); ok { lowerModel := strings.ToLower(model) - // Kimi k2 models only support temperature=1. if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { requestBody["temperature"] = 1.0 } else { @@ -139,17 +132,30 @@ func (p *Provider) Chat( // Prompt caching: pass a stable cache key so OpenAI can bucket requests // with the same key and reuse prefix KV cache across calls. - // The key is typically the agent ID — stable per agent, shared across requests. - // See: https://platform.openai.com/docs/guides/prompt-caching // Prompt caching is only supported by OpenAI-native endpoints. - // Non-OpenAI providers (Mistral, Gemini, DeepSeek, etc.) reject unknown - // fields with 422 errors, so only include it for OpenAI APIs. + // Non-OpenAI providers reject unknown fields with 422 errors. 
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" { if supportsPromptCacheKey(p.apiBase) { requestBody["prompt_cache_key"] = cacheKey } } + return requestBody +} + +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + requestBody := p.buildRequestBody(messages, tools, model, options) + jsonData, err := json.Marshal(requestBody) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) @@ -178,6 +184,195 @@ func (p *Provider) Chat( return common.ReadAndParseResponse(resp, p.apiBase) } +// ChatStream implements streaming via OpenAI-compatible SSE (stream: true). +// onChunk receives the accumulated text so far on each text delta. +func (p *Provider) ChatStream( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + onChunk func(accumulated string), +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + requestBody := p.buildRequestBody(messages, tools, model, options) + requestBody["stream"] = true + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + // Use a client without Timeout for streaming — the http.Client.Timeout covers + // the entire request lifecycle including body reads, which would kill long streams. 
+ // Context cancellation still provides the safety net. + streamClient := &http.Client{Transport: p.httpClient.Transport} + resp, err := streamClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, common.HandleErrorResponse(resp, p.apiBase) + } + + return parseStreamResponse(ctx, resp.Body, onChunk) +} + +// parseStreamResponse parses an OpenAI-compatible SSE stream. +func parseStreamResponse( + ctx context.Context, + reader io.Reader, + onChunk func(accumulated string), +) (*LLMResponse, error) { + var textContent strings.Builder + var finishReason string + var usage *UsageInfo + + // Tool call assembly: OpenAI streams tool calls as incremental deltas + type toolAccum struct { + id string + name string + argsJSON strings.Builder + } + activeTools := map[int]*toolAccum{} + + scanner := bufio.NewScanner(reader) + scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024) // 1MB initial, 10MB max + for scanner.Scan() { + // Check for context cancellation between chunks + if err := ctx.Err(); err != nil { + return nil, err + } + + line := scanner.Text() + + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var chunk struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + ToolCalls []struct { + Index int `json:"index"` + ID string `json:"id"` + Function *struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"delta"` + FinishReason *string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue // skip malformed chunks + } + + if chunk.Usage != nil { + usage = chunk.Usage + } + + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0] + + 
// Accumulate text content + if choice.Delta.Content != "" { + textContent.WriteString(choice.Delta.Content) + if onChunk != nil { + onChunk(textContent.String()) + } + } + + // Accumulate tool call deltas + for _, tc := range choice.Delta.ToolCalls { + acc, ok := activeTools[tc.Index] + if !ok { + acc = &toolAccum{} + activeTools[tc.Index] = acc + } + if tc.ID != "" { + acc.id = tc.ID + } + if tc.Function != nil { + if tc.Function.Name != "" { + acc.name = tc.Function.Name + } + if tc.Function.Arguments != "" { + acc.argsJSON.WriteString(tc.Function.Arguments) + } + } + } + + if choice.FinishReason != nil { + finishReason = *choice.FinishReason + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("streaming read error: %w", err) + } + + // Assemble tool calls from accumulated deltas + var toolCalls []ToolCall + for i := 0; i < len(activeTools); i++ { + acc, ok := activeTools[i] + if !ok { + continue + } + args := make(map[string]any) + raw := acc.argsJSON.String() + if raw != "" { + if err := json.Unmarshal([]byte(raw), &args); err != nil { + log.Printf("openai_compat stream: failed to decode tool call arguments for %q: %v", acc.name, err) + args["raw"] = raw + } + } + toolCalls = append(toolCalls, ToolCall{ + ID: acc.id, + Name: acc.name, + Arguments: args, + }) + } + + if finishReason == "" { + finishReason = "stop" + } + + return &LLMResponse{ + Content: textContent.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} + func normalizeModel(model, apiBase string) string { before, after, ok := strings.Cut(model, "/") if !ok { diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 1f28bc4ad..9a4d126a7 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -37,6 +37,20 @@ type StatefulProvider interface { Close() } +// StreamingProvider is an optional interface for providers that support token streaming. +// onChunk receives the accumulated text so far (not individual deltas). 
+// The returned LLMResponse is the same complete response for compatibility with tool-call handling. +type StreamingProvider interface { + ChatStream( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + onChunk func(accumulated string), + ) (*LLMResponse, error) +} + // ThinkingCapable is an optional interface for providers that support // extended thinking (e.g. Anthropic). Used by the agent loop to warn // when thinking_level is configured but the active provider cannot use it.