From 4185b832a4f89e671e6ecf201d21b75d866a48e4 Mon Sep 17 00:00:00 2001 From: jingrui Date: Sat, 14 Nov 2020 15:55:30 +0800 Subject: [PATCH] use path based socket for shims Signed-off-by: jingrui --- cmd/containerd-shim/main_unix.go | 16 +++-- cmd/ctr/commands/shim/shim.go | 2 + runtime/v1/linux/bundle.go | 37 +++++++++- runtime/v1/shim/client/client.go | 118 ++++++++++++++++++++++++++++--- 4 files changed, 159 insertions(+), 14 deletions(-) diff --git a/cmd/containerd-shim/main_unix.go b/cmd/containerd-shim/main_unix.go index e9c14263b..3a5bb6170 100644 --- a/cmd/containerd-shim/main_unix.go +++ b/cmd/containerd-shim/main_unix.go @@ -66,7 +66,7 @@ var ( func init() { flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs") flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") - flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve") + flag.StringVar(&socketFlag, "socket", "", "socket path to serve") flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data") flag.StringVar(&runtimeRootFlag, "runtime-root", proc.RuncRoot, "root directory for the runtime") @@ -190,10 +190,18 @@ func serve(ctx context.Context, server *ttrpc.Server, path string) error { } path = "[inherited from parent]" } else { - if len(path) > 106 { - return errors.Errorf("%q: unix socket path too long (> 106)", path) + const ( + abstractSocketPrefix = "\x00" + socketPathLimit = 106 + ) + p := strings.TrimPrefix(path, "unix://") + if len(p) == len(path) { + p = abstractSocketPrefix + p } - l, err = net.Listen("unix", "\x00"+path) + if len(p) > socketPathLimit { + return errors.Errorf("%q: unix socket path too long (> %d)", p, socketPathLimit) + } + l, err = net.Listen("unix", p) } if err != nil { return err diff --git a/cmd/ctr/commands/shim/shim.go b/cmd/ctr/commands/shim/shim.go index ec08cc68b..8ef068292 100644 --- 
a/cmd/ctr/commands/shim/shim.go +++ b/cmd/ctr/commands/shim/shim.go @@ -23,6 +23,7 @@ import ( "fmt" "io/ioutil" "net" + "strings" "github.com/containerd/console" "github.com/containerd/containerd/cmd/ctr/commands" @@ -231,6 +232,7 @@ func getTaskService(context *cli.Context) (task.TaskService, error) { return nil, errors.New("socket path must be specified") } + bindSocket = strings.TrimPrefix(bindSocket, "unix://") conn, err := net.Dial("unix", "\x00"+bindSocket) if err != nil { return nil, err diff --git a/runtime/v1/linux/bundle.go b/runtime/v1/linux/bundle.go index ef4200b29..0442246f9 100644 --- a/runtime/v1/linux/bundle.go +++ b/runtime/v1/linux/bundle.go @@ -20,6 +20,7 @@ package linux import ( "context" + "fmt" "io/ioutil" "os" "path/filepath" @@ -117,7 +118,7 @@ func ShimLocal(c *Config, exchange *exchange.Exchange) ShimOpt { // ShimConnect is a ShimOpt for connecting to an existing remote shim func ShimConnect(c *Config, onClose func()) ShimOpt { return func(b *bundle, ns string, ropts *runctypes.RuncOptions) (shim.Config, client.Opt) { - return b.shimConfig(ns, c, ropts), client.WithConnect(b.shimAddress(ns), onClose) + return b.shimConfig(ns, c, ropts), client.WithConnect(b.decideShimAddress(ns), onClose) } } @@ -129,6 +130,11 @@ func (b *bundle) NewShimClient(ctx context.Context, namespace string, getClientO // Delete deletes the bundle from disk func (b *bundle) Delete() error { + address, _ := b.loadAddress() + if address != "" { + // we don't care about errors here + client.RemoveSocket(address) + } err := os.RemoveAll(b.path) if err == nil { return os.RemoveAll(b.workDir) @@ -141,10 +147,37 @@ func (b *bundle) Delete() error { return errors.Wrapf(err, "Failed to remove both bundle and workdir locations: %v", err2) } -func (b *bundle) shimAddress(namespace string) string { +func (b *bundle) legacyShimAddress(namespace string) string { return filepath.Join(string(filepath.Separator), "containerd-shim", namespace, b.id, "shim.sock") } +const 
socketRoot = "/run/containerd" + +func (b *bundle) shimAddress(namespace string) string { + return fmt.Sprintf("unix://%s", b.shimSock()) +} + +func (b *bundle) shimSock() string { + return filepath.Join(socketRoot, "s", b.id) +} + +func (b *bundle) loadAddress() (string, error) { + addressPath := filepath.Join(b.path, "address") + data, err := ioutil.ReadFile(addressPath) + if err != nil { + return "", err + } + return string(data), nil +} + +func (b *bundle) decideShimAddress(namespace string) string { + address, err := b.loadAddress() + if err != nil { + return b.legacyShimAddress(namespace) + } + return address +} + func (b *bundle) shimConfig(namespace string, c *Config, runcOptions *runctypes.RuncOptions) shim.Config { var ( criuPath string diff --git a/runtime/v1/shim/client/client.go b/runtime/v1/shim/client/client.go index a4669d33c..06453b35a 100644 --- a/runtime/v1/shim/client/client.go +++ b/runtime/v1/shim/client/client.go @@ -20,11 +20,14 @@ package client import ( "context" + "fmt" "io" "net" "os" "os/exec" + "path/filepath" "runtime" + "strconv" "strings" "sync" "syscall" @@ -55,9 +58,17 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa return func(ctx context.Context, config shim.Config) (_ shimapi.ShimService, _ io.Closer, err error) { socket, err := newSocket(address) if err != nil { - return nil, nil, err + if !eaddrinuse(err) { + return nil, nil, err + } + if err := RemoveSocket(address); err != nil { + return nil, nil, errors.Wrap(err, "remove already used socket") + } + if socket, err = newSocket(address); err != nil { + return nil, nil, err + } } - defer socket.Close() + f, err := socket.File() if err != nil { return nil, nil, errors.Wrapf(err, "failed to get fd for socket %s", address) @@ -102,12 +113,22 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa if stderrLog != nil { stderrLog.Close() } + socket.Close() + RemoveSocket(address) }() 
log.G(ctx).WithFields(logrus.Fields{ "pid": cmd.Process.Pid, "address": address, "debug": debug, }).Infof("shim %s started", binary) + + if err := writeFile(filepath.Join(config.Path, "address"), address); err != nil { + return nil, nil, err + } + if err := writeFile(filepath.Join(config.Path, "shim.pid"), strconv.Itoa(cmd.Process.Pid)); err != nil { + return nil, nil, err + } + // set shim in cgroup if it is provided if cgroup != "" { if err := setCgroup(cgroup, cmd); err != nil { @@ -170,25 +191,106 @@ func newCommand(binary, daemonAddress string, debug bool, config shim.Config, so return cmd, nil } +// writeFile writes an address file atomically +func writeFile(path, address string) error { + path, err := filepath.Abs(path) + if err != nil { + return err + } + tempPath := filepath.Join(filepath.Dir(path), fmt.Sprintf(".%s", filepath.Base(path))) + f, err := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_SYNC, 0666) + if err != nil { + return err + } + _, err = f.WriteString(address) + f.Close() + if err != nil { + return err + } + return os.Rename(tempPath, path) +} + +const ( + abstractSocketPrefix = "\x00" + socketPathLimit = 106 +) + +func eaddrinuse(err error) bool { + cause := errors.Cause(err) + netErr, ok := cause.(*net.OpError) + if !ok { + return false + } + if netErr.Op != "listen" { + return false + } + syscallErr, ok := netErr.Err.(*os.SyscallError) + if !ok { + return false + } + errno, ok := syscallErr.Err.(syscall.Errno) + if !ok { + return false + } + return errno == syscall.EADDRINUSE +} + +type socket string + +func (s socket) isAbstract() bool { + return !strings.HasPrefix(string(s), "unix://") +} + +func (s socket) path() string { + path := strings.TrimPrefix(string(s), "unix://") + // if there was no trim performed, we assume an abstract socket + if len(path) == len(s) { + path = abstractSocketPrefix + path + } + return path +} + func newSocket(address string) (*net.UnixListener, error) { - if len(address) > 106 { - return nil, 
errors.Errorf("%q: unix socket path too long (> 106)", address) + if len(address) > socketPathLimit { + return nil, errors.Errorf("%q: unix socket path too long (> %d)", address, socketPathLimit) + } + var ( + sock = socket(address) + path = sock.path() + ) + if !sock.isAbstract() { + if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil { + return nil, errors.Wrapf(err, "%s", path) + } } - l, err := net.Listen("unix", "\x00"+address) + l, err := net.Listen("unix", path) if err != nil { - return nil, errors.Wrapf(err, "failed to listen to abstract unix socket %q", address) + return nil, errors.Wrapf(err, "failed to listen to unix socket %q (abstract: %t)", address, sock.isAbstract()) + } + if err := os.Chmod(path, 0600); err != nil { + l.Close() + return nil, err } return l.(*net.UnixListener), nil } +// RemoveSocket removes the socket at the specified address if +// it exists on the filesystem +func RemoveSocket(address string) error { + sock := socket(address) + if !sock.isAbstract() { + return os.Remove(sock.path()) + } + return nil +} + func connect(address string, d func(string, time.Duration) (net.Conn, error)) (net.Conn, error) { return d(address, 100*time.Second) } func annonDialer(address string, timeout time.Duration) (net.Conn, error) { - address = strings.TrimPrefix(address, "unix://") - return net.DialTimeout("unix", "\x00"+address, timeout) + return net.DialTimeout("unix", socket(address).path(), timeout) } // WithConnect connects to an existing shim -- 2.17.1