diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index c3733ce3a..be3e77a43 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -270,6 +270,10 @@ func filterDeepSeekReasoningTurn(messages []Message) []Message { } cloned := msg + // DeepSeek thinking-mode replay only requires reasoning_content for + // turns that participate in a tool interaction round. For plain + // assistant turns between two user messages, the docs say the API will + // ignore reasoning_content on replay, so we strip it here. if cloned.Role == "assistant" && strings.TrimSpace(cloned.ReasoningContent) != "" && !hasToolInteraction { cloned.ReasoningContent = "" } diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 594048ea5..4f68fb393 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -526,6 +526,112 @@ func TestProviderChat_HistoryCanonicalizationMatrix(t *testing.T) { }) } +func TestProviderChat_DeepSeekDocsReplayRequirements(t *testing.T) { + // DeepSeek's thinking-mode and multi-round chat docs distinguish two cases: + // - for a plain assistant turn between two user messages without tool calls, + // reasoning_content does not need to be replayed and the API ignores it if sent; + // - for a turn that participates in a tool-interaction round, assistant + // reasoning_content must be replayed on subsequent requests. + // + // Keep this behavior explicit here so future changes do not "fix" the + // non-tool stripping based on issue reports that are broader than the + // vendor documentation. + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + p.SetProviderName("deepseek") + + messages := []Message{ + {Role: "user", Content: "Who wrote The Hobbit?"}, + {Role: "assistant", Content: "J.R.R. Tolkien.", ReasoningContent: "I know this from general knowledge."}, + {Role: "user", Content: "What's the weather tomorrow?"}, + { + Role: "assistant", + Content: "Let me check the date first.", + ReasoningContent: "I need tomorrow's date before checking the weather.", + ToolCalls: []ToolCall{{ + ID: "call_date", + Type: "function", + Function: &FunctionCall{ + Name: "get_date", + Arguments: "{}", + }, + }}, + }, + {Role: "tool", ToolCallID: "call_date", Content: "2026-04-29"}, + { + Role: "assistant", + Content: "Tomorrow is 2026-04-30.", + ReasoningContent: "Now I can continue with the weather request.", + }, + {Role: "user", Content: "What about Guangzhou?"}, + } + + _, err := p.Chat(t.Context(), messages, nil, "deepseek-v4-flash", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + reqMessages, ok := requestBody["messages"].([]any) + if !ok { + t.Fatalf("messages is not []any: %T", requestBody["messages"]) + } + if len(reqMessages) != len(messages) { + t.Fatalf("len(messages) = %d, want %d", len(reqMessages), len(messages)) + } + + plainAssistant, ok := reqMessages[1].(map[string]any) + if !ok { + t.Fatalf("plain assistant message is not map[string]any: %T", reqMessages[1]) + } + if _, exists := plainAssistant["reasoning_content"]; exists { + t.Fatalf( + "plain DeepSeek turn should omit reasoning_content on replay, got %v", + plainAssistant["reasoning_content"], + ) + } + + toolAssistant, ok := reqMessages[3].(map[string]any) + if !ok { + t.Fatalf("tool assistant message is not map[string]any: %T", reqMessages[3]) + } + if toolAssistant["reasoning_content"] != "I need tomorrow's date before checking the weather." { + t.Fatalf( + "tool assistant reasoning_content = %v, want preserved", + toolAssistant["reasoning_content"], + ) + } + + finalAssistant, ok := reqMessages[5].(map[string]any) + if !ok { + t.Fatalf("final assistant message is not map[string]any: %T", reqMessages[5]) + } + if finalAssistant["reasoning_content"] != "Now I can continue with the weather request." { + t.Fatalf( + "final assistant reasoning_content = %v, want preserved", + finalAssistant["reasoning_content"], + ) + } +} + func TestProviderChat_HTTPError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad request", http.StatusBadRequest) diff --git a/pkg/seahorse/schema.go b/pkg/seahorse/schema.go index aa829358b..5b67fe9e0 100644 --- a/pkg/seahorse/schema.go +++ b/pkg/seahorse/schema.go @@ -46,6 +46,7 @@ func runSchema(db *sql.DB) error { conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id), role TEXT NOT NULL, content TEXT NOT NULL DEFAULT '', + reasoning_content TEXT NOT NULL DEFAULT '', token_count INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL DEFAULT (datetime('now')) )`, @@ -157,9 +158,57 @@ func runSchema(db *sql.DB) error { return err } } + + if err := ensureMessagesReasoningContentColumn(db); err != nil { + return err + } return nil } +func ensureMessagesReasoningContentColumn(db *sql.DB) error { + hasColumn, err := tableHasColumn(db, "messages", "reasoning_content") + if err != nil { + return fmt.Errorf("check messages.reasoning_content: %w", err) + } + if hasColumn { + return nil + } + + if _, err := db.Exec(`ALTER TABLE messages ADD COLUMN reasoning_content TEXT NOT NULL DEFAULT ''`); err != nil { + return fmt.Errorf("add messages.reasoning_content: %w", err) + } + return nil +} + +func tableHasColumn(db *sql.DB, tableName, columnName string) (bool, error) { + rows, err := db.Query(fmt.Sprintf(`PRAGMA table_info(%s)`, tableName)) + if err != nil { + return false, err + } + defer rows.Close() + + for rows.Next() { + var ( + cid int + name string + columnType string + notNull int + defaultVal sql.NullString + pk int + ) + if err := rows.Scan(&cid, &name, &columnType, ¬Null, &defaultVal, &pk); err != nil { + return false, err + } + if name == columnName { + return true, nil + } + } + if err := rows.Err(); err != nil { + return false, err + } + return false, nil +} + // checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled. // This is required for full-text search with CJK (Chinese, Japanese, Korean) support. func checkFTS5Support(db *sql.DB) error { diff --git a/pkg/seahorse/schema_test.go b/pkg/seahorse/schema_test.go index f3d6a3650..943b742b2 100644 --- a/pkg/seahorse/schema_test.go +++ b/pkg/seahorse/schema_test.go @@ -91,6 +91,53 @@ func TestRunMigrationsIdempotent(t *testing.T) { } } +func TestRunSchemaAddsMessagesReasoningContentColumn(t *testing.T) { + db := openTestDB(t) + + _, err := db.Exec(`CREATE TABLE messages ( + message_id INTEGER PRIMARY KEY AUTOINCREMENT, + conversation_id INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL DEFAULT '', + token_count INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + )`) + if err != nil { + t.Fatalf("create legacy messages table: %v", err) + } + + err = runSchema(db) + if err != nil { + t.Fatalf("runSchema: %v", err) + } + + var count int + err = db.QueryRow(`SELECT count(*) FROM pragma_table_info('messages') WHERE name = 'reasoning_content'`). + Scan(&count) + if err != nil { + t.Fatalf("query pragma_table_info: %v", err) + } + if count != 1 { + t.Fatalf("reasoning_content column count = %d, want 1", count) + } + + _, err = db.Exec( + `INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))`, + "reasoning-column-test", + ) + if err != nil { + t.Fatalf("insert conversation: %v", err) + } + + _, err = db.Exec( + `INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) + VALUES (1, 'assistant', 'answer', 'thinking', 1)`, + ) + if err != nil { + t.Fatalf("insert message with reasoning_content: %v", err) + } +} + func TestMigrationConversationUnique(t *testing.T) { db := openTestDB(t) if err := runSchema(db); err != nil { diff --git a/pkg/seahorse/short_engine.go b/pkg/seahorse/short_engine.go index f584788ce..0a8175617 100644 --- a/pkg/seahorse/short_engine.go +++ b/pkg/seahorse/short_engine.go @@ -253,9 +253,23 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa var added *Message var err error if len(msg.Parts) > 0 { - added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount) + added, err = e.store.AddMessageWithPartsAndReasoning( + ctx, + conv.ConversationID, + msg.Role, + msg.Parts, + msg.ReasoningContent, + msg.TokenCount, + ) } else { - added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount) + added, err = e.store.AddMessageWithReasoning( + ctx, + conv.ConversationID, + msg.Role, + msg.Content, + msg.ReasoningContent, + msg.TokenCount, + ) } if err != nil { return nil, fmt.Errorf("add message: %w", err) @@ -420,7 +434,7 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me // Fast path: DB has same count and exact match → no-op if len(dbMsgs) == len(messages) { matched := true - for i := 0; i < len(messages); i++ { + for i := range messages { if !messageMatches(dbMsgs[i], messages[i]) { matched = false break @@ -431,14 +445,21 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me } } - // Find longest matching prefix from the start - anchor := -1 - compareLen := len(dbMsgs) - if compareLen > len(messages) { - compareLen = len(messages) + // Migration repair path: old SeaHorse rows may be missing reasoning_content + // even though the canonical JSONL history already has it. Backfill those + // rows in place so we do not treat this as edited history and leave stale + // summaries/context behind after a partial raw-message rebuild. + if repaired, err := e.repairBootstrapReasoningContent(ctx, dbMsgs, messages); err != nil { + return fmt.Errorf("bootstrap: repair reasoning_content: %w", err) + } else if repaired && len(dbMsgs) == len(messages) { + return nil } - for i := 0; i < compareLen; i++ { + // Find longest matching prefix from the start + anchor := -1 + compareLen := min(len(dbMsgs), len(messages)) + + for i := range compareLen { if messageMatches(dbMsgs[i], messages[i]) { anchor = i } else { @@ -524,6 +545,57 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me return nil } +func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, messages []Message) (bool, error) { + if len(dbMsgs) == 0 || len(messages) == 0 { + return false, nil + } + + overlap := min(len(messages), len(dbMsgs)) + + var updates []struct { + index int + messageID int64 + reasoningContent string + } + + for i := range overlap { + if !messageMatchesIgnoringReasoning(dbMsgs[i], messages[i]) { + return false, nil + } + if dbMsgs[i].ReasoningContent == messages[i].ReasoningContent { + continue + } + if dbMsgs[i].ReasoningContent != "" || messages[i].ReasoningContent == "" { + return false, nil + } + updates = append(updates, struct { + index int + messageID int64 + reasoningContent string + }{ + index: i, + messageID: dbMsgs[i].ID, + reasoningContent: messages[i].ReasoningContent, + }) + } + + if len(updates) == 0 { + return false, nil + } + + for _, update := range updates { + if err := e.store.UpdateMessageReasoningContent(ctx, update.messageID, update.reasoningContent); err != nil { + return false, err + } + dbMsgs[update.index].ReasoningContent = update.reasoningContent + } + + logger.InfoCF("seahorse", "bootstrap: repaired missing reasoning_content", map[string]any{ + "messages": len(updates), + }) + return true, nil +} + // truncate shortens a string for logging. func truncate(s string, maxLen int) string { if len(s) <= maxLen { @@ -532,12 +604,19 @@ func truncate(s string, maxLen int) string { return s[:maxLen] + "..." } -// messageMatches compares two messages using (role, content) or (role, parts). -// TokenCount is NOT compared because it may be re-estimated differently -// during bootstrap (e.g., via tokenizer.EstimateMessageTokens). +// messageMatches compares two messages using role + reasoning_content and then +// either content or parts. TokenCount is NOT compared because it may be +// re-estimated differently during bootstrap (e.g., via tokenizer.EstimateMessageTokens). // For messages with Parts (tool_use, tool_result), compare Parts instead of Content -// since AddMessageWithParts stores empty Content in DB. +// because structured messages are matched by their parts payload. func messageMatches(a, b Message) bool { + if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent { + return false + } + return messageMatchesIgnoringReasoning(a, b) +} + +func messageMatchesIgnoringReasoning(a, b Message) bool { if a.Role != b.Role { return false } diff --git a/pkg/seahorse/short_engine_test.go b/pkg/seahorse/short_engine_test.go index d64634fb7..2a5c6c5d8 100644 --- a/pkg/seahorse/short_engine_test.go +++ b/pkg/seahorse/short_engine_test.go @@ -320,6 +320,108 @@ func TestEngineIngestWithParts(t *testing.T) { } } +func TestEngineIngestPreservesReasoningContent(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + msgs := []Message{ + { + Role: "assistant", + Content: "world", + ReasoningContent: "let me think this through", + TokenCount: 4, + }, + } + + _, err := eng.Ingest(ctx, "agent:reasoning", msgs) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + + conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:reasoning") + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 1 { + t.Fatalf("stored messages = %d, want 1", len(stored)) + } + if stored[0].ReasoningContent != "let me think this through" { + t.Errorf( + "stored[0].ReasoningContent = %q, want %q", + stored[0].ReasoningContent, + "let me think this through", + ) + } + + result, err := eng.Assemble(ctx, "agent:reasoning", AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + if len(result.Messages) != 1 { + t.Fatalf("assembled messages = %d, want 1", len(result.Messages)) + } + if result.Messages[0].ReasoningContent != "let me think this through" { + t.Errorf( + "assembled reasoning = %q, want %q", + result.Messages[0].ReasoningContent, + "let me think this through", + ) + } +} + +func TestEngineIngestWithPartsPreservesReasoningContent(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + msgs := []Message{ + { + Role: "assistant", + ReasoningContent: "I need to inspect the file first", + TokenCount: 10, + Parts: []MessagePart{ + {Type: "tool_use", Name: "read_file", Arguments: `{"path":"/tmp/test"}`, ToolCallID: "tc_123"}, + }, + }, + } + + _, err := eng.Ingest(ctx, "agent:parts-reasoning", msgs) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + + conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:parts-reasoning") + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 1 { + t.Fatalf("stored messages = %d, want 1", len(stored)) + } + if stored[0].ReasoningContent != "I need to inspect the file first" { + t.Errorf( + "stored reasoning = %q, want %q", + stored[0].ReasoningContent, + "I need to inspect the file first", + ) + } + + result, err := eng.Assemble(ctx, "agent:parts-reasoning", AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + if len(result.Messages) != 1 { + t.Fatalf("assembled messages = %d, want 1", len(result.Messages)) + } + if result.Messages[0].ReasoningContent != "I need to inspect the file first" { + t.Errorf( + "assembled reasoning = %q, want %q", + result.Messages[0].ReasoningContent, + "I need to inspect the file first", + ) + } +} + func TestEngineIngestAssemblePreservesParts(t *testing.T) { eng := newTestEngine(t) ctx := context.Background() @@ -514,6 +616,216 @@ func TestEngineBootstrapIdempotent(t *testing.T) { } } +func TestBootstrapRepairsMissingReasoningContent(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + sessionKey := "agent:repair-reasoning" + + conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3) + if err != nil { + t.Fatalf("AddMessage user: %v", err) + } + + assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3) + if err != nil { + t.Fatalf("AddMessage assistant: %v", err) + } + + err = eng.store.AppendContextMessages( + ctx, + conv.ConversationID, + []int64{userMsg.ID, assistantMsg.ID}, + ) + if err != nil { + t.Fatalf("AppendContextMessages: %v", err) + } + + err = eng.Bootstrap(ctx, sessionKey, []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", ReasoningContent: "let me think this through", TokenCount: 3}, + }) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 2 { + t.Fatalf("stored messages = %d, want 2", len(stored)) + } + if stored[1].ReasoningContent != "let me think this through" { + t.Errorf( + "stored[1].ReasoningContent = %q, want %q", + stored[1].ReasoningContent, + "let me think this through", + ) + } +} + +func TestBootstrapRepairsMissingReasoningContentWithoutDroppingSummaries(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + sessionKey := "agent:repair-reasoning-summary" + + conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3) + if err != nil { + t.Fatalf("AddMessage user: %v", err) + } + assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3) + if err != nil { + t.Fatalf("AddMessage assistant: %v", err) + } + + err = eng.store.AppendContextMessages( + ctx, + conv.ConversationID, + []int64{userMsg.ID, assistantMsg.ID}, + ) + if err != nil { + t.Fatalf("AppendContextMessages: %v", err) + } + + summary, err := eng.store.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "summary before repair", + TokenCount: 10, + }) + if err != nil { + t.Fatalf("CreateSummary: %v", err) + } + + err = eng.store.AppendContextSummary(ctx, conv.ConversationID, summary.SummaryID) + if err != nil { + t.Fatalf("AppendContextSummary: %v", err) + } + + err = eng.Bootstrap(ctx, sessionKey, []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", ReasoningContent: "let me think this through", TokenCount: 3}, + }) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 2 { + t.Fatalf("stored messages = %d, want 2", len(stored)) + } + if stored[1].ReasoningContent != "let me think this through" { + t.Errorf( + "stored[1].ReasoningContent = %q, want %q", + stored[1].ReasoningContent, + "let me think this through", + ) + } + + summaries, err := eng.store.GetSummariesByConversation(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetSummariesByConversation: %v", err) + } + if len(summaries) != 1 { + t.Fatalf("summaries = %d, want 1", len(summaries)) + } + if summaries[0].SummaryID != summary.SummaryID { + t.Errorf("SummaryID = %q, want %q", summaries[0].SummaryID, summary.SummaryID) + } + + items, err := eng.store.GetContextItems(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetContextItems: %v", err) + } + if len(items) != 3 { + t.Fatalf("context items = %d, want 3", len(items)) + } + if items[2].ItemType != "summary" || items[2].SummaryID != summary.SummaryID { + t.Errorf("summary context item = %+v, want summary %q", items[2], summary.SummaryID) + } +} + +func TestBootstrapRepairsMissingReasoningContentOnPrefixBeforeAppendingDelta(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + sessionKey := "agent:repair-reasoning-prefix" + + conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3) + if err != nil { + t.Fatalf("AddMessage user: %v", err) + } + assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3) + if err != nil { + t.Fatalf("AddMessage assistant: %v", err) + } + + err = eng.store.AppendContextMessages( + ctx, + conv.ConversationID, + []int64{userMsg.ID, assistantMsg.ID}, + ) + if err != nil { + t.Fatalf("AppendContextMessages: %v", err) + } + + err = eng.Bootstrap(ctx, sessionKey, []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", ReasoningContent: "let me think this through", TokenCount: 3}, + {Role: "user", Content: "follow-up", TokenCount: 2}, + }) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 3 { + t.Fatalf("stored messages = %d, want 3", len(stored)) + } + if stored[1].ReasoningContent != "let me think this through" { + t.Errorf( + "stored[1].ReasoningContent = %q, want %q", + stored[1].ReasoningContent, + "let me think this through", + ) + } + if stored[2].Content != "follow-up" { + t.Errorf("stored[2].Content = %q, want %q", stored[2].Content, "follow-up") + } + + items, err := eng.store.GetContextItems(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetContextItems: %v", err) + } + if len(items) != 3 { + t.Fatalf("context items = %d, want 3", len(items)) + } + if items[2].ItemType != "message" || items[2].MessageID != stored[2].ID { + t.Errorf("last context item = %+v, want appended message %d", items[2], stored[2].ID) + } +} + func TestEngineBootstrapDelta(t *testing.T) { eng := newTestEngine(t) ctx := context.Background() diff --git a/pkg/seahorse/store.go b/pkg/seahorse/store.go index c84aaaf07..0edbbd128 100644 --- a/pkg/seahorse/store.go +++ b/pkg/seahorse/store.go @@ -162,20 +162,31 @@ func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Tim // AddMessage appends a message to a conversation. func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) { + return s.AddMessageWithReasoning(ctx, convID, role, content, "", tokenCount) +} + +// AddMessageWithReasoning appends a message with reasoning content to a conversation. +func (s *Store) AddMessageWithReasoning( + ctx context.Context, + convID int64, + role, content, reasoningContent string, + tokenCount int, +) (*Message, error) { result, err := s.db.ExecContext(ctx, - "INSERT INTO messages (conversation_id, role, content, token_count) VALUES (?, ?, ?, ?)", - convID, role, content, tokenCount, + "INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)", + convID, role, content, reasoningContent, tokenCount, ) if err != nil { return nil, fmt.Errorf("add message: %w", err) } id, _ := result.LastInsertId() return &Message{ - ID: id, - ConversationID: convID, - Role: role, - Content: content, - TokenCount: tokenCount, + ID: id, + ConversationID: convID, + Role: role, + Content: content, + ReasoningContent: reasoningContent, + TokenCount: tokenCount, }, nil } @@ -212,6 +223,18 @@ func (s *Store) AddMessageWithParts( role string, parts []MessagePart, tokenCount int, +) (*Message, error) { + return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", tokenCount) +} + +// AddMessageWithPartsAndReasoning adds a message with structured parts and reasoning content. +func (s *Store) AddMessageWithPartsAndReasoning( + ctx context.Context, + convID int64, + role string, + parts []MessagePart, + reasoningContent string, + tokenCount int, ) (*Message, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { @@ -223,8 +246,8 @@ func (s *Store) AddMessageWithParts( readableContent := partsToReadableContent(parts) result, err := tx.ExecContext(ctx, - "INSERT INTO messages (conversation_id, role, content, token_count) VALUES (?, ?, ?, ?)", - convID, role, readableContent, tokenCount, + "INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)", + convID, role, readableContent, reasoningContent, tokenCount, ) if err != nil { return nil, fmt.Errorf("add message: %w", err) @@ -256,11 +279,12 @@ func (s *Store) AddMessageWithParts( // Return message with parts msg := &Message{ - ID: msgID, - ConversationID: convID, - Role: role, - TokenCount: tokenCount, - Parts: make([]MessagePart, len(parts)), + ID: msgID, + ConversationID: convID, + Role: role, + ReasoningContent: reasoningContent, + TokenCount: tokenCount, + Parts: make([]MessagePart, len(parts)), } for i, p := range parts { p.MessageID = msgID @@ -271,7 +295,7 @@ func (s *Store) AddMessageWithParts( // GetMessages retrieves messages for a conversation. func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, beforeID int64) ([]Message, error) { - query := "SELECT message_id, conversation_id, role, content, token_count, created_at FROM messages WHERE conversation_id = ?" + query := "SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE conversation_id = ?" args := []any{convID} if beforeID > 0 { query += " AND message_id < ?" @@ -298,6 +322,7 @@ func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, before &msg.ConversationID, &msg.Role, &msg.Content, + &msg.ReasoningContent, &msg.TokenCount, &createdAt, ); err != nil { @@ -335,10 +360,11 @@ func (s *Store) GetMessageCount(ctx context.Context, convID int64) (int, error) func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message, error) { var msg Message var createdAt string - err := s.db.QueryRowContext(ctx, - "SELECT message_id, conversation_id, role, content, token_count, created_at FROM messages WHERE message_id = ?", + err := s.db.QueryRowContext( + ctx, + "SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE message_id = ?", messageID, - ).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.TokenCount, &createdAt) + ).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.ReasoningContent, &msg.TokenCount, &createdAt) if err == sql.ErrNoRows { return nil, fmt.Errorf("message %d not found", messageID) } @@ -350,6 +376,28 @@ func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message, return &msg, nil } +// UpdateMessageReasoningContent updates reasoning_content for an existing message. +func (s *Store) UpdateMessageReasoningContent(ctx context.Context, messageID int64, reasoningContent string) error { + result, err := s.db.ExecContext( + ctx, + "UPDATE messages SET reasoning_content = ? WHERE message_id = ?", + reasoningContent, + messageID, + ) + if err != nil { + return fmt.Errorf("update message reasoning_content: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("update message reasoning_content rows affected: %w", err) + } + if rowsAffected == 0 { + return fmt.Errorf("message %d not found", messageID) + } + return nil +} + func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) { rows, err := s.db.QueryContext(ctx, `SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type @@ -534,7 +582,7 @@ func (s *Store) LinkSummaryToMessages(ctx context.Context, summaryID string, mes // GetSummarySourceMessages retrieves source messages for a summary. func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) ([]Message, error) { rows, err := s.db.QueryContext(ctx, - `SELECT m.message_id, m.conversation_id, m.role, m.content, m.token_count, m.created_at + `SELECT m.message_id, m.conversation_id, m.role, m.content, m.reasoning_content, m.token_count, m.created_at FROM summary_messages sm JOIN messages m ON m.message_id = sm.message_id WHERE sm.summary_id = ? @@ -555,6 +603,7 @@ func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) &msg.ConversationID, &msg.Role, &msg.Content, + &msg.ReasoningContent, &msg.TokenCount, &createdAt, ); err != nil { diff --git a/pkg/seahorse/store_test.go b/pkg/seahorse/store_test.go index 89635cc9a..67bed1c11 100644 --- a/pkg/seahorse/store_test.go +++ b/pkg/seahorse/store_test.go @@ -199,6 +199,47 @@ func TestStoreAddAndGetMessages(t *testing.T) { } } +func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:reasoning") + + msg, err := s.AddMessageWithReasoning( + ctx, + conv.ConversationID, + "assistant", + "hello world", + "let me think", + 5, + ) + if err != nil { + t.Fatalf("AddMessageWithReasoning: %v", err) + } + if msg.ReasoningContent != "let me think" { + t.Fatalf("ReasoningContent = %q, want %q", msg.ReasoningContent, "let me think") + } + + msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(msgs) != 1 { + t.Fatalf("got %d messages, want 1", len(msgs)) + } + if msgs[0].ReasoningContent != "let me think" { + t.Errorf("ReasoningContent = %q, want %q", msgs[0].ReasoningContent, "let me think") + } + + found, err := s.GetMessageByID(ctx, msg.ID) + if err != nil { + t.Fatalf("GetMessageByID: %v", err) + } + if found.ReasoningContent != "let me think" { + t.Errorf("GetMessageByID ReasoningContent = %q, want %q", found.ReasoningContent, "let me think") + } +} + func TestStoreAddMessageWithParts(t *testing.T) { s := openTestStore(t) ctx := context.Background() @@ -233,6 +274,43 @@ func TestStoreAddMessageWithParts(t *testing.T) { } } +func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:parts-reasoning") + + parts := []MessagePart{ + {Type: "tool_use", Name: "read_file", Arguments: `{"path":"/tmp/test"}`, ToolCallID: "tc_123"}, + } + _, err := s.AddMessageWithPartsAndReasoning( + ctx, + conv.ConversationID, + "assistant", + parts, + "need to inspect the file first", + 10, + ) + if err != nil { + t.Fatalf("AddMessageWithPartsAndReasoning: %v", err) + } + + msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(msgs) != 1 { + t.Fatalf("expected 1 message, got %d", len(msgs)) + } + if msgs[0].ReasoningContent != "need to inspect the file first" { + t.Errorf( + "ReasoningContent = %q, want %q", + msgs[0].ReasoningContent, + "need to inspect the file first", + ) + } +} + func TestStoreGetMessageCount(t *testing.T) { s := openTestStore(t) ctx := context.Background() @@ -275,6 +353,31 @@ func TestStoreGetMessageByID(t *testing.T) { } } +func TestStoreUpdateMessageReasoningContent(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:update-reasoning") + + msg, err := s.AddMessage(ctx, conv.ConversationID, "assistant", "answer", 3) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + err = s.UpdateMessageReasoningContent(ctx, msg.ID, "thinking") + if err != nil { + t.Fatalf("UpdateMessageReasoningContent: %v", err) + } + + found, err := s.GetMessageByID(ctx, msg.ID) + if err != nil { + t.Fatalf("GetMessageByID: %v", err) + } + if found.ReasoningContent != "thinking" { + t.Errorf("ReasoningContent = %q, want %q", found.ReasoningContent, "thinking") + } +} + // --- Summary Operations --- func TestStoreCreateAndGetSummary(t *testing.T) {