feat(model): llm rate limiting (#2198)

* feat(model): rate limiting

* fix(agent): preserve per-model identity in rate limiting and fallback

* fix test
This commit is contained in:
Mauro
2026-04-02 13:26:26 +02:00
committed by GitHub
parent dad5dcc30f
commit b114dcaeb1
9 changed files with 821 additions and 78 deletions
+52
View File
@@ -165,6 +165,58 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
}
}
func TestNewAgentInstance_PreservesDistinctLimiterIdentityForSharedResolvedModel(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "glm-4.7",
ModelFallbacks: []string{"glm-4.7__key_1"},
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "glm-4.7",
Model: "zhipu/glm-4.7",
RPM: 1,
},
{
ModelName: "glm-4.7__key_1",
Model: "zhipu/glm-4.7",
RPM: 3,
},
},
}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
if len(agent.Candidates) != 2 {
t.Fatalf("len(Candidates) = %d, want 2", len(agent.Candidates))
}
first := agent.Candidates[0]
second := agent.Candidates[1]
if first.Provider != "zhipu" || first.Model != "glm-4.7" {
t.Fatalf("first candidate = %s/%s, want zhipu/glm-4.7", first.Provider, first.Model)
}
if second.Provider != "zhipu" || second.Model != "glm-4.7" {
t.Fatalf("second candidate = %s/%s, want zhipu/glm-4.7", second.Provider, second.Model)
}
if first.IdentityKey != "model_name:glm-4.7" {
t.Fatalf("first identity key = %q, want %q", first.IdentityKey, "model_name:glm-4.7")
}
if second.IdentityKey != "model_name:glm-4.7__key_1" {
t.Fatalf("second identity key = %q, want %q", second.IdentityKey, "model_name:glm-4.7__key_1")
}
if first.RPM != 1 {
t.Fatalf("first RPM = %d, want 1", first.RPM)
}
if second.RPM != 3 {
t.Fatalf("second RPM = %d, want 3", second.RPM)
}
}
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
workspace := t.TempDir()
mediaDir := media.TempDir()
+21 -5
View File
@@ -119,9 +119,18 @@ func NewAgentLoop(
) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
// Set up shared fallback chain
// Set up shared fallback chain with rate limiting.
cooldown := providers.NewCooldownTracker()
fallbackChain := providers.NewFallbackChain(cooldown)
rl := providers.NewRateLimiterRegistry()
// Register rate limiters for all agents' candidates so that RPM limits
// configured in ModelConfig are enforced before each LLM call.
for _, agentID := range registry.ListAgentIDs() {
if agent, ok := registry.GetAgent(agentID); ok {
rl.RegisterCandidates(agent.Candidates)
rl.RegisterCandidates(agent.LightCandidates)
}
}
fallbackChain := providers.NewFallbackChain(cooldown, rl)
// Create state manager using default agent's workspace for channel recording
defaultAgent := registry.GetDefaultAgent()
@@ -1032,8 +1041,15 @@ func (al *AgentLoop) ReloadProviderAndConfig(
al.cfg = cfg
al.registry = registry
// Also update fallback chain with new config
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
// Also update fallback chain with new config; rebuild rate limiter registry.
newRL := providers.NewRateLimiterRegistry()
for _, agentID := range registry.ListAgentIDs() {
if agent, ok := registry.GetAgent(agentID); ok {
newRL.RegisterCandidates(agent.Candidates)
newRL.RegisterCandidates(agent.LightCandidates)
}
}
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL)
al.mu.Unlock()
@@ -3229,7 +3245,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
}
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, value, agent.Fallbacks)
if len(nextCandidates) == 0 {
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
}
+116 -43
View File
@@ -8,44 +8,102 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
ensureProtocol := func(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
func ensureProtocolModel(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
}
func modelConfigIdentityKey(mc *config.ModelConfig) string {
if mc == nil {
return ""
}
if name := strings.TrimSpace(mc.ModelName); name != "" {
return "model_name:" + name
}
return ""
}
func candidateFromModelConfig(
defaultProvider string,
mc *config.ModelConfig,
) (providers.FallbackCandidate, bool) {
if mc == nil {
return providers.FallbackCandidate{}, false
}
return func(raw string) (string, bool) {
raw = strings.TrimSpace(raw)
if raw == "" || cfg == nil {
return "", false
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return ensureProtocol(mc.Model), true
}
for i := range cfg.ModelList {
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return ensureProtocol(fullModel), true
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return ensureProtocol(fullModel), true
}
}
return "", false
ref := providers.ParseModelRef(ensureProtocolModel(mc.Model), defaultProvider)
if ref == nil {
return providers.FallbackCandidate{}, false
}
return providers.FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
RPM: mc.RPM,
IdentityKey: modelConfigIdentityKey(mc),
}, true
}
func lookupModelConfigByRef(cfg *config.Config, raw string) *config.ModelConfig {
raw = strings.TrimSpace(raw)
if raw == "" || cfg == nil {
return nil
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return mc
}
for i := range cfg.ModelList {
mc := cfg.ModelList[i]
if mc == nil {
continue
}
fullModel := strings.TrimSpace(mc.Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return mc
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return mc
}
}
return nil
}
func resolveModelCandidate(
cfg *config.Config,
defaultProvider string,
raw string,
) (providers.FallbackCandidate, bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return providers.FallbackCandidate{}, false
}
if mc := lookupModelConfigByRef(cfg, raw); mc != nil {
return candidateFromModelConfig(defaultProvider, mc)
}
ref := providers.ParseModelRef(raw, defaultProvider)
if ref == nil {
return providers.FallbackCandidate{}, false
}
return providers.FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
}, true
}
func resolveModelCandidates(
@@ -54,14 +112,29 @@ func resolveModelCandidates(
primary string,
fallbacks []string,
) []providers.FallbackCandidate {
return providers.ResolveCandidatesWithLookup(
providers.ModelConfig{
Primary: primary,
Fallbacks: fallbacks,
},
defaultProvider,
buildModelListResolver(cfg),
)
seen := make(map[string]bool)
candidates := make([]providers.FallbackCandidate, 0, 1+len(fallbacks))
addCandidate := func(raw string) {
candidate, ok := resolveModelCandidate(cfg, defaultProvider, raw)
if !ok {
return
}
key := candidate.StableKey()
if seen[key] {
return
}
seen[key] = true
candidates = append(candidates, candidate)
}
addCandidate(primary)
for _, fallback := range fallbacks {
addCandidate(fallback)
}
return candidates
}
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {