mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into fix/gemini-mcp-schema-sanitization
This commit is contained in:
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -10,7 +11,9 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -431,6 +434,10 @@ func computeConfigSignature(cfg *config.Config) string {
|
||||
}
|
||||
if cfg.Tools.Web.Enabled {
|
||||
toolSignatures = append(toolSignatures, "web")
|
||||
webConfig, err := json.Marshal(canonicalizeSignatureValue(reflect.ValueOf(cfg.Tools.Web)))
|
||||
if err == nil {
|
||||
parts = append(parts, "webcfg:"+string(webConfig))
|
||||
}
|
||||
}
|
||||
if cfg.Tools.WebFetch.Enabled {
|
||||
toolSignatures = append(toolSignatures, "web_fetch")
|
||||
@@ -474,9 +481,175 @@ func computeConfigSignature(cfg *config.Config) string {
|
||||
if len(toolSignatures) > 0 {
|
||||
parts = append(parts, "tools:"+strings.Join(toolSignatures, ","))
|
||||
}
|
||||
channelSignatures := computeChannelSignatures(cfg.Channels)
|
||||
if len(channelSignatures) > 0 {
|
||||
parts = append(parts, "channels:"+strings.Join(channelSignatures, ","))
|
||||
}
|
||||
return strings.Join(parts, ";")
|
||||
}
|
||||
|
||||
func computeChannelSignatures(channels config.ChannelsConfig) []string {
|
||||
if len(channels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(channels))
|
||||
for name := range channels {
|
||||
keys = append(keys, name)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
signatures := make([]string, 0, len(keys))
|
||||
for _, name := range keys {
|
||||
channel := channels[name]
|
||||
if channel == nil {
|
||||
signatures = append(signatures, name+":<nil>")
|
||||
continue
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type string `json:"type"`
|
||||
AllowFrom config.FlexibleStringSlice `json:"allow_from,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id,omitempty"`
|
||||
GroupTrigger config.GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Typing config.TypingConfig `json:"typing,omitempty"`
|
||||
Placeholder config.PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
Settings json.RawMessage `json:"settings,omitempty"`
|
||||
}{
|
||||
Enabled: channel.Enabled,
|
||||
Type: channel.Type,
|
||||
AllowFrom: channel.AllowFrom,
|
||||
ReasoningChannelID: channel.ReasoningChannelID,
|
||||
GroupTrigger: channel.GroupTrigger,
|
||||
Typing: channel.Typing,
|
||||
Placeholder: channel.Placeholder,
|
||||
Settings: normalizeChannelSettings(channel),
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
signatures = append(signatures, name+":<invalid>")
|
||||
continue
|
||||
}
|
||||
signatures = append(signatures, name+":"+string(encoded))
|
||||
}
|
||||
|
||||
return signatures
|
||||
}
|
||||
|
||||
func normalizeChannelSettings(channel *config.Channel) json.RawMessage {
|
||||
if channel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
decoded, err := channel.GetDecoded()
|
||||
if err == nil && decoded != nil {
|
||||
normalized, err := json.Marshal(canonicalizeSignatureValue(reflect.ValueOf(decoded)))
|
||||
if err == nil {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
return normalizeRawJSON(channel.Settings)
|
||||
}
|
||||
|
||||
func normalizeRawJSON(raw config.RawNode) json.RawMessage {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return bytes.TrimSpace(raw)
|
||||
}
|
||||
|
||||
normalized, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return bytes.TrimSpace(raw)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func canonicalizeSignatureValue(value reflect.Value) any {
|
||||
if !value.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if value.CanInterface() {
|
||||
switch typed := value.Interface().(type) {
|
||||
case config.SecureString:
|
||||
return typed.String()
|
||||
case *config.SecureString:
|
||||
if typed == nil {
|
||||
return ""
|
||||
}
|
||||
return typed.String()
|
||||
case config.SecureStrings:
|
||||
return typed.Values()
|
||||
case *config.SecureStrings:
|
||||
if typed == nil {
|
||||
return nil
|
||||
}
|
||||
return typed.Values()
|
||||
}
|
||||
}
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Interface, reflect.Pointer:
|
||||
if value.IsNil() {
|
||||
return nil
|
||||
}
|
||||
return canonicalizeSignatureValue(value.Elem())
|
||||
case reflect.Struct:
|
||||
result := make(map[string]any)
|
||||
valueType := value.Type()
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
field := valueType.Field(i)
|
||||
if field.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
tag := field.Tag.Get("json")
|
||||
name := field.Name
|
||||
if tag != "" {
|
||||
if comma := strings.Index(tag, ","); comma >= 0 {
|
||||
tag = tag[:comma]
|
||||
}
|
||||
if tag == "-" {
|
||||
continue
|
||||
}
|
||||
if tag != "" {
|
||||
name = tag
|
||||
}
|
||||
}
|
||||
result[name] = canonicalizeSignatureValue(value.Field(i))
|
||||
}
|
||||
return result
|
||||
case reflect.Slice, reflect.Array:
|
||||
length := value.Len()
|
||||
result := make([]any, 0, length)
|
||||
for i := 0; i < length; i++ {
|
||||
result = append(result, canonicalizeSignatureValue(value.Index(i)))
|
||||
}
|
||||
return result
|
||||
case reflect.Map:
|
||||
if value.Type().Key().Kind() != reflect.String {
|
||||
return value.Interface()
|
||||
}
|
||||
result := make(map[string]any, value.Len())
|
||||
iter := value.MapRange()
|
||||
for iter.Next() {
|
||||
result[iter.Key().String()] = canonicalizeSignatureValue(iter.Value())
|
||||
}
|
||||
return result
|
||||
default:
|
||||
if value.CanInterface() {
|
||||
return value.Interface()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func gatewayRestartRequiredBySignature(bootSignature, currentSignature, gatewayStatus string) bool {
|
||||
if gatewayStatus != "running" {
|
||||
return false
|
||||
@@ -742,6 +915,11 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
// Already holding gateway.mu from caller.
|
||||
if changed {
|
||||
refreshPicoTokensLocked(h.configPath)
|
||||
cfg, err = config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to reload config after ensuring pico channel: %w", err)
|
||||
}
|
||||
defaultModelName = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
|
||||
@@ -286,6 +286,61 @@ func TestStartGatewayLocked_ForwardsWildcardHostForPublicLauncher(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGatewayLocked_UsesReloadedConfigForBootSignature(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("sleep command differs on Windows")
|
||||
}
|
||||
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
delete(cfg.Channels, "pico")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
gatewayExecCommand = func(_ string, _ ...string) *exec.Cmd {
|
||||
return exec.Command("sleep", "30")
|
||||
}
|
||||
|
||||
originalSignature := computeConfigSignature(cfg)
|
||||
pid, err := h.startGatewayLocked("starting", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("startGatewayLocked() error = %v", err)
|
||||
}
|
||||
if pid <= 0 {
|
||||
t.Fatalf("startGatewayLocked() pid = %d, want > 0", pid)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
cmd := gateway.cmd
|
||||
bootSignature := gateway.bootConfigSignature
|
||||
gateway.mu.Unlock()
|
||||
t.Cleanup(func() {
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
if cmd != nil {
|
||||
_ = cmd.Wait()
|
||||
}
|
||||
})
|
||||
|
||||
updatedCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
expectedSignature := computeConfigSignature(updatedCfg)
|
||||
if expectedSignature == originalSignature {
|
||||
t.Fatal("expected EnsurePicoChannel() to change the config signature during gateway start")
|
||||
}
|
||||
if bootSignature != expectedSignature {
|
||||
t.Fatalf("bootConfigSignature = %q, want %q", bootSignature, expectedSignature)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
@@ -1108,6 +1163,136 @@ func TestGatewayStatusRequiresRestartAfterToolChange(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterChannelChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].SetAPIKey("test-key")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
|
||||
bootSignature := computeConfigSignature(cfg)
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
gateway.bootConfigSignature = bootSignature
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
updatedCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
telegram := updatedCfg.Channels.Get("telegram")
|
||||
if telegram == nil {
|
||||
t.Fatalf("expected default telegram channel config")
|
||||
}
|
||||
telegram.Enabled = !telegram.Enabled
|
||||
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != true {
|
||||
t.Fatalf("gateway_restart_required = %#v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterWebSearchConfigChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].SetAPIKey("test-key")
|
||||
cfg.Tools.Web.Enabled = true
|
||||
cfg.Tools.Web.Provider = "sogou"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
|
||||
bootSignature := computeConfigSignature(cfg)
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
gateway.bootConfigSignature = bootSignature
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
updatedCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
updatedCfg.Tools.Web.Provider = "duckduckgo"
|
||||
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != true {
|
||||
t.Fatalf("gateway_restart_required = %#v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusNoRestartRequiredForNonSensitiveChanges(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user