Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

daemon: add ability to use syscall based reboot #250

Merged
merged 11 commits into from
Aug 11, 2023
Merged
71 changes: 64 additions & 7 deletions internals/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func (d *Daemon) HandleRestart(t restart.RestartType) {
case restart.RestartSystem:
// try to schedule a fallback slow reboot already here
// in case we get stuck shutting down
if err := reboot(rebootWaitTimeout); err != nil {
if err := rebootHandler(rebootWaitTimeout); err != nil {
logger.Noticef("%s", err)
}

Expand Down Expand Up @@ -693,11 +693,11 @@ func (d *Daemon) doReboot(sigCh chan<- os.Signal, waitTimeout time.Duration) err
}
// ask for shutdown and wait for it to happen.
// if we exit, pebble will be restarted by systemd
if err := reboot(rebootDelay); err != nil {
if err := rebootHandler(rebootDelay); err != nil {
return err
}
// wait for reboot to happen
logger.Noticef("Waiting for system reboot")
logger.Noticef("Waiting for system reboot...")
if sigCh != nil {
signal.Stop(sigCh)
if len(sigCh) > 0 {
Expand All @@ -710,21 +710,78 @@ func (d *Daemon) doReboot(sigCh chan<- os.Signal, waitTimeout time.Duration) err
return fmt.Errorf("expected reboot did not happen")
}

var shutdownMsg = "reboot scheduled to update the system"
// rebootMsg is the human-readable reason passed along when a reboot is
// scheduled.
const rebootMsg = "reboot scheduled to update the system"

func rebootImpl(rebootDelay time.Duration) error {
// rebootHandler is the function invoked to request a system reboot after the
// given delay. It starts out as systemdModeReboot and is reassigned by
// SetRebootMode.
var rebootHandler = systemdModeReboot

// RebootMode identifies the mechanism used to reboot the system.
type RebootMode int

const (
	// The zero value is intentionally not a valid mode.
	_ RebootMode = iota
	// SystemdMode reboots through systemd.
	SystemdMode
	// SyscallMode reboots through direct kernel syscalls.
	SyscallMode
)

// SetRebootMode selects the reboot handler used for subsequent reboot
// requests. SystemdMode installs the handler that shells out to a userspace
// shutdown command; SyscallMode installs the handler that issues the reboot
// through kernel syscalls.
//
// It panics if mode is neither of the two supported values.
func SetRebootMode(mode RebootMode) {
	if mode == SystemdMode {
		rebootHandler = systemdModeReboot
		return
	}
	if mode == SyscallMode {
		rebootHandler = syscallModeReboot
		return
	}
	panic(fmt.Sprintf("unsupported reboot mode %v", mode))
}

// systemdModeReboot assumes a userspace shutdown command exists.
//
// It runs "shutdown -r +<mins> <message>", where <mins> is rebootDelay
// expressed in whole minutes. A negative delay is treated as zero. If the
// command fails, the returned error is built from the command's combined
// output by osutil.OutputErr; otherwise nil is returned.
func systemdModeReboot(rebootDelay time.Duration) error {
if rebootDelay < 0 {
rebootDelay = 0
}
// Integer division: any sub-minute remainder of the delay is dropped.
mins := int64(rebootDelay / time.Minute)
cmd := exec.Command("shutdown", "-r", fmt.Sprintf("+%d", mins), shutdownMsg)
cmd := exec.Command("shutdown", "-r", fmt.Sprintf("+%d", mins), rebootMsg)
if out, err := cmd.CombinedOutput(); err != nil {
return osutil.OutputErr(out, err)
}
return nil
}

var reboot = rebootImpl
// syscallSync and syscallReboot are package-level indirections over the
// corresponding syscall functions, so that the implementation used by
// syscallModeReboot can be substituted.
var (
syscallSync = syscall.Sync
syscallReboot = syscall.Reboot
)

// syscallModeReboot reboots the machine through direct Linux kernel
// syscalls instead of a userspace command.
//
// With a positive rebootDelay the reboot is armed on a timer and the call
// returns straight away. With a zero or negative delay the reboot is issued
// synchronously before returning.
//
// The returned error is always nil: a failing reboot syscall is only
// reported through the logger.
//
// Note: Reboot message not currently supported.
func syscallModeReboot(rebootDelay time.Duration) error {
	rebootNow := func() {
		// The reboot syscall requires a sync to be issued first.
		syscallSync()
		if err := syscallReboot(syscall.LINUX_REBOOT_CMD_RESTART); err != nil {
			logger.Noticef("Failed on reboot syscall: %v", err)
		}
	}

	if rebootDelay > 0 {
		// Deferred, non-blocking reboot.
		time.AfterFunc(rebootDelay, rebootNow)
		return nil
	}

	// Immediate, synchronous reboot.
	rebootNow()
	return nil
}

func (d *Daemon) Dying() <-chan struct{} {
return d.tomb.Dying()
Expand Down
133 changes: 130 additions & 3 deletions internals/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/gorilla/mux"
. "gopkg.in/check.v1"

"github.com/canonical/pebble/internals/logger"
"github.com/canonical/pebble/internals/osutil"
"github.com/canonical/pebble/internals/overlord/patch"
"github.com/canonical/pebble/internals/overlord/restart"
Expand Down Expand Up @@ -693,15 +694,15 @@ func (s *daemonSuite) TestRestartSystemWiring(c *C) {
oldRebootNoticeWait := rebootNoticeWait
oldRebootWaitTimeout := rebootWaitTimeout
defer func() {
reboot = rebootImpl
rebootHandler = systemdModeReboot
rebootNoticeWait = oldRebootNoticeWait
rebootWaitTimeout = oldRebootWaitTimeout
}()
rebootWaitTimeout = 100 * time.Millisecond
rebootNoticeWait = 150 * time.Millisecond

var delays []time.Duration
reboot = func(d time.Duration) error {
rebootHandler = func(d time.Duration) error {
delays = append(delays, d)
return nil
}
Expand Down Expand Up @@ -768,7 +769,7 @@ func (s *daemonSuite) TestRebootHelper(c *C) {
}

for _, t := range tests {
err := reboot(t.delay)
err := rebootHandler(t.delay)
c.Assert(err, IsNil)
c.Check(cmd.Calls(), DeepEquals, [][]string{
{"shutdown", "-r", t.delayArg, "reboot scheduled to update the system"},
Expand Down Expand Up @@ -1176,3 +1177,129 @@ services:
c.Assert(tasks, HasLen, 1)
c.Check(tasks[0].Kind(), Equals, "stop")
}

// rebootSuite groups the tests for the reboot handlers. It carries no state.
type rebootSuite struct{}

// Hand an instance of the suite to Suite (from the check package).
var _ = Suite(&rebootSuite{})

// TestSyscallPosRebootDelay checks that a positive delay defers the reboot
// syscall: the faked reboot must not be observed before the requested delay
// has elapsed.
func (s *rebootSuite) TestSyscallPosRebootDelay(c *C) {
	wait := make(chan struct{})
	defer FakeSyscallSync(func() {})()
	defer FakeSyscallReboot(func(cmd int) error {
		if cmd == syscall.LINUX_REBOOT_CMD_RESTART {
			close(wait)
		}
		return nil
	})()

	period := 25 * time.Millisecond
	// Take the start time BEFORE arming the timer. If it were taken
	// afterwards, any scheduling gap between the two statements would be
	// missing from the measured interval and the lower-bound assertion
	// below could fail spuriously.
	start := time.Now()
	syscallModeReboot(period)
	select {
	case <-wait:
	case <-time.After(10 * time.Second):
		c.Fatal("syscall did not take place and we timed out")
	}
	elapsed := time.Since(start)
	c.Assert(elapsed >= period, Equals, true)
}

// TestSyscallNegRebootDelay checks that a negative delay results in an
// immediate reboot syscall rather than a scheduled one.
func (s *rebootSuite) TestSyscallNegRebootDelay(c *C) {
	rebooted := make(chan struct{})
	defer FakeSyscallSync(func() {})()
	defer FakeSyscallReboot(func(cmd int) error {
		if cmd == syscall.LINUX_REBOOT_CMD_RESTART {
			close(rebooted)
		}
		return nil
	})()

	// A negative delay is handled as "reboot now", so the magnitude used
	// here never turns into an actual wait. It is deliberately large:
	// the final assertion compares elapsed time against it, and a
	// generous bound keeps that comparison from being timing sensitive.
	delay := 10 * time.Second
	go func() {
		// Issue the reboot from a separate goroutine.
		syscallModeReboot(-delay)
	}()
	begin := time.Now()
	select {
	case <-rebooted:
	case <-time.After(10 * time.Second):
		c.Fatal("syscall did not take place and we timed out")
	}
	took := time.Since(begin)
	c.Assert(took < delay, Equals, true)
}

// TestSetSyscall checks that SetRebootMode(SyscallMode) makes rebootHandler
// go through the (faked) reboot syscall.
func (s *rebootSuite) TestSetSyscall(c *C) {
	wait := make(chan struct{})
	defer FakeSyscallSync(func() {})()
	defer FakeSyscallReboot(func(cmd int) error {
		if cmd == syscall.LINUX_REBOOT_CMD_RESTART {
			close(wait)
		}
		return nil
	})()

	// Switch to syscall mode for this test only and restore the systemd
	// mode afterwards.
	SetRebootMode(SyscallMode)
	defer SetRebootMode(SystemdMode)

	// Buffered so the goroutine below can always deliver its result and
	// exit, even if the test bails out on the timeout branch and never
	// reads from the channel.
	errCh := make(chan error, 1)
	go func() {
		errCh <- rebootHandler(0)
	}()

	select {
	case <-wait:
	case <-time.After(10 * time.Second):
		c.Fatal("syscall did not take place and we timed out")
	}
	c.Assert(<-errCh, IsNil)
}

// fakeLogger is a logger stand-in for tests. It remembers the most recent
// notice and signals every notice on noticeCh.
type fakeLogger struct {
	msg      string
	noticeCh chan int
}

// Notice records text as the latest message and then sends on noticeCh.
// The send blocks until a receiver is ready when noticeCh is unbuffered.
func (fl *fakeLogger) Notice(text string) {
	fl.msg = text
	fl.noticeCh <- 1
}

// Debug discards the message.
func (*fakeLogger) Debug(string) {}

// TestSyscallRebootError checks that a failing reboot syscall is reported
// through the logger while the reboot handler itself still returns nil.
func (s *rebootSuite) TestSyscallRebootError(c *C) {
	defer FakeSyscallSync(func() {})()
	defer FakeSyscallReboot(func(cmd int) error {
		return fmt.Errorf("-EPERM")
	})()

	// Switch to syscall mode for this test only and restore the systemd
	// mode afterwards.
	SetRebootMode(SyscallMode)
	defer SetRebootMode(SystemdMode)

	complete := make(chan int)
	l := fakeLogger{noticeCh: complete}
	old := logger.SetLogger(&l)
	defer logger.SetLogger(old)

	// Buffered so the goroutine below can always deliver its result and
	// exit, even if the test bails out on the timeout branch.
	errCh := make(chan error, 1)
	go func() {
		// The handler must run on its own goroutine: fakeLogger.Notice
		// blocks on the unbuffered channel until it is received below.
		errCh <- rebootHandler(0)
	}()
	select {
	case <-complete:
	case <-time.After(10 * time.Second):
		c.Fatal("syscall did not take place and we timed out")
	}
	// Matches takes a regular expression, not a glob: ".*" is the
	// "any prefix" form. A bare leading "*" has nothing to repeat.
	c.Assert(l.msg, Matches, ".*-EPERM")
	c.Assert(<-errCh, IsNil)
}
16 changes: 16 additions & 0 deletions internals/daemon/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,19 @@ func FakeGetChecks(f func(o *overlord.Overlord) ([]*checkstate.CheckInfo, error)
getChecks = old
}
}

// FakeSyscallSync replaces the sync hook with f. Calling the returned
// function puts back the hook that was installed at the time of this call.
func FakeSyscallSync(f func()) (restore func()) {
	saved := syscallSync
	restore = func() { syscallSync = saved }
	syscallSync = f
	return restore
}

// FakeSyscallReboot replaces the reboot hook with f. Calling the returned
// function puts back the hook that was installed at the time of this call.
func FakeSyscallReboot(f func(cmd int) error) (restore func()) {
	saved := syscallReboot
	restore = func() { syscallReboot = saved }
	syscallReboot = f
	return restore
}
Loading