From e70824074dc2c73a463c85ea3b5bd24f66f25624 Mon Sep 17 00:00:00 2001
From: zhongjiawei
Date: Fri, 23 Feb 2024 14:58:08 +0800
Subject: [PATCH] containerd:disable Transparent HugePage for shim process if SHIM_DISABLE_THP is set

When SHIM_DISABLE_THP=1, containerd disables THP with
prctl(PR_SET_THP_DISABLE) only around starting a shim, so the shim inherits
the setting, and the shim re-enables THP only around starting its own
children (runc). The prctl call goes through golang.org/x/sys/unix so it
builds on every Linux architecture; on other platforms SetTHPState is a
no-op. A failure to restore the THP state never masks a start error and
never leaks an already-started process.

The go-runc monitor.go change also re-indents WaitTimeout and SameProcess
(whitespace-only).
---
 runtime/v1/shim/client/client.go              | 20 ++++++
 sys/reaper/reaper_unix.go                     | 18 ++++++-
 sys/thp.go                                    | 53 ++++++++++++++
 sys/thp_unsupported.go                        | 10 +++
 .../github.com/containerd/go-runc/monitor.go  | 71 ++++++++++++-------
 5 files changed, 143 insertions(+), 29 deletions(-)
 create mode 100644 sys/thp.go
 create mode 100644 sys/thp_unsupported.go

diff --git a/runtime/v1/shim/client/client.go b/runtime/v1/shim/client/client.go
index 965a5cf..af4917c 100644
--- a/runtime/v1/shim/client/client.go
+++ b/runtime/v1/shim/client/client.go
@@ -104,9 +104,29 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
 		if err != nil {
 			return nil, nil, err
 		}
+		// Disable THP for the shim: the flag is inherited across fork/exec, so
+		// it is switched off only around cmd.Start and restored right after.
+		if err := sys.SetTHPState(1, false); err != nil {
+			// The restore call also unlocks the OS thread locked above.
+			if rerr := sys.SetTHPState(0, true); rerr != nil {
+				return nil, nil, fmt.Errorf("%w (restore THP state: %v)", err, rerr)
+			}
+			return nil, nil, err
+		}
 		if err := cmd.Start(); err != nil {
+			// Report the start failure; a restore failure must not mask it.
+			if rerr := sys.SetTHPState(0, true); rerr != nil {
+				return nil, nil, fmt.Errorf("failed to start shim: %w (restore THP state: %v)", err, rerr)
+			}
 			return nil, nil, fmt.Errorf("failed to start shim: %w", err)
 		}
