Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

daemon: add ability to use syscall based reboot #250

Merged
merged 11 commits into from
Aug 11, 2023
Merged
48 changes: 41 additions & 7 deletions internals/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ func (d *Daemon) HandleRestart(t restart.RestartType) {
case restart.RestartSystem:
// try to schedule a fallback slow reboot already here
// in case we get stuck shutting down
if err := reboot(rebootWaitTimeout); err != nil {
if err := rebootHandler(rebootWaitTimeout); err != nil {
logger.Noticef("%s", err)
}

Expand Down Expand Up @@ -700,11 +700,11 @@ func (d *Daemon) doReboot(sigCh chan<- os.Signal, waitTimeout time.Duration) err
}
// ask for shutdown and wait for it to happen.
// if we exit, pebble will be restarted by systemd
if err := reboot(rebootDelay); err != nil {
if err := rebootHandler(rebootDelay); err != nil {
return err
}
// wait for reboot to happen
logger.Noticef("Waiting for system reboot")
logger.Noticef("Waiting for system reboot...")
if sigCh != nil {
signal.Stop(sigCh)
if len(sigCh) > 0 {
Expand All @@ -717,21 +717,55 @@ func (d *Daemon) doReboot(sigCh chan<- os.Signal, waitTimeout time.Duration) err
return fmt.Errorf("expected reboot did not happen")
}

var shutdownMsg = "reboot scheduled to update the system"
const rebootMsg = "reboot scheduled to update the system"

func rebootImpl(rebootDelay time.Duration) error {
var rebootHandler = commandReboot

// SetSyscallReboot replaces the default command-based reboot
// with a direct Linux kernel syscall based implementation.
func SetSyscallReboot() {
flotter marked this conversation as resolved.
Show resolved Hide resolved
rebootHandler = syscallReboot
}

// commandReboot assumes a userspace shutdown command exists.
func commandReboot(rebootDelay time.Duration) error {
if rebootDelay < 0 {
rebootDelay = 0
}
mins := int64(rebootDelay / time.Minute)
cmd := exec.Command("shutdown", "-r", fmt.Sprintf("+%d", mins), shutdownMsg)
cmd := exec.Command("shutdown", "-r", fmt.Sprintf("+%d", mins), rebootMsg)
if out, err := cmd.CombinedOutput(); err != nil {
return osutil.OutputErr(out, err)
}
return nil
}

var reboot = rebootImpl
var (
syncSyscall = syscall.Sync
rebootSyscall = syscall.Reboot
flotter marked this conversation as resolved.
Show resolved Hide resolved
)

// syscallReboot performs a delayed async reboot using direct Linux
// kernel syscalls.
//
// Note: Reboot message not currently supported.
func syscallReboot(rebootDelay time.Duration) error {
flotter marked this conversation as resolved.
Show resolved Hide resolved
if rebootDelay < 0 {
rebootDelay = 0
}
// This has to be non-blocking, and scheduled for a future
// point in time to mimic shutdown.
time.AfterFunc(rebootDelay, func() {
// As per the requirements of the reboot syscall, we
// have to first call sync.
syncSyscall()
err := rebootSyscall(syscall.LINUX_REBOOT_CMD_RESTART)
if err != nil {
logger.Noticef("reboot syscall failed : %v", err)
flotter marked this conversation as resolved.
Show resolved Hide resolved
}
})
return nil
}

func (d *Daemon) Dying() <-chan struct{} {
return d.tomb.Dying()
Expand Down
144 changes: 140 additions & 4 deletions internals/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ import (
"time"

"github.com/gorilla/mux"

"gopkg.in/check.v1"

// XXX Delete import above and make this file like the other ones.
. "gopkg.in/check.v1"

"github.com/canonical/pebble/internals/logger"
"github.com/canonical/pebble/internals/osutil"
"github.com/canonical/pebble/internals/overlord/patch"
"github.com/canonical/pebble/internals/overlord/restart"
Expand Down Expand Up @@ -665,15 +665,15 @@ func (s *daemonSuite) TestRestartSystemWiring(c *check.C) {
oldRebootNoticeWait := rebootNoticeWait
oldRebootWaitTimeout := rebootWaitTimeout
defer func() {
reboot = rebootImpl
rebootHandler = commandReboot
rebootNoticeWait = oldRebootNoticeWait
rebootWaitTimeout = oldRebootWaitTimeout
}()
rebootWaitTimeout = 100 * time.Millisecond
rebootNoticeWait = 150 * time.Millisecond

var delays []time.Duration
reboot = func(d time.Duration) error {
rebootHandler = func(d time.Duration) error {
delays = append(delays, d)
return nil
}
Expand Down Expand Up @@ -740,7 +740,7 @@ func (s *daemonSuite) TestRebootHelper(c *check.C) {
}

for _, t := range tests {
err := reboot(t.delay)
err := rebootHandler(t.delay)
c.Assert(err, check.IsNil)
c.Check(cmd.Calls(), check.DeepEquals, [][]string{
{"shutdown", "-r", t.delayArg, "reboot scheduled to update the system"},
Expand Down Expand Up @@ -1148,3 +1148,139 @@ services:
c.Assert(tasks, HasLen, 1)
c.Check(tasks[0].Kind(), Equals, "stop")
}

func mockSyncSyscall(f func()) (restore func()) {
flotter marked this conversation as resolved.
Show resolved Hide resolved
old := syncSyscall
syncSyscall = f
return func() {
syncSyscall = old
}
}

func mockRebootSyscall(f func(cmd int) error) (restore func()) {
old := rebootSyscall
rebootSyscall = f
return func() {
rebootSyscall = old
}
}

func (s *daemonSuite) TestSyscallPosRebootDelay(c *C) {
flotter marked this conversation as resolved.
Show resolved Hide resolved
wait := make(chan int)
defer mockSyncSyscall(func() {})()
defer mockRebootSyscall(func(cmd int) error {
if cmd == syscall.LINUX_REBOOT_CMD_RESTART {
wait <- 1
}
return nil
})()

period := time.Millisecond * 25
flotter marked this conversation as resolved.
Show resolved Hide resolved
timeout := time.Second * 10
flotter marked this conversation as resolved.
Show resolved Hide resolved
syscallReboot(period)
start := time.Now()
select {
case <-wait:
case <-time.After(timeout): // exit test if we fail and get stuck
c.Fail()
}
elapse := time.Now().Sub(start)
flotter marked this conversation as resolved.
Show resolved Hide resolved
c.Assert(elapse >= period, Equals, true)
}

func (s *daemonSuite) TestSyscallNegRebootDelay(c *C) {
wait := make(chan int)
defer mockSyncSyscall(func() {})()
defer mockRebootSyscall(func(cmd int) error {
if cmd == syscall.LINUX_REBOOT_CMD_RESTART {
wait <- 1
}
return nil
})()

// Negative periods will be zeroed, so do not fear the huge negative.
// We do supply a rather big value here because this test is
// effectively a race, but given the huge timeout, it is not going
// to be a problem (c).
period := time.Second * 10
syscallReboot(-period)
start := time.Now()
select {
case <-wait:
case <-time.After(period): // exit test if we fail and get stuck
c.Fail()
}
elapse := time.Now().Sub(start)
c.Assert(elapse < period, Equals, true)
}

func (s *daemonSuite) TestSetSyscall(c *C) {
wait := make(chan int)
defer mockSyncSyscall(func() {})()
defer mockRebootSyscall(func(cmd int) error {
if cmd == syscall.LINUX_REBOOT_CMD_RESTART {
wait <- 1
}
return nil
})()

// We know the default is commandReboot otherwise the unit tests
// above will fail. We need to check the switch works.
SetSyscallReboot()
defer func() {
rebootHandler = commandReboot
}()

err := rebootHandler(0)
c.Assert(err, IsNil)
// This would block forever if the switch did not work.
timeout := time.Second * 10
select {
case <-wait:
case <-time.After(timeout): // exit test if we fail and get stuck
c.Fail()
}
}

type fakeLogger struct {
s string
c chan int
flotter marked this conversation as resolved.
Show resolved Hide resolved
}

func (f *fakeLogger) Notice(msg string) {
f.s = msg
f.c <- 1
}

func (f *fakeLogger) Debug(msg string) {}

func (s *daemonSuite) TestSyscallRebootError(c *C) {
defer mockSyncSyscall(func() {})()
defer mockRebootSyscall(func(cmd int) error {
return fmt.Errorf("-EPERM")
})()

// We know the default is commandReboot otherwise the unit tests
// above will fail. We need to check the switch works.
SetSyscallReboot()
defer func() {
rebootHandler = commandReboot
}()
complete := make(chan int)
l := fakeLogger{c: complete}
old := logger.SetLogger(&l)
defer func() {
logger.SetLogger(old)
}()
flotter marked this conversation as resolved.
Show resolved Hide resolved

err := rebootHandler(0)
c.Assert(err, IsNil)
// This would block forever if the switch did not work.
timeout := time.Second * 10
select {
case <-complete:
case <-time.After(timeout): // exit test if we fail and get stuck
c.Fail()
}
c.Assert(l.s, Matches, "*-EPERM")
}