diff --git a/daemon/command/config_unix.go b/daemon/command/config_unix.go index ecc892e0b73e2..c51391c56acf8 100644 --- a/daemon/command/config_unix.go +++ b/daemon/command/config_unix.go @@ -50,6 +50,7 @@ func installConfigFlags(conf *config.Config, flags *pflag.FlagSet) { flags.StringVar(&conf.SeccompProfile, "seccomp-profile", conf.SeccompProfile, `Path to seccomp profile. Set to "unconfined" to disable the default seccomp profile`) flags.Var(&conf.ShmSize, "default-shm-size", "Default shm size for containers") flags.BoolVar(&conf.NoNewPrivileges, "no-new-privileges", false, "Set no-new-privileges by default for new containers") + flags.BoolVar(&conf.AdoptUserCgroups, "adopt-user-cgroups", false, "Automatically set container cgroup parent based on the API client's cgroup") flags.StringVar(&conf.IpcMode, "default-ipc-mode", conf.IpcMode, `Default mode for containers ipc ("shareable" | "private")`) flags.Var(&conf.NetworkConfig.DefaultAddressPools, "default-address-pool", "Default address pools for node specific local networks") flags.StringVar(&conf.NetworkConfig.FirewallBackend, "firewall-backend", "", "Firewall backend to use, iptables or nftables") diff --git a/daemon/command/daemon.go b/daemon/command/daemon.go index 366a98a90a207..899a97e7d467e 100644 --- a/daemon/command/daemon.go +++ b/daemon/command/daemon.go @@ -211,6 +211,12 @@ func (cli *daemonCLI) start(ctx context.Context) (retErr error) { httpServer := &http.Server{ ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout. + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + // Store the connection in context so middleware can access it for peer credentials + // Use a custom key instead of http.LocalAddrContextKey because the HTTP stack + // overwrites that with the address, losing the connection. 
+ return context.WithValue(ctx, middleware.PeerConnKey, c) + }, } apiShutdownCtx, apiShutdownCancel := context.WithCancel(context.WithoutCancel(ctx)) apiShutdownDone := make(chan struct{}) @@ -868,6 +874,12 @@ func initMiddlewares(_ context.Context, s *apiserver.Server, cfg *config.Config, } s.UseMiddleware(*vm) + // Register peer credential middleware for Unix socket connections. + // This extracts UID/GID/PID from the connection and adds them to request context. + // Required for features like cgroup adoption that need to know the API client's identity. + peerCredMiddleware := middleware.NewPeerCredMiddleware() + s.UseMiddleware(peerCredMiddleware) + authzMiddleware := authorization.NewMiddleware(cfg.AuthorizationPlugins, pluginStore) s.UseMiddleware(authzMiddleware) return authzMiddleware, nil diff --git a/daemon/config/config_linux.go b/daemon/config/config_linux.go index 80c0ab8e49712..060138fa80fce 100644 --- a/daemon/config/config_linux.go +++ b/daemon/config/config_linux.go @@ -94,6 +94,10 @@ type Config struct { // ResolvConf is the path to the configuration of the host resolver ResolvConf string `json:"resolv-conf,omitempty"` Rootless bool `json:"rootless,omitempty"` + // AdoptUserCgroups forces containers to inherit their creator's cgroup parent. + // When enabled, containers cannot override CgroupParent and will be placed under + // the cgroup of the process making the API request (requires Unix socket connection). 
+ AdoptUserCgroups bool `json:"adopt-user-cgroups,omitempty"` } // GetExecRoot returns the user configured Exec-root diff --git a/daemon/create.go b/daemon/create.go index 70125f4add2cc..d6a38c01c82ac 100644 --- a/daemon/create.go +++ b/daemon/create.go @@ -119,7 +119,7 @@ func (daemon *Daemon) containerCreate(ctx context.Context, daemonCfg *configStor if opts.params.HostConfig == nil { opts.params.HostConfig = &containertypes.HostConfig{} } - err = daemon.adaptContainerSettings(&daemonCfg.Config, opts.params.HostConfig) + err = daemon.adaptContainerSettings(ctx, &daemonCfg.Config, opts.params.HostConfig) if err != nil { return containertypes.CreateResponse{Warnings: warnings}, errdefs.InvalidParameter(err) } diff --git a/daemon/daemon_unix.go b/daemon/daemon_unix.go index 5e74287addbd6..4a85b035b326a 100644 --- a/daemon/daemon_unix.go +++ b/daemon/daemon_unix.go @@ -31,6 +31,7 @@ import ( "github.com/moby/moby/v2/daemon/internal/otelutil" "github.com/moby/moby/v2/daemon/internal/usergroup" "github.com/moby/moby/v2/daemon/libnetwork" + "github.com/moby/moby/v2/daemon/server/middleware" nwconfig "github.com/moby/moby/v2/daemon/libnetwork/config" "github.com/moby/moby/v2/daemon/libnetwork/drivers/bridge" "github.com/moby/moby/v2/daemon/libnetwork/netlabel" @@ -39,6 +40,7 @@ import ( "github.com/moby/moby/v2/daemon/pkg/opts" volumemounts "github.com/moby/moby/v2/daemon/volume/mounts" "github.com/moby/moby/v2/errdefs" + cgroupsadopt "github.com/moby/moby/v2/pkg/cgroups" "github.com/moby/moby/v2/pkg/sysinfo" "github.com/moby/sys/mount" "github.com/moby/sys/user" @@ -317,7 +319,7 @@ func adjustParallelLimit(n int, limit int) int { // adaptContainerSettings is called during container creation to modify any // settings necessary in the HostConfig structure. 
-func (daemon *Daemon) adaptContainerSettings(daemonCfg *config.Config, hostConfig *containertypes.HostConfig) error { +func (daemon *Daemon) adaptContainerSettings(ctx context.Context, daemonCfg *config.Config, hostConfig *containertypes.HostConfig) error { if hostConfig.Memory > 0 && hostConfig.MemorySwap == 0 { // By default, MemorySwap is set to twice the size of Memory. hostConfig.MemorySwap = hostConfig.Memory * 2 @@ -368,6 +370,42 @@ func (daemon *Daemon) adaptContainerSettings(daemonCfg *config.Config, hostConfi hostConfig.OomKillDisable = &defaultOomKillDisable } + // Apply cgroup adoption if enabled + if daemonCfg.AdoptUserCgroups { + if err := daemon.applyCgroupAdoption(ctx, hostConfig); err != nil { + return fmt.Errorf("failed to apply cgroup adoption: %w", err) + } + } + + return nil +} + +// applyCgroupAdoption enforces cgroup parent adoption based on the API client's cgroup. +// When enabled via daemon config, this ensures containers run under their creator's cgroup. +func (daemon *Daemon) applyCgroupAdoption(ctx context.Context, hostConfig *containertypes.HostConfig) error { + // Extract peer credentials from context (set by peer credential middleware) + creds, ok := ctx.Value(middleware.PeerCredKey).(*middleware.PeerCredentials) + if !ok || creds == nil { + return fmt.Errorf("peer credentials not available") + } + + // Derive the cgroup parent from the peer's PID + parent, err := cgroupsadopt.DeriveParentFromPid(creds.PID) + if err != nil { + return fmt.Errorf("failed to derive cgroup parent: %w", err) + } + + // ENFORCE: Reject if user specified a different cgroup parent + // This ensures ALL containers run under their creator's cgroup without exception + if hostConfig.CgroupParent != "" && hostConfig.CgroupParent != parent { + return errdefs.InvalidParameter(fmt.Errorf( + "cannot set cgroup parent when --adopt-user-cgroups is enabled: "+ + "containers must run under creator's cgroup (%s)", parent)) + } + + // Set the adopted cgroup parent + 
hostConfig.CgroupParent = parent + return nil } diff --git a/daemon/daemon_unix_test.go b/daemon/daemon_unix_test.go index b4c739023c654..88b8c26391d7e 100644 --- a/daemon/daemon_unix_test.go +++ b/daemon/daemon_unix_test.go @@ -364,3 +364,4 @@ func TestGetBlkioThrottleDevices(t *testing.T) { assert.Check(t, retDevs[0].Rate == WEIGHT, "get device rate") }) } + diff --git a/daemon/daemon_windows.go b/daemon/daemon_windows.go index c49ec4c03e49a..c2a5f14c66031 100644 --- a/daemon/daemon_windows.go +++ b/daemon/daemon_windows.go @@ -61,7 +61,7 @@ func setupInitLayer(uid int, gid int) func(string) error { // adaptContainerSettings is called during container creation to modify any // settings necessary in the HostConfig structure. -func (daemon *Daemon) adaptContainerSettings(daemonCfg *config.Config, hostConfig *containertypes.HostConfig) error { +func (daemon *Daemon) adaptContainerSettings(ctx context.Context, daemonCfg *config.Config, hostConfig *containertypes.HostConfig) error { return nil } diff --git a/daemon/server/middleware/peercred_linux.go b/daemon/server/middleware/peercred_linux.go new file mode 100644 index 0000000000000..ce3689b913cdf --- /dev/null +++ b/daemon/server/middleware/peercred_linux.go @@ -0,0 +1,105 @@ +package middleware + +import ( + "context" + "fmt" + "net" + "net/http" + "syscall" + + "golang.org/x/sys/unix" +) + +// PeerCredKey is the context key for storing peer credentials +var PeerCredKey = &struct{ name string }{"peercred"} + +// PeerConnKey is the context key for storing the raw connection (set by ConnContext) +// We use a custom key instead of http.LocalAddrContextKey because the HTTP stack +// overwrites that key with the address, losing the original connection. 
+var PeerConnKey = &struct{ name string }{"peerconn"} + +// PeerCredentials contains the credentials of a peer connection +type PeerCredentials struct { + PID int // Process ID + UID int // User ID + GID int // Group ID +} + +// PeerCredMiddleware extracts peer credentials from Unix socket connections +// and adds them to the request context. +type PeerCredMiddleware struct{} + +// NewPeerCredMiddleware creates a new peer credential middleware +func NewPeerCredMiddleware() PeerCredMiddleware { + return PeerCredMiddleware{} +} + +// WrapHandler wraps an HTTP handler to extract peer credentials from Unix socket connections +func (m PeerCredMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + // Attempt to extract peer credentials from the connection + if creds, err := extractPeerCredentials(r); err == nil && creds != nil { + // Add credentials to context for downstream handlers + ctx = context.WithValue(ctx, PeerCredKey, creds) + } + + return handler(ctx, w, r, vars) + } +} + +// extractPeerCredentials extracts the peer credentials from an HTTP request +// by accessing the underlying Unix socket file descriptor and calling SO_PEERCRED. +// +// This only works for Unix domain socket connections. For TCP connections or +// other transport types, this function returns nil, nil (no error, no credentials). +func extractPeerCredentials(r *http.Request) (*PeerCredentials, error) { + // Try to get the underlying connection from the request context + // We use PeerConnKey (set by ConnContext) instead of http.LocalAddrContextKey + // because the HTTP stack overwrites that key with the address. 
+	conn, ok := r.Context().Value(PeerConnKey).(net.Conn)
+	if !ok || conn == nil {
+		// Not a direct connection or connection not available - this is expected for some scenarios
+		return nil, nil
+	}
+
+	// A syscall.Conn is not necessarily a Unix socket: *net.TCPConn implements it as well
+	sc, ok := conn.(syscall.Conn)
+	if !ok || conn.LocalAddr() == nil || conn.LocalAddr().Network() != "unix" {
+		// No raw fd, or not a Unix stream socket (TCP, TLS-wrapped, ...) - no peer credentials to report
+		return nil, nil
+	}
+
+	// Get the raw syscall connection
+	rc, err := sc.SyscallConn()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get syscall connection: %w", err)
+	}
+
+	// Extract peer credentials using SO_PEERCRED
+	var creds *PeerCredentials
+	var ctrlErr error
+
+	// Control() provides access to the underlying file descriptor
+	err = rc.Control(func(fd uintptr) {
+		ucred, err := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
+		if err != nil {
+			ctrlErr = fmt.Errorf("SO_PEERCRED failed: %w", err)
+			return
+		}
+
+		creds = &PeerCredentials{
+			PID: int(ucred.Pid),
+			UID: int(ucred.Uid),
+			GID: int(ucred.Gid),
+		}
+	})
+
+	if err != nil {
+		return nil, fmt.Errorf("failed to access file descriptor: %w", err)
+	}
+	if ctrlErr != nil {
+		return nil, ctrlErr
+	}
+
+	return creds, nil
+}
diff --git a/daemon/server/middleware/peercred_test.go b/daemon/server/middleware/peercred_test.go
new file mode 100644
index 0000000000000..2fb20389697b4
--- /dev/null
+++ b/daemon/server/middleware/peercred_test.go
@@ -0,0 +1,102 @@
+package middleware
+
+import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"gotest.tools/v3/assert"
+)
+
+func TestPeerCredentials_ContextValue(t *testing.T) {
+	// Test that PeerCredKey can be used to store/retrieve credentials from context
+	ctx := context.Background()
+
+	creds := &PeerCredentials{
+		PID: 1234,
+		UID: 1000,
+		GID: 1000,
+	}
+
+	ctx = context.WithValue(ctx, PeerCredKey, creds)
+
+	retrieved, ok := ctx.Value(PeerCredKey).(*PeerCredentials)
+	assert.Assert(t, ok, "should be able to retrieve peer
credentials from context")
+	assert.Equal(t, retrieved.PID, 1234)
+	assert.Equal(t, retrieved.UID, 1000)
+	assert.Equal(t, retrieved.GID, 1000)
+}
+
+func TestPeerCredentials_NilContext(t *testing.T) {
+	// Test that retrieving from context without credentials returns nil gracefully
+	ctx := context.Background()
+
+	retrieved, ok := ctx.Value(PeerCredKey).(*PeerCredentials)
+	assert.Assert(t, !ok || retrieved == nil, "should return nil when no credentials in context")
+}
+
+// TestPeerCredMiddleware_Structure tests that the middleware invokes the wrapped handler and returns its nil error for a request with no connection in its context.
+// Note: This is a basic structure test. Actual SO_PEERCRED extraction can only be tested with real Unix sockets.
+func TestPeerCredMiddleware_Structure(t *testing.T) {
+	middleware := NewPeerCredMiddleware()
+
+	handlerCalled := false
+	testHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		handlerCalled = true
+		return nil
+	}
+
+	wrapped := middleware.WrapHandler(testHandler)
+
+	// Create a test request
+	req := httptest.NewRequest("GET", "http://example.com/test", nil)
+	w := httptest.NewRecorder()
+
+	// Call the wrapped handler
+	err := wrapped(context.Background(), w, req, nil)
+
+	assert.NilError(t, err)
+	assert.Assert(t, handlerCalled, "handler should have been called")
+}
+
+func TestPeerCredMiddleware_NoConnection(t *testing.T) {
+	// Test that middleware doesn't fail when no connection is in context
+	middleware := NewPeerCredMiddleware()
+
+	var capturedCtx context.Context
+	testHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		capturedCtx = ctx
+		return nil
+	}
+
+	wrapped := middleware.WrapHandler(testHandler)
+
+	// Create a test request without connection in context
+	req := httptest.NewRequest("GET", "http://example.com/test", nil)
+	w := httptest.NewRecorder()
+
+	err := wrapped(context.Background(), w, req, nil)
+
+	assert.NilError(t, err)
+	// Verify
no credentials were added (since no connection was available)
+	creds, ok := capturedCtx.Value(PeerCredKey).(*PeerCredentials)
+	assert.Assert(t, !ok || creds == nil, "should not have credentials when no connection")
+}
+
+func TestPeerConnKey_Uniqueness(t *testing.T) {
+	// Verify that PeerConnKey is distinct from http.LocalAddrContextKey
+	// This is important because http.LocalAddrContextKey gets overwritten by the HTTP stack
+	ctx := context.Background()
+
+	// Simulate what happens in ConnContext and the HTTP stack
+	ctx = context.WithValue(ctx, PeerConnKey, "connection")
+	ctx = context.WithValue(ctx, http.LocalAddrContextKey, "address")
+
+	// Both values should be retrievable independently
+	conn := ctx.Value(PeerConnKey)
+	addr := ctx.Value(http.LocalAddrContextKey)
+
+	assert.Equal(t, conn, "connection")
+	assert.Equal(t, addr, "address")
+}
diff --git a/daemon/server/middleware/peercred_unsupported.go b/daemon/server/middleware/peercred_unsupported.go
new file mode 100644
index 0000000000000..6e73323bac293
--- /dev/null
+++ b/daemon/server/middleware/peercred_unsupported.go
@@ -0,0 +1,35 @@
+//go:build !linux
+
+package middleware
+
+import (
+	"context"
+	"net/http"
+)
+
+// PeerCredKey is the context key for storing peer credentials
+var PeerCredKey = &struct{ name string }{"peercred"}
+
+// PeerConnKey is the context key for storing the raw connection (set by ConnContext).
+// It is declared here as well because files that are not Linux-only by name (daemon/command/daemon.go, peercred_test.go) reference it.
+var PeerConnKey = &struct{ name string }{"peerconn"}
+
+// PeerCredentials contains the credentials of a peer connection
+type PeerCredentials struct {
+	PID int // Process ID
+	UID int // User ID
+	GID int // Group ID
+}
+
+// PeerCredMiddleware is a no-op on non-Linux platforms
+type PeerCredMiddleware struct{}
+
+// NewPeerCredMiddleware creates a new peer credential middleware
+func NewPeerCredMiddleware() PeerCredMiddleware {
+	return PeerCredMiddleware{}
+}
+
+// WrapHandler returns the handler unchanged on non-Linux platforms
+func (m PeerCredMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w
http.ResponseWriter, r *http.Request, vars map[string]string) error { + return handler +} diff --git a/daemon/server/server.go b/daemon/server/server.go index e400e87609b53..bda789546eac6 100644 --- a/daemon/server/server.go +++ b/daemon/server/server.go @@ -61,6 +61,7 @@ func (s *Server) makeHTTPHandler(route router.Route) http.HandlerFunc { // use intermediate variable to prevent "should not use basic type // string as key in context.WithValue" golint errors ua := r.Header.Get("User-Agent") + ctx := baggage.ContextWithBaggage(context.WithValue(r.Context(), dockerversion.UAStringKey{}, ua), otelutil.MustNewBaggage( otelutil.MustNewMemberRaw(otelutil.TriggerKey, "api"), )) diff --git a/integration/container/cgroup_adoption_test.go b/integration/container/cgroup_adoption_test.go new file mode 100644 index 0000000000000..3f4f69aaea190 --- /dev/null +++ b/integration/container/cgroup_adoption_test.go @@ -0,0 +1,169 @@ +//go:build !windows + +package container + +import ( + "context" + "fmt" + "os" + "testing" + + cerrdefs "github.com/containerd/errdefs" + containertypes "github.com/moby/moby/api/types/container" + "github.com/moby/moby/client" + cgroupsadopt "github.com/moby/moby/v2/pkg/cgroups" + "github.com/moby/moby/v2/integration/internal/container" + "github.com/moby/moby/v2/internal/testutil" + "github.com/moby/moby/v2/internal/testutil/daemon" + "gotest.tools/v3/assert" + is "gotest.tools/v3/assert/cmp" + "gotest.tools/v3/skip" +) + +// TestCgroupAdoptionEnabled verifies that when --adopt-user-cgroups is enabled, +// containers inherit their creator's cgroup parent. 
+func TestCgroupAdoptionEnabled(t *testing.T) { + skip.If(t, os.Getuid() != 0, "requires root") + + ctx := testutil.StartSpan(baseContext, t) + + d := daemon.New(t) + defer d.Stop(t) + d.Start(t, "--adopt-user-cgroups") + + apiClient := d.NewClientT(t) + defer apiClient.Close() + + // Derive expected cgroup parent from current process + expectedParent, err := cgroupsadopt.DeriveParentFromPid(os.Getpid()) + assert.NilError(t, err) + + // Create container without specifying CgroupParent + cID := container.Run(ctx, t, apiClient) + defer apiClient.ContainerRemove(ctx, cID, client.ContainerRemoveOptions{Force: true}) + + // Verify container's CgroupParent matches our cgroup + inspect, err := apiClient.ContainerInspect(ctx, cID, client.ContainerInspectOptions{}) + assert.NilError(t, err) + assert.Equal(t, inspect.Container.HostConfig.CgroupParent, expectedParent) +} + +// TestCgroupAdoptionDisabled verifies that cgroup adoption is disabled by default. +func TestCgroupAdoptionDisabled(t *testing.T) { + skip.If(t, os.Getuid() != 0, "requires root") + + ctx := testutil.StartSpan(baseContext, t) + + d := daemon.New(t) + defer d.Stop(t) + d.Start(t) // No --adopt-user-cgroups flag + + apiClient := d.NewClientT(t) + defer apiClient.Close() + + // Create container + cID := container.Run(ctx, t, apiClient) + defer apiClient.ContainerRemove(ctx, cID, client.ContainerRemoveOptions{Force: true}) + + // Verify container's CgroupParent is empty (not adopted) + inspect, err := apiClient.ContainerInspect(ctx, cID, client.ContainerInspectOptions{}) + assert.NilError(t, err) + assert.Equal(t, inspect.Container.HostConfig.CgroupParent, "") +} + +// TestCgroupAdoptionUserOverrideRejected verifies that when --adopt-user-cgroups is enabled, +// users cannot override the cgroup parent with a different value. 
+func TestCgroupAdoptionUserOverrideRejected(t *testing.T) { + skip.If(t, os.Getuid() != 0, "requires root") + + ctx := testutil.StartSpan(baseContext, t) + + d := daemon.New(t) + defer d.Stop(t) + d.Start(t, "--adopt-user-cgroups") + + apiClient := d.NewClientT(t) + defer apiClient.Close() + + customParent := "/docker/custom-parent" + + // Attempt to create container WITH explicit CgroupParent + _, err := apiClient.ContainerCreate(ctx, client.ContainerCreateOptions{ + Config: &containertypes.Config{Image: "busybox"}, + HostConfig: &containertypes.HostConfig{ + CgroupParent: customParent, + }, + }) + + // Verify request is REJECTED with appropriate error + assert.Check(t, cerrdefs.IsInvalidArgument(err)) + assert.Check(t, is.ErrorContains(err, "cannot set cgroup parent when --adopt-user-cgroups is enabled")) +} + +// TestCgroupAdoptionMatchingParentAccepted verifies that when --adopt-user-cgroups is enabled, +// users CAN specify the cgroup parent if it matches the expected adopted value. 
+func TestCgroupAdoptionMatchingParentAccepted(t *testing.T) { + skip.If(t, os.Getuid() != 0, "requires root") + + ctx := testutil.StartSpan(baseContext, t) + + d := daemon.New(t) + defer d.Stop(t) + d.Start(t, "--adopt-user-cgroups") + + apiClient := d.NewClientT(t) + defer apiClient.Close() + + // Get current process's expected cgroup parent + expectedParent, err := cgroupsadopt.DeriveParentFromPid(os.Getpid()) + assert.NilError(t, err) + + // Create container with MATCHING cgroup parent (should be allowed) + resp, err := apiClient.ContainerCreate(ctx, client.ContainerCreateOptions{ + Config: &containertypes.Config{Image: "busybox"}, + HostConfig: &containertypes.HostConfig{ + CgroupParent: expectedParent, + }, + }) + assert.NilError(t, err) + + defer apiClient.ContainerRemove(ctx, resp.ID, client.ContainerRemoveOptions{Force: true}) + + // Verify it was created successfully with the correct cgroup parent + inspect, err := apiClient.ContainerInspect(ctx, resp.ID, client.ContainerInspectOptions{}) + assert.NilError(t, err) + assert.Equal(t, inspect.Container.HostConfig.CgroupParent, expectedParent) +} + +// TestCgroupAdoptionNoPeerCredentials verifies behavior when peer credentials are unavailable +// (e.g., when not using Unix socket). 
+func TestCgroupAdoptionNoPeerCredentials(t *testing.T) {
+	skip.If(t, os.Getuid() != 0, "requires root")
+
+	ctx := testutil.StartSpan(baseContext, t)
+
+	d := daemon.New(t)
+	defer d.Stop(t)
+
+	// Start the daemon on TCP (no Unix peer credentials); the client below must dial this same address
+	host := fmt.Sprintf("tcp://127.0.0.1:%d", testutil.GetFreePort(t))
+	d.Start(t, "--adopt-user-cgroups", "-H", host)
+
+	apiClient, err := client.New(
+		client.FromEnv,
+		client.WithHost(host),
+	)
+	assert.NilError(t, err)
+	defer apiClient.Close()
+
+	// Attempt to create container - should fail because peer creds unavailable
+	_, err = apiClient.ContainerCreate(ctx, client.ContainerCreateOptions{
+		Config: &containertypes.Config{Image: "busybox"},
+	})
+
+	// The daemon cannot derive the creator's cgroup for a TCP client, so the
+	// request must be rejected instead of silently using the default parent.
+	// adaptContainerSettings failures are returned as invalid-parameter errors,
+	// which is the class asserted here (the message text is not pinned).
+	assert.Check(t, cerrdefs.IsInvalidArgument(err))
+}
diff --git a/pkg/cgroups/adoption_linux.go b/pkg/cgroups/adoption_linux.go
new file mode 100644
index 0000000000000..f4e95506c31aa
--- /dev/null
+++ b/pkg/cgroups/adoption_linux.go
@@ -0,0 +1,97 @@
+package cgroups
+
+import (
+	"fmt"
+	"os"
+	"strings"
+)
+
+// DeriveParentFromPid reads /proc/<pid>/cgroup and derives the cgroup parent
+// path for containers created by that process.
+//
+// The process's full cgroup path is returned with its leading slash removed;
+// it is not truncated to a systemd `.slice` component.
+//
+// For cgroup v2 (unified hierarchy), reads the single line prefixed with "0::".
+// For cgroup v1, prioritizes the "name=systemd" controller.
+//
+// Returns an error if the cgroup file cannot be read or parsed, or if the process is in the root cgroup.
+func DeriveParentFromPid(pid int) (string, error) {
+	cgroupPath := fmt.Sprintf("/proc/%d/cgroup", pid)
+	return deriveParentFromCgroupFile(cgroupPath)
+}
+
+// deriveParentFromCgroupFile reads a cgroup file and returns the selected cgroup path without its leading slash.
+// Separated from DeriveParentFromPid for testability.
+func deriveParentFromCgroupFile(cgroupPath string) (string, error) {
+	data, err := os.ReadFile(cgroupPath)
+	if err != nil {
+		return "", fmt.Errorf("failed to read cgroup file: %w", err)
+	}
+
+	if len(data) == 0 {
+		return "", fmt.Errorf("cgroup file is empty")
+	}
+
+	lines := strings.Split(string(data), "\n")
+
+	var cgPath string
+
+	// Parse cgroup file format:
+	// - Cgroup v2: "0::/path/to/cgroup"
+	// - Cgroup v1: "hierarchy-ID:controller-list:path"
+	//
+	// We prefer cgroup v2 unified hierarchy if present, otherwise fall back to
+	// the systemd controller from v1.
+	for _, line := range lines {
+		if line == "" {
+			continue
+		}
+
+		parts := strings.SplitN(line, ":", 3)
+		if len(parts) < 3 {
+			continue
+		}
+
+		hierarchyID := parts[0]
+		controllers := parts[1]
+		path := parts[2]
+
+		// Cgroup v2 unified hierarchy
+		if hierarchyID == "0" && controllers == "" {
+			cgPath = path
+			break
+		}
+
+		// Cgroup v1: prefer systemd controller
+		if controllers == "name=systemd" {
+			cgPath = path
+			// Keep searching in case a v2 line appears later
+		}
+
+		// Cgroup v1: fallback to the cpu controller (exact list entry, so "cpuset" does not match) if nothing found yet
+		if cgPath == "" && strings.Contains(","+controllers+",", ",cpu,") {
+			cgPath = path
+		}
+	}
+
+	if cgPath == "" {
+		return "", fmt.Errorf("no valid cgroup path found in file")
+	}
+
+	// Return the full cgroup path so containers inherit the exact cgroup hierarchy
+	// of their creator. This ensures that SLURM jobs, systemd scopes, and other
+	// cgroup hierarchies are preserved.
+ // + // For example: + // - "/system.slice/slurmstepd.scope/job_123/step_0/user/task_0" -> "system.slice/slurmstepd.scope/job_123/step_0/user/task_0" + // - "/user.slice/user-1000.slice/session-1.scope" -> "user.slice/user-1000.slice/session-1.scope" + // + // This ensures containers are properly accounted under the creator's resource limits. + if cgPath == "/" || cgPath == "" { + return "", fmt.Errorf("cannot derive cgroup parent from root cgroup") + } + + // Return the full path without leading slash + return strings.TrimPrefix(cgPath, "/"), nil +} diff --git a/pkg/cgroups/adoption_linux_test.go b/pkg/cgroups/adoption_linux_test.go new file mode 100644 index 0000000000000..dae195c805dd2 --- /dev/null +++ b/pkg/cgroups/adoption_linux_test.go @@ -0,0 +1,130 @@ +package cgroups + +import ( + "os" + "path/filepath" + "testing" + + "gotest.tools/v3/assert" +) + +func TestDeriveParentFromPid_Cgroupv2(t *testing.T) { + // Test cgroup v2 format: "0::/user.slice/user-1000.slice/session-1.scope" + tmpdir := t.TempDir() + cgroupFile := filepath.Join(tmpdir, "cgroup") + + content := `0::/user.slice/user-1000.slice/session-1.scope +` + err := os.WriteFile(cgroupFile, []byte(content), 0644) + assert.NilError(t, err) + + parent, err := deriveParentFromCgroupFile(cgroupFile) + assert.NilError(t, err) + assert.Equal(t, parent, "user.slice/user-1000.slice/session-1.scope") +} + +func TestDeriveParentFromPid_Cgroupv1_Systemd(t *testing.T) { + // Test cgroup v1 format with systemd controller + tmpdir := t.TempDir() + cgroupFile := filepath.Join(tmpdir, "cgroup") + + content := `12:blkio:/user.slice/user-1000.slice/session-1.scope +11:pids:/user.slice/user-1000.slice/session-1.scope +10:memory:/user.slice/user-1000.slice/session-1.scope +9:perf_event:/ +8:devices:/user.slice/user-1000.slice/session-1.scope +7:cpuset:/ +6:freezer:/ +5:net_cls,net_prio:/ +4:cpu,cpuacct:/user.slice/user-1000.slice/session-1.scope +3:hugetlb:/ +2:rdma:/ 
+1:name=systemd:/user.slice/user-1000.slice/session-1.scope +` + err := os.WriteFile(cgroupFile, []byte(content), 0644) + assert.NilError(t, err) + + parent, err := deriveParentFromCgroupFile(cgroupFile) + assert.NilError(t, err) + assert.Equal(t, parent, "user.slice/user-1000.slice/session-1.scope") +} + +func TestDeriveParentFromPid_MultipleSlices(t *testing.T) { + // Test that we return the full cgroup path + tmpdir := t.TempDir() + cgroupFile := filepath.Join(tmpdir, "cgroup") + + content := `0::/system.slice/docker.service/user.slice/user-1000.slice/app.scope +` + err := os.WriteFile(cgroupFile, []byte(content), 0644) + assert.NilError(t, err) + + parent, err := deriveParentFromCgroupFile(cgroupFile) + assert.NilError(t, err) + // Should return the full path + assert.Equal(t, parent, "system.slice/docker.service/user.slice/user-1000.slice/app.scope") +} + +func TestDeriveParentFromPid_NoSlice(t *testing.T) { + // Test that we return full path even without .slice components + tmpdir := t.TempDir() + cgroupFile := filepath.Join(tmpdir, "cgroup") + + content := `0::/docker/container-id +` + err := os.WriteFile(cgroupFile, []byte(content), 0644) + assert.NilError(t, err) + + parent, err := deriveParentFromCgroupFile(cgroupFile) + assert.NilError(t, err) + assert.Equal(t, parent, "docker/container-id") +} + +func TestDeriveParentFromPid_RootCgroup(t *testing.T) { + // Test root cgroup "/" - should return error + tmpdir := t.TempDir() + cgroupFile := filepath.Join(tmpdir, "cgroup") + + content := `0::/ +` + err := os.WriteFile(cgroupFile, []byte(content), 0644) + assert.NilError(t, err) + + parent, err := deriveParentFromCgroupFile(cgroupFile) + assert.Assert(t, err != nil, "should return error for root cgroup") + assert.Equal(t, parent, "") +} + +func TestDeriveParentFromPid_InvalidFile(t *testing.T) { + parent, err := deriveParentFromCgroupFile("/nonexistent/file") + assert.Assert(t, err != nil, "should return error for nonexistent file") + assert.Equal(t, parent, 
"") +} + +func TestDeriveParentFromPid_EmptyFile(t *testing.T) { + tmpdir := t.TempDir() + cgroupFile := filepath.Join(tmpdir, "cgroup") + + err := os.WriteFile(cgroupFile, []byte(""), 0644) + assert.NilError(t, err) + + parent, err := deriveParentFromCgroupFile(cgroupFile) + assert.Assert(t, err != nil, "should return error for empty file") + assert.Equal(t, parent, "") +} + +func TestDeriveParentFromPid_SlurmCgroup(t *testing.T) { + // Test SLURM job cgroup hierarchy - must preserve full path + tmpdir := t.TempDir() + cgroupFile := filepath.Join(tmpdir, "cgroup") + + content := `0::/system.slice/slurmstepd.scope/job_298726/step_0/user/task_0 +` + err := os.WriteFile(cgroupFile, []byte(content), 0644) + assert.NilError(t, err) + + parent, err := deriveParentFromCgroupFile(cgroupFile) + assert.NilError(t, err) + // Must return the full SLURM hierarchy, not just system.slice + assert.Equal(t, parent, "system.slice/slurmstepd.scope/job_298726/step_0/user/task_0") +}