mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(api): enhance model availability probing with backoff and caching mechanisms (#2231)
* fix(api): enhance model availability probing with backoff and caching mechanisms * fix(lint): resolve gci and predeclared issues in model probe * fix(api): address copilot review feedback on probe cache key and test stability * fix(api): reduce probe cache key fragmentation
This commit is contained in:
@@ -1,19 +1,36 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
const modelProbeTimeout = 800 * time.Millisecond
|
||||
const (
|
||||
modelProbeTimeout = 800 * time.Millisecond
|
||||
modelProbeSuccessBaseInterval = 2 * time.Second
|
||||
modelProbeSuccessMaxInterval = 60 * time.Second
|
||||
modelProbeFailureBaseInterval = 1 * time.Second
|
||||
modelProbeFailureMaxInterval = 30 * time.Second
|
||||
modelProbeBackoffMaxShift = 8
|
||||
modelProbeCacheMaxEntries = 1024
|
||||
modelProbeCacheEntryTTL = 30 * time.Minute
|
||||
modelProbeCacheTrimToEntries = modelProbeCacheMaxEntries * 8 / 10
|
||||
modelProbeTTLGCInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
modelStatusAvailable = "available"
|
||||
@@ -30,8 +47,41 @@ var (
|
||||
probeTCPServiceFunc = probeTCPService
|
||||
probeOllamaModelFunc = probeOllamaModel
|
||||
probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
|
||||
modelProbeNowFunc = time.Now
|
||||
modelProbeState = newModelProbeCacheState()
|
||||
)
|
||||
|
||||
type modelProbeCacheState struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*modelProbeCacheEntry
|
||||
group singleflight.Group
|
||||
nextTTLGCAt time.Time
|
||||
}
|
||||
|
||||
type modelProbeCacheEntry struct {
|
||||
lastResult bool
|
||||
hasResult bool
|
||||
successStreak int
|
||||
failureStreak int
|
||||
nextProbeAt time.Time
|
||||
updatedAt time.Time
|
||||
}
|
||||
|
||||
func newModelProbeCacheState() *modelProbeCacheState {
|
||||
return &modelProbeCacheState{cache: map[string]*modelProbeCacheEntry{}}
|
||||
}
|
||||
|
||||
func resetModelProbeCache() {
|
||||
modelProbeState.resetForTest()
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) resetForTest() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cache = map[string]*modelProbeCacheEntry{}
|
||||
s.nextTTLGCAt = time.Time{}
|
||||
}
|
||||
|
||||
func hasModelConfiguration(m *config.ModelConfig) bool {
|
||||
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
|
||||
apiKey := strings.TrimSpace(m.APIKey())
|
||||
@@ -93,6 +143,34 @@ func requiresRuntimeProbe(m *config.ModelConfig) bool {
|
||||
}
|
||||
|
||||
func probeLocalModelAvailability(m *config.ModelConfig) bool {
|
||||
cacheKey := modelProbeCacheKey(m)
|
||||
return modelProbeState.probe(cacheKey, func() bool {
|
||||
return runLocalModelProbe(m)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) probe(cacheKey string, probeFunc func() bool) bool {
|
||||
now := modelProbeNowFunc()
|
||||
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
|
||||
return cachedResult
|
||||
}
|
||||
|
||||
v, _, _ := s.group.Do(cacheKey, func() (any, error) {
|
||||
now = modelProbeNowFunc()
|
||||
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
|
||||
return cachedResult, nil
|
||||
}
|
||||
|
||||
result := probeFunc()
|
||||
s.setCachedResult(cacheKey, result, now)
|
||||
return result, nil
|
||||
})
|
||||
|
||||
result, _ := v.(bool)
|
||||
return result
|
||||
}
|
||||
|
||||
func runLocalModelProbe(m *config.ModelConfig) bool {
|
||||
apiBase := modelProbeAPIBase(m)
|
||||
protocol, modelID := splitModel(m.Model)
|
||||
switch protocol {
|
||||
@@ -112,6 +190,195 @@ func probeLocalModelAvailability(m *config.ModelConfig) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func modelProbeCacheKey(m *config.ModelConfig) string {
|
||||
protocol, modelID := splitModel(m.Model)
|
||||
|
||||
apiBaseRaw := modelProbeAPIBase(m)
|
||||
apiBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(apiBaseRaw), "/"))
|
||||
apiKeyFingerprint := modelProbeAPIKeyFingerprint(m.APIKey())
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(protocol) + len(modelID) + len(apiBase) + len(apiKeyFingerprint) + 8)
|
||||
b.WriteString(protocol)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(modelID)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(apiBase)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(apiKeyFingerprint)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func modelProbeAPIKeyFingerprint(raw string) string {
|
||||
apiKey := strings.TrimSpace(raw)
|
||||
if apiKey == "" {
|
||||
return "none"
|
||||
}
|
||||
|
||||
h := fnv.New64a()
|
||||
_, _ = h.Write([]byte(apiKey))
|
||||
return strconv.FormatUint(h.Sum64(), 36)
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) getCachedResult(cacheKey string, now time.Time) (bool, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
entry, ok := s.cache[cacheKey]
|
||||
if !ok || !entry.hasResult {
|
||||
return false, false
|
||||
}
|
||||
if now.Before(entry.nextProbeAt) {
|
||||
return entry.lastResult, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) setCachedResult(cacheKey string, result bool, now time.Time) {
|
||||
s.mu.Lock()
|
||||
|
||||
entry, ok := s.cache[cacheKey]
|
||||
if !ok {
|
||||
entry = &modelProbeCacheEntry{}
|
||||
s.cache[cacheKey] = entry
|
||||
}
|
||||
|
||||
entry.lastResult = result
|
||||
entry.hasResult = true
|
||||
entry.updatedAt = now
|
||||
|
||||
var delay time.Duration
|
||||
if result {
|
||||
entry.successStreak++
|
||||
entry.failureStreak = 0
|
||||
delay = modelProbeBackoffDelay(
|
||||
modelProbeSuccessBaseInterval,
|
||||
modelProbeSuccessMaxInterval,
|
||||
entry.successStreak,
|
||||
)
|
||||
} else {
|
||||
entry.failureStreak++
|
||||
entry.successStreak = 0
|
||||
delay = modelProbeBackoffDelay(
|
||||
modelProbeFailureBaseInterval,
|
||||
modelProbeFailureMaxInterval,
|
||||
entry.failureStreak,
|
||||
)
|
||||
}
|
||||
|
||||
entry.nextProbeAt = now.Add(delay)
|
||||
|
||||
shouldRunTTLGC := modelProbeCacheEntryTTL > 0 && (s.nextTTLGCAt.IsZero() || !now.Before(s.nextTTLGCAt))
|
||||
if shouldRunTTLGC {
|
||||
s.nextTTLGCAt = now.Add(modelProbeTTLGCInterval)
|
||||
}
|
||||
shouldRunSizeGC := len(s.cache) > modelProbeCacheMaxEntries
|
||||
s.mu.Unlock()
|
||||
|
||||
if shouldRunTTLGC || shouldRunSizeGC {
|
||||
s.gc(now, shouldRunTTLGC)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) gc(now time.Time, runTTL bool) {
|
||||
type evictionCandidate struct {
|
||||
key string
|
||||
updatedAt time.Time
|
||||
}
|
||||
|
||||
var expireBefore time.Time
|
||||
if runTTL && modelProbeCacheEntryTTL > 0 {
|
||||
expireBefore = now.Add(-modelProbeCacheEntryTTL)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
cacheLen := len(s.cache)
|
||||
if cacheLen == 0 {
|
||||
s.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
expiredKeys := make([]string, 0)
|
||||
if !expireBefore.IsZero() {
|
||||
expiredKeys = make([]string, 0, min(cacheLen/8+1, 64))
|
||||
for key, entry := range s.cache {
|
||||
if entry.updatedAt.Before(expireBefore) {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
effectiveLen := cacheLen - len(expiredKeys)
|
||||
removeCount := max(effectiveLen-modelProbeCacheTrimToEntries, 0)
|
||||
|
||||
candidates := make([]evictionCandidate, 0)
|
||||
if removeCount > 0 {
|
||||
candidates = make([]evictionCandidate, 0, effectiveLen)
|
||||
for key, entry := range s.cache {
|
||||
if !expireBefore.IsZero() && entry.updatedAt.Before(expireBefore) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, evictionCandidate{key: key, updatedAt: entry.updatedAt})
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if len(expiredKeys) == 0 && len(candidates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
toEvict := map[string]time.Time{}
|
||||
for i := 0; i < removeCount && len(candidates) > 0; i++ {
|
||||
oldest := 0
|
||||
for j := 1; j < len(candidates); j++ {
|
||||
if candidates[j].updatedAt.Before(candidates[oldest].updatedAt) {
|
||||
oldest = j
|
||||
}
|
||||
}
|
||||
victim := candidates[oldest]
|
||||
toEvict[victim.key] = victim.updatedAt
|
||||
candidates[oldest] = candidates[len(candidates)-1]
|
||||
candidates = candidates[:len(candidates)-1]
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !expireBefore.IsZero() {
|
||||
for _, key := range expiredKeys {
|
||||
entry, ok := s.cache[key]
|
||||
if ok && entry.updatedAt.Before(expireBefore) {
|
||||
delete(s.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for key, victimUpdatedAt := range toEvict {
|
||||
entry, ok := s.cache[key]
|
||||
if ok && !entry.updatedAt.After(victimUpdatedAt) {
|
||||
delete(s.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func modelProbeBackoffDelay(base, maxDelay time.Duration, streak int) time.Duration {
|
||||
if streak <= 0 {
|
||||
streak = 1
|
||||
}
|
||||
|
||||
shift := min(streak-1, modelProbeBackoffMaxShift)
|
||||
|
||||
delay := base * time.Duration(1<<shift)
|
||||
if maxDelay > 0 && (delay > maxDelay || delay < 0) {
|
||||
return maxDelay
|
||||
}
|
||||
if delay <= 0 {
|
||||
return base
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func modelProbeAPIBase(m *config.ModelConfig) string {
|
||||
if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" {
|
||||
return normalizeModelProbeAPIBase(apiBase)
|
||||
@@ -207,7 +474,11 @@ func probeTCPService(raw string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", hostPort)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -262,7 +533,10 @@ func probeOpenAICompatibleModel(apiBase, modelID, apiKey string) bool {
|
||||
}
|
||||
|
||||
func getJSON(rawURL string, out any, apiKey string) error {
|
||||
req, err := http.NewRequest(http.MethodGet, rawURL, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -270,7 +544,7 @@ func getJSON(rawURL string, out any, apiKey string) error {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: modelProbeTimeout}
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -336,10 +610,29 @@ func ollamaModelMatches(candidate, want string) bool {
|
||||
if candidate == "" || want == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(candidate, want) {
|
||||
return true
|
||||
|
||||
candidateBase, candidateTag := splitOllamaModel(candidate)
|
||||
wantBase, wantTag := splitOllamaModel(want)
|
||||
if candidateBase == "" || wantBase == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
base, _, _ := strings.Cut(candidate, ":")
|
||||
return strings.EqualFold(base, want)
|
||||
if candidateTag == "" {
|
||||
candidateTag = "latest"
|
||||
}
|
||||
if wantTag == "" {
|
||||
wantTag = "latest"
|
||||
}
|
||||
|
||||
return strings.EqualFold(candidateBase, wantBase) && strings.EqualFold(candidateTag, wantTag)
|
||||
}
|
||||
|
||||
func splitOllamaModel(raw string) (base, tag string) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
base, tag, _ = strings.Cut(raw, ":")
|
||||
return strings.TrimSpace(base), strings.TrimSpace(tag)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ package api
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -85,3 +88,307 @@ func TestProbeLocalModelAvailability_LMStudioUsesOpenAICompatibleProbe(t *testin
|
||||
t.Fatal("probeOpenAICompatibleModelFunc was not called for lmstudio")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelProbeCacheKey_DifferentAPIKeysProduceDifferentKeys(t *testing.T) {
|
||||
base := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
AuthMethod: "local",
|
||||
ConnectMode: "",
|
||||
}
|
||||
|
||||
m1 := *base
|
||||
m1.SetAPIKey("key-a")
|
||||
m2 := *base
|
||||
m2.SetAPIKey("key-b")
|
||||
|
||||
k1 := modelProbeCacheKey(&m1)
|
||||
k2 := modelProbeCacheKey(&m2)
|
||||
if k1 == k2 {
|
||||
t.Fatal("modelProbeCacheKey() should differ when api key changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelProbeCacheKey_NormalizesTrailingSlashInAPIBase(t *testing.T) {
|
||||
m1 := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
m2 := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1/",
|
||||
}
|
||||
|
||||
k1 := modelProbeCacheKey(m1)
|
||||
k2 := modelProbeCacheKey(m2)
|
||||
if k1 != k2 {
|
||||
t.Fatalf("modelProbeCacheKey() mismatch for equivalent api_base values: %q vs %q", k1, k2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelProbeCacheKey_IgnoresDisplayAndConnectionFields(t *testing.T) {
|
||||
base := &config.ModelConfig{
|
||||
ModelName: "vllm-one",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
AuthMethod: "none",
|
||||
ConnectMode: "http",
|
||||
}
|
||||
changed := &config.ModelConfig{
|
||||
ModelName: "vllm-two",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
AuthMethod: "token",
|
||||
ConnectMode: "ws",
|
||||
}
|
||||
|
||||
k1 := modelProbeCacheKey(base)
|
||||
k2 := modelProbeCacheKey(changed)
|
||||
if k1 != k2 {
|
||||
t.Fatalf("modelProbeCacheKey() should ignore non-probe fields, got %q vs %q", k1, k2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_SuccessBackoff(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000000, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
calls := 0
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
calls++
|
||||
return true
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("first probe result = false, want true")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after first probe = %d, want 1", calls)
|
||||
}
|
||||
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached probe result = false, want true")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after immediate re-check = %d, want 1", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("second probe result = false, want true")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls after success backoff window = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached result after doubled backoff = false, want true")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls before doubled backoff expires = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("third probe result = false, want true")
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Fatalf("probe calls after doubled backoff expires = %d, want 3", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_FailureBackoff(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000100, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
calls := 0
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
calls++
|
||||
return false
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("first probe result = true, want false")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after first failure = %d, want 1", calls)
|
||||
}
|
||||
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached failed probe result = true, want false")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after immediate failed re-check = %d, want 1", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("second failed probe result = true, want false")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls after failure backoff window = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached failure after doubled backoff = true, want false")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls before doubled failure backoff expires = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("third failed probe result = true, want false")
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Fatalf("probe calls after doubled failure backoff expires = %d, want 3", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_ResultFlipResetsBackoff(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000200, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
results := []bool{true, false, false}
|
||||
index := 0
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
if index >= len(results) {
|
||||
return false
|
||||
}
|
||||
result := results[index]
|
||||
index++
|
||||
return result
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("first probe result = false, want true")
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("second probe result = true, want false")
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("third probe result = true, want false")
|
||||
}
|
||||
|
||||
if index != 3 {
|
||||
t.Fatalf("probe invocations = %d, want 3", index)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_DeduplicatesInflightProbe(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000300, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
var calls int32
|
||||
probeStarted := make(chan struct{})
|
||||
releaseProbe := make(chan struct{})
|
||||
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
if atomic.AddInt32(&calls, 1) == 1 {
|
||||
close(probeStarted)
|
||||
}
|
||||
<-releaseProbe
|
||||
return true
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
const workers = 8
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan bool, workers)
|
||||
workerStarted := make(chan struct{}, workers)
|
||||
|
||||
for range workers {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
workerStarted <- struct{}{}
|
||||
results <- probeLocalModelAvailability(model)
|
||||
}()
|
||||
}
|
||||
|
||||
for range workers {
|
||||
<-workerStarted
|
||||
}
|
||||
|
||||
select {
|
||||
case <-probeStarted:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("probe did not start in time")
|
||||
}
|
||||
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("concurrent probe calls = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(releaseProbe)
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
for result := range results {
|
||||
if !result {
|
||||
t.Fatal("deduplicated probe result = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("final probe calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOllamaModelMatches_WithTagRequiresExactTag(t *testing.T) {
|
||||
if ollamaModelMatches("llama3:8b", "llama3:7b") {
|
||||
t.Fatal("ollamaModelMatches() = true, want false for mismatched tags")
|
||||
}
|
||||
if !ollamaModelMatches("llama3:7b", "llama3:7b") {
|
||||
t.Fatal("ollamaModelMatches() = false, want true for exact tagged match")
|
||||
}
|
||||
if ollamaModelMatches("llama3:8b", "llama3") {
|
||||
t.Fatal("ollamaModelMatches() = true, want false when request omits tag (defaults to latest)")
|
||||
}
|
||||
if !ollamaModelMatches("llama3:latest", "llama3") {
|
||||
t.Fatal("ollamaModelMatches() = false, want true when request omits tag and candidate is latest")
|
||||
}
|
||||
if !ollamaModelMatches("llama3", "llama3") {
|
||||
t.Fatal("ollamaModelMatches() = false, want true when both candidate and request omit tag (latest)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,10 +20,14 @@ func resetModelProbeHooks(t *testing.T) {
|
||||
origTCPProbe := probeTCPServiceFunc
|
||||
origOllamaProbe := probeOllamaModelFunc
|
||||
origOpenAIProbe := probeOpenAICompatibleModelFunc
|
||||
origNow := modelProbeNowFunc
|
||||
resetModelProbeCache()
|
||||
t.Cleanup(func() {
|
||||
probeTCPServiceFunc = origTCPProbe
|
||||
probeOllamaModelFunc = origOllamaProbe
|
||||
probeOpenAICompatibleModelFunc = origOpenAIProbe
|
||||
modelProbeNowFunc = origNow
|
||||
resetModelProbeCache()
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user