diff --git a/web/backend/api/model_status.go b/web/backend/api/model_status.go index 160c4d257..98bd501f5 100644 --- a/web/backend/api/model_status.go +++ b/web/backend/api/model_status.go @@ -1,19 +1,36 @@ package api import ( + "context" "encoding/json" "fmt" + "hash/fnv" "net" "net/http" "net/url" + "strconv" "strings" + "sync" "time" + "golang.org/x/sync/singleflight" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" ) -const modelProbeTimeout = 800 * time.Millisecond +const ( + modelProbeTimeout = 800 * time.Millisecond + modelProbeSuccessBaseInterval = 2 * time.Second + modelProbeSuccessMaxInterval = 60 * time.Second + modelProbeFailureBaseInterval = 1 * time.Second + modelProbeFailureMaxInterval = 30 * time.Second + modelProbeBackoffMaxShift = 8 + modelProbeCacheMaxEntries = 1024 + modelProbeCacheEntryTTL = 30 * time.Minute + modelProbeCacheTrimToEntries = modelProbeCacheMaxEntries * 8 / 10 + modelProbeTTLGCInterval = 1 * time.Minute +) const ( modelStatusAvailable = "available" @@ -30,8 +47,41 @@ var ( probeTCPServiceFunc = probeTCPService probeOllamaModelFunc = probeOllamaModel probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel + modelProbeNowFunc = time.Now + modelProbeState = newModelProbeCacheState() ) +type modelProbeCacheState struct { + mu sync.RWMutex + cache map[string]*modelProbeCacheEntry + group singleflight.Group + nextTTLGCAt time.Time +} + +type modelProbeCacheEntry struct { + lastResult bool + hasResult bool + successStreak int + failureStreak int + nextProbeAt time.Time + updatedAt time.Time +} + +func newModelProbeCacheState() *modelProbeCacheState { + return &modelProbeCacheState{cache: map[string]*modelProbeCacheEntry{}} +} + +func resetModelProbeCache() { + modelProbeState.resetForTest() +} + +func (s *modelProbeCacheState) resetForTest() { + s.mu.Lock() + defer s.mu.Unlock() + s.cache = map[string]*modelProbeCacheEntry{} + s.nextTTLGCAt = time.Time{} +} + func hasModelConfiguration(m *config.ModelConfig) bool { authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod)) apiKey := strings.TrimSpace(m.APIKey()) @@ -93,6 +143,34 @@ func requiresRuntimeProbe(m *config.ModelConfig) bool { } func probeLocalModelAvailability(m *config.ModelConfig) bool { + cacheKey := modelProbeCacheKey(m) + return modelProbeState.probe(cacheKey, func() bool { + return runLocalModelProbe(m) + }) +} + +func (s *modelProbeCacheState) probe(cacheKey string, probeFunc func() bool) bool { + now := modelProbeNowFunc() + if cachedResult, ok := s.getCachedResult(cacheKey, now); ok { + return cachedResult + } + + v, _, _ := s.group.Do(cacheKey, func() (any, error) { + now = modelProbeNowFunc() + if cachedResult, ok := s.getCachedResult(cacheKey, now); ok { + return cachedResult, nil + } + + result := probeFunc() + s.setCachedResult(cacheKey, result, now) + return result, nil + }) + + result, _ := v.(bool) + return result +} + +func runLocalModelProbe(m *config.ModelConfig) bool { apiBase := modelProbeAPIBase(m) protocol, modelID := splitModel(m.Model) switch protocol { @@ -112,6 +190,195 @@ func probeLocalModelAvailability(m *config.ModelConfig) bool { } } +func modelProbeCacheKey(m *config.ModelConfig) string { + protocol, modelID := splitModel(m.Model) + + apiBaseRaw := modelProbeAPIBase(m) + apiBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(apiBaseRaw), "/")) + apiKeyFingerprint := modelProbeAPIKeyFingerprint(m.APIKey()) + + var b strings.Builder + b.Grow(len(protocol) + len(modelID) + len(apiBase) + len(apiKeyFingerprint) + 8) + b.WriteString(protocol) + b.WriteByte('|') + b.WriteString(modelID) + b.WriteByte('|') + b.WriteString(apiBase) + b.WriteByte('|') + b.WriteString(apiKeyFingerprint) + + return b.String() +} + +func modelProbeAPIKeyFingerprint(raw string) string { + apiKey := strings.TrimSpace(raw) + if apiKey == "" { + return "none" + } + + h := fnv.New64a() + _, _ = h.Write([]byte(apiKey)) + return strconv.FormatUint(h.Sum64(), 36) +} + +func (s *modelProbeCacheState) getCachedResult(cacheKey string, now time.Time) (bool, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + entry, ok := s.cache[cacheKey] + if !ok || !entry.hasResult { + return false, false + } + if now.Before(entry.nextProbeAt) { + return entry.lastResult, true + } + return false, false +} + +func (s *modelProbeCacheState) setCachedResult(cacheKey string, result bool, now time.Time) { + s.mu.Lock() + + entry, ok := s.cache[cacheKey] + if !ok { + entry = &modelProbeCacheEntry{} + s.cache[cacheKey] = entry + } + + entry.lastResult = result + entry.hasResult = true + entry.updatedAt = now + + var delay time.Duration + if result { + entry.successStreak++ + entry.failureStreak = 0 + delay = modelProbeBackoffDelay( + modelProbeSuccessBaseInterval, + modelProbeSuccessMaxInterval, + entry.successStreak, + ) + } else { + entry.failureStreak++ + entry.successStreak = 0 + delay = modelProbeBackoffDelay( + modelProbeFailureBaseInterval, + modelProbeFailureMaxInterval, + entry.failureStreak, + ) + } + + entry.nextProbeAt = now.Add(delay) + + shouldRunTTLGC := modelProbeCacheEntryTTL > 0 && (s.nextTTLGCAt.IsZero() || !now.Before(s.nextTTLGCAt)) + if shouldRunTTLGC { + s.nextTTLGCAt = now.Add(modelProbeTTLGCInterval) + } + shouldRunSizeGC := len(s.cache) > modelProbeCacheMaxEntries + s.mu.Unlock() + + if shouldRunTTLGC || shouldRunSizeGC { + s.gc(now, shouldRunTTLGC) + } +} + +func (s *modelProbeCacheState) gc(now time.Time, runTTL bool) { + type evictionCandidate struct { + key string + updatedAt time.Time + } + + var expireBefore time.Time + if runTTL && modelProbeCacheEntryTTL > 0 { + expireBefore = now.Add(-modelProbeCacheEntryTTL) + } + + s.mu.RLock() + cacheLen := len(s.cache) + if cacheLen == 0 { + s.mu.RUnlock() + return + } + + expiredKeys := make([]string, 0) + if !expireBefore.IsZero() { + expiredKeys = make([]string, 0, min(cacheLen/8+1, 64)) + for key, entry := range s.cache { + if entry.updatedAt.Before(expireBefore) { + expiredKeys = append(expiredKeys, key) + } + } + } + + effectiveLen := cacheLen - len(expiredKeys) + removeCount := max(effectiveLen-modelProbeCacheTrimToEntries, 0) + + candidates := make([]evictionCandidate, 0) + if removeCount > 0 { + candidates = make([]evictionCandidate, 0, effectiveLen) + for key, entry := range s.cache { + if !expireBefore.IsZero() && entry.updatedAt.Before(expireBefore) { + continue + } + candidates = append(candidates, evictionCandidate{key: key, updatedAt: entry.updatedAt}) + } + } + s.mu.RUnlock() + + if len(expiredKeys) == 0 && len(candidates) == 0 { + return + } + + toEvict := map[string]time.Time{} + for i := 0; i < removeCount && len(candidates) > 0; i++ { + oldest := 0 + for j := 1; j < len(candidates); j++ { + if candidates[j].updatedAt.Before(candidates[oldest].updatedAt) { + oldest = j + } + } + victim := candidates[oldest] + toEvict[victim.key] = victim.updatedAt + candidates[oldest] = candidates[len(candidates)-1] + candidates = candidates[:len(candidates)-1] + } + + s.mu.Lock() + defer s.mu.Unlock() + + if !expireBefore.IsZero() { + for _, key := range expiredKeys { + entry, ok := s.cache[key] + if ok && entry.updatedAt.Before(expireBefore) { + delete(s.cache, key) + } + } + } + + for key, victimUpdatedAt := range toEvict { + entry, ok := s.cache[key] + if ok && !entry.updatedAt.After(victimUpdatedAt) { + delete(s.cache, key) + } + } +} + +func modelProbeBackoffDelay(base, maxDelay time.Duration, streak int) time.Duration { + if streak <= 0 { + streak = 1 + } + + shift := min(streak-1, modelProbeBackoffMaxShift) + + delay := base * time.Duration(1< 0 && (delay > maxDelay || delay < 0) { + return maxDelay + } + if delay <= 0 { + return base + } + return delay +} + func modelProbeAPIBase(m *config.ModelConfig) string { if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" { return normalizeModelProbeAPIBase(apiBase) @@ -207,7 +474,11 @@ func probeTCPService(raw string) bool { return false } - conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout) + ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout) + defer cancel() + + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", hostPort) if err != nil { return false } @@ -262,7 +533,10 @@ func probeOpenAICompatibleModel(apiBase, modelID, apiKey string) bool { } func getJSON(rawURL string, out any, apiKey string) error { - req, err := http.NewRequest(http.MethodGet, rawURL, nil) + ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) if err != nil { return err } @@ -270,7 +544,7 @@ func getJSON(rawURL string, out any, apiKey string) error { req.Header.Set("Authorization", "Bearer "+apiKey) } - client := &http.Client{Timeout: modelProbeTimeout} + client := &http.Client{} resp, err := client.Do(req) if err != nil { return err @@ -336,10 +610,29 @@ func ollamaModelMatches(candidate, want string) bool { if candidate == "" || want == "" { return false } - if strings.EqualFold(candidate, want) { - return true + + candidateBase, candidateTag := splitOllamaModel(candidate) + wantBase, wantTag := splitOllamaModel(want) + if candidateBase == "" || wantBase == "" { + return false } - base, _, _ := strings.Cut(candidate, ":") - return strings.EqualFold(base, want) + if candidateTag == "" { + candidateTag = "latest" + } + if wantTag == "" { + wantTag = "latest" + } + + return strings.EqualFold(candidateBase, wantBase) && strings.EqualFold(candidateTag, wantTag) +} + +func splitOllamaModel(raw string) (base, tag string) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "" + } + + base, tag, _ = strings.Cut(raw, ":") + return strings.TrimSpace(base), strings.TrimSpace(tag) } diff --git a/web/backend/api/model_status_test.go b/web/backend/api/model_status_test.go index bfeadf1fe..d5463a856 100644 --- a/web/backend/api/model_status_test.go +++ b/web/backend/api/model_status_test.go @@ -3,7 +3,10 @@ package api import ( "net/http" "net/http/httptest" + "sync" + "sync/atomic" "testing" + "time" "github.com/sipeed/picoclaw/pkg/config" ) @@ -85,3 +88,307 @@ func TestProbeLocalModelAvailability_LMStudioUsesOpenAICompatibleProbe(t *testin t.Fatal("probeOpenAICompatibleModelFunc was not called for lmstudio") } } + +func TestModelProbeCacheKey_DifferentAPIKeysProduceDifferentKeys(t *testing.T) { + base := &config.ModelConfig{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + AuthMethod: "local", + ConnectMode: "", + } + + m1 := *base + m1.SetAPIKey("key-a") + m2 := *base + m2.SetAPIKey("key-b") + + k1 := modelProbeCacheKey(&m1) + k2 := modelProbeCacheKey(&m2) + if k1 == k2 { + t.Fatal("modelProbeCacheKey() should differ when api key changes") + } +} + +func TestModelProbeCacheKey_NormalizesTrailingSlashInAPIBase(t *testing.T) { + m1 := &config.ModelConfig{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + } + m2 := &config.ModelConfig{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1/", + } + + k1 := modelProbeCacheKey(m1) + k2 := modelProbeCacheKey(m2) + if k1 != k2 { + t.Fatalf("modelProbeCacheKey() mismatch for equivalent api_base values: %q vs %q", k1, k2) + } +} + +func TestModelProbeCacheKey_IgnoresDisplayAndConnectionFields(t *testing.T) { + base := &config.ModelConfig{ + ModelName: "vllm-one", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + AuthMethod: "none", + ConnectMode: "http", + } + changed := &config.ModelConfig{ + ModelName: "vllm-two", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + AuthMethod: "token", + ConnectMode: "ws", + } + + k1 := modelProbeCacheKey(base) + k2 := modelProbeCacheKey(changed) + if k1 != k2 { + t.Fatalf("modelProbeCacheKey() should ignore non-probe fields, got %q vs %q", k1, k2) + } +} + +func TestProbeLocalModelAvailability_SuccessBackoff(t *testing.T) { + resetModelProbeHooks(t) + + now := time.Unix(1700000000, 0) + modelProbeNowFunc = func() time.Time { return now } + + calls := 0 + probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool { + calls++ + return true + } + + model := &config.ModelConfig{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + } + + if !probeLocalModelAvailability(model) { + t.Fatal("first probe result = false, want true") + } + if calls != 1 { + t.Fatalf("probe calls after first probe = %d, want 1", calls) + } + + if !probeLocalModelAvailability(model) { + t.Fatal("cached probe result = false, want true") + } + if calls != 1 { + t.Fatalf("probe calls after immediate re-check = %d, want 1", calls) + } + + now = now.Add(modelProbeSuccessBaseInterval) + if !probeLocalModelAvailability(model) { + t.Fatal("second probe result = false, want true") + } + if calls != 2 { + t.Fatalf("probe calls after success backoff window = %d, want 2", calls) + } + + now = now.Add(modelProbeSuccessBaseInterval) + if !probeLocalModelAvailability(model) { + t.Fatal("cached result after doubled backoff = false, want true") + } + if calls != 2 { + t.Fatalf("probe calls before doubled backoff expires = %d, want 2", calls) + } + + now = now.Add(modelProbeSuccessBaseInterval) + if !probeLocalModelAvailability(model) { + t.Fatal("third probe result = false, want true") + } + if calls != 3 { + t.Fatalf("probe calls after doubled backoff expires = %d, want 3", calls) + } +} + +func TestProbeLocalModelAvailability_FailureBackoff(t *testing.T) { + resetModelProbeHooks(t) + + now := time.Unix(1700000100, 0) + modelProbeNowFunc = func() time.Time { return now } + + calls := 0 + probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool { + calls++ + return false + } + + model := &config.ModelConfig{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + } + + if probeLocalModelAvailability(model) { + t.Fatal("first probe result = true, want false") + } + if calls != 1 { + t.Fatalf("probe calls after first failure = %d, want 1", calls) + } + + if probeLocalModelAvailability(model) { + t.Fatal("cached failed probe result = true, want false") + } + if calls != 1 { + t.Fatalf("probe calls after immediate failed re-check = %d, want 1", calls) + } + + now = now.Add(modelProbeFailureBaseInterval) + if probeLocalModelAvailability(model) { + t.Fatal("second failed probe result = true, want false") + } + if calls != 2 { + t.Fatalf("probe calls after failure backoff window = %d, want 2", calls) + } + + now = now.Add(modelProbeFailureBaseInterval) + if probeLocalModelAvailability(model) { + t.Fatal("cached failure after doubled backoff = true, want false") + } + if calls != 2 { + t.Fatalf("probe calls before doubled failure backoff expires = %d, want 2", calls) + } + + now = now.Add(modelProbeFailureBaseInterval) + if probeLocalModelAvailability(model) { + t.Fatal("third failed probe result = true, want false") + } + if calls != 3 { + t.Fatalf("probe calls after doubled failure backoff expires = %d, want 3", calls) + } +} + +func TestProbeLocalModelAvailability_ResultFlipResetsBackoff(t *testing.T) { + resetModelProbeHooks(t) + + now := time.Unix(1700000200, 0) + modelProbeNowFunc = func() time.Time { return now } + + results := []bool{true, false, false} + index := 0 + probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool { + if index >= len(results) { + return false + } + result := results[index] + index++ + return result + } + + model := &config.ModelConfig{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + } + + if !probeLocalModelAvailability(model) { + t.Fatal("first probe result = false, want true") + } + + now = now.Add(modelProbeSuccessBaseInterval) + if probeLocalModelAvailability(model) { + t.Fatal("second probe result = true, want false") + } + + now = now.Add(modelProbeFailureBaseInterval) + if probeLocalModelAvailability(model) { + t.Fatal("third probe result = true, want false") + } + + if index != 3 { + t.Fatalf("probe invocations = %d, want 3", index) + } +} + +func TestProbeLocalModelAvailability_DeduplicatesInflightProbe(t *testing.T) { + resetModelProbeHooks(t) + + now := time.Unix(1700000300, 0) + modelProbeNowFunc = func() time.Time { return now } + + var calls int32 + probeStarted := make(chan struct{}) + releaseProbe := make(chan struct{}) + + probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool { + if atomic.AddInt32(&calls, 1) == 1 { + close(probeStarted) + } + <-releaseProbe + return true + } + + model := &config.ModelConfig{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + } + + const workers = 8 + var wg sync.WaitGroup + results := make(chan bool, workers) + workerStarted := make(chan struct{}, workers) + + for range workers { + wg.Add(1) + go func() { + defer wg.Done() + workerStarted <- struct{}{} + results <- probeLocalModelAvailability(model) + }() + } + + for range workers { + <-workerStarted + } + + select { + case <-probeStarted: + case <-time.After(200 * time.Millisecond): + t.Fatal("probe did not start in time") + } + + if got := atomic.LoadInt32(&calls); got != 1 { + t.Fatalf("concurrent probe calls = %d, want 1", got) + } + + close(releaseProbe) + wg.Wait() + close(results) + + for result := range results { + if !result { + t.Fatal("deduplicated probe result = false, want true") + } + } + + if got := atomic.LoadInt32(&calls); got != 1 { + t.Fatalf("final probe calls = %d, want 1", got) + } +} + +func TestOllamaModelMatches_WithTagRequiresExactTag(t *testing.T) { + if ollamaModelMatches("llama3:8b", "llama3:7b") { + t.Fatal("ollamaModelMatches() = true, want false for mismatched tags") + } + if !ollamaModelMatches("llama3:7b", "llama3:7b") { + t.Fatal("ollamaModelMatches() = false, want true for exact tagged match") + } + if ollamaModelMatches("llama3:8b", "llama3") { + t.Fatal("ollamaModelMatches() = true, want false when request omits tag (defaults to latest)") + } + if !ollamaModelMatches("llama3:latest", "llama3") { + t.Fatal("ollamaModelMatches() = false, want true when request omits tag and candidate is latest") + } + if !ollamaModelMatches("llama3", "llama3") { + t.Fatal("ollamaModelMatches() = false, want true when both candidate and request omit tag (latest)") + } +} diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index e78de1606..e54d5b77c 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -20,10 +20,14 @@ func resetModelProbeHooks(t *testing.T) { origTCPProbe := probeTCPServiceFunc origOllamaProbe := probeOllamaModelFunc origOpenAIProbe := probeOpenAICompatibleModelFunc + origNow := modelProbeNowFunc + resetModelProbeCache() t.Cleanup(func() { probeTCPServiceFunc = origTCPProbe probeOllamaModelFunc = origOllamaProbe probeOpenAICompatibleModelFunc = origOpenAIProbe + modelProbeNowFunc = origNow + resetModelProbeCache() }) }