diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 28ef76ad3..30d61aec3 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -212,6 +212,9 @@ func gatewayCmd() { fmt.Println("\nShutting down...") cancel() + if cp, ok := provider.(providers.SessionProvider); ok { + cp.Close() + } healthServer.Stop(context.Background()) deviceService.Stop() heartbeatService.Stop() diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go index 6124881f7..8131b76fc 100644 --- a/pkg/providers/github_copilot_provider.go +++ b/pkg/providers/github_copilot_provider.go @@ -4,60 +4,75 @@ import ( "context" "encoding/json" "fmt" + "sync" copilot "github.com/github/copilot-sdk/go" ) type GitHubCopilotProvider struct { uri string - connectMode string // `stdio` or `grpc`` + connectMode string // "stdio" or "grpc" + client *copilot.Client session *copilot.Session + + mu sync.Mutex } func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) { - var session *copilot.Session if connectMode == "" { connectMode = "grpc" } - switch connectMode { + switch connectMode { case "stdio": - // todo + // TODO: + return nil, fmt.Errorf("stdio mode not implemented") case "grpc": client := copilot.NewClient(&copilot.ClientOptions{ CLIUrl: uri, }) if err := client.Start(context.Background()); err != nil { - return nil, fmt.Errorf( - "Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details", - ) + return nil, fmt.Errorf("can't connect to Github Copilot: %w; `https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server` for details", err) } - defer client.Stop() - session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{ + + session, err := client.CreateSession(context.Background(), &copilot.SessionConfig{ Model: model, Hooks: &copilot.SessionHooks{}, }) + if err != nil { + client.Stop() + return nil, fmt.Errorf("create session failed: %w", err) + } + + return &GitHubCopilotProvider{ + uri: uri, + connectMode: connectMode, + client: client, + session: session, + }, nil + default: + return nil, fmt.Errorf("unknown connect mode: %s", connectMode) } - - return &GitHubCopilotProvider{ - uri: uri, - connectMode: connectMode, - session: session, - }, nil } -// Chat sends a chat request to GitHub Copilot -func (p *GitHubCopilotProvider) Chat( - ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, -) (*LLMResponse, error) { +func (p *GitHubCopilotProvider) Close() { + p.mu.Lock() + defer p.mu.Unlock() + if p.client != nil { + p.client.Stop() + p.client = nil + p.session = nil + } +} + +func (p *GitHubCopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { type tempMessage struct { Role string `json:"role"` Content string `json:"content"` } out := make([]tempMessage, 0, len(messages)) - for _, msg := range messages { out = append(out, tempMessage{ Role: msg.Role, @@ -65,18 +80,31 @@ func (p *GitHubCopilotProvider) Chat( }) } - fullcontent, _ := json.Marshal(out) + fullcontent, err := json.Marshal(out) + if err != nil { + return nil, fmt.Errorf("marshal messages: %w", err) + } + p.mu.Lock() + defer p.mu.Unlock() - content, _ := p.session.Send(ctx, copilot.MessageOptions{ + resp, err := p.session.SendAndWait(ctx, copilot.MessageOptions{ Prompt: string(fullcontent), }) + if err != nil { + return nil, err + } + + var content string + if resp != nil && resp.Data.Content != nil { + content = *resp.Data.Content + } return &LLMResponse{ FinishReason: "stop", Content: content, }, nil } - func (p *GitHubCopilotProvider) GetDefaultModel() string { + return "gpt-4.1" } diff --git a/pkg/providers/types.go b/pkg/providers/types.go index f711e7803..40ff6f7c8 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -30,6 +30,11 @@ type LLMProvider interface { GetDefaultModel() string } +type SessionProvider interface { + LLMProvider + Close() +} + // FailoverReason classifies why an LLM request failed for fallback decisions. type FailoverReason string