diff --git a/activate.go b/activate.go index 78e586bb..5451d796 100644 --- a/activate.go +++ b/activate.go @@ -661,22 +661,24 @@ func (m *activateOneContainerStateMachine) tryWithUserAuthKeyslots(ctx context.C break } - cred, err := authRequestor.RequestUserCredential(ctx, name, m.container.Path(), authType) + cred, credAuthType, err := authRequestor.RequestUserCredential(ctx, name, m.container.Path(), authType) if err != nil { return fmt.Errorf("cannot request user credential: %w", err) } + credAuthType &= authType + // We have a user credential. - // 1) Try it against every keyslot with a passphrase. - // 2) See if it decodes as a PIN and try it against every keyslot with a passphrase. - // 3) See if it decodes as a recovery key, and try it against every recovery keyslot. + // 1) If it's a passphrase, try it against every keyslot with a passphrase. + // 2) If it's a PIN, try it against every keyslot with a passphrase. + // 3) If it's a recovery key, try it against every recovery keyslot. var ( unlockKey DiskUnlockKey primaryKey PrimaryKey ) - if passphraseTries > 0 { + if credAuthType&UserAuthTypePassphrase > 0 { passphraseTries -= 1 if uk, pk, success := m.tryPassphraseKeyslotsHelper(ctx, passphraseSlotRecords, cred); success { unlockKey = uk @@ -684,7 +686,7 @@ func (m *activateOneContainerStateMachine) tryWithUserAuthKeyslots(ctx context.C } } - if m.status == activationIncomplete && pinTries > 0 { + if m.status == activationIncomplete && credAuthType&UserAuthTypePIN > 0 { pin, err := ParsePIN(cred) switch { case err != nil && authType == UserAuthTypePIN: @@ -707,7 +709,7 @@ func (m *activateOneContainerStateMachine) tryWithUserAuthKeyslots(ctx context.C } } - if m.status == activationIncomplete && recoveryKeyTries > 0 { + if m.status == activationIncomplete && credAuthType&UserAuthTypeRecoveryKey > 0 { recoveryKey, err := ParseRecoveryKey(cred) switch { case err != nil && authType == UserAuthTypeRecoveryKey: diff --git a/activate_test.go b/activate_test.go index 5a6c2704..66d8e2ef 100644 --- a/activate_test.go +++ b/activate_test.go @@ -4843,6 +4843,62 @@ Error with keyslot "default": cannot recover keys from keyslot: user authorizati c.Check(err, Equals, ErrCannotActivate) } +func (s *activateSuite) TestActivateContainerAuthModePassphraseAuthRequestorOnlyReturnsRecoveryKey(c *C) { + // Test a simple case with 2 keyslots with passphrase auth and + // a recovery keyslot. Unlocking happens with a recovery keyslot + // after initially entering what looks like a correct passphrase + // but the AuthRequestor indicated it was only a recovery key. + primaryKey := testutil.DecodeHexString(c, "ed988fada3dbf68e13862cfc52b6d6205c862dd0941e643a81dcab106a79ce6a") + kd1, unlockKey1 := s.makeKeyDataBlobWithPassphrase(c, primaryKey, testutil.DecodeHexString(c, "4d8b57f05f0e70a73768c1d9f1078b8e9b0e9c399f555342e1ac4e675fea122e"), "run+recover", "secret") + kd2, unlockKey2 := s.makeKeyDataBlobWithPassphrase(c, primaryKey, testutil.DecodeHexString(c, "d72501b0b558c3119e036d5585629a026e82c05b6a4f19511daa3f12cc37902f"), "recover", "foo") + + recoveryKey := testutil.DecodeHexString(c, "9124e9a56e40c65424c5f652127f8d18") + + authRequestor := &mockAuthRequestor{ + responses: []any{ + mockAuthRequestorResponse{response: "secret", authTypes: UserAuthTypeRecoveryKey}, + mockAuthRequestorResponse{response: makeRecoveryKey(c, recoveryKey), authTypes: UserAuthTypeRecoveryKey}, + }, + } + + err := s.testActivateContextActivateContainer(c, &testActivateContextActivateContainerParams{ + contextOpts: []ActivateContextOption{ + WithAuthRequestor(authRequestor), + WithPassphraseTries(3), + WithRecoveryKeyTries(3), + }, + authRequestor: authRequestor, + container: newMockStorageContainer( + withStorageContainerPath("/dev/sda1"), + withStorageContainerCredentialName("sda1"), + withStorageContainerKeyslot("default", unlockKey1, KeyslotTypePlatform, 0, kd1), + withStorageContainerKeyslot("default-fallback", unlockKey2, KeyslotTypePlatform, 0, kd2), + withStorageContainerKeyslot("default-recovery", recoveryKey, KeyslotTypeRecovery, 0, nil), + ), + opts: []ActivateOption{ + WithAuthRequestorUserVisibleName("data"), + }, + expectedAuthRequestName: "data", + expectedAuthRequestPath: "/dev/sda1", + expectedAuthRequestTypes: []UserAuthType{ + UserAuthTypePassphrase | UserAuthTypeRecoveryKey, + UserAuthTypePassphrase | UserAuthTypeRecoveryKey, + }, + expectedActivateConfig: map[any]any{ + AuthRequestorKey: authRequestor, + PassphraseTriesKey: uint(3), + RecoveryKeyTriesKey: uint(3), + AuthRequestorUserVisibleNameKey: "data", + }, + expectedUnlockKey: recoveryKey, + expectedState: &ContainerActivateState{ + Status: ActivationSucceededWithRecoveryKey, + Keyslot: "default-recovery", + }, + }) + c.Check(err, IsNil) +} + func (s *activateSuite) TestActivateContainerAuthModePIN(c *C) { // Test a simple case with 2 keyslots with PIN auth. primaryKey := testutil.DecodeHexString(c, "ed988fada3dbf68e13862cfc52b6d6205c862dd0941e643a81dcab106a79ce6a") @@ -5536,6 +5592,62 @@ Error with keyslot "default": cannot recover keys from keyslot: user authorizati c.Check(err, Equals, ErrCannotActivate) } +func (s *activateSuite) TestActivateContainerAuthModePINAuthRequestorOnlyReturnsRecoveryKey(c *C) { + // Test a simple case with 2 keyslots with PIN auth and + // a recovery keyslot. Unlocking happens with a recovery keyslot + // after initially entering what looks like a correct PIN + // but the AuthRequestor indicated it was only a recovery key. + primaryKey := testutil.DecodeHexString(c, "ed988fada3dbf68e13862cfc52b6d6205c862dd0941e643a81dcab106a79ce6a") + kd1, unlockKey1 := s.makeKeyDataBlobWithPIN(c, primaryKey, testutil.DecodeHexString(c, "4d8b57f05f0e70a73768c1d9f1078b8e9b0e9c399f555342e1ac4e675fea122e"), "run+recover", makePIN(c, "1234")) + kd2, unlockKey2 := s.makeKeyDataBlobWithPIN(c, primaryKey, testutil.DecodeHexString(c, "d72501b0b558c3119e036d5585629a026e82c05b6a4f19511daa3f12cc37902f"), "recover", makePIN(c, "5678")) + + recoveryKey := testutil.DecodeHexString(c, "9124e9a56e40c65424c5f652127f8d18") + + authRequestor := &mockAuthRequestor{ + responses: []any{ + mockAuthRequestorResponse{response: "1234", authTypes: UserAuthTypeRecoveryKey}, + mockAuthRequestorResponse{response: makeRecoveryKey(c, recoveryKey), authTypes: UserAuthTypeRecoveryKey}, + }, + } + + err := s.testActivateContextActivateContainer(c, &testActivateContextActivateContainerParams{ + contextOpts: []ActivateContextOption{ + WithAuthRequestor(authRequestor), + WithPINTries(3), + WithRecoveryKeyTries(3), + }, + authRequestor: authRequestor, + container: newMockStorageContainer( + withStorageContainerPath("/dev/sda1"), + withStorageContainerCredentialName("sda1"), + withStorageContainerKeyslot("default", unlockKey1, KeyslotTypePlatform, 0, kd1), + withStorageContainerKeyslot("default-fallback", unlockKey2, KeyslotTypePlatform, 0, kd2), + withStorageContainerKeyslot("default-recovery", recoveryKey, KeyslotTypeRecovery, 0, nil), + ), + opts: []ActivateOption{ + WithAuthRequestorUserVisibleName("data"), + }, + expectedAuthRequestName: "data", + expectedAuthRequestPath: "/dev/sda1", + expectedAuthRequestTypes: []UserAuthType{ + UserAuthTypePIN | UserAuthTypeRecoveryKey, + UserAuthTypePIN | UserAuthTypeRecoveryKey, + }, + expectedActivateConfig: map[any]any{ + AuthRequestorKey: authRequestor, + PinTriesKey: uint(3), + RecoveryKeyTriesKey: uint(3), + AuthRequestorUserVisibleNameKey: "data", + }, + expectedUnlockKey: recoveryKey, + expectedState: &ContainerActivateState{ + Status: ActivationSucceededWithRecoveryKey, + Keyslot: "default-recovery", + }, + }) + c.Check(err, IsNil) +} + func (s *activateSuite) TestDeactivateContainer(c *C) { state := &ActivateState{ Activations: map[string]*ContainerActivateState{ diff --git a/auth_requestor.go b/auth_requestor.go index cb91ca4f..b7d242c7 100644 --- a/auth_requestor.go +++ b/auth_requestor.go @@ -46,5 +46,7 @@ type AuthRequestor interface { // and can be supplied via the ActivateContext API using the // WithAuthRequestorUserVisibleName option. The authTypes argument is used // to indicate what types of credential are being requested. - RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, error) + // The implementation returns the requested credential and its type, which + // may be a subset of the requested credential types. + RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, UserAuthType, error) } diff --git a/auth_requestor_plymouth.go b/auth_requestor_plymouth.go index 187f4dd8..ac86bbb1 100644 --- a/auth_requestor_plymouth.go +++ b/auth_requestor_plymouth.go @@ -45,10 +45,10 @@ type plymouthAuthRequestor struct { stringer PlymouthAuthRequestorStringer } -func (r *plymouthAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, error) { +func (r *plymouthAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, UserAuthType, error) { fmtString, err := r.stringer.RequestUserCredentialFormatString(authTypes) if err != nil { - return "", fmt.Errorf("cannot request format string for requested auth types: %w", err) + return "", 0, fmt.Errorf("cannot request format string for requested auth types: %w", err) } msg := fmt.Sprintf(fmtString, name, path) @@ -59,15 +59,15 @@ func (r *plymouthAuthRequestor) RequestUserCredential(ctx context.Context, name, cmd.Stdout = out cmd.Stdin = os.Stdin if err := cmd.Run(); err != nil { - return "", fmt.Errorf("cannot execute plymouth ask-for-password: %w", err) + return "", 0, fmt.Errorf("cannot execute plymouth ask-for-password: %w", err) } result, err := io.ReadAll(out) if err != nil { // The only error returned from bytes.Buffer.Read should be io.EOF, // which io.ReadAll filters out. - return "", fmt.Errorf("unexpected error: %w", err) + return "", 0, fmt.Errorf("unexpected error: %w", err) } - return string(result), nil + return string(result), authTypes, nil } // NewPlymouthAuthRequestor creates an implementation of AuthRequestor that diff --git a/auth_requestor_plymouth_test.go b/auth_requestor_plymouth_test.go index 436fc927..21b3a138 100644 --- a/auth_requestor_plymouth_test.go +++ b/auth_requestor_plymouth_test.go @@ -102,9 +102,10 @@ func (s *authRequestorPlymouthSuite) testRequestUserCredential(c *C, params *tes requestor, err := NewPlymouthAuthRequestor(new(mockPlymouthAuthRequestorStringer)) c.Assert(err, IsNil) - passphrase, err := requestor.RequestUserCredential(params.ctx, params.name, params.path, params.authTypes) + passphrase, passphraseType, err := requestor.RequestUserCredential(params.ctx, params.name, params.path, params.authTypes) c.Check(err, IsNil) c.Check(passphrase, Equals, params.passphrase) + c.Check(passphraseType, Equals, params.authTypes) c.Check(s.mockPlymouth.Calls(), HasLen, 1) c.Check(s.mockPlymouth.Calls()[0], DeepEquals, []string{"plymouth", "ask-for-password", "--prompt", params.expectedMsg}) @@ -231,7 +232,7 @@ func (s *authRequestorPlymouthSuite) TestRequestUserCredentialObtainFormatString }) c.Assert(err, IsNil) - _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) + _, _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) c.Check(err, ErrorMatches, `cannot request format string for requested auth types: some error`) } @@ -239,7 +240,7 @@ func (s *authRequestorPlymouthSuite) TestRequestUserCredentialFailure(c *C) { requestor, err := NewPlymouthAuthRequestor(new(mockPlymouthAuthRequestorStringer)) c.Assert(err, IsNil) - _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) + _, _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) c.Check(err, ErrorMatches, "cannot execute plymouth ask-for-password: exit status 1") } @@ -252,7 +253,7 @@ func (s *authRequestorPlymouthSuite) TestRequestUserCredentialCanceledContext(c ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = requestor.RequestUserCredential(ctx, "data", "/dev/sda1", UserAuthTypePassphrase) + _, _, err = requestor.RequestUserCredential(ctx, "data", "/dev/sda1", UserAuthTypePassphrase) c.Check(err, ErrorMatches, "cannot execute plymouth ask-for-password: context canceled") c.Check(errors.Is(err, context.Canceled), testutil.IsTrue) } diff --git a/auth_requestor_systemd.go b/auth_requestor_systemd.go index 8078262b..a58c7d33 100644 --- a/auth_requestor_systemd.go +++ b/auth_requestor_systemd.go @@ -34,10 +34,10 @@ type systemdAuthRequestor struct { formatStringFn func(UserAuthType) (string, error) } -func (r *systemdAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, error) { +func (r *systemdAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, UserAuthType, error) { fmtString, err := r.formatStringFn(authTypes) if err != nil { - return "", fmt.Errorf("cannot request format string for requested auth types: %w", err) + return "", 0, fmt.Errorf("cannot request format string for requested auth types: %w", err) } msg := fmt.Sprintf(fmtString, name, path) @@ -50,14 +50,14 @@ func (r *systemdAuthRequestor) RequestUserCredential(ctx context.Context, name, cmd.Stdout = out cmd.Stdin = os.Stdin if err := cmd.Run(); err != nil { - return "", fmt.Errorf("cannot execute systemd-ask-password: %w", err) + return "", 0, fmt.Errorf("cannot execute systemd-ask-password: %w", err) } result, err := out.ReadString('\n') if err != nil { // The only error returned from bytes.Buffer.ReadString is io.EOF. - return "", errors.New("systemd-ask-password output is missing terminating newline") + return "", 0, errors.New("systemd-ask-password output is missing terminating newline") } - return strings.TrimRight(result, "\n"), nil + return strings.TrimRight(result, "\n"), authTypes, nil } // NewSystemdAuthRequestor creates an implementation of AuthRequestor that diff --git a/auth_requestor_systemd_test.go b/auth_requestor_systemd_test.go index 7ab5149d..4ecea549 100644 --- a/auth_requestor_systemd_test.go +++ b/auth_requestor_systemd_test.go @@ -93,9 +93,10 @@ func (s *authRequestorSystemdSuite) testRequestUserCredential(c *C, params *test }) c.Assert(err, IsNil) - passphrase, err := requestor.RequestUserCredential(params.ctx, params.name, params.path, params.authTypes) + passphrase, passphraseType, err := requestor.RequestUserCredential(params.ctx, params.name, params.path, params.authTypes) c.Check(err, IsNil) c.Check(passphrase, Equals, params.passphrase) + c.Check(passphraseType, Equals, params.authTypes) c.Check(s.mockSdAskPassword.Calls(), HasLen, 1) c.Check(s.mockSdAskPassword.Calls()[0], DeepEquals, []string{"systemd-ask-password", "--icon", "drive-harddisk", @@ -223,7 +224,7 @@ func (s *authRequestorSystemdSuite) TestRequestUserCredentialObtainFormatStringE }) c.Assert(err, IsNil) - _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) + _, _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) c.Check(err, ErrorMatches, `cannot request format string for requested auth types: some error`) } @@ -235,7 +236,7 @@ func (s *authRequestorSystemdSuite) TestRequestUserCredentialInvalidResponse(c * }) c.Assert(err, IsNil) - _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) + _, _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) c.Check(err, ErrorMatches, "systemd-ask-password output is missing terminating newline") } @@ -245,7 +246,7 @@ func (s *authRequestorSystemdSuite) TestRequestUserCredentialFailure(c *C) { }) c.Assert(err, IsNil) - _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) + _, _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase) c.Check(err, ErrorMatches, "cannot execute systemd-ask-password: exit status 1") } @@ -260,7 +261,7 @@ func (s *authRequestorSystemdSuite) TestRequestUserCredentialCanceledContext(c * ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = requestor.RequestUserCredential(ctx, "data", "/dev/sda1", UserAuthTypePassphrase) + _, _, err = requestor.RequestUserCredential(ctx, "data", "/dev/sda1", UserAuthTypePassphrase) c.Check(err, ErrorMatches, "cannot execute systemd-ask-password: context canceled") c.Check(errors.Is(err, context.Canceled), testutil.IsTrue) } diff --git a/crypt.go b/crypt.go index 9d0ec7c3..1cdbce1a 100755 --- a/crypt.go +++ b/crypt.go @@ -276,7 +276,7 @@ func (s *activateWithKeyDataState) run() (success bool, err error) { // a maximum of 2 keys with passphrases enabled (Ubuntu Core based desktop on // a UEFI+TPM platform with run+recovery and recovery-only protectors for // ubuntu-data). - passphrase, err := s.authRequestor.RequestUserCredential(context.Background(), s.volumeName, s.sourceDevicePath, UserAuthTypePassphrase) + passphrase, _, err := s.authRequestor.RequestUserCredential(context.Background(), s.volumeName, s.sourceDevicePath, UserAuthTypePassphrase) if err != nil { passphraseErr = xerrors.Errorf("cannot obtain passphrase: %w", err) continue @@ -329,7 +329,7 @@ func activateWithRecoveryKey(volumeName, sourceDevicePath string, authRequestor for ; tries > 0; tries-- { lastErr = nil - keyString, err := authRequestor.RequestUserCredential(context.Background(), volumeName, sourceDevicePath, UserAuthTypeRecoveryKey) + keyString, _, err := authRequestor.RequestUserCredential(context.Background(), volumeName, sourceDevicePath, UserAuthTypeRecoveryKey) if err != nil { lastErr = xerrors.Errorf("cannot obtain recovery key: %w", err) continue diff --git a/crypt_test.go b/crypt_test.go index 83d0137e..ad3f08e1 100644 --- a/crypt_test.go +++ b/crypt_test.go @@ -64,8 +64,13 @@ func (ctb *cryptTestBase) newRecoveryKey() RecoveryKey { return key } +type mockAuthRequestorResponse struct { + response any + authTypes UserAuthType +} + type mockAuthRequestor struct { - responses []interface{} + responses []any requests []struct { name string path string @@ -73,7 +78,7 @@ type mockAuthRequestor struct { } } -func (r *mockAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, error) { +func (r *mockAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, UserAuthType, error) { r.requests = append(r.requests, struct { name string path string @@ -85,18 +90,24 @@ func (r *mockAuthRequestor) RequestUserCredential(ctx context.Context, name, pat }) if len(r.responses) == 0 { - return "", errors.New("no response") + return "", 0, errors.New("no response") } response := r.responses[0] r.responses = r.responses[1:] + switch rsp := response.(type) { + case mockAuthRequestorResponse: + response = rsp.response + authTypes = rsp.authTypes + } + switch rsp := response.(type) { case string: - return rsp, nil + return rsp, authTypes, nil case RecoveryKey: - return rsp.String(), nil + return rsp.String(), authTypes, nil case error: - return "", rsp + return "", 0, rsp default: panic("invalid type") }