+		if err := sys.SetTHPState(0, true); err != nil {
+			// The shim is already running but the cleanup defer below is not
+			// registered yet: kill and reap it here so it is not leaked.
+			cmd.Process.Kill()
+			cmd.Wait()
+			return nil, nil, err
+		}
 		defer func() {
 			if err != nil {
 				cmd.Process.Kill()
diff --git a/sys/reaper/reaper_unix.go b/sys/reaper/reaper_unix.go
index 61c2e8a..2181432 100644
--- a/sys/reaper/reaper_unix.go
+++ b/sys/reaper/reaper_unix.go
@@ -26,6 +26,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/containerd/containerd/sys"
 	runc "github.com/containerd/go-runc"
 	"github.com/sirupsen/logrus"
 	exec "golang.org/x/sys/execabs"
@@ -94,9 +95,24 @@ type Monitor struct {
 
 // Start starts the command a registers the process with the reaper
 func (m *Monitor) Start(c *exec.Cmd) (chan runc.Exit, error) {
+	// Re-enable THP for the child: the shim itself may run with THP disabled,
+	// and the flag is inherited across fork/exec.
+	if err := sys.SetTHPState(0, false); err != nil {
+		// The restore call also unlocks the OS thread locked above.
+		if rerr := sys.SetTHPState(1, true); rerr != nil {
+			logrus.WithError(rerr).Warn("failed to restore THP state")
+		}
+		return nil, err
+	}
 	ec := m.Subscribe()
-	if err := c.Start(); err != nil {
+	err := c.Start()
+	// Restore exactly once, whatever Start returned. On success the child is
+	// already running, so a restore failure is logged rather than returned.
+	if rerr := sys.SetTHPState(1, true); rerr != nil {
+		logrus.WithError(rerr).Warn("failed to restore THP state")
+	}
+	if err != nil {
 		m.Unsubscribe(ec)
 		return nil, err
 	}
 	return ec, nil
diff --git a/sys/thp.go b/sys/thp.go
new file mode 100644
index 0000000..25c97a6
--- /dev/null
+++ b/sys/thp.go
@@ -0,0 +1,53 @@
+//go:build linux
+// +build linux
+
+package sys
+
+import (
+	"fmt"
+	"os"
+	"runtime"
+
+	"github.com/sirupsen/logrus"
+	"golang.org/x/sys/unix"
+)
+
+// Kept for code written against the first version of this patch; prefer the
+// golang.org/x/sys/unix names, which are defined for every Linux GOARCH.
+const (
+	PR_SET_THP_DISABLE = unix.PR_SET_THP_DISABLE
+	PRCTL_SYSCALL      = unix.SYS_PRCTL
+)
+
+// SetTHPState sets the "THP disable" flag of the calling process to flag via
+// prctl(PR_SET_THP_DISABLE, flag) when the environment variable
+// SHIM_DISABLE_THP is "1"; otherwise it does nothing and returns nil.
+//
+// flag 1 disables transparent huge pages and 0 re-enables them. Per prctl(2)
+// the flag is inherited by children created with fork(2) and preserved across
+// execve(2), which is how callers apply it to a process they are starting.
+//
+// Calls must be paired on the same goroutine: resume=false locks the goroutine
+// to its OS thread before changing the flag, and the matching resume=true call
+// changes it back and then unlocks the thread, even when prctl fails.
+//
+// NOTE(review): prctl(2) describes the flag as per-thread, but the kernel
+// stores it on the mm (shared by all threads), so concurrent disable/restore
+// sequences in different goroutines may interleave - verify for your kernel.
+func SetTHPState(flag int, resume bool) error {
+	if os.Getenv("SHIM_DISABLE_THP") != "1" {
+		return nil
+	}
+
+	if resume {
+		defer runtime.UnlockOSThread()
+	} else {
+		runtime.LockOSThread()
+	}
+
+	logrus.Debugf("setting PR_SET_THP_DISABLE to %d", flag)
+	if err := unix.Prctl(unix.PR_SET_THP_DISABLE, uintptr(flag), 0, 0, 0); err != nil {
+		return fmt.Errorf("prctl(PR_SET_THP_DISABLE, %d): %w", flag, err)
+	}
+	return nil
+}
diff --git a/sys/thp_unsupported.go b/sys/thp_unsupported.go
new file mode 100644
--- /dev/null
+++ b/sys/thp_unsupported.go
@@ -0,0 +1,10 @@
+//go:build !linux
+// +build !linux
+
+package sys
+
+// SetTHPState is a no-op outside Linux, where prctl(PR_SET_THP_DISABLE) does
+// not exist; it always returns nil so callers need no platform checks.
+func SetTHPState(flag int, resume bool) error {
+	return nil
+}
diff --git a/vendor/github.com/containerd/go-runc/monitor.go b/vendor/github.com/containerd/go-runc/monitor.go
index 73c8ac1..c7b4451 100644
--- a/vendor/github.com/containerd/go-runc/monitor.go
+++ b/vendor/github.com/containerd/go-runc/monitor.go
@@ -25,6 +25,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/containerd/containerd/sys"
 	"github.com/pkg/errors"
 	"github.com/sirupsen/logrus"
 )
@@ -54,7 +55,21 @@ type defaultMonitor struct {
 }
 
 func (m *defaultMonitor) Start(c *exec.Cmd) (chan Exit, error) {
-	if err := c.Start(); err != nil {
+	// Re-enable THP for the runc child; the flag is inherited across fork/exec.
+	if err := sys.SetTHPState(0, false); err != nil {
+		// The restore call also unlocks the OS thread locked above.
+		if rerr := sys.SetTHPState(1, true); rerr != nil {
+			logrus.WithError(rerr).Warn("failed to restore THP state")
+		}
+		return nil, err
+	}
+	err := c.Start()
+	// Restore exactly once; on success the child is already running, so a
+	// restore failure is logged rather than returned.
+	if rerr := sys.SetTHPState(1, true); rerr != nil {
+		logrus.WithError(rerr).Warn("failed to restore THP state")
+	}
+	if err != nil {
 		return nil, err
 	}
 	ec := make(chan Exit, 1)
@@ -84,27 +99,27 @@ func (m *defaultMonitor) Wait(c *exec.Cmd, ec chan Exit) (int, error) {
 }
 
 func (m *defaultMonitor) WaitTimeout(c *exec.Cmd, ec chan Exit, sec int64) (int, error) {
-    select {
-    case <-time.After(time.Duration(sec) * time.Second):
-        if SameProcess(c, c.Process.Pid) {
-            logrus.Devour(syscall.Kill(c.Process.Pid, syscall.SIGKILL))
-        }
-        return 0, errors.Errorf("timeout %ds for cmd(pid=%d): %s, %s", sec, c.Process.Pid, c.Path, c.Args)
-    case e := <-ec:
-        return e.Status, nil
-    }
+	select {
+	case <-time.After(time.Duration(sec) * time.Second):
+		if SameProcess(c, c.Process.Pid) {
+			logrus.Devour(syscall.Kill(c.Process.Pid, syscall.SIGKILL))
+		}
+		return 0, errors.Errorf("timeout %ds for cmd(pid=%d): %s, %s", sec, c.Process.Pid, c.Path, c.Args)
+	case e := <-ec:
+		return e.Status, nil
+	}
 }
 
 func SameProcess(cmd *exec.Cmd, pid int) bool {
-    bytes, err := ioutil.ReadFile(filepath.Join("/proc", strconv.Itoa(pid), "cmdline"))
-    if err != nil {
-        return false
-    }
-    for i := range bytes {
-        if bytes[i] == 0 {
-            bytes[i] = 32
-        }
-    }
-    cmdline := string(bytes)
-    return strings.EqualFold(cmdline, strings.Join(cmd.Args, " ")+" ")
+	bytes, err := ioutil.ReadFile(filepath.Join("/proc", strconv.Itoa(pid), "cmdline"))
+	if err != nil {
+		return false
+	}
+	for i := range bytes {
+		if bytes[i] == 0 {
+			bytes[i] = 32
+		}
+	}
+	cmdline := string(bytes)
+	return strings.EqualFold(cmdline, strings.Join(cmd.Args, " ")+" ")
 }
-- 
2.33.0