diff --git a/internal/procutil/process_linux.go b/internal/procutil/process_linux.go index 15b46d4..8a011b0 100644 --- a/internal/procutil/process_linux.go +++ b/internal/procutil/process_linux.go @@ -9,18 +9,33 @@ import ( "fmt" "os" "path/filepath" + "strings" ) -// IsExpectedProcess checks if the process at pid is running the expected binary. -// On Linux, reads /proc//exe to verify the binary path. Returns false if -// the process does not exist or is running a different binary, preventing -// signals to recycled PIDs. +// IsExpectedProcess checks if the process at pid is running the expected +// binary. On Linux, reads /proc//exe to verify the binary path. Returns +// false if the process does not exist or is running a different binary, +// preventing signals to recycled PIDs. +// +// When expectedBinary is an absolute path, the comparison is against the +// full resolved exe path — this is the strong guarantee. When +// expectedBinary is just a name, the comparison falls back to the base +// name; two unrelated binaries with the same base name on the same host +// would collide under the fallback, so callers should pass absolute paths +// when possible. +// +// The "(deleted)" suffix that the kernel appends when the underlying +// binary has been unlinked post-exec is stripped so that processes still +// running from a since-removed extract directory are correctly identified. func IsExpectedProcess(pid int, expectedBinary string) bool { exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pid)) if err != nil { return false // Process gone or no permission } - // Compare base names: the state may store just the binary name while - // /proc/pid/exe returns the full resolved path. + exePath = strings.TrimSuffix(exePath, " (deleted)") + + if filepath.IsAbs(expectedBinary) { + return filepath.Clean(exePath) == filepath.Clean(expectedBinary) + } return filepath.Base(exePath) == filepath.Base(expectedBinary) } diff --git a/internal/procutil/process_linux_test.go b/internal/procutil/process_linux_test.go index e88d01f..1ba6a9a 100644 --- a/internal/procutil/process_linux_test.go +++ b/internal/procutil/process_linux_test.go @@ -23,7 +23,9 @@ func TestIsExpectedProcess_Self(t *testing.T) { } } -func TestIsExpectedProcess_SelfBaseName(t *testing.T) { +func TestIsExpectedProcess_SelfBaseNameFallback(t *testing.T) { + // When expectedBinary is NOT absolute (just a bare name), the fallback + // base-name comparison applies. pid := os.Getpid() selfExe, err := os.Executable() @@ -31,8 +33,24 @@ func TestIsExpectedProcess_SelfBaseName(t *testing.T) { t.Fatalf("failed to get self executable: %v", err) } baseName := filepath.Base(selfExe) - if !IsExpectedProcess(pid, "/some/other/path/"+baseName) { - t.Errorf("IsExpectedProcess with different dir but same base name should return true") + if !IsExpectedProcess(pid, baseName) { + t.Errorf("IsExpectedProcess with bare base name should match self") + } +} + +func TestIsExpectedProcess_AbsolutePathMismatch(t *testing.T) { + // When expectedBinary IS absolute, a different directory with the same + // base name must NOT match. This is the strengthened guarantee against + // unrelated binaries with colliding base names. + pid := os.Getpid() + + selfExe, err := os.Executable() + if err != nil { + t.Fatalf("failed to get self executable: %v", err) + } + baseName := filepath.Base(selfExe) + if IsExpectedProcess(pid, "/some/other/path/"+baseName) { + t.Errorf("absolute path mismatch must not match even with same base name") } } diff --git a/microvm.go b/microvm.go index cefa654..b8e6410 100644 --- a/microvm.go +++ b/microvm.go @@ -98,11 +98,10 @@ func Run(ctx context.Context, imageRef string, opts ...Option) (*VM, error) { slog.Warn("egress policy overrides firewall default action to Deny") } cfg.firewallDefaultAction = firewall.Deny - if cfg.netProvider == nil { - cfg.netProvider = hosted.NewProvider() - } } + wireDefaultProvider(cfg) + // 1. Preflight checks. { ctx, span := tracer.Start(ctx, "microvm.Preflight") @@ -312,6 +311,7 @@ func Run(ctx context.Context, imageRef string, opts ...Option) (*VM, error) { ls.State.Name = cfg.name if pid, pidErr := pidFromID(handle.ID()); pidErr == nil { ls.State.PID = pid + ls.State.PIDStartTime = time.Now().UTC() } else { slog.Warn("could not persist VM PID", "id", handle.ID(), "error", pidErr) } @@ -354,6 +354,21 @@ const ( staleTermPoll = 250 * time.Millisecond ) +// wireDefaultProvider auto-creates a hosted network provider when any +// firewall configuration (egress policy, static rules, or a non-Allow +// default action) is set but no provider was supplied explicitly. The +// default runner-side networking path does not enforce firewall rules, +// so without this the caller's deny-default would silently degrade to +// allow-all. No-op when a provider is already set. +func wireDefaultProvider(cfg *config) { + firewallConfigured := cfg.egressPolicy != nil || + len(cfg.firewallRules) > 0 || + cfg.firewallDefaultAction != firewall.Allow + if firewallConfigured && cfg.netProvider == nil { + cfg.netProvider = hosted.NewProvider() + } +} + func cleanDataDir(cfg *config) error { if cfg.dataDir == "" { return nil @@ -409,6 +424,13 @@ func terminateStaleRunner(cfg *config) { slog.Debug("stale runner already dead", "pid", st.PID) return } + if !cfg.processIsExpected(st.PID) { + // PID has been recycled onto an unrelated binary since we wrote + // the state file. Signalling it would kill the wrong process + // group (or fail silently if we lack permission). Bail out. + slog.Warn("stale PID does not match expected runner binary, skipping termination", "pid", st.PID) + return + } // Use negative PID to signal the entire process group (PGID == PID // because the runner starts with Setsid: true). This ensures any diff --git a/microvm_test.go b/microvm_test.go index 04c39ff..4dc682e 100644 --- a/microvm_test.go +++ b/microvm_test.go @@ -21,11 +21,20 @@ import ( "github.com/stacklok/go-microvm/hypervisor" "github.com/stacklok/go-microvm/image" "github.com/stacklok/go-microvm/internal/testutil" + propnet "github.com/stacklok/go-microvm/net" "github.com/stacklok/go-microvm/net/firewall" "github.com/stacklok/go-microvm/preflight" "github.com/stacklok/go-microvm/state" ) +// sentinelProvider is a minimal net.Provider used by tests to assert that +// a caller-supplied provider survives auto-wiring without being replaced. +type sentinelProvider struct{} + +func (*sentinelProvider) Start(_ context.Context, _ propnet.Config) error { return nil } +func (*sentinelProvider) SocketPath() string { return "" } +func (*sentinelProvider) Stop() {} + // --- Pure function tests --- func TestBuildInitConfig_NilOCIConfig(t *testing.T) { @@ -660,6 +669,53 @@ func TestBuildNetConfig_Empty(t *testing.T) { // --- Egress validation tests --- +func TestWireDefaultProvider(t *testing.T) { + t.Parallel() + + t.Run("no firewall config leaves provider nil", func(t *testing.T) { + t.Parallel() + cfg := defaultConfig() + wireDefaultProvider(cfg) + assert.Nil(t, cfg.netProvider) + }) + + t.Run("egress policy auto-wires provider", func(t *testing.T) { + t.Parallel() + cfg := defaultConfig() + cfg.egressPolicy = &EgressPolicy{} + wireDefaultProvider(cfg) + assert.NotNil(t, cfg.netProvider) + }) + + t.Run("firewall rules alone auto-wire provider", func(t *testing.T) { + t.Parallel() + cfg := defaultConfig() + cfg.firewallRules = []firewall.Rule{{Direction: firewall.Egress, Action: firewall.Allow}} + wireDefaultProvider(cfg) + assert.NotNil(t, cfg.netProvider, + "firewall-only config must auto-wire a provider; otherwise rules go unenforced") + }) + + t.Run("deny default alone auto-wires provider", func(t *testing.T) { + t.Parallel() + cfg := defaultConfig() + cfg.firewallDefaultAction = firewall.Deny + wireDefaultProvider(cfg) + assert.NotNil(t, cfg.netProvider, + "deny-default config must auto-wire a provider to actually deny") + }) + + t.Run("explicit provider is not overwritten", func(t *testing.T) { + t.Parallel() + existing := &sentinelProvider{} + cfg := defaultConfig() + cfg.netProvider = existing + cfg.firewallDefaultAction = firewall.Deny + wireDefaultProvider(cfg) + assert.Same(t, existing, cfg.netProvider) + }) +} + func TestRun_EgressPolicy_EmptyHosts_DenyAll(t *testing.T) { t.Parallel() @@ -862,6 +918,7 @@ func TestTerminateStaleRunner_AliveProcess_GracefulExit(t *testing.T) { // (after SIGTERM + first poll). return aliveCount <= 1 } + cfg.processIsExpected = func(_ int) bool { return true } terminateStaleRunner(cfg) @@ -899,6 +956,7 @@ func TestTerminateStaleRunner_AliveProcess_RequiresKill(t *testing.T) { } // Process never exits on its own. cfg.processAlive = func(_ int) bool { return true } + cfg.processIsExpected = func(_ int) bool { return true } terminateStaleRunner(cfg) @@ -941,6 +999,7 @@ func TestTerminateStaleRunner_SendsToProcessGroup(t *testing.T) { aliveCount++ return aliveCount <= 1 } + cfg.processIsExpected = func(_ int) bool { return true } terminateStaleRunner(cfg) @@ -950,6 +1009,38 @@ func TestTerminateStaleRunner_SendsToProcessGroup(t *testing.T) { assert.Equal(t, -55555, receivedPIDs[0], "killProcess should receive negative PID for process group") } +func TestTerminateStaleRunner_RecycledPID_Skipped(t *testing.T) { + t.Parallel() + + // The state file points at a live PID, but processIsExpected says the + // binary at that PID is not the runner (as if the kernel had recycled + // the PID onto an unrelated process since state was written). The + // function must refuse to signal it. + dataDir := t.TempDir() + + mgr := state.NewManager(dataDir) + ls, err := mgr.LoadAndLock(context.Background()) + require.NoError(t, err) + ls.State.Active = true + ls.State.PID = 77777 + require.NoError(t, ls.Save()) + ls.Release() + + cfg := defaultConfig() + cfg.dataDir = dataDir + + var killCalled bool + cfg.killProcess = func(_ int, _ syscall.Signal) error { + killCalled = true + return nil + } + cfg.processAlive = func(_ int) bool { return true } + cfg.processIsExpected = func(_ int) bool { return false } + + terminateStaleRunner(cfg) + assert.False(t, killCalled, "must not signal a recycled PID belonging to an unrelated binary") +} + func TestTerminateStaleRunner_PID1_Skipped(t *testing.T) { t.Parallel() @@ -973,6 +1064,7 @@ func TestTerminateStaleRunner_PID1_Skipped(t *testing.T) { return nil } cfg.processAlive = func(_ int) bool { return true } + cfg.processIsExpected = func(_ int) bool { return true } terminateStaleRunner(cfg) assert.False(t, killCalled, "should not attempt to kill PID 1") diff --git a/options.go b/options.go index 252be1c..b178e88 100644 --- a/options.go +++ b/options.go @@ -11,6 +11,7 @@ import ( "github.com/stacklok/go-microvm/hypervisor" "github.com/stacklok/go-microvm/image" + "github.com/stacklok/go-microvm/internal/procutil" "github.com/stacklok/go-microvm/net" "github.com/stacklok/go-microvm/net/firewall" "github.com/stacklok/go-microvm/preflight" @@ -103,6 +104,7 @@ type config struct { stat func(string) (os.FileInfo, error) killProcess func(pid int, sig syscall.Signal) error processAlive func(pid int) bool + processIsExpected func(pid int) bool } func defaultConfig() *config { @@ -126,9 +128,17 @@ func defaultConfig() *config { } return proc.Signal(syscall.Signal(0)) == nil }, + processIsExpected: func(pid int) bool { + return procutil.IsExpectedProcess(pid, runnerBinaryName) + }, } } +// runnerBinaryName is the base name of the runner executable — used by +// the default processIsExpected check to distinguish the go-microvm +// runner from an unrelated process that happens to be at the same PID. +const runnerBinaryName = "go-microvm-runner" + func defaultDataDir() string { if dir := os.Getenv("GO_MICROVM_DATA_DIR"); dir != "" { return dir diff --git a/runner/process_linux_test.go b/runner/process_linux_test.go index bcee405..85a2fbb 100644 --- a/runner/process_linux_test.go +++ b/runner/process_linux_test.go @@ -28,8 +28,9 @@ func TestIsExpectedProcess_Self(t *testing.T) { } } -func TestIsExpectedProcess_SelfBaseName(t *testing.T) { - // Should match by base name even if full paths differ. +func TestIsExpectedProcess_BaseNameFallbackForRelative(t *testing.T) { + // A relative/bare binary name still matches by base name — that is the + // documented fallback. Absolute mismatches must no longer pass. pid := os.Getpid() selfExe, err := os.Executable() @@ -37,8 +38,21 @@ func TestIsExpectedProcess_SelfBaseName(t *testing.T) { t.Fatalf("failed to get self executable: %v", err) } baseName := selfExe[len(selfExe)-len("runner.test"):] // last component - if !isExpectedProcess(pid, "/some/other/path/"+baseName) { - t.Errorf("isExpectedProcess with different dir but same base name should return true") + if !isExpectedProcess(pid, baseName) { + t.Errorf("isExpectedProcess with bare base name should match self") + } +} + +func TestIsExpectedProcess_AbsolutePathMismatchFails(t *testing.T) { + pid := os.Getpid() + + selfExe, err := os.Executable() + if err != nil { + t.Fatalf("failed to get self executable: %v", err) + } + baseName := selfExe[len(selfExe)-len("runner.test"):] + if isExpectedProcess(pid, "/some/other/path/"+baseName) { + t.Errorf("absolute path with different dir must not match even when base name matches") } } diff --git a/state/state.go b/state/state.go index 3553595..4f57de9 100644 --- a/state/state.go +++ b/state/state.go @@ -61,6 +61,13 @@ type State struct { // PID is the process ID of the VM runner, or 0 if not running. PID int `json:"pid,omitempty"` + // PIDStartTime records wall-clock time when PID was recorded. Used + // to disambiguate a recycled PID from the original runner in contexts + // where /proc/PID/exe comparison is unavailable (e.g. macOS) or as + // belt-and-suspenders alongside the exe-path check on Linux. + // Zero time on state files written before this field was introduced. + PIDStartTime time.Time `json:"pid_start_time,omitempty"` + // CreatedAt is the time the VM state was first created. CreatedAt time.Time `json:"created_at"` } diff --git a/state/state_test.go b/state/state_test.go index 4561607..58ece10 100644 --- a/state/state_test.go +++ b/state/state_test.go @@ -58,6 +58,51 @@ func TestManager_SaveLoad_RoundTrip(t *testing.T) { assert.Equal(t, stateVersion, loaded.Version) } +func TestManager_PIDStartTime_RoundTrip(t *testing.T) { + t.Parallel() + + dataDir := t.TempDir() + mgr := NewManager(dataDir) + + ls, err := mgr.LoadAndLock(context.Background()) + require.NoError(t, err) + + want := time.Date(2026, 4, 17, 10, 0, 0, 0, time.UTC) + ls.State.PID = 42 + ls.State.PIDStartTime = want + + require.NoError(t, ls.Save()) + ls.Release() + + loaded, err := mgr.Load() + require.NoError(t, err) + assert.Equal(t, 42, loaded.PID) + assert.True(t, loaded.PIDStartTime.Equal(want), + "PIDStartTime round-trip lost fidelity: got %v, want %v", + loaded.PIDStartTime, want) +} + +func TestManager_Load_MissingPIDStartTime_IsZero(t *testing.T) { + t.Parallel() + + // Legacy state files written before the PIDStartTime field was + // introduced must still load cleanly; the field should come back as + // the zero time. + dataDir := t.TempDir() + mgr := NewManager(dataDir) + + legacy := []byte(`{"version":1,"name":"legacy","pid":100,"created_at":"2025-01-01T00:00:00Z"}`) + require.NoError(t, os.MkdirAll(dataDir, 0o700)) + require.NoError(t, os.WriteFile(filepath.Join(dataDir, stateFileName), legacy, 0o600)) + + loaded, err := mgr.Load() + require.NoError(t, err) + assert.Equal(t, 100, loaded.PID) + assert.True(t, loaded.PIDStartTime.IsZero(), + "missing PIDStartTime should unmarshal to zero time; got %v", + loaded.PIDStartTime) +} + func TestManager_LoadAndLock_SaveUnderLock(t *testing.T) { t.Parallel()