diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index cf7f3563a..3010c1451 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -211,6 +211,9 @@ func gatewayCmd() { <-sigChan fmt.Println("\nShutting down...") + if cp, ok := provider.(providers.StatefulProvider); ok { + cp.Close() + } cancel() healthServer.Stop(context.Background()) deviceService.Stop() diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go index 6124881f7..9210021e1 100644 --- a/pkg/providers/github_copilot_provider.go +++ b/pkg/providers/github_copilot_provider.go @@ -4,60 +4,84 @@ import ( "context" "encoding/json" "fmt" + "sync" copilot "github.com/github/copilot-sdk/go" ) type GitHubCopilotProvider struct { uri string - connectMode string // `stdio` or `grpc`` + connectMode string // "stdio" or "grpc" + client *copilot.Client session *copilot.Session + + mu sync.Mutex } func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) { - var session *copilot.Session if connectMode == "" { connectMode = "grpc" } - switch connectMode { + switch connectMode { case "stdio": - // todo + // TODO: + return nil, fmt.Errorf("stdio mode not implemented") case "grpc": client := copilot.NewClient(&copilot.ClientOptions{ CLIUrl: uri, }) if err := client.Start(context.Background()); err != nil { return nil, fmt.Errorf( - "Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details", + "can't connect to Github Copilot: %w; `https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server` for details", + err, ) } - defer client.Stop() - session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{ + + session, err := client.CreateSession(context.Background(), &copilot.SessionConfig{ Model: model, Hooks: &copilot.SessionHooks{}, }) + if err != nil { + client.Stop() + return nil, fmt.Errorf("create session failed: %w", err) + } + + return &GitHubCopilotProvider{ + uri: uri, + connectMode: connectMode, + client: client, + session: session, + }, nil + default: + return nil, fmt.Errorf("unknown connect mode: %s", connectMode) + } +} + +func (p *GitHubCopilotProvider) Close() { + p.mu.Lock() + defer p.mu.Unlock() + if p.client != nil { + p.client.Stop() + p.client = nil + p.session = nil } - - return &GitHubCopilotProvider{ - uri: uri, - connectMode: connectMode, - session: session, - }, nil } -// Chat sends a chat request to GitHub Copilot func (p *GitHubCopilotProvider) Chat( - ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, ) (*LLMResponse, error) { type tempMessage struct { Role string `json:"role"` Content string `json:"content"` } out := make([]tempMessage, 0, len(messages)) - for _, msg := range messages { out = append(out, tempMessage{ Role: msg.Role, @@ -65,12 +89,30 @@ func (p *GitHubCopilotProvider) Chat( }) } - fullcontent, _ := json.Marshal(out) + fullcontent, err := json.Marshal(out) + if err != nil { + return nil, fmt.Errorf("marshal messages: %w", err) + } + p.mu.Lock() + session := p.session + p.mu.Unlock() - content, _ := p.session.Send(ctx, copilot.MessageOptions{ + if session == nil { + return nil, fmt.Errorf("provider closed") + } + + resp, err := session.SendAndWait(ctx, copilot.MessageOptions{ Prompt: string(fullcontent), }) + if resp == nil { + return nil, fmt.Errorf("empty response from copilot") + } + if resp.Data.Content == nil { + return nil, fmt.Errorf("no content in copilot response") + } + content := *resp.Data.Content + return &LLMResponse{ FinishReason: "stop", Content: content, diff --git a/pkg/providers/types.go b/pkg/providers/types.go index f711e7803..b2dda04a5 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -30,6 +30,11 @@ type LLMProvider interface { GetDefaultModel() string } +type StatefulProvider interface { + LLMProvider + Close() +} + // FailoverReason classifies why an LLM request failed for fallback decisions. type FailoverReason string