fix(api): enhance model availability probing with backoff and caching mechanisms (#2231)

* fix(api): enhance model availability probing with backoff and caching mechanisms

* fix(lint): resolve gci and predeclared issues in model probe

* fix(api): address copilot review feedback on probe cache key and test stability

* fix(api): reduce probe cache key fragmentation
This commit is contained in:
LC
2026-04-01 14:15:28 +08:00
committed by GitHub
parent 0f395ce110
commit f327859cce
3 changed files with 612 additions and 8 deletions
+301 -8
View File
@@ -1,19 +1,36 @@
package api
import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
const modelProbeTimeout = 800 * time.Millisecond
const (
modelProbeTimeout = 800 * time.Millisecond
modelProbeSuccessBaseInterval = 2 * time.Second
modelProbeSuccessMaxInterval = 60 * time.Second
modelProbeFailureBaseInterval = 1 * time.Second
modelProbeFailureMaxInterval = 30 * time.Second
modelProbeBackoffMaxShift = 8
modelProbeCacheMaxEntries = 1024
modelProbeCacheEntryTTL = 30 * time.Minute
modelProbeCacheTrimToEntries = modelProbeCacheMaxEntries * 8 / 10
modelProbeTTLGCInterval = 1 * time.Minute
)
const (
modelStatusAvailable = "available"
@@ -30,8 +47,41 @@ var (
probeTCPServiceFunc = probeTCPService
probeOllamaModelFunc = probeOllamaModel
probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
modelProbeNowFunc = time.Now
modelProbeState = newModelProbeCacheState()
)
type modelProbeCacheState struct {
mu sync.RWMutex
cache map[string]*modelProbeCacheEntry
group singleflight.Group
nextTTLGCAt time.Time
}
type modelProbeCacheEntry struct {
lastResult bool
hasResult bool
successStreak int
failureStreak int
nextProbeAt time.Time
updatedAt time.Time
}
func newModelProbeCacheState() *modelProbeCacheState {
return &modelProbeCacheState{cache: map[string]*modelProbeCacheEntry{}}
}
func resetModelProbeCache() {
modelProbeState.resetForTest()
}
func (s *modelProbeCacheState) resetForTest() {
s.mu.Lock()
defer s.mu.Unlock()
s.cache = map[string]*modelProbeCacheEntry{}
s.nextTTLGCAt = time.Time{}
}
func hasModelConfiguration(m *config.ModelConfig) bool {
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
apiKey := strings.TrimSpace(m.APIKey())
@@ -93,6 +143,34 @@ func requiresRuntimeProbe(m *config.ModelConfig) bool {
}
func probeLocalModelAvailability(m *config.ModelConfig) bool {
cacheKey := modelProbeCacheKey(m)
return modelProbeState.probe(cacheKey, func() bool {
return runLocalModelProbe(m)
})
}
func (s *modelProbeCacheState) probe(cacheKey string, probeFunc func() bool) bool {
now := modelProbeNowFunc()
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
return cachedResult
}
v, _, _ := s.group.Do(cacheKey, func() (any, error) {
now = modelProbeNowFunc()
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
return cachedResult, nil
}
result := probeFunc()
s.setCachedResult(cacheKey, result, now)
return result, nil
})
result, _ := v.(bool)
return result
}
func runLocalModelProbe(m *config.ModelConfig) bool {
apiBase := modelProbeAPIBase(m)
protocol, modelID := splitModel(m.Model)
switch protocol {
@@ -112,6 +190,195 @@ func probeLocalModelAvailability(m *config.ModelConfig) bool {
}
}
func modelProbeCacheKey(m *config.ModelConfig) string {
protocol, modelID := splitModel(m.Model)
apiBaseRaw := modelProbeAPIBase(m)
apiBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(apiBaseRaw), "/"))
apiKeyFingerprint := modelProbeAPIKeyFingerprint(m.APIKey())
var b strings.Builder
b.Grow(len(protocol) + len(modelID) + len(apiBase) + len(apiKeyFingerprint) + 8)
b.WriteString(protocol)
b.WriteByte('|')
b.WriteString(modelID)
b.WriteByte('|')
b.WriteString(apiBase)
b.WriteByte('|')
b.WriteString(apiKeyFingerprint)
return b.String()
}
func modelProbeAPIKeyFingerprint(raw string) string {
apiKey := strings.TrimSpace(raw)
if apiKey == "" {
return "none"
}
h := fnv.New64a()
_, _ = h.Write([]byte(apiKey))
return strconv.FormatUint(h.Sum64(), 36)
}
func (s *modelProbeCacheState) getCachedResult(cacheKey string, now time.Time) (bool, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
entry, ok := s.cache[cacheKey]
if !ok || !entry.hasResult {
return false, false
}
if now.Before(entry.nextProbeAt) {
return entry.lastResult, true
}
return false, false
}
func (s *modelProbeCacheState) setCachedResult(cacheKey string, result bool, now time.Time) {
s.mu.Lock()
entry, ok := s.cache[cacheKey]
if !ok {
entry = &modelProbeCacheEntry{}
s.cache[cacheKey] = entry
}
entry.lastResult = result
entry.hasResult = true
entry.updatedAt = now
var delay time.Duration
if result {
entry.successStreak++
entry.failureStreak = 0
delay = modelProbeBackoffDelay(
modelProbeSuccessBaseInterval,
modelProbeSuccessMaxInterval,
entry.successStreak,
)
} else {
entry.failureStreak++
entry.successStreak = 0
delay = modelProbeBackoffDelay(
modelProbeFailureBaseInterval,
modelProbeFailureMaxInterval,
entry.failureStreak,
)
}
entry.nextProbeAt = now.Add(delay)
shouldRunTTLGC := modelProbeCacheEntryTTL > 0 && (s.nextTTLGCAt.IsZero() || !now.Before(s.nextTTLGCAt))
if shouldRunTTLGC {
s.nextTTLGCAt = now.Add(modelProbeTTLGCInterval)
}
shouldRunSizeGC := len(s.cache) > modelProbeCacheMaxEntries
s.mu.Unlock()
if shouldRunTTLGC || shouldRunSizeGC {
s.gc(now, shouldRunTTLGC)
}
}
func (s *modelProbeCacheState) gc(now time.Time, runTTL bool) {
type evictionCandidate struct {
key string
updatedAt time.Time
}
var expireBefore time.Time
if runTTL && modelProbeCacheEntryTTL > 0 {
expireBefore = now.Add(-modelProbeCacheEntryTTL)
}
s.mu.RLock()
cacheLen := len(s.cache)
if cacheLen == 0 {
s.mu.RUnlock()
return
}
expiredKeys := make([]string, 0)
if !expireBefore.IsZero() {
expiredKeys = make([]string, 0, min(cacheLen/8+1, 64))
for key, entry := range s.cache {
if entry.updatedAt.Before(expireBefore) {
expiredKeys = append(expiredKeys, key)
}
}
}
effectiveLen := cacheLen - len(expiredKeys)
removeCount := max(effectiveLen-modelProbeCacheTrimToEntries, 0)
candidates := make([]evictionCandidate, 0)
if removeCount > 0 {
candidates = make([]evictionCandidate, 0, effectiveLen)
for key, entry := range s.cache {
if !expireBefore.IsZero() && entry.updatedAt.Before(expireBefore) {
continue
}
candidates = append(candidates, evictionCandidate{key: key, updatedAt: entry.updatedAt})
}
}
s.mu.RUnlock()
if len(expiredKeys) == 0 && len(candidates) == 0 {
return
}
toEvict := map[string]time.Time{}
for i := 0; i < removeCount && len(candidates) > 0; i++ {
oldest := 0
for j := 1; j < len(candidates); j++ {
if candidates[j].updatedAt.Before(candidates[oldest].updatedAt) {
oldest = j
}
}
victim := candidates[oldest]
toEvict[victim.key] = victim.updatedAt
candidates[oldest] = candidates[len(candidates)-1]
candidates = candidates[:len(candidates)-1]
}
s.mu.Lock()
defer s.mu.Unlock()
if !expireBefore.IsZero() {
for _, key := range expiredKeys {
entry, ok := s.cache[key]
if ok && entry.updatedAt.Before(expireBefore) {
delete(s.cache, key)
}
}
}
for key, victimUpdatedAt := range toEvict {
entry, ok := s.cache[key]
if ok && !entry.updatedAt.After(victimUpdatedAt) {
delete(s.cache, key)
}
}
}
func modelProbeBackoffDelay(base, maxDelay time.Duration, streak int) time.Duration {
if streak <= 0 {
streak = 1
}
shift := min(streak-1, modelProbeBackoffMaxShift)
delay := base * time.Duration(1<<shift)
if maxDelay > 0 && (delay > maxDelay || delay < 0) {
return maxDelay
}
if delay <= 0 {
return base
}
return delay
}
func modelProbeAPIBase(m *config.ModelConfig) string {
if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" {
return normalizeModelProbeAPIBase(apiBase)
@@ -207,7 +474,11 @@ func probeTCPService(raw string) bool {
return false
}
conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout)
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
defer cancel()
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", hostPort)
if err != nil {
return false
}
@@ -262,7 +533,10 @@ func probeOpenAICompatibleModel(apiBase, modelID, apiKey string) bool {
}
func getJSON(rawURL string, out any, apiKey string) error {
req, err := http.NewRequest(http.MethodGet, rawURL, nil)
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return err
}
@@ -270,7 +544,7 @@ func getJSON(rawURL string, out any, apiKey string) error {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
client := &http.Client{Timeout: modelProbeTimeout}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
@@ -336,10 +610,29 @@ func ollamaModelMatches(candidate, want string) bool {
if candidate == "" || want == "" {
return false
}
if strings.EqualFold(candidate, want) {
return true
candidateBase, candidateTag := splitOllamaModel(candidate)
wantBase, wantTag := splitOllamaModel(want)
if candidateBase == "" || wantBase == "" {
return false
}
base, _, _ := strings.Cut(candidate, ":")
return strings.EqualFold(base, want)
if candidateTag == "" {
candidateTag = "latest"
}
if wantTag == "" {
wantTag = "latest"
}
return strings.EqualFold(candidateBase, wantBase) && strings.EqualFold(candidateTag, wantTag)
}
func splitOllamaModel(raw string) (base, tag string) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", ""
}
base, tag, _ = strings.Cut(raw, ":")
return strings.TrimSpace(base), strings.TrimSpace(tag)
}
+307
View File
@@ -3,7 +3,10 @@ package api
import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -85,3 +88,307 @@ func TestProbeLocalModelAvailability_LMStudioUsesOpenAICompatibleProbe(t *testin
t.Fatal("probeOpenAICompatibleModelFunc was not called for lmstudio")
}
}
func TestModelProbeCacheKey_DifferentAPIKeysProduceDifferentKeys(t *testing.T) {
base := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
AuthMethod: "local",
ConnectMode: "",
}
m1 := *base
m1.SetAPIKey("key-a")
m2 := *base
m2.SetAPIKey("key-b")
k1 := modelProbeCacheKey(&m1)
k2 := modelProbeCacheKey(&m2)
if k1 == k2 {
t.Fatal("modelProbeCacheKey() should differ when api key changes")
}
}
func TestModelProbeCacheKey_NormalizesTrailingSlashInAPIBase(t *testing.T) {
m1 := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
m2 := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1/",
}
k1 := modelProbeCacheKey(m1)
k2 := modelProbeCacheKey(m2)
if k1 != k2 {
t.Fatalf("modelProbeCacheKey() mismatch for equivalent api_base values: %q vs %q", k1, k2)
}
}
func TestModelProbeCacheKey_IgnoresDisplayAndConnectionFields(t *testing.T) {
base := &config.ModelConfig{
ModelName: "vllm-one",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
AuthMethod: "none",
ConnectMode: "http",
}
changed := &config.ModelConfig{
ModelName: "vllm-two",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
AuthMethod: "token",
ConnectMode: "ws",
}
k1 := modelProbeCacheKey(base)
k2 := modelProbeCacheKey(changed)
if k1 != k2 {
t.Fatalf("modelProbeCacheKey() should ignore non-probe fields, got %q vs %q", k1, k2)
}
}
func TestProbeLocalModelAvailability_SuccessBackoff(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000000, 0)
modelProbeNowFunc = func() time.Time { return now }
calls := 0
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
calls++
return true
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
if !probeLocalModelAvailability(model) {
t.Fatal("first probe result = false, want true")
}
if calls != 1 {
t.Fatalf("probe calls after first probe = %d, want 1", calls)
}
if !probeLocalModelAvailability(model) {
t.Fatal("cached probe result = false, want true")
}
if calls != 1 {
t.Fatalf("probe calls after immediate re-check = %d, want 1", calls)
}
now = now.Add(modelProbeSuccessBaseInterval)
if !probeLocalModelAvailability(model) {
t.Fatal("second probe result = false, want true")
}
if calls != 2 {
t.Fatalf("probe calls after success backoff window = %d, want 2", calls)
}
now = now.Add(modelProbeSuccessBaseInterval)
if !probeLocalModelAvailability(model) {
t.Fatal("cached result after doubled backoff = false, want true")
}
if calls != 2 {
t.Fatalf("probe calls before doubled backoff expires = %d, want 2", calls)
}
now = now.Add(modelProbeSuccessBaseInterval)
if !probeLocalModelAvailability(model) {
t.Fatal("third probe result = false, want true")
}
if calls != 3 {
t.Fatalf("probe calls after doubled backoff expires = %d, want 3", calls)
}
}
func TestProbeLocalModelAvailability_FailureBackoff(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000100, 0)
modelProbeNowFunc = func() time.Time { return now }
calls := 0
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
calls++
return false
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
if probeLocalModelAvailability(model) {
t.Fatal("first probe result = true, want false")
}
if calls != 1 {
t.Fatalf("probe calls after first failure = %d, want 1", calls)
}
if probeLocalModelAvailability(model) {
t.Fatal("cached failed probe result = true, want false")
}
if calls != 1 {
t.Fatalf("probe calls after immediate failed re-check = %d, want 1", calls)
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("second failed probe result = true, want false")
}
if calls != 2 {
t.Fatalf("probe calls after failure backoff window = %d, want 2", calls)
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("cached failure after doubled backoff = true, want false")
}
if calls != 2 {
t.Fatalf("probe calls before doubled failure backoff expires = %d, want 2", calls)
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("third failed probe result = true, want false")
}
if calls != 3 {
t.Fatalf("probe calls after doubled failure backoff expires = %d, want 3", calls)
}
}
func TestProbeLocalModelAvailability_ResultFlipResetsBackoff(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000200, 0)
modelProbeNowFunc = func() time.Time { return now }
results := []bool{true, false, false}
index := 0
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
if index >= len(results) {
return false
}
result := results[index]
index++
return result
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
if !probeLocalModelAvailability(model) {
t.Fatal("first probe result = false, want true")
}
now = now.Add(modelProbeSuccessBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("second probe result = true, want false")
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("third probe result = true, want false")
}
if index != 3 {
t.Fatalf("probe invocations = %d, want 3", index)
}
}
func TestProbeLocalModelAvailability_DeduplicatesInflightProbe(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000300, 0)
modelProbeNowFunc = func() time.Time { return now }
var calls int32
probeStarted := make(chan struct{})
releaseProbe := make(chan struct{})
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
if atomic.AddInt32(&calls, 1) == 1 {
close(probeStarted)
}
<-releaseProbe
return true
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
const workers = 8
var wg sync.WaitGroup
results := make(chan bool, workers)
workerStarted := make(chan struct{}, workers)
for range workers {
wg.Add(1)
go func() {
defer wg.Done()
workerStarted <- struct{}{}
results <- probeLocalModelAvailability(model)
}()
}
for range workers {
<-workerStarted
}
select {
case <-probeStarted:
case <-time.After(200 * time.Millisecond):
t.Fatal("probe did not start in time")
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("concurrent probe calls = %d, want 1", got)
}
close(releaseProbe)
wg.Wait()
close(results)
for result := range results {
if !result {
t.Fatal("deduplicated probe result = false, want true")
}
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("final probe calls = %d, want 1", got)
}
}
func TestOllamaModelMatches_WithTagRequiresExactTag(t *testing.T) {
if ollamaModelMatches("llama3:8b", "llama3:7b") {
t.Fatal("ollamaModelMatches() = true, want false for mismatched tags")
}
if !ollamaModelMatches("llama3:7b", "llama3:7b") {
t.Fatal("ollamaModelMatches() = false, want true for exact tagged match")
}
if ollamaModelMatches("llama3:8b", "llama3") {
t.Fatal("ollamaModelMatches() = true, want false when request omits tag (defaults to latest)")
}
if !ollamaModelMatches("llama3:latest", "llama3") {
t.Fatal("ollamaModelMatches() = false, want true when request omits tag and candidate is latest")
}
if !ollamaModelMatches("llama3", "llama3") {
t.Fatal("ollamaModelMatches() = false, want true when both candidate and request omit tag (latest)")
}
}
+4
View File
@@ -20,10 +20,14 @@ func resetModelProbeHooks(t *testing.T) {
origTCPProbe := probeTCPServiceFunc
origOllamaProbe := probeOllamaModelFunc
origOpenAIProbe := probeOpenAICompatibleModelFunc
origNow := modelProbeNowFunc
resetModelProbeCache()
t.Cleanup(func() {
probeTCPServiceFunc = origTCPProbe
probeOllamaModelFunc = origOllamaProbe
probeOpenAICompatibleModelFunc = origOpenAIProbe
modelProbeNowFunc = origNow
resetModelProbeCache()
})
}