diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index 99aa9a61b..7cee9f79f 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -717,10 +717,9 @@ func (d *Daemon) doReboot(sigCh chan<- os.Signal, waitTimeout time.Duration) err return fmt.Errorf("expected reboot did not happen") } -var ( - rebootMsg = "reboot scheduled to update the system" - rebootHandler = commandReboot -) +const rebootMsg = "reboot scheduled to update the system" + +var rebootHandler = commandReboot // SetSyscallReboot replaces the default command-based reboot // with a direct Linux kernel syscall based implementation. @@ -741,23 +740,48 @@ func commandReboot(rebootDelay time.Duration) error { return nil } -var shutdownSyscall = func() { - // As per the requirements of the reboot syscall, we have to - // first call sync. - unix.Sync() - // This syscall can fail (EINVAL) if invalid arguments are - // supplied, which should not be the case, but let's panic if - // that were ever to be true just to be sure we catch it. - err := unix.Reboot(unix.LINUX_REBOOT_CMD_RESTART) - if err != nil { - panic("internal error: reboot syscall failed") +var ( + // checkCapSysBoot returns nil if the system has the correct + // permissions to issue a reboot request to the Linux kernel + checkCapSysBoot = func() error { + var caps unix.CapUserData + // We deliberately use v1 caps here due to: + // https://github.com/golang/go/issues/44312 + hdr := unix.CapUserHeader{Version: unix.LINUX_CAPABILITY_VERSION_1} + err := unix.Capget(&hdr, &caps) + if err == nil { + if (int32(caps.Effective) & (1 << unix.CAP_SYS_BOOT)) == 0 { + err = fmt.Errorf("no capability to reboot") + } + } + return err } -} + + shutdownSyscall = func() { + // As per the requirements of the reboot syscall, we have to + // first call sync. + unix.Sync() + // This syscall can fail (EINVAL/EPERM) if invalid arguments are + // supplied or CAP_SYS_BOOT capability is missing. We cover the + // latter case in a separate capability check, so this will not + // happen here, but let's panic if we see something to make sure + // we catch anything unexpected. + err := unix.Reboot(unix.LINUX_REBOOT_CMD_RESTART) + if err != nil { + panic("internal error: reboot syscall failed") + } + } +) // syscallReboot performs a reboot using direct Linux kernel syscalls. // // Note: Reboot message not currently supported. func syscallReboot(rebootDelay time.Duration) error { + err := checkCapSysBoot() + if err != nil { + return err + } + if rebootDelay < 0 { rebootDelay = 0 } diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index f678b535b..b2bb4293b 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -1149,34 +1149,97 @@ services: c.Check(tasks[0].Kind(), Equals, "stop") } -func (s *daemonSuite) TestSyscallRebootDelay(c *C) { - waitState := 0 +func mockShutdownSyscall(f func()) (restore func()) { old := shutdownSyscall - shutdownSyscall = func() { - waitState = 1 - } - defer func() { + shutdownSyscall = f + return func() { shutdownSyscall = old - }() - syscallReboot(time.Millisecond * 25) - c.Assert(waitState, Equals, 0) - time.Sleep(time.Millisecond * 50) - c.Assert(waitState, Equals, 1) + } } -func (s *daemonSuite) TestSetSyscall(c *C) { - check := 0 - old := shutdownSyscall - shutdownSyscall = func() { - check = 1 +func mockCheckCapSysBoot(f func() error) (restore func()) { + old := checkCapSysBoot + checkCapSysBoot = f + return func() { + checkCapSysBoot = old } +} + +func (s *daemonSuite) TestSyscallPosRebootDelay(c *C) { + wait := make(chan int) + defer mockCheckCapSysBoot(func() error { + return nil + })() + defer mockShutdownSyscall(func() { + wait <- 1 + })() + + period := time.Millisecond * 25 + syscallReboot(period) + start := time.Now() + <-wait + elapse := time.Now().Sub(start) + c.Assert(elapse >= period, Equals, true) +} + +func (s *daemonSuite) TestSyscallNegRebootDelay(c *C) { + wait := make(chan int) + defer mockCheckCapSysBoot(func() error { + return nil + })() + defer mockShutdownSyscall(func() { + wait <- 1 + })() + + // Negative periods will be zeroed, so do not fear the huge negative. + // We do supply a rather big value here because this test is + // effectively a race, but given the huge timeout, it is not going + // to be a problem (c). + period := time.Second * 10 + syscallReboot(-period) + start := time.Now() + <-wait + elapse := time.Now().Sub(start) + c.Assert(elapse < period, Equals, true) +} + +func (s *daemonSuite) TestSetSyscall(c *C) { + wait := make(chan int) + defer mockCheckCapSysBoot(func() error { + return nil + })() + defer mockShutdownSyscall(func() { + wait <- 1 + })() + + // We know the default is commandReboot otherwise the unit tests + // above will fail. We need to check the switch works. + SetSyscallReboot() defer func() { - shutdownSyscall = old + rebootHandler = commandReboot }() + + err := rebootHandler(0) + c.Assert(err, IsNil) + // This would block forever if the switch did not work. + <-wait +} + +func (s *daemonSuite) TestCapSysBootFail(c *C) { + defer mockCheckCapSysBoot(func() error { + return fmt.Errorf("no reboot cap") + })() + defer mockShutdownSyscall(func() { + panic("this should not happen") + })() + // We know the default is commandReboot otherwise the unit tests // above will fail. We need to check the switch works. SetSyscallReboot() - rebootHandler(0) - time.Sleep(time.Millisecond * 50) - c.Assert(check, Equals, 1) + defer func() { + rebootHandler = commandReboot + }() + + err := rebootHandler(0) + c.Assert(err, ErrorMatches, "no reboot cap") }