mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(web,api): fetch models and saved catalog support (#2832)
* feat(web,api): add fetch models and saved catalog support Split from PR #2752 (part 2 of 3). Backend: - /api/models/catalog endpoint for browsing remote model catalogs - /api/models/fetch endpoint for fetching available models from providers - Credential reuse with provider/API base matching for security - Default API base resolution for providers without explicit base Frontend: - FetchModelsDialog for importing models from remote providers - CatalogDialog for browsing and importing from model catalogs - Static import for FetchModelsDialog (replaces dynamic import from PR1) - Dynamic import retained for TestModelDialog (PR3 territory) * fix(web,api): support bare-array responses in fetchOpenAICompatibleModels * fix(web,api): tighten maskAPIKeyValue to match maskAPIKey policy For 9-12 character keys, maskAPIKeyValue exposed first 4 + last 4 chars (only 1 char masked for a 9-char key). Now uses the same policy as maskAPIKey: first 3 + last 2 for 9-12 chars, first 3 + last 4 for longer keys. Adds tests covering all key length boundaries.
This commit is contained in:
@@ -0,0 +1,169 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
// CatalogModel represents a single model entry in a saved catalog.
|
||||
type CatalogModel struct {
|
||||
ID string `json:"id"`
|
||||
OwnedBy string `json:"owned_by,omitempty"`
|
||||
Extra map[string]any `json:"extra,omitempty"`
|
||||
}
|
||||
|
||||
// CatalogEntry is a saved list of upstream models fetched for a specific provider+key combination.
|
||||
type CatalogEntry struct {
|
||||
ID string `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
APIBase string `json:"api_base"`
|
||||
APIKeyMask string `json:"api_key_mask"`
|
||||
Models []CatalogModel `json:"models"`
|
||||
FetchedAt string `json:"fetched_at"`
|
||||
}
|
||||
|
||||
// CatalogStore holds all saved model catalogs.
|
||||
type CatalogStore struct {
|
||||
Entries map[string]*CatalogEntry `json:"entries"`
|
||||
}
|
||||
|
||||
func catalogFilePath() string {
|
||||
return filepath.Join(config.GetHome(), "model_catalogs.json")
|
||||
}
|
||||
|
||||
// generateCatalogKey creates a deterministic key for a provider+base+key combination.
|
||||
func generateCatalogKey(provider, apiBase, apiKey string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
apiBase = strings.TrimRight(strings.TrimSpace(apiBase), "/")
|
||||
hash := sha256.Sum256([]byte(apiKey))
|
||||
return fmt.Sprintf("%s|%s|%x", provider, apiBase, hash[:6])
|
||||
}
|
||||
|
||||
// maskAPIKeyValue masks an API key for display.
|
||||
// Keys longer than 12 chars show prefix + last 4 chars: "sk-****abcd".
|
||||
// Keys 9-12 chars show prefix + last 2 chars: "sk-****cd".
|
||||
// Shorter keys are fully masked as "****".
|
||||
// Empty keys return empty string.
|
||||
// Ensure at least 40% of the key will not be displayed.
|
||||
func maskAPIKeyValue(key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
if len(key) <= 8 {
|
||||
return "****"
|
||||
}
|
||||
if len(key) <= 12 {
|
||||
return key[:3] + "****" + key[len(key)-2:]
|
||||
}
|
||||
return key[:3] + "****" + key[len(key)-4:]
|
||||
}
|
||||
|
||||
func loadCatalogs() (*CatalogStore, error) {
|
||||
path := catalogFilePath()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &CatalogStore{Entries: make(map[string]*CatalogEntry)}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var store CatalogStore
|
||||
if err := json.Unmarshal(data, &store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if store.Entries == nil {
|
||||
store.Entries = make(map[string]*CatalogEntry)
|
||||
}
|
||||
return &store, nil
|
||||
}
|
||||
|
||||
func saveCatalogs(store *CatalogStore) error {
|
||||
path := catalogFilePath()
|
||||
data, err := json.MarshalIndent(store, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
// SaveCatalog persists a fetched model list for a given provider+key combination.
|
||||
// If a catalog with the same key already exists, it is updated.
|
||||
func SaveCatalog(provider, apiBase, apiKey string, models []CatalogModel) error {
|
||||
store, err := loadCatalogs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := generateCatalogKey(provider, apiBase, apiKey)
|
||||
store.Entries[key] = &CatalogEntry{
|
||||
ID: key,
|
||||
Provider: strings.ToLower(strings.TrimSpace(provider)),
|
||||
APIBase: strings.TrimRight(strings.TrimSpace(apiBase), "/"),
|
||||
APIKeyMask: maskAPIKeyValue(apiKey),
|
||||
Models: models,
|
||||
FetchedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
return saveCatalogs(store)
|
||||
}
|
||||
|
||||
// handleListCatalogs returns all saved model catalogs.
|
||||
//
|
||||
// GET /api/models/catalog
|
||||
func (h *Handler) handleListCatalogs(w http.ResponseWriter, r *http.Request) {
|
||||
store, err := loadCatalogs()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load catalogs: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
entries := make([]*CatalogEntry, 0, len(store.Entries))
|
||||
for _, e := range store.Entries {
|
||||
entries = append(entries, e)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"entries": entries,
|
||||
"total": len(entries),
|
||||
})
|
||||
}
|
||||
|
||||
// handleDeleteCatalog deletes a saved model catalog by ID.
|
||||
//
|
||||
// DELETE /api/models/catalog/{id}
|
||||
func (h *Handler) handleDeleteCatalog(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
http.Error(w, "id is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
store, err := loadCatalogs()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load catalogs: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := store.Entries[id]; !ok {
|
||||
http.Error(w, "catalog not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
delete(store.Entries, id)
|
||||
if err := saveCatalogs(store); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save catalogs: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMaskAPIKeyValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty key",
|
||||
key: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
key: " ",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "short key fully masked",
|
||||
key: "abcd",
|
||||
want: "****",
|
||||
},
|
||||
{
|
||||
name: "length 8 boundary fully masked",
|
||||
key: "12345678",
|
||||
want: "****",
|
||||
},
|
||||
{
|
||||
name: "length 9 boundary shows last 2",
|
||||
key: "123456789",
|
||||
want: "123****89",
|
||||
},
|
||||
{
|
||||
name: "length 10 shows last 2",
|
||||
key: "1234567890",
|
||||
want: "123****90",
|
||||
},
|
||||
{
|
||||
name: "length 12 boundary shows last 2",
|
||||
key: "abcdefghijkl",
|
||||
want: "abc****kl",
|
||||
},
|
||||
{
|
||||
name: "length 13 boundary shows last 4",
|
||||
key: "abcdefghijklm",
|
||||
want: "abc****jklm",
|
||||
},
|
||||
{
|
||||
name: "typical api key",
|
||||
key: "sk-1234567890abcd",
|
||||
want: "sk-****abcd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := maskAPIKeyValue(tc.key)
|
||||
if got != tc.want {
|
||||
t.Fatalf("maskAPIKeyValue(%q) = %q, want %q", tc.key, got, tc.want)
|
||||
}
|
||||
|
||||
if tc.key != "" {
|
||||
displayed := strings.Replace(got, "****", "", 1)
|
||||
if len(strings.TrimSpace(tc.key)) <= 8 {
|
||||
if displayed != "" {
|
||||
t.Fatalf("maskAPIKeyValue(%q) displayed part = %q, want empty", tc.key, displayed)
|
||||
}
|
||||
} else {
|
||||
if len(displayed)*10 > len(strings.TrimSpace(tc.key))*6 {
|
||||
t.Fatalf(
|
||||
"maskAPIKeyValue(%q) displayed length = %d, want at most 60%% of %d",
|
||||
tc.key,
|
||||
len(displayed),
|
||||
len(strings.TrimSpace(tc.key)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/asr"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -15,6 +17,18 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// fetchableProviders lists providers that support OpenAI-compatible /models listing.
|
||||
var fetchableProviders = map[string]bool{
|
||||
"openai": true, "deepseek": true, "openrouter": true,
|
||||
"qwen-portal": true, "qwen-intl": true, "moonshot": true,
|
||||
"volcengine": true, "zhipu": true, "groq": true,
|
||||
"mistral": true, "nvidia": true, "cerebras": true,
|
||||
"venice": true, "shengsuanyun": true, "vivgrid": true,
|
||||
"minimax": true, "longcat": true, "modelscope": true,
|
||||
"mimo": true, "avian": true, "zai": true, "novita": true,
|
||||
"litellm": true, "vllm": true, "lmstudio": true, "ollama": true,
|
||||
}
|
||||
|
||||
// registerModelRoutes binds model list management endpoints to the ServeMux.
|
||||
func (h *Handler) registerModelRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/models", h.handleListModels)
|
||||
@@ -22,6 +36,9 @@ func (h *Handler) registerModelRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("POST /api/models/default", h.handleSetDefaultModel)
|
||||
mux.HandleFunc("PUT /api/models/{index}", h.handleUpdateModel)
|
||||
mux.HandleFunc("DELETE /api/models/{index}", h.handleDeleteModel)
|
||||
mux.HandleFunc("POST /api/models/fetch", h.handleFetchModels)
|
||||
mux.HandleFunc("GET /api/models/catalog", h.handleListCatalogs)
|
||||
mux.HandleFunc("DELETE /api/models/catalog/{id}", h.handleDeleteCatalog)
|
||||
}
|
||||
|
||||
// modelResponse is the JSON structure returned for each model in the list.
|
||||
@@ -614,3 +631,199 @@ func maskAPIKey(key string) string {
|
||||
// Show first 3 chars and last 4 chars
|
||||
return key[:3] + "****" + key[len(key)-4:]
|
||||
}
|
||||
|
||||
// handleFetchModels fetches available models from an upstream provider.
|
||||
//
|
||||
// POST /api/models/fetch
|
||||
func (h *Handler) handleFetchModels(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Provider string `json:"provider"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIBase string `json:"api_base"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Provider == "" {
|
||||
http.Error(w, "provider is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !fetchableProviders[strings.ToLower(req.Provider)] {
|
||||
http.Error(w, fmt.Sprintf("provider %q does not support model listing", req.Provider), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
apiBase := strings.TrimSpace(req.APIBase)
|
||||
if apiBase == "" {
|
||||
apiBase = providers.DefaultAPIBaseForProtocol(req.Provider)
|
||||
}
|
||||
if apiBase == "" {
|
||||
http.Error(w, fmt.Sprintf("No default API base for provider %q", req.Provider), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
models, err := fetchUpstreamModels(ctx, req.Provider, apiBase, req.APIKey)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to fetch models: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-save fetched models to catalog
|
||||
catalogModels := make([]CatalogModel, len(models))
|
||||
for i, m := range models {
|
||||
catalogModels[i] = CatalogModel{ID: m.ID, OwnedBy: m.OwnedBy}
|
||||
}
|
||||
if saveErr := SaveCatalog(req.Provider, apiBase, req.APIKey, catalogModels); saveErr != nil {
|
||||
// Log but don't fail the request — saving catalog is non-critical
|
||||
logger.Warnf("Failed to save model catalog: %v", saveErr)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"models": models,
|
||||
"total": len(models),
|
||||
})
|
||||
}
|
||||
|
||||
type upstreamModel struct {
|
||||
ID string `json:"id"`
|
||||
OwnedBy string `json:"owned_by,omitempty"`
|
||||
}
|
||||
|
||||
func fetchUpstreamModels(ctx context.Context, provider, apiBase, apiKey string) ([]upstreamModel, error) {
|
||||
apiBase = strings.TrimRight(strings.TrimSpace(apiBase), "/")
|
||||
|
||||
var fetchURL string
|
||||
switch strings.ToLower(provider) {
|
||||
case "ollama":
|
||||
// Strip /v1 suffix if present to get the Ollama root
|
||||
root := apiBase
|
||||
if strings.HasSuffix(root, "/v1") {
|
||||
root = root[:len(root)-3]
|
||||
}
|
||||
root = strings.TrimRight(root, "/")
|
||||
fetchURL = root + "/api/tags"
|
||||
return fetchOllamaModels(ctx, fetchURL)
|
||||
default:
|
||||
// OpenAI-compatible: /v1/models
|
||||
fetchURL = apiBase + "/models"
|
||||
return fetchOpenAICompatibleModels(ctx, fetchURL, apiKey)
|
||||
}
|
||||
}
|
||||
|
||||
func fetchOpenAICompatibleModels(ctx context.Context, fetchURL, apiKey string) ([]upstreamModel, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fetchURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if apiKey = strings.TrimSpace(apiKey); apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
type modelItem struct {
|
||||
ID string `json:"id"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
// {"data": [...]} envelope. Distinguish "envelope shape with empty list"
|
||||
// from "object without a data key" via Data being non-nil after unmarshal:
|
||||
// json.Unmarshal sets Data to []modelItem{} for `{"data":[]}` but leaves
|
||||
// it as nil when "data" is absent or null.
|
||||
var envelope struct {
|
||||
Data []modelItem `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &envelope); err == nil && envelope.Data != nil {
|
||||
models := make([]upstreamModel, 0, len(envelope.Data))
|
||||
for _, m := range envelope.Data {
|
||||
if m.ID != "" {
|
||||
models = append(models, upstreamModel{ID: m.ID, OwnedBy: m.OwnedBy})
|
||||
}
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// Bare-array shape, including `[]`.
|
||||
var arr []modelItem
|
||||
if err := json.Unmarshal(body, &arr); err == nil {
|
||||
models := make([]upstreamModel, 0, len(arr))
|
||||
for _, m := range arr {
|
||||
if m.ID != "" {
|
||||
models = append(models, upstreamModel{ID: m.ID, OwnedBy: m.OwnedBy})
|
||||
}
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
preview := body
|
||||
if len(preview) > 256 {
|
||||
preview = preview[:256]
|
||||
}
|
||||
return nil, fmt.Errorf("decode response: unrecognized shape: %s", strings.TrimSpace(string(preview)))
|
||||
}
|
||||
|
||||
func fetchOllamaModels(ctx context.Context, fetchURL string) ([]upstreamModel, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fetchURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Models []struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
} `json:"models"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models := make([]upstreamModel, 0, len(parsed.Models))
|
||||
for _, m := range parsed.Models {
|
||||
id := m.Name
|
||||
if id == "" {
|
||||
id = m.Model
|
||||
}
|
||||
if id != "" {
|
||||
models = append(models, upstreamModel{ID: id})
|
||||
}
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -2219,3 +2220,174 @@ func TestMaskAPIKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchOpenAICompatibleModels_ResponseShapes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response string
|
||||
apiKey string
|
||||
wantLen int
|
||||
wantFirst struct {
|
||||
id, ownedBy string
|
||||
}
|
||||
wantSecond struct {
|
||||
id, ownedBy string
|
||||
}
|
||||
}{
|
||||
{
|
||||
name: "envelope shape",
|
||||
response: `{"data":[{"id":"gpt-4o","owned_by":"openai"},{"id":"gpt-4o-mini","owned_by":"openai"}]}`,
|
||||
apiKey: "test-key",
|
||||
wantLen: 2,
|
||||
wantFirst: struct {
|
||||
id, ownedBy string
|
||||
}{id: "gpt-4o", ownedBy: "openai"},
|
||||
wantSecond: struct {
|
||||
id, ownedBy string
|
||||
}{id: "gpt-4o-mini", ownedBy: "openai"},
|
||||
},
|
||||
{
|
||||
name: "bare array shape",
|
||||
response: `[{"id":"qwen-max","owned_by":"qwen"},{"id":"qwen-plus","owned_by":"qwen"}]`,
|
||||
apiKey: "",
|
||||
wantLen: 2,
|
||||
wantFirst: struct {
|
||||
id, ownedBy string
|
||||
}{id: "qwen-max", ownedBy: "qwen"},
|
||||
wantSecond: struct {
|
||||
id, ownedBy string
|
||||
}{id: "qwen-plus", ownedBy: "qwen"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, tt.response)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", tt.apiKey)
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if len(models) != tt.wantLen {
|
||||
t.Fatalf("len(models) = %d, want %d", len(models), tt.wantLen)
|
||||
}
|
||||
if models[0].ID != tt.wantFirst.id || models[0].OwnedBy != tt.wantFirst.ownedBy {
|
||||
t.Fatalf("models[0] = %+v, want {ID:%s OwnedBy:%s}", models[0], tt.wantFirst.id, tt.wantFirst.ownedBy)
|
||||
}
|
||||
if models[1].ID != tt.wantSecond.id || models[1].OwnedBy != tt.wantSecond.ownedBy {
|
||||
t.Fatalf("models[1] = %+v, want {ID:%s OwnedBy:%s}", models[1], tt.wantSecond.id, tt.wantSecond.ownedBy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchOpenAICompatibleModels_EmptyEnvelopeReturnsEmptySlice(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"data":[]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k")
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("len(models) = %d, want 0", len(models))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchOpenAICompatibleModels_EmptyBareArrayReturnsEmptySlice(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `[]`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k")
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("len(models) = %d, want 0", len(models))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchOpenAICompatibleModels_UnrecognizedShape(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"models":[],"error":"unsupported"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
_, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k")
|
||||
if err == nil {
|
||||
t.Fatal("error = nil, want unrecognized shape error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unrecognized shape") {
|
||||
t.Fatalf("error = %q, want it to contain 'unrecognized shape'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchOpenAICompatibleModels_FiltersEmptyIDs(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"data":[`+
|
||||
`{"id":"gpt-4o","owned_by":"openai"},`+
|
||||
`{"id":"","owned_by":"openai"},`+
|
||||
`{"id":"gpt-4o-mini"}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
models, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "k")
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("len(models) = %d, want 2 (empty IDs should be filtered)", len(models))
|
||||
}
|
||||
if models[0].ID != "gpt-4o" {
|
||||
t.Fatalf("models[0].ID = %q, want %q", models[0].ID, "gpt-4o")
|
||||
}
|
||||
if models[1].ID != "gpt-4o-mini" {
|
||||
t.Fatalf("models[1].ID = %q, want %q", models[1].ID, "gpt-4o-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchOpenAICompatibleModels_SetsAuthorizationHeader(t *testing.T) {
|
||||
var gotAuth string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"data":[{"id":"m1"}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
if _, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", "my-secret-key"); err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if gotAuth != "Bearer my-secret-key" {
|
||||
t.Fatalf("Authorization = %q, want %q", gotAuth, "Bearer my-secret-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchOpenAICompatibleModels_NoAuthHeaderWhenKeyEmpty(t *testing.T) {
|
||||
var gotAuth string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `[{"id":"m1"}]`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
if _, err := fetchOpenAICompatibleModels(t.Context(), srv.URL+"/models", ""); err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if gotAuth != "" {
|
||||
t.Fatalf("Authorization = %q, want empty", gotAuth)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user