mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
359 lines
8.2 KiB
Go
359 lines
8.2 KiB
Go
// PicoClaw - Ultra-lightweight personal AI agent
|
|
// License: MIT
|
|
//
|
|
// Copyright (c) 2026 PicoClaw contributors
|
|
|
|
package config
|
|
|
|
import (
|
|
"encoding/json"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
)
|
|
|
|
func TestGetModelConfig_Found(t *testing.T) {
|
|
cfg := (&Config{
|
|
Version: CurrentVersion,
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "test-model", Model: "openai/gpt-4o"},
|
|
{ModelName: "other-model", Model: "anthropic/claude"},
|
|
},
|
|
}).WithSecurity(&SecurityConfig{ModelList: map[string]ModelSecurityEntry{
|
|
"test-model:0": {
|
|
APIKeys: []string{"key1"},
|
|
},
|
|
"other-model:0": {
|
|
APIKeys: []string{"key2"},
|
|
},
|
|
}})
|
|
|
|
result, err := cfg.GetModelConfig("test-model")
|
|
if err != nil {
|
|
t.Fatalf("GetModelConfig() error = %v", err)
|
|
}
|
|
if result.Model != "openai/gpt-4o" {
|
|
t.Errorf("Model = %q, want %q", result.Model, "openai/gpt-4o")
|
|
}
|
|
}
|
|
|
|
func TestGetModelConfig_NotFound(t *testing.T) {
|
|
cfg := (&Config{
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "test-model", Model: "openai/gpt-4o"},
|
|
},
|
|
}).WithSecurity(&SecurityConfig{
|
|
ModelList: map[string]ModelSecurityEntry{
|
|
"test-model:0": {
|
|
APIKeys: []string{"key1"},
|
|
},
|
|
},
|
|
})
|
|
|
|
_, err := cfg.GetModelConfig("nonexistent")
|
|
if err == nil {
|
|
t.Fatal("GetModelConfig() expected error for nonexistent model")
|
|
}
|
|
}
|
|
|
|
func TestGetModelConfig_EmptyList(t *testing.T) {
|
|
cfg := &Config{
|
|
ModelList: []*ModelConfig{},
|
|
}
|
|
|
|
_, err := cfg.GetModelConfig("any-model")
|
|
if err == nil {
|
|
t.Fatal("GetModelConfig() expected error for empty model list")
|
|
}
|
|
}
|
|
|
|
func TestGetModelConfig_RoundRobin(t *testing.T) {
|
|
cfg := (&Config{
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "lb-model", Model: "openai/gpt-4o-1"},
|
|
{ModelName: "lb-model", Model: "openai/gpt-4o-2"},
|
|
{ModelName: "lb-model", Model: "openai/gpt-4o-3"},
|
|
},
|
|
}).WithSecurity(&SecurityConfig{
|
|
ModelList: map[string]ModelSecurityEntry{
|
|
"lb-model:0": {
|
|
APIKeys: []string{"key1"},
|
|
},
|
|
"lb-model:1": {
|
|
APIKeys: []string{"key2"},
|
|
},
|
|
"lb-model:2": {
|
|
APIKeys: []string{"key3"},
|
|
},
|
|
},
|
|
})
|
|
|
|
// Test round-robin distribution
|
|
results := make(map[string]int)
|
|
for range 30 {
|
|
result, err := cfg.GetModelConfig("lb-model")
|
|
if err != nil {
|
|
t.Fatalf("GetModelConfig() error = %v", err)
|
|
}
|
|
results[result.Model]++
|
|
}
|
|
|
|
// Each model should appear roughly 10 times (30 calls / 3 models)
|
|
for model, count := range results {
|
|
if count < 5 || count > 15 {
|
|
t.Errorf("Model %s appeared %d times, expected ~10", model, count)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetModelConfig_RoundRobinStartsFromFirstMatch(t *testing.T) {
|
|
rrCounter.Store(0)
|
|
|
|
cfg := &Config{
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "lb-model", Model: "openai/gpt-4o-1", apiKeys: []string{"key1"}},
|
|
{ModelName: "lb-model", Model: "openai/gpt-4o-2", apiKeys: []string{"key2"}},
|
|
{ModelName: "lb-model", Model: "openai/gpt-4o-3", apiKeys: []string{"key3"}},
|
|
},
|
|
}
|
|
|
|
wantOrder := []string{
|
|
"openai/gpt-4o-1",
|
|
"openai/gpt-4o-2",
|
|
"openai/gpt-4o-3",
|
|
"openai/gpt-4o-1",
|
|
"openai/gpt-4o-2",
|
|
}
|
|
|
|
for i, want := range wantOrder {
|
|
result, err := cfg.GetModelConfig("lb-model")
|
|
if err != nil {
|
|
t.Fatalf("GetModelConfig() call %d error = %v", i, err)
|
|
}
|
|
if result.Model != want {
|
|
t.Fatalf("GetModelConfig() call %d model = %q, want %q", i, result.Model, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetModelConfig_Concurrent(t *testing.T) {
|
|
cfg := &Config{
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "concurrent-model", Model: "openai/gpt-4o-1", apiKeys: []string{"key1"}},
|
|
{ModelName: "concurrent-model", Model: "openai/gpt-4o-2", apiKeys: []string{"key2"}},
|
|
},
|
|
}
|
|
|
|
const goroutines = 100
|
|
const iterations = 10
|
|
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, goroutines*iterations)
|
|
|
|
for range goroutines {
|
|
wg.Go(func() {
|
|
for range iterations {
|
|
_, err := cfg.GetModelConfig("concurrent-model")
|
|
if err != nil {
|
|
errors <- err
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
for err := range errors {
|
|
t.Errorf("Concurrent GetModelConfig() error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAgentDefaultsV0_JSON_BackwardCompat(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
json string
|
|
wantName string
|
|
}{
|
|
{
|
|
name: "new model_name field",
|
|
json: `{"model_name": "gpt4"}`,
|
|
wantName: "gpt4",
|
|
},
|
|
{
|
|
name: "old model field",
|
|
json: `{"model": "gpt4"}`,
|
|
wantName: "gpt4",
|
|
},
|
|
{
|
|
name: "both fields - model_name wins",
|
|
json: `{"model_name": "new", "model": "old"}`,
|
|
wantName: "new",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var defaults agentDefaultsV0
|
|
if err := json.Unmarshal([]byte(tt.json), &defaults); err != nil {
|
|
t.Fatalf("Unmarshal error: %v", err)
|
|
}
|
|
if got := defaults.GetModelName(); got != tt.wantName {
|
|
t.Errorf("GetModelName() = %q, want %q", got, tt.wantName)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestModelConfig_Validate(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config ModelConfig
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid config",
|
|
config: ModelConfig{
|
|
ModelName: "test",
|
|
Model: "openai/gpt-4o",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "missing model_name",
|
|
config: ModelConfig{
|
|
Model: "openai/gpt-4o",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing model",
|
|
config: ModelConfig{
|
|
ModelName: "test",
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty config",
|
|
config: ModelConfig{},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := tt.config.Validate()
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConfig_ValidateModelList(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *Config
|
|
wantErr bool
|
|
errMsg string // partial error message to check
|
|
}{
|
|
{
|
|
name: "valid list",
|
|
config: &Config{
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "test1", Model: "openai/gpt-4o"},
|
|
{ModelName: "test2", Model: "anthropic/claude"},
|
|
},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid entry",
|
|
config: &Config{
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "test1", Model: "openai/gpt-4o"},
|
|
{ModelName: "", Model: "anthropic/claude"}, // missing model_name
|
|
},
|
|
},
|
|
wantErr: true,
|
|
errMsg: "model_name is required",
|
|
},
|
|
{
|
|
name: "empty list",
|
|
config: &Config{
|
|
ModelList: []*ModelConfig{},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
// Load balancing: multiple entries with same model_name are allowed
|
|
name: "duplicate model_name for load balancing",
|
|
config: &Config{
|
|
ModelList: []*ModelConfig{},
|
|
},
|
|
wantErr: false, // Changed: duplicates are allowed for load balancing
|
|
},
|
|
{
|
|
// Load balancing: non-adjacent entries with same model_name are also allowed
|
|
name: "duplicate model_name non-adjacent for load balancing",
|
|
config: &Config{
|
|
ModelList: []*ModelConfig{
|
|
{ModelName: "model-a", Model: "openai/gpt-4o"},
|
|
{ModelName: "model-b", Model: "anthropic/claude"},
|
|
{ModelName: "model-a", Model: "openai/gpt-4-turbo"},
|
|
},
|
|
},
|
|
wantErr: false, // Changed: duplicates are allowed for load balancing
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := tt.config.ValidateModelList()
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("ValidateModelList() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
if err != nil && tt.errMsg != "" {
|
|
if !strings.Contains(err.Error(), tt.errMsg) {
|
|
t.Errorf("ValidateModelList() error = %v, want error containing %q", err, tt.errMsg)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestModelConfig_RequestTimeoutParsing(t *testing.T) {
|
|
jsonData := `{
|
|
"model_name": "slow-local",
|
|
"model": "openai/local-model",
|
|
"api_base": "http://localhost:11434/v1",
|
|
"request_timeout": 300
|
|
}`
|
|
|
|
var cfg ModelConfig
|
|
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
|
t.Fatalf("Unmarshal() error = %v", err)
|
|
}
|
|
|
|
if cfg.RequestTimeout != 300 {
|
|
t.Fatalf("RequestTimeout = %d, want 300", cfg.RequestTimeout)
|
|
}
|
|
}
|
|
|
|
func TestModelConfig_RequestTimeoutDefaultZeroValue(t *testing.T) {
|
|
jsonData := `{
|
|
"model_name": "default-timeout",
|
|
"model": "openai/gpt-4o",
|
|
"api_key": "test-key"
|
|
}`
|
|
|
|
var cfg ModelConfig
|
|
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
|
t.Fatalf("Unmarshal() error = %v", err)
|
|
}
|
|
|
|
if cfg.RequestTimeout != 0 {
|
|
t.Fatalf("RequestTimeout = %d, want 0", cfg.RequestTimeout)
|
|
}
|
|
}
|