mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Add azure entra id support for azure openai provider
This commit is contained in:
@@ -413,12 +413,14 @@ PicoClaw supports 30+ LLM providers through the `model_list` configuration. Use
|
||||
| [Ollama](https://ollama.com/) | `ollama/` | Not needed | Local models, self-hosted |
|
||||
| [vLLM](https://docs.vllm.ai/) | `vllm/` | Not needed | Local deployment, OpenAI-compatible |
|
||||
| [LiteLLM](https://docs.litellm.ai/) | `litellm/` | Varies | Proxy for 100+ providers |
|
||||
| [Azure OpenAI](https://portal.azure.com/) | `azure/` | Required | Enterprise Azure deployment |
|
||||
| [Azure OpenAI](https://portal.azure.com/) | `azure/` | API key or Entra ID** | Enterprise Azure deployment |
|
||||
| [GitHub Copilot](https://github.com/features/copilot) | `github-copilot/` | OAuth | Device code login |
|
||||
| [Antigravity](https://console.cloud.google.com/) | `antigravity/` | OAuth | Google Cloud AI |
|
||||
| [AWS Bedrock](https://console.aws.amazon.com/bedrock)* | `bedrock/` | AWS credentials | Claude, Llama, Mistral on AWS |
|
||||
|
||||
> \* AWS Bedrock requires build tag: `go build -tags bedrock`. Set `api_base` to a region name (e.g., `us-east-1`) for automatic endpoint resolution across all AWS partitions (aws, aws-cn, aws-us-gov). When using a full endpoint URL instead, you must also configure `AWS_REGION` via environment variable or AWS config/profile.
|
||||
>
|
||||
> \*\* Azure OpenAI uses `api_key` when set. If `api_key` is omitted, the provider falls back to Microsoft Entra ID via `DefaultAzureCredential` (env vars, workload identity, managed identity, Azure CLI, etc.). The Entra ID path requires build tag: `go build -tags azidentity`.
|
||||
|
||||
<details>
|
||||
<summary><b>Local deployment (Ollama, vLLM, etc.)</b></summary>
|
||||
|
||||
@@ -4,6 +4,8 @@ go 1.25.10
|
||||
|
||||
require (
|
||||
fyne.io/systray v1.12.1
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1
|
||||
github.com/SevereCloud/vksdk/v3 v3.3.1
|
||||
github.com/adhocore/gronx v1.20.0
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0
|
||||
@@ -55,6 +57,8 @@ require (
|
||||
require (
|
||||
aead.dev/minisign v0.2.0 // indirect
|
||||
filippo.io/edwards25519 v1.2.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
|
||||
@@ -82,7 +86,9 @@ require (
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@@ -91,6 +97,7 @@ require (
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/segmentio/asm v1.1.3 // indirect
|
||||
|
||||
@@ -5,6 +5,18 @@ filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
|
||||
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
|
||||
fyne.io/systray v1.12.1 h1:ygBD6aZXwiOmZoY5N+ukbH9pih0Kq6fYgVeMYbr5skQ=
|
||||
fyne.io/systray v1.12.1/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
|
||||
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM=
|
||||
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/SevereCloud/vksdk/v3 v3.3.1 h1:O86zsp5LQnHE+O5acvuXM/s6S1LyxzVTkF6+Lup0Jyg=
|
||||
@@ -164,6 +176,8 @@ github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyf
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
|
||||
github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao=
|
||||
@@ -179,6 +193,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.7.5 h1:dimv+ZAGia01f4xCDGvCiBHKWMf4K1AB7fGsM+lv5Jw=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.7.5/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
|
||||
github.com/line/line-bot-sdk-go/v8 v8.20.0 h1:Jv22DV3JuQ5qZvniqUbg504bJrVzffXs2CMpyoiuIZU=
|
||||
@@ -225,6 +241,8 @@ github.com/pion/rtp v1.10.2 h1:l+f6tTDcAH6xwepaAoW791ddhuYsJlqRATOzirO04Mo=
|
||||
github.com/pion/rtp v1.10.2/go.mod h1:Au8fc6cEByy8RLTwKTQTEeQqDB/SJDxwL4mZuxYA5Pk=
|
||||
github.com/pion/webrtc/v3 v3.3.6 h1:7XAh4RPtlY1Vul6/GmZrv7z+NnxKA6If0KStXBI2ZLE=
|
||||
github.com/pion/webrtc/v3 v3.3.6/go.mod h1:zyN7th4mZpV27eXybfR/cnUf3J2DRy8zw/mdjD9JTNM=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -384,6 +402,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
//go:build azidentity
|
||||
|
||||
// Package azure: Entra ID (DefaultAzureCredential) auth adapter.
|
||||
// Built only when -tags azidentity is supplied; otherwise identity_stub.go
|
||||
// satisfies the same exported API with a friendly error.
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
)
|
||||
|
||||
// azureOpenAIScope is the OAuth scope for Azure OpenAI (Cognitive Services).
|
||||
// Service-wide scope, so it covers all regions including sovereign clouds.
|
||||
const azureOpenAIScope = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
// NewProviderWithIdentity creates an Azure OpenAI provider authenticated via
|
||||
// the DefaultAzureCredential chain (env vars, workload identity, managed
|
||||
// identity, Azure CLI, ...). Construction itself only fails if the credential
|
||||
// chain cannot be built; misconfigured environments surface their error on
|
||||
// the first Chat call when GetToken is invoked.
|
||||
func NewProviderWithIdentity(apiBase, proxy, userAgent string, opts ...Option) (*Provider, error) {
|
||||
cred, err := azidentity.NewDefaultAzureCredential(nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating azure default credential: %w", err)
|
||||
}
|
||||
|
||||
ts := func(ctx context.Context) (string, error) {
|
||||
tok, err := cred.GetToken(ctx, policy.TokenRequestOptions{
|
||||
Scopes: []string{azureOpenAIScope},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("acquiring azure access token: %w", err)
|
||||
}
|
||||
return tok.Token, nil
|
||||
}
|
||||
|
||||
return NewProviderWithTokenSource(apiBase, proxy, userAgent, ts, opts...), nil
|
||||
}
|
||||
|
||||
// NewProviderWithIdentityAndTimeout mirrors NewProviderWithTimeout for the
|
||||
// identity auth path.
|
||||
func NewProviderWithIdentityAndTimeout(
|
||||
apiBase, proxy, userAgent string,
|
||||
requestTimeoutSeconds int,
|
||||
) (*Provider, error) {
|
||||
return NewProviderWithIdentity(
|
||||
apiBase, proxy, userAgent,
|
||||
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
//go:build !azidentity
|
||||
|
||||
// Package azure: stub for the Entra ID auth path when built without
|
||||
// the azidentity tag. Mirrors the exported surface of identity.go so
|
||||
// callers compile cleanly in the default build.
|
||||
package azure
|
||||
|
||||
import "fmt"
|
||||
|
||||
const azidentityBuildHint = "azure identity auth not available: build with -tags azidentity to enable Entra ID auth, or set api_key"
|
||||
|
||||
// NewProviderWithIdentity returns an error in the default build.
|
||||
func NewProviderWithIdentity(apiBase, proxy, userAgent string, opts ...Option) (*Provider, error) {
|
||||
return nil, fmt.Errorf("%s", azidentityBuildHint)
|
||||
}
|
||||
|
||||
// NewProviderWithIdentityAndTimeout returns an error in the default build.
|
||||
func NewProviderWithIdentityAndTimeout(
|
||||
apiBase, proxy, userAgent string,
|
||||
requestTimeoutSeconds int,
|
||||
) (*Provider, error) {
|
||||
return nil, fmt.Errorf("%s", azidentityBuildHint)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
//go:build azidentity
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewProviderWithIdentity_Construction(t *testing.T) {
|
||||
// DefaultAzureCredential construction itself does not require any env vars;
|
||||
// failures surface only on the first GetToken call. Verify we get a
|
||||
// non-nil provider back with a token source wired in.
|
||||
p, err := NewProviderWithIdentity("https://example.openai.azure.com", "", "ua-test")
|
||||
if err != nil {
|
||||
t.Fatalf("NewProviderWithIdentity() error = %v", err)
|
||||
}
|
||||
if p == nil {
|
||||
t.Fatal("NewProviderWithIdentity() returned nil provider")
|
||||
}
|
||||
if p.tokenSource == nil {
|
||||
t.Fatal("provider.tokenSource should be set")
|
||||
}
|
||||
if p.apiKey != "" {
|
||||
t.Errorf("provider.apiKey = %q, want empty", p.apiKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProviderWithIdentityAndTimeout_Construction(t *testing.T) {
|
||||
p, err := NewProviderWithIdentityAndTimeout("https://example.openai.azure.com", "", "ua-test", 30)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProviderWithIdentityAndTimeout() error = %v", err)
|
||||
}
|
||||
if p == nil {
|
||||
t.Fatal("returned nil provider")
|
||||
}
|
||||
if p.httpClient.Timeout.Seconds() != 30 {
|
||||
t.Errorf("timeout = %v, want 30s", p.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
@@ -33,10 +33,11 @@ const (
|
||||
// It handles Azure-specific authentication (Bearer token), URL construction
|
||||
// (Responses API), and request/response formatting.
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
userAgent string
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
userAgent string
|
||||
tokenSource func(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// Option configures the Azure Provider.
|
||||
@@ -58,6 +59,14 @@ func WithUserAgent(userAgent string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenSource sets a callback that returns a bearer token per request.
|
||||
// When set, it takes precedence over the static api key.
|
||||
func WithTokenSource(ts func(ctx context.Context) (string, error)) Option {
|
||||
return func(p *Provider) {
|
||||
p.tokenSource = ts
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider creates a new Azure OpenAI provider.
|
||||
func NewProvider(apiKey, apiBase, proxy, userAgent string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
@@ -84,6 +93,30 @@ func NewProviderWithTimeout(apiKey, apiBase, proxy, userAgent string, requestTim
|
||||
)
|
||||
}
|
||||
|
||||
// NewProviderWithTokenSource creates a new Azure OpenAI provider that obtains its
|
||||
// bearer token from the supplied callback on every request. Used for Entra ID auth
|
||||
// where tokens are short-lived and refreshed by the underlying credential.
|
||||
func NewProviderWithTokenSource(
|
||||
apiBase, proxy, userAgent string,
|
||||
tokenSource func(ctx context.Context) (string, error),
|
||||
opts ...Option,
|
||||
) *Provider {
|
||||
p := &Provider{
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
userAgent: userAgent,
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
tokenSource: tokenSource,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(p)
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Chat sends a request to the Azure OpenAI Responses API endpoint.
|
||||
// The model parameter is passed in the request body.
|
||||
func (p *Provider) Chat(
|
||||
@@ -147,7 +180,14 @@ func (p *Provider) Chat(
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
switch {
|
||||
case p.tokenSource != nil:
|
||||
tok, tokErr := p.tokenSource(ctx)
|
||||
if tokErr != nil {
|
||||
return nil, fmt.Errorf("acquiring azure identity token: %w", tokErr)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tok)
|
||||
case p.apiKey != "":
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
if p.userAgent != "" {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -415,3 +417,68 @@ func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) {
|
||||
t.Errorf("tool type = %v, want %q", tool["type"], "function")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureTokenSourceHeader(t *testing.T) {
|
||||
var capturedAuth string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ts := func(ctx context.Context) (string, error) {
|
||||
return "fake-entra-token", nil
|
||||
}
|
||||
p := NewProviderWithTokenSource(server.URL, "", "", ts)
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if capturedAuth != "Bearer fake-entra-token" {
|
||||
t.Errorf("Authorization header = %q, want %q", capturedAuth, "Bearer fake-entra-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureTokenSourceError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wantErr := errors.New("creds gone")
|
||||
ts := func(ctx context.Context) (string, error) {
|
||||
return "", wantErr
|
||||
}
|
||||
p := NewProviderWithTokenSource(server.URL, "", "", ts)
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from token source")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "creds gone") {
|
||||
t.Errorf("error %q should wrap original token source error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureTokenSourcePrecedence(t *testing.T) {
|
||||
var capturedAuth string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ts := func(ctx context.Context) (string, error) {
|
||||
return "from-token-source", nil
|
||||
}
|
||||
// Provider with both an api_key AND a token source: token source must win.
|
||||
p := NewProvider("static-api-key", server.URL, "", "", WithTokenSource(ts))
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if capturedAuth != "Bearer from-token-source" {
|
||||
t.Errorf("Authorization header = %q, want token-source value", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,23 +137,32 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "azure":
|
||||
// Azure OpenAI uses deployment-based URLs, api-key header auth,
|
||||
// and always sends max_completion_tokens.
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for azure protocol")
|
||||
}
|
||||
// Azure OpenAI uses deployment-based URLs. Auth is Bearer token via api_key
|
||||
// when set; otherwise falls back to Entra ID (DefaultAzureCredential).
|
||||
if cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf(
|
||||
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
|
||||
)
|
||||
}
|
||||
return finalizeProviderFromConfig(azure.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
if cfg.APIKey() != "" {
|
||||
return finalizeProviderFromConfig(azure.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, cfg)
|
||||
}
|
||||
provider, err := azure.NewProviderWithIdentityAndTimeout(
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, cfg)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "bedrock":
|
||||
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
//go:build azidentity
|
||||
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// With the azidentity build tag, an azure config with no api_key must succeed
|
||||
// (falls back to DefaultAzureCredential). Construction does not require any
|
||||
// real Azure environment — token acquisition happens on first Chat.
|
||||
func TestCreateProviderFromConfig_AzureIdentityFallback(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-gpt5-deployment" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment")
|
||||
}
|
||||
}
|
||||
@@ -870,8 +870,11 @@ func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
|
||||
// Without api_key the factory falls back to identity auth, which in the
|
||||
// default build is stubbed out and surfaces a build-tag error. With the
|
||||
// azidentity tag, the call succeeds and is covered by a separate test.
|
||||
if err != nil && !strings.Contains(err.Error(), "azidentity") {
|
||||
t.Fatalf("CreateProviderFromConfig() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user