Merge branch 'main' into fix/gemini-mcp-schema-sanitization

This commit is contained in:
Mauro
2026-04-27 21:14:25 +02:00
committed by GitHub
24 changed files with 1380 additions and 101 deletions
+178
View File
@@ -2,6 +2,7 @@ package api
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
@@ -10,7 +11,9 @@ import (
"net/http"
"os"
"os/exec"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
"sync"
@@ -431,6 +434,10 @@ func computeConfigSignature(cfg *config.Config) string {
}
if cfg.Tools.Web.Enabled {
toolSignatures = append(toolSignatures, "web")
webConfig, err := json.Marshal(canonicalizeSignatureValue(reflect.ValueOf(cfg.Tools.Web)))
if err == nil {
parts = append(parts, "webcfg:"+string(webConfig))
}
}
if cfg.Tools.WebFetch.Enabled {
toolSignatures = append(toolSignatures, "web_fetch")
@@ -474,9 +481,175 @@ func computeConfigSignature(cfg *config.Config) string {
if len(toolSignatures) > 0 {
parts = append(parts, "tools:"+strings.Join(toolSignatures, ","))
}
channelSignatures := computeChannelSignatures(cfg.Channels)
if len(channelSignatures) > 0 {
parts = append(parts, "channels:"+strings.Join(channelSignatures, ","))
}
return strings.Join(parts, ";")
}
func computeChannelSignatures(channels config.ChannelsConfig) []string {
if len(channels) == 0 {
return nil
}
keys := make([]string, 0, len(channels))
for name := range channels {
keys = append(keys, name)
}
sort.Strings(keys)
signatures := make([]string, 0, len(keys))
for _, name := range keys {
channel := channels[name]
if channel == nil {
signatures = append(signatures, name+":<nil>")
continue
}
payload := struct {
Enabled bool `json:"enabled"`
Type string `json:"type"`
AllowFrom config.FlexibleStringSlice `json:"allow_from,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id,omitempty"`
GroupTrigger config.GroupTriggerConfig `json:"group_trigger,omitempty"`
Typing config.TypingConfig `json:"typing,omitempty"`
Placeholder config.PlaceholderConfig `json:"placeholder,omitempty"`
Settings json.RawMessage `json:"settings,omitempty"`
}{
Enabled: channel.Enabled,
Type: channel.Type,
AllowFrom: channel.AllowFrom,
ReasoningChannelID: channel.ReasoningChannelID,
GroupTrigger: channel.GroupTrigger,
Typing: channel.Typing,
Placeholder: channel.Placeholder,
Settings: normalizeChannelSettings(channel),
}
encoded, err := json.Marshal(payload)
if err != nil {
signatures = append(signatures, name+":<invalid>")
continue
}
signatures = append(signatures, name+":"+string(encoded))
}
return signatures
}
func normalizeChannelSettings(channel *config.Channel) json.RawMessage {
if channel == nil {
return nil
}
decoded, err := channel.GetDecoded()
if err == nil && decoded != nil {
normalized, err := json.Marshal(canonicalizeSignatureValue(reflect.ValueOf(decoded)))
if err == nil {
return normalized
}
}
return normalizeRawJSON(channel.Settings)
}
func normalizeRawJSON(raw config.RawNode) json.RawMessage {
if len(raw) == 0 {
return nil
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return bytes.TrimSpace(raw)
}
normalized, err := json.Marshal(value)
if err != nil {
return bytes.TrimSpace(raw)
}
return normalized
}
func canonicalizeSignatureValue(value reflect.Value) any {
if !value.IsValid() {
return nil
}
if value.CanInterface() {
switch typed := value.Interface().(type) {
case config.SecureString:
return typed.String()
case *config.SecureString:
if typed == nil {
return ""
}
return typed.String()
case config.SecureStrings:
return typed.Values()
case *config.SecureStrings:
if typed == nil {
return nil
}
return typed.Values()
}
}
switch value.Kind() {
case reflect.Interface, reflect.Pointer:
if value.IsNil() {
return nil
}
return canonicalizeSignatureValue(value.Elem())
case reflect.Struct:
result := make(map[string]any)
valueType := value.Type()
for i := 0; i < value.NumField(); i++ {
field := valueType.Field(i)
if field.PkgPath != "" {
continue
}
tag := field.Tag.Get("json")
name := field.Name
if tag != "" {
if comma := strings.Index(tag, ","); comma >= 0 {
tag = tag[:comma]
}
if tag == "-" {
continue
}
if tag != "" {
name = tag
}
}
result[name] = canonicalizeSignatureValue(value.Field(i))
}
return result
case reflect.Slice, reflect.Array:
length := value.Len()
result := make([]any, 0, length)
for i := 0; i < length; i++ {
result = append(result, canonicalizeSignatureValue(value.Index(i)))
}
return result
case reflect.Map:
if value.Type().Key().Kind() != reflect.String {
return value.Interface()
}
result := make(map[string]any, value.Len())
iter := value.MapRange()
for iter.Next() {
result[iter.Key().String()] = canonicalizeSignatureValue(iter.Value())
}
return result
default:
if value.CanInterface() {
return value.Interface()
}
return nil
}
}
func gatewayRestartRequiredBySignature(bootSignature, currentSignature, gatewayStatus string) bool {
if gatewayStatus != "running" {
return false
@@ -742,6 +915,11 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
// Already holding gateway.mu from caller.
if changed {
refreshPicoTokensLocked(h.configPath)
cfg, err = config.LoadConfig(h.configPath)
if err != nil {
return 0, fmt.Errorf("failed to reload config after ensuring pico channel: %w", err)
}
defaultModelName = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
}
if err := cmd.Start(); err != nil {
+185
View File
@@ -286,6 +286,61 @@ func TestStartGatewayLocked_ForwardsWildcardHostForPublicLauncher(t *testing.T)
}
}
func TestStartGatewayLocked_UsesReloadedConfigForBootSignature(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("sleep command differs on Windows")
}
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
delete(cfg.Channels, "pico")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
h.SetServerOptions(18800, false, false, nil)
gatewayExecCommand = func(_ string, _ ...string) *exec.Cmd {
return exec.Command("sleep", "30")
}
originalSignature := computeConfigSignature(cfg)
pid, err := h.startGatewayLocked("starting", 0)
if err != nil {
t.Fatalf("startGatewayLocked() error = %v", err)
}
if pid <= 0 {
t.Fatalf("startGatewayLocked() pid = %d, want > 0", pid)
}
gateway.mu.Lock()
cmd := gateway.cmd
bootSignature := gateway.bootConfigSignature
gateway.mu.Unlock()
t.Cleanup(func() {
if cmd != nil && cmd.Process != nil {
_ = cmd.Process.Kill()
}
if cmd != nil {
_ = cmd.Wait()
}
})
updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
expectedSignature := computeConfigSignature(updatedCfg)
if expectedSignature == originalSignature {
t.Fatal("expected EnsurePicoChannel() to change the config signature during gateway start")
}
if bootSignature != expectedSignature {
t.Fatalf("bootConfigSignature = %q, want %q", bootSignature, expectedSignature)
}
}
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
@@ -1108,6 +1163,136 @@ func TestGatewayStatusRequiresRestartAfterToolChange(t *testing.T) {
}
}
func TestGatewayStatusRequiresRestartAfterChannelChange(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].SetAPIKey("test-key")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
process, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}
bootSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
gateway.cmd = &exec.Cmd{Process: process}
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
gateway.bootConfigSignature = bootSignature
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
telegram := updatedCfg.Channels.Get("telegram")
if telegram == nil {
t.Fatalf("expected default telegram channel config")
}
telegram.Enabled = !telegram.Enabled
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["gateway_restart_required"]; got != true {
t.Fatalf("gateway_restart_required = %#v, want true", got)
}
}
func TestGatewayStatusRequiresRestartAfterWebSearchConfigChange(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].SetAPIKey("test-key")
cfg.Tools.Web.Enabled = true
cfg.Tools.Web.Provider = "sogou"
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
process, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}
bootSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
gateway.cmd = &exec.Cmd{Process: process}
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
gateway.bootConfigSignature = bootSignature
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
updatedCfg.Tools.Web.Provider = "duckduckgo"
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["gateway_restart_required"]; got != true {
t.Fatalf("gateway_restart_required = %#v, want true", got)
}
}
func TestGatewayStatusNoRestartRequiredForNonSensitiveChanges(t *testing.T) {
resetGatewayTestState(t)