mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2514 from lc6464/fix/issue-2488-host-binding
feat(launcher): add host overrides for launcher and gateway
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
@@ -119,6 +120,7 @@ var (
|
||||
gatewayRestartGracePeriod = 5 * time.Second
|
||||
gatewayRestartForceKillWindow = 3 * time.Second
|
||||
gatewayRestartPollInterval = 100 * time.Millisecond
|
||||
gatewayExecCommand = exec.Command
|
||||
)
|
||||
|
||||
var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
@@ -262,7 +264,7 @@ func (h *Handler) getGatewayHealthForPidData(
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
}
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
host = netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health"
|
||||
@@ -723,7 +725,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath))
|
||||
|
||||
cmd = exec.Command(execPath, h.gatewayCommandArgs()...)
|
||||
cmd = gatewayExecCommand(execPath, h.gatewayCommandArgs()...)
|
||||
cmd.Env = os.Environ()
|
||||
// Forward the launcher's config path via the environment variable that
|
||||
// GetConfigPath() already reads, so the gateway sub-process uses the same
|
||||
@@ -731,8 +733,9 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
if h.configPath != "" {
|
||||
cmd.Env = append(cmd.Env, config.EnvConfig+"="+h.configPath)
|
||||
}
|
||||
if host := h.gatewayHostOverride(); host != "" {
|
||||
cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+host)
|
||||
gatewayHostOverride := h.gatewayHostOverride()
|
||||
if gatewayHostOverride != "" {
|
||||
cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+gatewayHostOverride)
|
||||
}
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
|
||||
@@ -8,9 +8,15 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
func (h *Handler) effectiveLauncherPublic() bool {
|
||||
if h.serverHostExplicit {
|
||||
// -host takes precedence over -public and launcher-config public setting.
|
||||
return false
|
||||
}
|
||||
|
||||
if h.serverPublicExplicit {
|
||||
return h.serverPublic
|
||||
}
|
||||
@@ -24,8 +30,11 @@ func (h *Handler) effectiveLauncherPublic() bool {
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayHostOverride() string {
|
||||
if h.serverHostExplicit {
|
||||
return strings.TrimSpace(h.serverHostInput)
|
||||
}
|
||||
if h.effectiveLauncherPublic() {
|
||||
return "0.0.0.0"
|
||||
return "*"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -41,10 +50,11 @@ func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string {
|
||||
}
|
||||
|
||||
func gatewayProbeHost(bindHost string) string {
|
||||
if bindHost == "" || bindHost == "0.0.0.0" {
|
||||
return "127.0.0.1"
|
||||
plan, err := netbind.BuildPlan(bindHost, netbind.DefaultLoopback)
|
||||
if err != nil || strings.TrimSpace(plan.ProbeHost) == "" {
|
||||
return netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
return bindHost
|
||||
return plan.ProbeHost
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayProxyURL() *url.URL {
|
||||
@@ -72,7 +82,7 @@ func requestHostName(r *http.Request) string {
|
||||
if strings.TrimSpace(r.Host) != "" {
|
||||
return r.Host
|
||||
}
|
||||
return "127.0.0.1"
|
||||
return netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func requestWSScheme(r *http.Request) string {
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
@@ -26,8 +28,8 @@ func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) {
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != "0.0.0.0" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0")
|
||||
if got := h.gatewayHostOverride(); got != "*" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "*")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,8 +66,36 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
|
||||
if got := gatewayProbeHost("0.0.0.0"); got != "127.0.0.1" {
|
||||
t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1")
|
||||
want := "127.0.0.1"
|
||||
if got := gatewayProbeHost("0.0.0.0"); got != want {
|
||||
t.Fatalf("gatewayProbeHost() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesPreferredLoopbackForEmptyBind(t *testing.T) {
|
||||
want := netbind.ResolveAdaptiveLoopbackHost()
|
||||
if got := gatewayProbeHost(""); got != want {
|
||||
t.Fatalf("gatewayProbeHost(empty) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesPreferredLoopbackForLocalhostBind(t *testing.T) {
|
||||
want := netbind.ResolveAdaptiveLoopbackHost()
|
||||
if got := gatewayProbeHost("localhost"); got != want {
|
||||
t.Fatalf("gatewayProbeHost(localhost) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesLoopbackForIPv6WildcardBind(t *testing.T) {
|
||||
want := "::1"
|
||||
if got := gatewayProbeHost("::"); got != want {
|
||||
t.Fatalf("gatewayProbeHost(::) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesFirstConcreteHostForMultiHostBind(t *testing.T) {
|
||||
if got := gatewayProbeHost("127.0.0.1,::1"); got != "127.0.0.1" {
|
||||
t.Fatalf("gatewayProbeHost(multi) = %q, want %q", got, "127.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,8 +167,9 @@ func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://127.0.0.1:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
|
||||
want := "http://" + net.JoinHostPort(netbind.ResolveAdaptiveLoopbackHost(), "18791") + "/health"
|
||||
if requestedURL != want {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,3 +271,43 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://localhost:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHostOverrideWithExplicitHostAndAlignedGatewayHost(t *testing.T) {
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
h.SetServerBindHost("0.0.0.0", true)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != "0.0.0.0" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHostOverrideWithExplicitHostAndLocalhostGatewayHost(t *testing.T) {
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
h.SetServerBindHost("::", true)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != "::" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "::")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHostOverrideWithExplicitMultiHost(t *testing.T) {
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
h.SetServerBindHost("127.0.0.1,::1", true)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != "127.0.0.1,::1" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "127.0.0.1,::1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHostExplicitIgnoresPublicFlag(t *testing.T) {
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
h.SetServerBindHost("127.0.0.1", true)
|
||||
|
||||
if got := h.effectiveLauncherPublic(); got {
|
||||
t.Fatalf("effectiveLauncherPublic() = %t, want false when explicit host is set", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +97,7 @@ func resetGatewayTestState(t *testing.T) {
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
originalProcessMatcher := gatewayProcessMatcher
|
||||
originalExecCommand := gatewayExecCommand
|
||||
originalRestartGracePeriod := gatewayRestartGracePeriod
|
||||
originalRestartForceKillWindow := gatewayRestartForceKillWindow
|
||||
originalRestartPollInterval := gatewayRestartPollInterval
|
||||
@@ -104,6 +105,7 @@ func resetGatewayTestState(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
gatewayProcessMatcher = originalProcessMatcher
|
||||
gatewayExecCommand = originalExecCommand
|
||||
gatewayRestartGracePeriod = originalRestartGracePeriod
|
||||
gatewayRestartForceKillWindow = originalRestartForceKillWindow
|
||||
gatewayRestartPollInterval = originalRestartPollInterval
|
||||
@@ -119,6 +121,159 @@ func resetGatewayTestState(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type gatewayStartEnvSnapshot struct {
|
||||
GatewayHost string `json:"gateway_host"`
|
||||
GatewayHostSet bool `json:"gateway_host_set"`
|
||||
ConfigPath string `json:"config_path"`
|
||||
}
|
||||
|
||||
func TestGatewayStartHelperProcess(t *testing.T) {
|
||||
var envPath string
|
||||
for i, arg := range os.Args {
|
||||
if arg == "--" && i+2 < len(os.Args) && os.Args[i+1] == "gateway-env-helper" {
|
||||
envPath = os.Args[i+2]
|
||||
break
|
||||
}
|
||||
}
|
||||
if envPath == "" {
|
||||
t.Skip("helper process")
|
||||
}
|
||||
|
||||
host, ok := os.LookupEnv(config.EnvGatewayHost)
|
||||
raw, err := json.Marshal(gatewayStartEnvSnapshot{
|
||||
GatewayHost: host,
|
||||
GatewayHostSet: ok,
|
||||
ConfigPath: os.Getenv(config.EnvConfig),
|
||||
})
|
||||
if err != nil {
|
||||
_, _ = io.WriteString(os.Stderr, err.Error())
|
||||
os.Exit(2)
|
||||
}
|
||||
if err := os.WriteFile(envPath, raw, 0o600); err != nil {
|
||||
_, _ = io.WriteString(os.Stderr, err.Error())
|
||||
os.Exit(2)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func unsetGatewayStartEnvForTest(t *testing.T, key string) {
|
||||
t.Helper()
|
||||
|
||||
prev, hadPrev := os.LookupEnv(key)
|
||||
if err := os.Unsetenv(key); err != nil {
|
||||
t.Fatalf("Unsetenv(%q) error = %v", key, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if hadPrev {
|
||||
_ = os.Setenv(key, prev)
|
||||
return
|
||||
}
|
||||
_ = os.Unsetenv(key)
|
||||
})
|
||||
}
|
||||
|
||||
func newGatewayStartTestHandler(t *testing.T) *Handler {
|
||||
t.Helper()
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func startGatewayAndCaptureEnv(t *testing.T, h *Handler) gatewayStartEnvSnapshot {
|
||||
t.Helper()
|
||||
|
||||
unsetGatewayStartEnvForTest(t, config.EnvGatewayHost)
|
||||
|
||||
envPath := filepath.Join(t.TempDir(), "gateway-child-env.json")
|
||||
gatewayExecCommand = func(_ string, _ ...string) *exec.Cmd {
|
||||
return exec.Command(
|
||||
os.Args[0],
|
||||
"-test.run=TestGatewayStartHelperProcess",
|
||||
"--",
|
||||
"gateway-env-helper",
|
||||
envPath,
|
||||
)
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked("starting", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("startGatewayLocked() error = %v", err)
|
||||
}
|
||||
if pid <= 0 {
|
||||
t.Fatalf("startGatewayLocked() pid = %d, want > 0", pid)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for {
|
||||
raw, err := os.ReadFile(envPath)
|
||||
if err == nil {
|
||||
var snapshot gatewayStartEnvSnapshot
|
||||
err = json.Unmarshal(raw, &snapshot)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal(child env) error = %v", err)
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatalf("ReadFile(%q) error = %v", envPath, err)
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("timed out waiting for gateway child env snapshot %q", envPath)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGatewayLocked_ForwardsLauncherHostOverrideToGatewayEnv(t *testing.T) {
|
||||
h := newGatewayStartTestHandler(t)
|
||||
h.SetServerBindHost("127.0.0.1,::1", true)
|
||||
|
||||
snapshot := startGatewayAndCaptureEnv(t, h)
|
||||
if !snapshot.GatewayHostSet {
|
||||
t.Fatal("gateway host env was not set")
|
||||
}
|
||||
if snapshot.GatewayHost != "127.0.0.1,::1" {
|
||||
t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "127.0.0.1,::1")
|
||||
}
|
||||
if snapshot.ConfigPath != h.configPath {
|
||||
t.Fatalf("config env = %q, want %q", snapshot.ConfigPath, h.configPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGatewayLocked_ForwardsLauncherHostFromEnvironmentToGatewayEnv(t *testing.T) {
|
||||
h := newGatewayStartTestHandler(t)
|
||||
h.SetServerBindHost("::", true)
|
||||
|
||||
snapshot := startGatewayAndCaptureEnv(t, h)
|
||||
if !snapshot.GatewayHostSet {
|
||||
t.Fatal("gateway host env was not set")
|
||||
}
|
||||
if snapshot.GatewayHost != "::" {
|
||||
t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "::")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGatewayLocked_ForwardsWildcardHostForPublicLauncher(t *testing.T) {
|
||||
h := newGatewayStartTestHandler(t)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
snapshot := startGatewayAndCaptureEnv(t, h)
|
||||
if !snapshot.GatewayHostSet {
|
||||
t.Fatal("gateway host env was not set")
|
||||
}
|
||||
if snapshot.GatewayHost != "*" {
|
||||
t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "*")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
@@ -13,6 +14,8 @@ type Handler struct {
|
||||
serverPort int
|
||||
serverPublic bool
|
||||
serverPublicExplicit bool
|
||||
serverHostInput string
|
||||
serverHostExplicit bool
|
||||
serverCIDRs []string
|
||||
debug bool
|
||||
oauthMu sync.Mutex
|
||||
@@ -41,9 +44,21 @@ func (h *Handler) SetServerOptions(port int, public bool, publicExplicit bool, a
|
||||
h.serverPort = port
|
||||
h.serverPublic = public
|
||||
h.serverPublicExplicit = publicExplicit
|
||||
h.serverHostInput = ""
|
||||
h.serverHostExplicit = false
|
||||
h.serverCIDRs = append([]string(nil), allowedCIDRs...)
|
||||
}
|
||||
|
||||
// SetServerBindHost stores the launcher's effective bind host.
|
||||
// When explicit is true, hostInput is the normalized -host / PICOCLAW_LAUNCHER_HOST value.
|
||||
func (h *Handler) SetServerBindHost(hostInput string, explicit bool) {
|
||||
h.serverHostInput = strings.TrimSpace(hostInput)
|
||||
if !explicit {
|
||||
h.serverHostInput = ""
|
||||
}
|
||||
h.serverHostExplicit = explicit
|
||||
}
|
||||
|
||||
func (h *Handler) SetDebug(debug bool) {
|
||||
h.debug = debug
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user