Merge remote-tracking branch 'upstream/main'

This commit is contained in:
afjcjsbx
2026-05-26 09:22:23 +02:00
18 changed files with 742 additions and 89 deletions
+9
View File
@@ -6,6 +6,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/providers"
@@ -200,6 +201,7 @@ func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
ModelName: msg.ModelName, ModelName: msg.ModelName,
ReasoningContent: msg.ReasoningContent, ReasoningContent: msg.ReasoningContent,
TokenCount: tokenizer.EstimateMessageTokens(msg), TokenCount: tokenizer.EstimateMessageTokens(msg),
CreatedAt: normalizeSeahorseMessageCreatedAt(msg.CreatedAt),
} }
// Convert ToolCalls → MessageParts // Convert ToolCalls → MessageParts
@@ -235,6 +237,13 @@ func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
return result return result
} }
func normalizeSeahorseMessageCreatedAt(createdAt *time.Time) time.Time {
if createdAt == nil || createdAt.IsZero() {
return time.Time{}
}
return createdAt.UTC().Truncate(time.Second)
}
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message. // seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message { func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
messages := make([]protocoltypes.Message, 0, len(result.Messages)) messages := make([]protocoltypes.Message, 0, len(result.Messages))
+5
View File
@@ -171,11 +171,13 @@ func TestProviderToSeahorseMessageWithMedia(t *testing.T) {
} }
func TestProviderToSeahorseMessageWithReasoning(t *testing.T) { func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
createdAt := time.Date(2026, 5, 6, 7, 8, 9, 123000000, time.UTC)
msg := protocoltypes.Message{ msg := protocoltypes.Message{
Role: "assistant", Role: "assistant",
Content: "response text", Content: "response text",
ModelName: "gpt-5.4-mini", ModelName: "gpt-5.4-mini",
ReasoningContent: "I thought about this carefully", ReasoningContent: "I thought about this carefully",
CreatedAt: &createdAt,
} }
result := providerToSeahorseMessage(msg) result := providerToSeahorseMessage(msg)
@@ -185,6 +187,9 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
if result.ModelName != "gpt-5.4-mini" { if result.ModelName != "gpt-5.4-mini" {
t.Errorf("ModelName = %q, want %q", result.ModelName, "gpt-5.4-mini") t.Errorf("ModelName = %q, want %q", result.ModelName, "gpt-5.4-mini")
} }
if !result.CreatedAt.Equal(time.Date(2026, 5, 6, 7, 8, 9, 0, time.UTC)) {
t.Errorf("CreatedAt = %v, want 2026-05-06 07:08:09 UTC", result.CreatedAt)
}
} }
func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) { func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
+108 -24
View File
@@ -9,6 +9,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"time"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
@@ -261,6 +262,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa
msg.ModelName, msg.ModelName,
msg.ReasoningContent, msg.ReasoningContent,
msg.TokenCount, msg.TokenCount,
msg.CreatedAt,
) )
} else { } else {
added, err = e.store.AddMessageWithReasoning( added, err = e.store.AddMessageWithReasoning(
@@ -271,6 +273,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa
msg.ModelName, msg.ModelName,
msg.ReasoningContent, msg.ReasoningContent,
msg.TokenCount, msg.TokenCount,
msg.CreatedAt,
) )
} }
if err != nil { if err != nil {
@@ -445,10 +448,14 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
if err != nil { if err != nil {
return fmt.Errorf("bootstrap: repair model_name: %w", err) return fmt.Errorf("bootstrap: repair model_name: %w", err)
} }
if (repairedReasoning || repairedModelName) && len(dbMsgs) == len(messages) { repairedCreatedAt, err := e.repairBootstrapCreatedAt(ctx, dbMsgs, messages)
if err != nil {
return fmt.Errorf("bootstrap: repair created_at: %w", err)
}
if (repairedReasoning || repairedModelName || repairedCreatedAt) && len(dbMsgs) == len(messages) {
matched := true matched := true
for i := range messages { for i := range messages {
if !messageMatches(dbMsgs[i], messages[i]) { if !messagesMatch(dbMsgs[i], messages[i], messageMatchOptions{}) {
matched = false matched = false
break break
} }
@@ -462,7 +469,7 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
if len(dbMsgs) == len(messages) { if len(dbMsgs) == len(messages) {
matched := true matched := true
for i := range messages { for i := range messages {
if !messageMatches(dbMsgs[i], messages[i]) { if !messagesMatch(dbMsgs[i], messages[i], messageMatchOptions{}) {
matched = false matched = false
break break
} }
@@ -477,7 +484,7 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
compareLen := min(len(dbMsgs), len(messages)) compareLen := min(len(dbMsgs), len(messages))
for i := range compareLen { for i := range compareLen {
if messageMatches(dbMsgs[i], messages[i]) { if messagesMatch(dbMsgs[i], messages[i], messageMatchOptions{}) {
anchor = i anchor = i
} else { } else {
// Mismatch detected - log details and rebuild // Mismatch detected - log details and rebuild
@@ -578,7 +585,11 @@ func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, me
} }
for i := range overlap { for i := range overlap {
if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) { if !messagesMatch(dbMsgs[i], messages[i], messageMatchOptions{
IgnoreReasoningContent: true,
IgnoreModelName: true,
IgnoreCreatedAt: true,
}) {
return false, nil return false, nil
} }
if dbMsgs[i].ReasoningContent == messages[i].ReasoningContent { if dbMsgs[i].ReasoningContent == messages[i].ReasoningContent {
@@ -629,7 +640,11 @@ func (e *Engine) repairBootstrapModelName(ctx context.Context, dbMsgs, messages
} }
for i := range overlap { for i := range overlap {
if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) { if !messagesMatch(dbMsgs[i], messages[i], messageMatchOptions{
IgnoreReasoningContent: true,
IgnoreModelName: true,
IgnoreCreatedAt: true,
}) {
return false, nil return false, nil
} }
if dbMsgs[i].ModelName == messages[i].ModelName { if dbMsgs[i].ModelName == messages[i].ModelName {
@@ -666,6 +681,64 @@ func (e *Engine) repairBootstrapModelName(ctx context.Context, dbMsgs, messages
return true, nil return true, nil
} }
func (e *Engine) repairBootstrapCreatedAt(ctx context.Context, dbMsgs, messages []Message) (bool, error) {
if len(dbMsgs) == 0 || len(messages) == 0 {
return false, nil
}
overlap := min(len(messages), len(dbMsgs))
var updates []struct {
index int
messageID int64
createdAt time.Time
}
for i := range overlap {
if !messagesMatch(dbMsgs[i], messages[i], messageMatchOptions{
IgnoreReasoningContent: true,
IgnoreModelName: true,
IgnoreCreatedAt: true,
}) {
return false, nil
}
wantCreatedAt := normalizeMessageCreatedAt(messages[i].CreatedAt)
if wantCreatedAt.IsZero() {
return false, nil
}
if dbMsgs[i].CreatedAt.Equal(wantCreatedAt) {
continue
}
updates = append(updates, struct {
index int
messageID int64
createdAt time.Time
}{
index: i,
messageID: dbMsgs[i].ID,
createdAt: wantCreatedAt,
})
}
if len(updates) == 0 {
return false, nil
}
for _, update := range updates {
if err := e.store.UpdateMessageCreatedAt(ctx, update.messageID, update.createdAt); err != nil {
return false, err
}
dbMsgs[update.index].CreatedAt = update.createdAt
}
logger.InfoCF("seahorse", "bootstrap: repaired message created_at", map[string]any{
"messages": len(updates),
})
return true, nil
}
// truncate shortens a string for logging. // truncate shortens a string for logging.
func truncate(s string, maxLen int) string { func truncate(s string, maxLen int) string {
if len(s) <= maxLen { if len(s) <= maxLen {
@@ -674,29 +747,28 @@ func truncate(s string, maxLen int) string {
return s[:maxLen] + "..." return s[:maxLen] + "..."
} }
// messageMatches compares two messages using role + reasoning_content and then type messageMatchOptions struct {
// either content or parts. TokenCount is NOT compared because it may be IgnoreReasoningContent bool
// re-estimated differently during bootstrap (e.g., via tokenizer.EstimateMessageTokens). IgnoreModelName bool
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content IgnoreCreatedAt bool
// because structured messages are matched by their parts payload.
func messageMatches(a, b Message) bool {
if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent || a.ModelName != b.ModelName {
return false
}
return messageMatchesIgnoringReasoning(a, b)
} }
func messageMatchesIgnoringReasoning(a, b Message) bool { // messagesMatch compares two messages by role and payload, plus the optional
if a.ModelName != b.ModelName { // metadata fields used by bootstrap repair. TokenCount is intentionally ignored
return false // because bootstrap may re-estimate it differently.
} func messagesMatch(a, b Message, opts messageMatchOptions) bool {
return messageMatchesIgnoringReasoningAndModelName(a, b)
}
func messageMatchesIgnoringReasoningAndModelName(a, b Message) bool {
if a.Role != b.Role { if a.Role != b.Role {
return false return false
} }
if !opts.IgnoreReasoningContent && a.ReasoningContent != b.ReasoningContent {
return false
}
if !opts.IgnoreModelName && a.ModelName != b.ModelName {
return false
}
if !opts.IgnoreCreatedAt && !messageCreatedAtMatches(a.CreatedAt, b.CreatedAt) {
return false
}
// If either message has Parts, compare Parts // If either message has Parts, compare Parts
if len(a.Parts) > 0 || len(b.Parts) > 0 { if len(a.Parts) > 0 || len(b.Parts) > 0 {
return partsMatch(a.Parts, b.Parts) return partsMatch(a.Parts, b.Parts)
@@ -705,6 +777,18 @@ func messageMatchesIgnoringReasoningAndModelName(a, b Message) bool {
return a.Content == b.Content return a.Content == b.Content
} }
// messageCreatedAtMatches treats missing timestamps as compatible so bootstrap
// can preserve legacy histories while still enforcing exact equality once both
// sides carry canonical created_at values.
func messageCreatedAtMatches(a, b time.Time) bool {
na := normalizeMessageCreatedAt(a)
nb := normalizeMessageCreatedAt(b)
if na.IsZero() || nb.IsZero() {
return true
}
return na.Equal(nb)
}
// partsMatch compares two slices of MessagePart for equality. // partsMatch compares two slices of MessagePart for equality.
func partsMatch(a, b []MessagePart) bool { func partsMatch(a, b []MessagePart) bool {
if len(a) != len(b) { if len(a) != len(b) {
+81 -5
View File
@@ -57,8 +57,8 @@ func prepareBootstrapRepairConversation(
} }
return conv, []Message{ return conv, []Message{
{Role: "user", Content: "hello", TokenCount: 3}, {Role: "user", Content: "hello", TokenCount: 3, CreatedAt: userMsg.CreatedAt},
{Role: "assistant", Content: "world", TokenCount: 3}, {Role: "assistant", Content: "world", TokenCount: 3, CreatedAt: assistantMsg.CreatedAt},
} }
} }
@@ -464,13 +464,19 @@ func TestBootstrapRepairsReasoningContentAndModelNameTogether(t *testing.T) {
} }
err = eng.Bootstrap(ctx, sessionKey, []Message{ err = eng.Bootstrap(ctx, sessionKey, []Message{
{Role: "user", Content: "hello", TokenCount: 3}, {
Role: "user",
Content: "hello",
TokenCount: 3,
CreatedAt: time.Date(2026, 3, 4, 5, 6, 7, 0, time.UTC),
},
{ {
Role: "assistant", Role: "assistant",
Content: "world", Content: "world",
ModelName: "gpt-5.4", ModelName: "gpt-5.4",
ReasoningContent: "let me think this through", ReasoningContent: "let me think this through",
TokenCount: 3, TokenCount: 3,
CreatedAt: time.Date(2026, 3, 4, 5, 6, 8, 0, time.UTC),
}, },
}) })
if err != nil { if err != nil {
@@ -515,6 +521,7 @@ func TestBootstrapRepairsIncorrectNonEmptyModelName(t *testing.T) {
"wrong-model", "wrong-model",
"", "",
3, 3,
time.Time{},
) )
if err != nil { if err != nil {
t.Fatalf("AddMessageWithReasoning assistant: %v", err) t.Fatalf("AddMessageWithReasoning assistant: %v", err)
@@ -545,6 +552,64 @@ func TestBootstrapRepairsIncorrectNonEmptyModelName(t *testing.T) {
} }
} }
func TestBootstrapRepairsCreatedAt(t *testing.T) {
eng := newTestEngine(t)
ctx := context.Background()
sessionKey := "agent:repair-created-at"
conv, msgs := prepareBootstrapRepairConversation(t, eng, ctx, sessionKey)
wantCreatedAt := time.Date(2026, 3, 4, 5, 6, 7, 0, time.UTC)
msgs[1].CreatedAt = wantCreatedAt
err := eng.Bootstrap(ctx, sessionKey, msgs)
if err != nil {
t.Fatalf("Bootstrap: %v", err)
}
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(stored) != 2 {
t.Fatalf("stored messages = %d, want 2", len(stored))
}
if !stored[1].CreatedAt.Equal(wantCreatedAt) {
t.Fatalf("stored[1].CreatedAt = %v, want %v", stored[1].CreatedAt, wantCreatedAt)
}
}
func TestEngineIngestPreservesCreatedAt(t *testing.T) {
eng := newTestEngine(t)
ctx := context.Background()
wantCreatedAt := time.Date(2026, 4, 5, 6, 7, 8, 0, time.UTC)
msgs := []Message{
{
Role: "assistant",
Content: "world",
TokenCount: 4,
CreatedAt: wantCreatedAt,
},
}
_, err := eng.Ingest(ctx, "agent:created-at", msgs)
if err != nil {
t.Fatalf("Ingest: %v", err)
}
conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:created-at")
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(stored) != 1 {
t.Fatalf("stored messages = %d, want 1", len(stored))
}
if !stored[0].CreatedAt.Equal(wantCreatedAt) {
t.Fatalf("stored[0].CreatedAt = %v, want %v", stored[0].CreatedAt, wantCreatedAt)
}
}
func TestEngineIngestWithPartsPreservesReasoningContent(t *testing.T) { func TestEngineIngestWithPartsPreservesReasoningContent(t *testing.T) {
eng := newTestEngine(t) eng := newTestEngine(t)
ctx := context.Background() ctx := context.Background()
@@ -864,8 +929,19 @@ func TestBootstrapRepairsMissingReasoningContentWithoutDroppingSummaries(t *test
} }
err = eng.Bootstrap(ctx, sessionKey, []Message{ err = eng.Bootstrap(ctx, sessionKey, []Message{
{Role: "user", Content: "hello", TokenCount: 3}, {
{Role: "assistant", Content: "world", ReasoningContent: "let me think this through", TokenCount: 3}, Role: "user",
Content: "hello",
TokenCount: 3,
CreatedAt: time.Date(2026, 3, 4, 5, 6, 7, 0, time.UTC),
},
{
Role: "assistant",
Content: "world",
ReasoningContent: "let me think this through",
TokenCount: 3,
CreatedAt: time.Date(2026, 3, 4, 5, 6, 8, 0, time.UTC),
},
}) })
if err != nil { if err != nil {
t.Fatalf("Bootstrap: %v", err) t.Fatalf("Bootstrap: %v", err)
+75 -17
View File
@@ -8,6 +8,8 @@ import (
"time" "time"
) )
const sqliteTimeLayout = "2006-01-02 15:04:05"
// Store provides SQLite storage for seahorse. // Store provides SQLite storage for seahorse.
type Store struct { type Store struct {
db *sql.DB db *sql.DB
@@ -75,8 +77,8 @@ func (s *Store) GetConversationBySessionKey(ctx context.Context, sessionKey stri
if err != nil { if err != nil {
return nil, fmt.Errorf("get conversation by session key: %w", err) return nil, fmt.Errorf("get conversation by session key: %w", err)
} }
conv.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) conv.CreatedAt = parseSQLiteTime(createdAt)
conv.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) conv.UpdatedAt = parseSQLiteTime(updatedAt)
return &conv, nil return &conv, nil
} }
@@ -153,8 +155,8 @@ func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Tim
if err != nil || minTime == "" { if err != nil || minTime == "" {
return time.Time{}, time.Time{}, err return time.Time{}, time.Time{}, err
} }
oldest, _ := time.Parse("2006-01-02 15:04:05", minTime) oldest := parseSQLiteTime(minTime)
newest, _ := time.Parse("2006-01-02 15:04:05", maxTime) newest := parseSQLiteTime(maxTime)
return oldest, newest, nil return oldest, newest, nil
} }
@@ -162,7 +164,7 @@ func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Tim
// AddMessage appends a message to a conversation. // AddMessage appends a message to a conversation.
func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) { func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) {
return s.AddMessageWithReasoning(ctx, convID, role, content, "", "", tokenCount) return s.AddMessageWithReasoning(ctx, convID, role, content, "", "", tokenCount, time.Time{})
} }
// AddMessageWithReasoning appends a message with reasoning content to a conversation. // AddMessageWithReasoning appends a message with reasoning content to a conversation.
@@ -171,16 +173,22 @@ func (s *Store) AddMessageWithReasoning(
convID int64, convID int64,
role, content, modelName, reasoningContent string, role, content, modelName, reasoningContent string,
tokenCount int, tokenCount int,
createdAt time.Time,
) (*Message, error) { ) (*Message, error) {
storedCreatedAt := normalizeMessageCreatedAt(createdAt)
if storedCreatedAt.IsZero() {
storedCreatedAt = normalizeMessageCreatedAt(time.Now())
}
result, err := s.db.ExecContext( result, err := s.db.ExecContext(
ctx, ctx,
"INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)", "INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
convID, convID,
role, role,
content, content,
modelName, modelName,
reasoningContent, reasoningContent,
tokenCount, tokenCount,
formatSQLiteTime(storedCreatedAt),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("add message: %w", err) return nil, fmt.Errorf("add message: %w", err)
@@ -194,6 +202,7 @@ func (s *Store) AddMessageWithReasoning(
ModelName: modelName, ModelName: modelName,
ReasoningContent: reasoningContent, ReasoningContent: reasoningContent,
TokenCount: tokenCount, TokenCount: tokenCount,
CreatedAt: storedCreatedAt,
}, nil }, nil
} }
@@ -231,7 +240,7 @@ func (s *Store) AddMessageWithParts(
parts []MessagePart, parts []MessagePart,
tokenCount int, tokenCount int,
) (*Message, error) { ) (*Message, error) {
return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", "", tokenCount) return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", "", tokenCount, time.Time{})
} }
// AddMessageWithPartsAndReasoning adds a message with structured parts and reasoning content. // AddMessageWithPartsAndReasoning adds a message with structured parts and reasoning content.
@@ -243,6 +252,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
modelName string, modelName string,
reasoningContent string, reasoningContent string,
tokenCount int, tokenCount int,
createdAt time.Time,
) (*Message, error) { ) (*Message, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
@@ -250,18 +260,24 @@ func (s *Store) AddMessageWithPartsAndReasoning(
} }
defer tx.Rollback() defer tx.Rollback()
storedCreatedAt := normalizeMessageCreatedAt(createdAt)
if storedCreatedAt.IsZero() {
storedCreatedAt = normalizeMessageCreatedAt(time.Now())
}
// Derive readable content from Parts for FTS5 indexing and summary formatting // Derive readable content from Parts for FTS5 indexing and summary formatting
readableContent := partsToReadableContent(parts) readableContent := partsToReadableContent(parts)
result, err := tx.ExecContext( result, err := tx.ExecContext(
ctx, ctx,
"INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)", "INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
convID, convID,
role, role,
readableContent, readableContent,
modelName, modelName,
reasoningContent, reasoningContent,
tokenCount, tokenCount,
formatSQLiteTime(storedCreatedAt),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("add message: %w", err) return nil, fmt.Errorf("add message: %w", err)
@@ -299,6 +315,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
ModelName: modelName, ModelName: modelName,
ReasoningContent: reasoningContent, ReasoningContent: reasoningContent,
TokenCount: tokenCount, TokenCount: tokenCount,
CreatedAt: storedCreatedAt,
Parts: make([]MessagePart, len(parts)), Parts: make([]MessagePart, len(parts)),
} }
for i, p := range parts { for i, p := range parts {
@@ -344,7 +361,7 @@ func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, before
); err != nil { ); err != nil {
return nil, err return nil, err
} }
msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) msg.CreatedAt = parseSQLiteTime(createdAt)
msgs = append(msgs, msg) msgs = append(msgs, msg)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@@ -387,7 +404,7 @@ func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message,
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) msg.CreatedAt = parseSQLiteTime(createdAt)
msg.Parts, _ = s.loadMessageParts(ctx, msg.ID) msg.Parts, _ = s.loadMessageParts(ctx, msg.ID)
return &msg, nil return &msg, nil
} }
@@ -435,6 +452,32 @@ func (s *Store) UpdateMessageModelName(ctx context.Context, messageID int64, mod
return nil return nil
} }
func (s *Store) UpdateMessageCreatedAt(ctx context.Context, messageID int64, createdAt time.Time) error {
storedCreatedAt := normalizeMessageCreatedAt(createdAt)
if storedCreatedAt.IsZero() {
return fmt.Errorf("message %d created_at cannot be zero", messageID)
}
result, err := s.db.ExecContext(
ctx,
"UPDATE messages SET created_at = ? WHERE message_id = ?",
formatSQLiteTime(storedCreatedAt),
messageID,
)
if err != nil {
return fmt.Errorf("update message created_at: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("update message created_at rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("message %d not found", messageID)
}
return nil
}
func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) { func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) {
rows, err := s.db.QueryContext(ctx, rows, err := s.db.QueryContext(ctx,
`SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type `SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type
@@ -648,7 +691,7 @@ func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string)
); err != nil { ); err != nil {
return nil, err return nil, err
} }
msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) msg.CreatedAt = parseSQLiteTime(createdAt)
msgs = append(msgs, msg) msgs = append(msgs, msg)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@@ -714,8 +757,7 @@ func (s *Store) GetContextItems(ctx context.Context, convID int64) ([]ContextIte
item.MessageID = messageID.Int64 item.MessageID = messageID.Int64
} }
if createdAt.Valid { if createdAt.Valid {
t, _ := time.Parse("2006-01-02 15:04:05", createdAt.String) item.CreatedAt = parseSQLiteTime(createdAt.String)
item.CreatedAt = t
} }
items = append(items, item) items = append(items, item)
} }
@@ -1449,7 +1491,7 @@ func (s *Store) scanSearchResults(rows *sql.Rows, withRank bool) ([]SearchResult
} }
} }
r.Kind = SummaryKind(kind) r.Kind = SummaryKind(kind)
r.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) r.CreatedAt = parseSQLiteTime(createdAt)
results = append(results, r) results = append(results, r)
} }
return results, nil return results, nil
@@ -1573,7 +1615,7 @@ func (s *Store) scanMessageSearchResults(rows *sql.Rows, withRank bool) ([]Searc
} }
} }
r.Snippet = content r.Snippet = content
r.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) r.CreatedAt = parseSQLiteTime(createdAt)
results = append(results, r) results = append(results, r)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@@ -1606,7 +1648,7 @@ func (s *Store) scanSummary(ctx context.Context, where string, args ...any) (*Su
return nil, err return nil, err
} }
sum.Kind = SummaryKind(kind) sum.Kind = SummaryKind(kind)
sum.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) sum.CreatedAt = parseSQLiteTime(createdAt)
if earliestAt.Valid { if earliestAt.Valid {
t, _ := time.Parse(time.RFC3339, earliestAt.String) t, _ := time.Parse(time.RFC3339, earliestAt.String)
sum.EarliestAt = &t sum.EarliestAt = &t
@@ -1633,7 +1675,7 @@ func (s *Store) scanSummaries(rows *sql.Rows) ([]Summary, error) {
return nil, err return nil, err
} }
sum.Kind = SummaryKind(kind) sum.Kind = SummaryKind(kind)
sum.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) sum.CreatedAt = parseSQLiteTime(createdAt)
if earliestAt.Valid { if earliestAt.Valid {
t, _ := time.Parse(time.RFC3339, earliestAt.String) t, _ := time.Parse(time.RFC3339, earliestAt.String)
sum.EarliestAt = &t sum.EarliestAt = &t
@@ -1659,6 +1701,22 @@ func isUniqueViolation(err error) bool {
contains(err.Error(), "constraint failed")) contains(err.Error(), "constraint failed"))
} }
func normalizeMessageCreatedAt(createdAt time.Time) time.Time {
if createdAt.IsZero() {
return time.Time{}
}
return createdAt.UTC().Truncate(time.Second)
}
func formatSQLiteTime(t time.Time) string {
return normalizeMessageCreatedAt(t).Format(sqliteTimeLayout)
}
func parseSQLiteTime(raw string) time.Time {
parsed, _ := time.Parse(sqliteTimeLayout, raw)
return parsed
}
func contains(s, sub string) bool { func contains(s, sub string) bool {
return len(s) >= len(sub) && searchSubstring(s, sub) return len(s) >= len(sub) && searchSubstring(s, sub)
} }
+14
View File
@@ -213,6 +213,7 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
"gpt-5.4-mini", "gpt-5.4-mini",
"let me think", "let me think",
5, 5,
time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC),
) )
if err != nil { if err != nil {
t.Fatalf("AddMessageWithReasoning: %v", err) t.Fatalf("AddMessageWithReasoning: %v", err)
@@ -223,6 +224,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
if msg.ModelName != "gpt-5.4-mini" { if msg.ModelName != "gpt-5.4-mini" {
t.Fatalf("ModelName = %q, want %q", msg.ModelName, "gpt-5.4-mini") t.Fatalf("ModelName = %q, want %q", msg.ModelName, "gpt-5.4-mini")
} }
if !msg.CreatedAt.Equal(time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)) {
t.Fatalf("CreatedAt = %v, want 2026-01-02 03:04:05 UTC", msg.CreatedAt)
}
msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0) msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0)
if err != nil { if err != nil {
@@ -237,6 +241,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
if msgs[0].ModelName != "gpt-5.4-mini" { if msgs[0].ModelName != "gpt-5.4-mini" {
t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4-mini") t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4-mini")
} }
if !msgs[0].CreatedAt.Equal(time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)) {
t.Errorf("CreatedAt = %v, want 2026-01-02 03:04:05 UTC", msgs[0].CreatedAt)
}
found, err := s.GetMessageByID(ctx, msg.ID) found, err := s.GetMessageByID(ctx, msg.ID)
if err != nil { if err != nil {
@@ -248,6 +255,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
if found.ModelName != "gpt-5.4-mini" { if found.ModelName != "gpt-5.4-mini" {
t.Errorf("GetMessageByID ModelName = %q, want %q", found.ModelName, "gpt-5.4-mini") t.Errorf("GetMessageByID ModelName = %q, want %q", found.ModelName, "gpt-5.4-mini")
} }
if !found.CreatedAt.Equal(time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)) {
t.Errorf("GetMessageByID CreatedAt = %v, want 2026-01-02 03:04:05 UTC", found.CreatedAt)
}
} }
func TestStoreAddMessageWithParts(t *testing.T) { func TestStoreAddMessageWithParts(t *testing.T) {
@@ -301,6 +311,7 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) {
"gpt-5.4", "gpt-5.4",
"need to inspect the file first", "need to inspect the file first",
10, 10,
time.Date(2026, 2, 3, 4, 5, 6, 0, time.UTC),
) )
if err != nil { if err != nil {
t.Fatalf("AddMessageWithPartsAndReasoning: %v", err) t.Fatalf("AddMessageWithPartsAndReasoning: %v", err)
@@ -323,6 +334,9 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) {
if msgs[0].ModelName != "gpt-5.4" { if msgs[0].ModelName != "gpt-5.4" {
t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4") t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4")
} }
if !msgs[0].CreatedAt.Equal(time.Date(2026, 2, 3, 4, 5, 6, 0, time.UTC)) {
t.Errorf("CreatedAt = %v, want 2026-02-03 04:05:06 UTC", msgs[0].CreatedAt)
}
} }
func TestStoreGetMessageCount(t *testing.T) { func TestStoreGetMessageCount(t *testing.T) {
+19 -10
View File
@@ -60,6 +60,21 @@ func (sm *SessionManager) GetOrCreate(key string) *Session {
return session return session
} }
func ensureMessageCreatedAt(msg *providers.Message, fallback time.Time) {
if msg.CreatedAt != nil && !msg.CreatedAt.IsZero() {
return
}
ts := fallback
msg.CreatedAt = &ts
}
func normalizeHistoryCreatedAt(history []providers.Message) {
now := time.Now()
for i := range history {
ensureMessageCreatedAt(&history[i], now)
}
}
func (sm *SessionManager) AddMessage(sessionKey, role, content string) { func (sm *SessionManager) AddMessage(sessionKey, role, content string) {
sm.AddFullMessage(sessionKey, providers.Message{ sm.AddFullMessage(sessionKey, providers.Message{
Role: role, Role: role,
@@ -88,9 +103,7 @@ func (sm *SessionManager) AddFullMessage(sessionKey string, msg providers.Messag
} }
now := time.Now() now := time.Now()
if msg.CreatedAt == nil { ensureMessageCreatedAt(&msg, now)
msg.CreatedAt = &now
}
session.Messages = append(session.Messages, msg) session.Messages = append(session.Messages, msg)
session.Updated = now session.Updated = now
@@ -280,6 +293,7 @@ func (sm *SessionManager) loadSessions() error {
continue continue
} }
session.Messages = messageutil.FilterInvalidHistoryMessages(session.Messages) session.Messages = messageutil.FilterInvalidHistoryMessages(session.Messages)
normalizeHistoryCreatedAt(session.Messages)
sm.sessions[session.Key] = &session sm.sessions[session.Key] = &session
} }
@@ -305,13 +319,8 @@ func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
// from the caller's slice. // from the caller's slice.
msgs := make([]providers.Message, len(history)) msgs := make([]providers.Message, len(history))
copy(msgs, history) copy(msgs, history)
now := time.Now() normalizeHistoryCreatedAt(msgs)
for i := range msgs {
if msgs[i].CreatedAt == nil {
msgs[i].CreatedAt = &now
}
}
session.Messages = msgs session.Messages = msgs
session.Updated = now session.Updated = time.Now()
} }
} }
+29
View File
@@ -83,3 +83,32 @@ func TestSave_RejectsPathTraversal(t *testing.T) {
t.Errorf("expected foo_bar.json in storage (sanitized from foo/bar)") t.Errorf("expected foo_bar.json in storage (sanitized from foo/bar)")
} }
} }
func TestLoadSessions_NormalizesMissingCreatedAt(t *testing.T) {
tmpDir := t.TempDir()
sessionPath := filepath.Join(tmpDir, "telegram_legacy.json")
legacy := `{
"key": "telegram:legacy",
"messages": [
{
"role": "user",
"content": "hello"
}
],
"created": "2026-01-01T00:00:00Z",
"updated": "2026-01-01T00:00:00Z"
}`
if err := os.WriteFile(sessionPath, []byte(legacy), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
sm := NewSessionManager(tmpDir)
history := sm.GetHistory("telegram:legacy")
if len(history) != 1 {
t.Fatalf("history = %d, want 1", len(history))
}
if history[0].CreatedAt == nil || history[0].CreatedAt.IsZero() {
t.Fatalf("history[0].CreatedAt = %v, want non-zero timestamp", history[0].CreatedAt)
}
}
@@ -13,6 +13,10 @@ import rehypeSanitize from "rehype-sanitize"
import remarkGfm from "remark-gfm" import remarkGfm from "remark-gfm"
import type { SkillDetailResponse, SkillSupportItem } from "@/api/skills" import type { SkillDetailResponse, SkillSupportItem } from "@/api/skills"
import {
MarkdownCodeBlock,
MessageCodeBlock,
} from "@/components/chat/message-code-block"
import { import {
Sheet, Sheet,
SheetContent, SheetContent,
@@ -176,6 +180,9 @@ export function DetailSheet({
<ReactMarkdown <ReactMarkdown
remarkPlugins={[remarkGfm]} remarkPlugins={[remarkGfm]}
rehypePlugins={[rehypeRaw, rehypeSanitize, rehypeHighlight]} rehypePlugins={[rehypeRaw, rehypeSanitize, rehypeHighlight]}
components={{
pre: MarkdownCodeBlock,
}}
> >
{selectedSkillDetail.content} {selectedSkillDetail.content}
</ReactMarkdown> </ReactMarkdown>
@@ -183,11 +190,12 @@ export function DetailSheet({
) : null} ) : null}
{detailView === "raw" ? ( {detailView === "raw" ? (
<div className="border-border/50 overflow-x-auto rounded-xl border bg-zinc-950 p-5 shadow-sm"> <MessageCodeBlock
<pre className="font-mono text-[13px] leading-relaxed break-words whitespace-pre-wrap text-zinc-100/90"> code={selectedSkillDetail.content}
<code>{selectedSkillDetail.content}</code> label={t("pages.agent.skills.detail_tabs.raw")}
</pre> className="my-0"
</div> bodyClassName="text-[13px] leading-relaxed"
/>
) : null} ) : null}
{detailView === "meta" ? ( {detailView === "meta" ? (
@@ -1,4 +1,5 @@
import type { ChannelConfig } from "@/api/channels" import type { ChannelConfig } from "@/api/channels"
import { MessageCodeBlock } from "@/components/chat/message-code-block"
import { getSecretInputPlaceholder } from "@/components/channels/channel-config-fields" import { getSecretInputPlaceholder } from "@/components/channels/channel-config-fields"
import { Field, KeyInput } from "@/components/shared-form" import { Field, KeyInput } from "@/components/shared-form"
import { import {
@@ -180,9 +181,12 @@ export function MqttForm({
{t("channels.mqtt.uplink")} {t("channels.mqtt.uplink")}
</p> </p>
<CodeLine>{`${topicBase}/request`}</CodeLine> <CodeLine>{`${topicBase}/request`}</CodeLine>
<pre className="bg-muted text-foreground rounded px-3 py-2 font-mono text-xs leading-relaxed"> <MessageCodeBlock
{`{\n "text": "your message"\n}`} code={`{\n "text": "your message"\n}`}
</pre> language="json"
className="my-0"
bodyClassName="px-3 py-2 text-xs leading-relaxed"
/>
<div className="text-muted-foreground space-y-1 text-xs"> <div className="text-muted-foreground space-y-1 text-xs">
<p> <p>
<span className="text-foreground font-medium"> <span className="text-foreground font-medium">
@@ -199,9 +203,12 @@ export function MqttForm({
{t("channels.mqtt.downlink")} {t("channels.mqtt.downlink")}
</p> </p>
<CodeLine>{`${topicBase}/response`}</CodeLine> <CodeLine>{`${topicBase}/response`}</CodeLine>
<pre className="bg-muted text-foreground rounded px-3 py-2 font-mono text-xs leading-relaxed"> <MessageCodeBlock
{`{\n "text": "agent response"\n}`} code={`{\n "text": "agent response"\n}`}
</pre> language="json"
className="my-0"
bodyClassName="px-3 py-2 text-xs leading-relaxed"
/>
<div className="text-muted-foreground space-y-1 text-xs"> <div className="text-muted-foreground space-y-1 text-xs">
<p> <p>
<span className="text-foreground font-medium"> <span className="text-foreground font-medium">
@@ -197,7 +197,6 @@ export function AssistantMessage({
label={toolName || t("chat.toolCallArgumentsLabel")} label={toolName || t("chat.toolCallArgumentsLabel")}
className="my-0 shadow-none" className="my-0 shadow-none"
bodyClassName="px-3 py-2 text-[12px] leading-relaxed" bodyClassName="px-3 py-2 text-[12px] leading-relaxed"
wrapLongLines
/> />
)} )}
</div> </div>
@@ -3,17 +3,30 @@ import {
IconChevronDown, IconChevronDown,
IconCopy, IconCopy,
} from "@tabler/icons-react" } from "@tabler/icons-react"
import { useAtom } from "jotai"
import hljs from "highlight.js/lib/core" import hljs from "highlight.js/lib/core"
import json from "highlight.js/lib/languages/json" import json from "highlight.js/lib/languages/json"
import { type ComponentProps, type ReactNode, useState } from "react" import {
type ComponentProps,
type CSSProperties,
type ReactNode,
useState,
} from "react"
import { useTranslation } from "react-i18next" import { useTranslation } from "react-i18next"
import { useCopyToClipboard } from "@/hooks/use-copy-to-clipboard" import { useCopyToClipboard } from "@/hooks/use-copy-to-clipboard"
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils"
import { codeBlockWrapAtom } from "@/store/code-block"
import { import {
extractCodeBlockFromPreNode, extractCodeBlockFromPreNode,
extractCodeBlockRenderState,
type MarkdownNode, type MarkdownNode,
splitCodeIntoLines,
splitHighlightedHtmlIntoLines,
splitRenderedCodeContentIntoLines,
trimTrailingEmptyRenderedCodeLine,
trimTrailingEmptyStringLine,
} from "./message-code-block.utils" } from "./message-code-block.utils"
import { Button } from "@/components/ui/button" import { Button } from "@/components/ui/button"
@@ -27,10 +40,10 @@ interface MessageCodeBlockProps {
code: string code: string
language?: string | null language?: string | null
label?: string label?: string
children?: ReactNode
className?: string className?: string
bodyClassName?: string bodyClassName?: string
wrapLongLines?: boolean children?: ReactNode
trimTrailingEmptyLine?: boolean
} }
interface MarkdownCodeBlockProps extends ComponentProps<"pre"> { interface MarkdownCodeBlockProps extends ComponentProps<"pre"> {
@@ -53,13 +66,14 @@ export function MessageCodeBlock({
code, code,
language = null, language = null,
label, label,
children,
className, className,
bodyClassName, bodyClassName,
wrapLongLines = false, children,
trimTrailingEmptyLine = false,
}: MessageCodeBlockProps) { }: MessageCodeBlockProps) {
const { t } = useTranslation() const { t } = useTranslation()
const { copy, isCopied } = useCopyToClipboard() const { copy, isCopied } = useCopyToClipboard()
const [wrapLongLines, setWrapLongLines] = useAtom(codeBlockWrapAtom)
const [isExpanded, setIsExpanded] = useState(true) const [isExpanded, setIsExpanded] = useState(true)
const blockLabel = const blockLabel =
label ?? label ??
@@ -68,7 +82,31 @@ export function MessageCodeBlock({
: t("chat.codeLabel").toLocaleLowerCase()) : t("chat.codeLabel").toLocaleLowerCase())
const copyLabel = isCopied ? t("chat.copiedLabel") : t("chat.copyCode") const copyLabel = isCopied ? t("chat.copiedLabel") : t("chat.copyCode")
const expandLabel = isExpanded ? t("chat.collapseCode") : t("chat.expandCode") const expandLabel = isExpanded ? t("chat.collapseCode") : t("chat.expandCode")
const wrapLabel = wrapLongLines
? t("chat.disableCodeWrap")
: t("chat.enableCodeWrap")
const renderedCodeState = children
? extractCodeBlockRenderState(children)
: {
renderedContent: null,
className: undefined,
}
const highlightedHtml = !children ? getHighlightedHtml(code, language) : null const highlightedHtml = !children ? getHighlightedHtml(code, language) : null
const highlightedLines = highlightedHtml
? splitHighlightedHtmlIntoLines(highlightedHtml)
: null
const codeLines = children
? (trimTrailingEmptyLine
? trimTrailingEmptyRenderedCodeLine(
splitRenderedCodeContentIntoLines(renderedCodeState.renderedContent),
)
: splitRenderedCodeContentIntoLines(renderedCodeState.renderedContent))
: (trimTrailingEmptyLine
? trimTrailingEmptyStringLine(
highlightedLines ?? splitCodeIntoLines(code),
)
: (highlightedLines ?? splitCodeIntoLines(code)))
const lineNumberWidth = `${String(codeLines.length).length + 1}ch`
return ( return (
<div <div
@@ -102,6 +140,18 @@ export function MessageCodeBlock({
)} )}
<span className="hidden sm:inline">{copyLabel}</span> <span className="hidden sm:inline">{copyLabel}</span>
</Button> </Button>
<Button
type="button"
variant="ghost"
size="xs"
className="h-7 px-2 text-[11px] text-zinc-600 hover:bg-zinc-300/70 hover:text-zinc-900 dark:text-zinc-400 dark:hover:bg-zinc-800 dark:hover:text-zinc-100"
onClick={() => setWrapLongLines((current) => !current)}
aria-pressed={wrapLongLines}
aria-label={wrapLabel}
title={wrapLabel}
>
{wrapLabel}
</Button>
<Button <Button
type="button" type="button"
variant="ghost" variant="ghost"
@@ -123,23 +173,56 @@ export function MessageCodeBlock({
{isExpanded && ( {isExpanded && (
<pre <pre
className={cn( className={cn(
"m-0 overflow-x-auto bg-transparent px-4 py-3 font-mono text-[13px] leading-6 [&_code]:block [&_code]:bg-transparent [&_code]:p-0 [&_code]:text-inherit", "m-0 overflow-x-auto bg-transparent px-4 py-3 font-mono text-[13px] leading-6",
wrapLongLines ? "break-words whitespace-pre-wrap" : "whitespace-pre",
bodyClassName, bodyClassName,
)} )}
> >
{children ?? ( <code
highlightedHtml ? ( className={cn(
<code "block bg-transparent p-0 text-inherit",
className={cn("hljs", language && `language-${language}`)} children
dangerouslySetInnerHTML={{ __html: highlightedHtml }} ? renderedCodeState.className
/> : cn(highlightedHtml && "hljs", language && `language-${language}`),
) : ( )}
<code className={language ? `language-${language}` : undefined}> >
{code} {codeLines.map((line, index) => (
</code> <span
) key={`${index}-${line.length}`}
)} className="grid grid-cols-[var(--code-line-number-width)_minmax(0,1fr)] items-start gap-x-3"
style={
{
"--code-line-number-width": lineNumberWidth,
} as CSSProperties
}
>
<span className="sticky left-0 z-1 select-none bg-[#f6f8fa] text-right text-zinc-500/80 dark:bg-[#0d1117] dark:text-zinc-500">
{index + 1}
</span>
{!children && highlightedLines ? (
<span
className={cn(
"min-w-0",
wrapLongLines
? "break-words whitespace-pre-wrap"
: "whitespace-pre",
)}
dangerouslySetInnerHTML={{ __html: line }}
/>
) : (
<span
className={cn(
"min-w-0",
wrapLongLines
? "break-words whitespace-pre-wrap"
: "whitespace-pre",
)}
>
{line}
</span>
)}
</span>
))}
</code>
</pre> </pre>
)} )}
</div> </div>
@@ -158,6 +241,7 @@ export function MarkdownCodeBlock({
code={code} code={code}
language={language} language={language}
bodyClassName={className} bodyClassName={className}
trimTrailingEmptyLine
> >
{children} {children}
</MessageCodeBlock> </MessageCodeBlock>
@@ -1,3 +1,11 @@
import {
Children,
cloneElement,
Fragment,
isValidElement,
type ReactNode,
} from "react"
export interface MarkdownNode { export interface MarkdownNode {
type?: string type?: string
value?: string value?: string
@@ -6,7 +14,7 @@ export interface MarkdownNode {
children?: MarkdownNode[] children?: MarkdownNode[]
} }
function toClassNameTokens(className: unknown): string[] { export function toClassNameTokens(className: unknown): string[] {
if (typeof className === "string") { if (typeof className === "string") {
return className.split(/\s+/).filter(Boolean) return className.split(/\s+/).filter(Boolean)
} }
@@ -72,6 +80,10 @@ export function extractCodeBlockLanguage(className: unknown): string | null {
return languageToken ? languageToken.slice("language-".length) : null return languageToken ? languageToken.slice("language-".length) : null
} }
export function stripSingleTrailingLineBreak(value: string): string {
return value.replace(/\r?\n$/, "")
}
export function extractCodeBlockFromPreNode(node: MarkdownNode | undefined): { export function extractCodeBlockFromPreNode(node: MarkdownNode | undefined): {
code: string code: string
language: string | null language: string | null
@@ -79,7 +91,248 @@ export function extractCodeBlockFromPreNode(node: MarkdownNode | undefined): {
const codeNode = findFirstDescendantByTagName(node, "code") const codeNode = findFirstDescendantByTagName(node, "code")
return { return {
code: extractTextFromMarkdownNode(codeNode ?? node), code: stripSingleTrailingLineBreak(extractTextFromMarkdownNode(codeNode ?? node)),
language: extractCodeBlockLanguage(codeNode?.properties?.className), language: extractCodeBlockLanguage(codeNode?.properties?.className),
} }
} }
export function extractCodeBlockRenderState(children: ReactNode): {
renderedContent: ReactNode
className: string | undefined
} {
const childNodes = Children.toArray(children)
const codeChild = childNodes.find(
(child) =>
isValidElement<{ children?: ReactNode; className?: unknown }>(child) &&
typeof child.type === "string" &&
child.type === "code",
)
if (
isValidElement<{ children?: ReactNode; className?: unknown }>(codeChild)
) {
const classNameTokens = toClassNameTokens(codeChild.props.className)
return {
renderedContent: codeChild.props.children,
className:
classNameTokens.length > 0 ? classNameTokens.join(" ") : undefined,
}
}
return {
renderedContent: children,
className: undefined,
}
}
function mergeNodeLineGroups(
currentLines: Node[][],
nextLines: Node[][],
): Node[][] {
if (nextLines.length === 0) {
return currentLines
}
const mergedLines = currentLines.map((line) => [...line])
mergedLines[mergedLines.length - 1].push(...nextLines[0])
for (const line of nextLines.slice(1)) {
mergedLines.push([...line])
}
return mergedLines
}
function splitDomNodeIntoLines(node: Node, ownerDocument: Document): Node[][] {
if (node.nodeType === Node.TEXT_NODE) {
return (node.textContent ?? "").split("\n").map((line) =>
line.length > 0 ? [ownerDocument.createTextNode(line)] : [],
)
}
if (node.nodeType !== Node.ELEMENT_NODE) {
return [[]]
}
const element = node as Element
if (element.tagName.toLowerCase() === "br") {
return [
[],
[],
]
}
const childLines = splitHighlightedHtmlIntoNodeLines(
Array.from(element.childNodes),
ownerDocument,
)
return childLines.map((lineChildren) => {
const clonedElement = element.cloneNode(false)
for (const child of lineChildren) {
clonedElement.appendChild(child)
}
return [clonedElement]
})
}
function splitHighlightedHtmlIntoNodeLines(
nodes: Node[],
ownerDocument: Document,
): Node[][] {
let lines: Node[][] = [[]]
for (const node of nodes) {
lines = mergeNodeLineGroups(
lines,
splitDomNodeIntoLines(node, ownerDocument),
)
}
return lines
}
export function splitCodeIntoLines(code: string): string[] {
return code.split("\n")
}
export function splitHighlightedHtmlIntoLines(highlightedHtml: string): string[] {
if (typeof document === "undefined") {
return splitCodeIntoLines(highlightedHtml)
}
const container = document.createElement("div")
container.innerHTML = highlightedHtml
return splitHighlightedHtmlIntoNodeLines(
Array.from(container.childNodes),
document,
).map((lineNodes) => {
const lineContainer = document.createElement("div")
for (const node of lineNodes) {
lineContainer.appendChild(node)
}
return lineContainer.innerHTML
})
}
export function trimTrailingEmptyStringLine(lines: string[]): string[] {
if (lines.length > 1 && lines[lines.length - 1] === "") {
return lines.slice(0, -1)
}
return lines
}
function isEmptyRenderedCodeNode(node: ReactNode): boolean {
if (node === null || node === undefined || typeof node === "boolean") {
return true
}
if (typeof node === "string" || typeof node === "number") {
return String(node).length === 0
}
if (Array.isArray(node)) {
return node.every(isEmptyRenderedCodeNode)
}
if (!isValidElement<{ children?: ReactNode }>(node)) {
return false
}
return Children.toArray(node.props.children).every(isEmptyRenderedCodeNode)
}
export function trimTrailingEmptyRenderedCodeLine(
lines: ReactNode[][],
): ReactNode[][] {
if (
lines.length > 1 &&
lines[lines.length - 1].every(isEmptyRenderedCodeNode)
) {
return lines.slice(0, -1)
}
return lines
}
function mergeReactLineGroups(
currentLines: ReactNode[][],
nextLines: ReactNode[][],
): ReactNode[][] {
if (nextLines.length === 0) {
return currentLines
}
const mergedLines = currentLines.map((line) => [...line])
mergedLines[mergedLines.length - 1].push(...nextLines[0])
for (const line of nextLines.slice(1)) {
mergedLines.push([...line])
}
return mergedLines
}
function splitTextNodeIntoLines(value: string | number): ReactNode[][] {
return String(value).split("\n").map((line) => (line.length > 0 ? [line] : []))
}
function splitReactNodeIntoLines(node: ReactNode): ReactNode[][] {
if (node === null || node === undefined || typeof node === "boolean") {
return [[]]
}
if (typeof node === "string" || typeof node === "number") {
return splitTextNodeIntoLines(node)
}
if (Array.isArray(node)) {
return splitRenderedCodeContentIntoLines(node)
}
if (!isValidElement<{ children?: ReactNode }>(node)) {
return [[node]]
}
if (node.type === Fragment) {
return splitRenderedCodeContentIntoLines(Children.toArray(node.props.children))
}
if (typeof node.type === "string" && node.type === "br") {
return [
[],
[],
]
}
const childLines = splitRenderedCodeContentIntoLines(
Children.toArray(node.props.children),
)
return childLines.map((lineChildren, lineIndex) => [
cloneElement(
node,
{
key: `${node.key ?? "code-line"}-${lineIndex}`,
},
...lineChildren,
),
])
}
export function splitRenderedCodeContentIntoLines(
content: ReactNode,
): ReactNode[][] {
const contentNodes = Array.isArray(content) ? content : [content]
let lines: ReactNode[][] = [[]]
for (const node of contentNodes) {
lines = mergeReactLineGroups(lines, splitReactNodeIntoLines(node))
}
return lines
}
+2
View File
@@ -76,6 +76,8 @@
"copyMessage": "Copy message", "copyMessage": "Copy message",
"copyCode": "Copy code", "copyCode": "Copy code",
"copiedLabel": "Copied", "copiedLabel": "Copied",
"enableCodeWrap": "Wrap lines",
"disableCodeWrap": "Disable wrap",
"expandCode": "Expand code", "expandCode": "Expand code",
"collapseCode": "Collapse code", "collapseCode": "Collapse code",
"history": "History", "history": "History",
+2
View File
@@ -76,6 +76,8 @@
"copyMessage": "Copiar mensagem", "copyMessage": "Copiar mensagem",
"copyCode": "Copiar código", "copyCode": "Copiar código",
"copiedLabel": "Copiado", "copiedLabel": "Copiado",
"enableCodeWrap": "Quebrar linhas",
"disableCodeWrap": "Desativar quebra",
"expandCode": "Expandir código", "expandCode": "Expandir código",
"collapseCode": "Recolher código", "collapseCode": "Recolher código",
"history": "Histórico", "history": "Histórico",
+2
View File
@@ -76,6 +76,8 @@
"copyMessage": "复制消息", "copyMessage": "复制消息",
"copyCode": "复制代码", "copyCode": "复制代码",
"copiedLabel": "已复制", "copiedLabel": "已复制",
"enableCodeWrap": "开启换行",
"disableCodeWrap": "关闭换行",
"expandCode": "展开代码", "expandCode": "展开代码",
"collapseCode": "折叠代码", "collapseCode": "折叠代码",
"history": "历史记录", "history": "历史记录",
+11
View File
@@ -0,0 +1,11 @@
import { atomWithStorage } from "jotai/utils"
export const CODE_BLOCK_WRAP_STORAGE_KEY = "picoclaw:code-block-wrap"
export const DEFAULT_CODE_BLOCK_WRAP = false
export const codeBlockWrapAtom = atomWithStorage<boolean>(
CODE_BLOCK_WRAP_STORAGE_KEY,
DEFAULT_CODE_BLOCK_WRAP,
undefined,
{ getOnInit: true },
)
+1
View File
@@ -1,3 +1,4 @@
export * from "./gateway" export * from "./gateway"
export * from "./chat" export * from "./chat"
export * from "./code-block"
export * from "./tour" export * from "./tour"