diff --git a/.github/ISSUE_TEMPLATE/BUG-REPORT.yml b/.github/ISSUE_TEMPLATE/BUG-REPORT.yml index 4048324eb8a7..b3db2a2d3de2 100644 --- a/.github/ISSUE_TEMPLATE/BUG-REPORT.yml +++ b/.github/ISSUE_TEMPLATE/BUG-REPORT.yml @@ -1,43 +1,48 @@ # AUTO-GENERATED, DO NOT EDIT! # Please edit the original at https://github.com/ory/meta/blob/master/templates/repository/common/.github/ISSUE_TEMPLATE/BUG-REPORT.yml -description: "Create a bug report" +description: 'Create a bug report' labels: - bug -name: "Bug Report" +name: 'Bug Report' body: - attributes: value: "Thank you for taking the time to fill out this bug report!\n" type: markdown - attributes: - label: "Preflight checklist" + label: 'Preflight checklist' options: - - label: "I could not find a solution in the existing issues, docs, nor - discussions." + - label: + 'I could not find a solution in the existing issues, docs, nor + discussions.' required: true - - label: "I agree to follow this project's [Code of + - label: + "I agree to follow this project's [Code of Conduct](https://github.com/ory/kratos/blob/master/CODE_OF_CONDUCT.md)." required: true - - label: "I have read and am following this repository's [Contribution + - label: + "I have read and am following this repository's [Contribution Guidelines](https://github.com/ory/kratos/blob/master/CONTRIBUTING.md)." required: true - - label: "I have joined the [Ory Community Slack](https://slack.ory.sh)." - - label: "I am signed up to the [Ory Security Patch - Newsletter](https://www.ory.sh/l/sign-up-newsletter)." + - label: + 'I have joined the [Ory Community Slack](https://slack.ory.sh).' + - label: + 'I am signed up to the [Ory Security Patch + Newsletter](https://www.ory.sh/l/sign-up-newsletter).' id: checklist type: checkboxes - attributes: description: - "Enter the slug or API URL of the affected Ory Network project. Leave - empty when you are self-hosting." - label: "Ory Network Project" - placeholder: "https://.projects.oryapis.com" + 'Enter the slug or API URL of the affected Ory Network project. Leave + empty when you are self-hosting.' + label: 'Ory Network Project' + placeholder: 'https://.projects.oryapis.com' id: ory-network-project type: input - attributes: - description: "A clear and concise description of what the bug is." - label: "Describe the bug" - placeholder: "Tell us what you see!" + description: 'A clear and concise description of what the bug is.' + label: 'Describe the bug' + placeholder: 'Tell us what you see!' id: describe-bug type: textarea validations: @@ -51,16 +56,17 @@ body: 1. Run `docker run ....` 2. Make API Request to with `curl ...` 3. Request fails with response: `{"some": "error"}` - label: "Reproducing the bug" + label: 'Reproducing the bug' id: reproduce-bug type: textarea validations: required: true - attributes: - description: "Please copy and paste any relevant log output. This will be + description: + 'Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. Please - redact any sensitive information" - label: "Relevant log output" + redact any sensitive information' + label: 'Relevant log output' render: shell placeholder: | log=error .... @@ -68,10 +74,10 @@ body: type: textarea - attributes: description: - "Please copy and paste any relevant configuration. This will be + 'Please copy and paste any relevant configuration. This will be automatically formatted into code, so no need for backticks. Please - redact any sensitive information!" - label: "Relevant configuration" + redact any sensitive information!' + label: 'Relevant configuration' render: yml placeholder: | server: @@ -80,14 +86,14 @@ body: id: config type: textarea - attributes: - description: "What version of our software are you running?" + description: 'What version of our software are you running?' label: Version id: version type: input validations: required: true - attributes: - label: "On which operating system are you observing this issue?" + label: 'On which operating system are you observing this issue?' options: - Ory Network - macOS @@ -98,19 +104,19 @@ body: id: operating-system type: dropdown - attributes: - label: "In which environment are you deploying?" + label: 'In which environment are you deploying?' options: - Ory Network - Docker - - "Docker Compose" - - "Kubernetes with Helm" + - 'Docker Compose' + - 'Kubernetes with Helm' - Kubernetes - Binary - Other id: deployment type: dropdown - attributes: - description: "Add any other context about the problem here." + description: 'Add any other context about the problem here.' label: Additional Context id: additional type: textarea diff --git a/.github/ISSUE_TEMPLATE/DESIGN-DOC.yml b/.github/ISSUE_TEMPLATE/DESIGN-DOC.yml index b5741119698b..a6a86f36dcc7 100644 --- a/.github/ISSUE_TEMPLATE/DESIGN-DOC.yml +++ b/.github/ISSUE_TEMPLATE/DESIGN-DOC.yml @@ -1,10 +1,11 @@ # AUTO-GENERATED, DO NOT EDIT! # Please edit the original at https://github.com/ory/meta/blob/master/templates/repository/common/.github/ISSUE_TEMPLATE/DESIGN-DOC.yml -description: "A design document is needed for non-trivial changes to the code base." +description: + 'A design document is needed for non-trivial changes to the code base.' labels: - rfc -name: "Design Document" +name: 'Design Document' body: - attributes: value: | @@ -20,34 +21,39 @@ body: after code reviews, and your pull requests will be merged faster. type: markdown - attributes: - label: "Preflight checklist" + label: 'Preflight checklist' options: - - label: "I could not find a solution in the existing issues, docs, nor - discussions." + - label: + 'I could not find a solution in the existing issues, docs, nor + discussions.' required: true - - label: "I agree to follow this project's [Code of + - label: + "I agree to follow this project's [Code of Conduct](https://github.com/ory/kratos/blob/master/CODE_OF_CONDUCT.md)." required: true - - label: "I have read and am following this repository's [Contribution + - label: + "I have read and am following this repository's [Contribution Guidelines](https://github.com/ory/kratos/blob/master/CONTRIBUTING.md)." required: true - - label: "I have joined the [Ory Community Slack](https://slack.ory.sh)." - - label: "I am signed up to the [Ory Security Patch - Newsletter](https://www.ory.sh/l/sign-up-newsletter)." + - label: + 'I have joined the [Ory Community Slack](https://slack.ory.sh).' + - label: + 'I am signed up to the [Ory Security Patch + Newsletter](https://www.ory.sh/l/sign-up-newsletter).' id: checklist type: checkboxes - attributes: description: - "Enter the slug or API URL of the affected Ory Network project. Leave - empty when you are self-hosting." - label: "Ory Network Project" - placeholder: "https://.projects.oryapis.com" + 'Enter the slug or API URL of the affected Ory Network project. Leave + empty when you are self-hosting.' + label: 'Ory Network Project' + placeholder: 'https://.projects.oryapis.com' id: ory-network-project type: input - attributes: description: | This section gives the reader a very rough overview of the landscape in which the new system is being built and what is actually being built. This isn’t a requirements doc. Keep it succinct! The goal is that readers are brought up to speed but some previous knowledge can be assumed and detailed info can be linked to. This section should be entirely focused on objective background facts. - label: "Context and scope" + label: 'Context and scope' id: scope type: textarea validations: @@ -56,7 +62,7 @@ body: - attributes: description: | A short list of bullet points of what the goals of the system are, and, sometimes more importantly, what non-goals are. Note, that non-goals aren’t negated goals like “The system shouldn’t crash”, but rather things that could reasonably be goals, but are explicitly chosen not to be goals. A good example would be “ACID compliance”; when designing a database, you’d certainly want to know whether that is a goal or non-goal. And if it is a non-goal you might still select a solution that provides it, if it doesn’t introduce trade-offs that prevent achieving the goals. - label: "Goals and non-goals" + label: 'Goals and non-goals' id: goals type: textarea validations: @@ -68,7 +74,7 @@ body: The design doc is the place to write down the trade-offs you made in designing your software. Focus on those trade-offs to produce a useful document with long-term value. That is, given the context (facts), goals and non-goals (requirements), the design doc is the place to suggest solutions and show why a particular solution best satisfies those goals. The point of writing a document over a more formal medium is to provide the flexibility to express the problem at hand in an appropriate manner. Because of this, there is no explicit guidance on how to actually describe the design. - label: "The design" + label: 'The design' id: design type: textarea validations: @@ -77,21 +83,21 @@ body: - attributes: description: | If the system under design exposes an API, then sketching out that API is usually a good idea. In most cases, however, one should withstand the temptation to copy-paste formal interface or data definitions into the doc as these are often verbose, contain unnecessary detail and quickly get out of date. Instead, focus on the parts that are relevant to the design and its trade-offs. - label: "APIs" + label: 'APIs' id: apis type: textarea - attributes: description: | Systems that store data should likely discuss how and in what rough form this happens. Similar to the advice on APIs, and for the same reasons, copy-pasting complete schema definitions should be avoided. Instead, focus on the parts that are relevant to the design and its trade-offs. - label: "Data storage" + label: 'Data storage' id: persistence type: textarea - attributes: description: | Design docs should rarely contain code, or pseudo-code except in situations where novel algorithms are described. As appropriate, link to prototypes that show the feasibility of the design. - label: "Code and pseudo-code" + label: 'Code and pseudo-code' id: pseudocode type: textarea @@ -104,7 +110,7 @@ body: On the other end are systems where the possible solutions are very well defined, but it isn't at all obvious how they could even be combined to achieve the goals. This may be a legacy system that is difficult to change and wasn't designed to do what you want it to do or a library design that needs to operate within the constraints of the host programming language. In this situation, you may be able to enumerate all the things you can do relatively easily, but you need to creatively put those things together to achieve the goals. There may be multiple solutions, and none of them are great, and hence such a document should focus on selecting the best way given all identified trade-offs. - label: "Degree of constraint" + label: 'Degree of constraint' id: constrait type: textarea diff --git a/.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml b/.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml index 7152fbdde4cf..7c023c2f48b3 100644 --- a/.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml +++ b/.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml @@ -1,10 +1,11 @@ # AUTO-GENERATED, DO NOT EDIT! # Please edit the original at https://github.com/ory/meta/blob/master/templates/repository/common/.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml -description: "Suggest an idea for this project without a plan for implementation" +description: + 'Suggest an idea for this project without a plan for implementation' labels: - feat -name: "Feature Request" +name: 'Feature Request' body: - attributes: value: | @@ -13,33 +14,39 @@ body: If you already have a plan to implement a feature or a change, please create a [design document](https://github.com/aeneasr/gh-template-test/issues/new?assignees=&labels=rfc&template=DESIGN-DOC.yml) instead if the change is non-trivial! type: markdown - attributes: - label: "Preflight checklist" + label: 'Preflight checklist' options: - - label: "I could not find a solution in the existing issues, docs, nor - discussions." + - label: + 'I could not find a solution in the existing issues, docs, nor + discussions.' required: true - - label: "I agree to follow this project's [Code of + - label: + "I agree to follow this project's [Code of Conduct](https://github.com/ory/kratos/blob/master/CODE_OF_CONDUCT.md)." required: true - - label: "I have read and am following this repository's [Contribution + - label: + "I have read and am following this repository's [Contribution Guidelines](https://github.com/ory/kratos/blob/master/CONTRIBUTING.md)." required: true - - label: "I have joined the [Ory Community Slack](https://slack.ory.sh)." - - label: "I am signed up to the [Ory Security Patch - Newsletter](https://www.ory.sh/l/sign-up-newsletter)." + - label: + 'I have joined the [Ory Community Slack](https://slack.ory.sh).' + - label: + 'I am signed up to the [Ory Security Patch + Newsletter](https://www.ory.sh/l/sign-up-newsletter).' id: checklist type: checkboxes - attributes: description: - "Enter the slug or API URL of the affected Ory Network project. Leave - empty when you are self-hosting." - label: "Ory Network Project" - placeholder: "https://.projects.oryapis.com" + 'Enter the slug or API URL of the affected Ory Network project. Leave + empty when you are self-hosting.' + label: 'Ory Network Project' + placeholder: 'https://.projects.oryapis.com' id: ory-network-project type: input - attributes: - description: "Is your feature request related to a problem? Please describe." - label: "Describe your problem" + description: + 'Is your feature request related to a problem? Please describe.' + label: 'Describe your problem' placeholder: "A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]" @@ -52,27 +59,28 @@ body: Describe the solution you'd like placeholder: | A clear and concise description of what you want to happen. - label: "Describe your ideal solution" + label: 'Describe your ideal solution' id: solution type: textarea validations: required: true - attributes: description: "Describe alternatives you've considered" - label: "Workarounds or alternatives" + label: 'Workarounds or alternatives' id: alternatives type: textarea validations: required: true - attributes: - description: "What version of our software are you running?" + description: 'What version of our software are you running?' label: Version id: version type: input validations: required: true - attributes: - description: "Add any other context or screenshots about the feature request here." + description: + 'Add any other context or screenshots about the feature request here.' label: Additional Context id: additional type: textarea diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index ef4c482ae405..abb0b696c9d9 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -5,8 +5,10 @@ blank_issues_enabled: false contact_links: - name: Ory Kratos Forum url: https://github.com/ory/kratos/discussions - about: Please ask and answer questions here, show your implementations and + about: + Please ask and answer questions here, show your implementations and discuss ideas. - name: Ory Chat url: https://www.ory.sh/chat - about: Hang out with other Ory community members to ask and answer questions. + about: + Hang out with other Ory community members to ask and answer questions. diff --git a/SECURITY.md b/SECURITY.md index 7a05c1cfc62e..026e3afb70f8 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,30 +1,53 @@ - - - -- [Security Policy](#security-policy) - - [Supported Versions](#supported-versions) - - [Reporting a Vulnerability](#reporting-a-vulnerability) - - - -# Security Policy - -## Supported Versions - -We release patches for security vulnerabilities. Which versions are eligible for -receiving such patches depends on the CVSS v3.0 Rating: - -| CVSS v3.0 | Supported Versions | -| --------- | ----------------------------------------- | -| 9.0-10.0 | Releases within the previous three months | -| 4.0-8.9 | Most recent release | +# Ory Security Policy + +## Overview + +This security policy outlines the security support commitments for different +types of Ory users. + +[Get in touch](https://www.ory.sh/contact/) to learn more about Ory's security +SLAs and process. + +## Apache 2.0 License Users + +- **Security SLA:** No security Service Level Agreement (SLA) is provided. +- **Release Schedule:** Releases are planned every 3 to 6 months. These releases + will contain all security fixes implemented up to that point. +- **Version Support:** Security patches are only provided for the current + release version. + +## Ory Enterprise License Customers + +- **Security SLA:** The following timelines apply for security vulnerabilities + based on their severity: + - Critical: Resolved within 14 days. + - High: Resolved within 30 days. + - Medium: Resolved within 90 days. + - Low: Resolved within 180 days. + - Informational: Addressed as needed. +- **Release Schedule:** Updates are provided as soon as vulnerabilities are + resolved, adhering to the above SLA. +- **Version Support:** Depending on the Ory Enterprise License agreement + multiple versions can be supported. + +## Ory Network Users + +- **Security SLA:** The following timelines apply for security vulnerabilities + based on their severity: + - Critical: Resolved within 14 days. + - High: Resolved within 30 days. + - Medium: Resolved within 90 days. + - Low: Resolved within 180 days. + - Informational: Addressed as needed. +- **Release Schedule:** Updates are automatically deployed to Ory Network as + soon as vulnerabilities are resolved, adhering to the above SLA. +- **Version Support:** Ory Network always runs the most current version. ## Reporting a Vulnerability -Please report (suspected) security vulnerabilities to -**[security@ory.sh](mailto:security@ory.sh)**. You will receive a response from -us within 48 hours. If the issue is confirmed, we will release a patch as soon -as possible depending on complexity but historically within a few days. +Please head over to our +[security policy](https://www.ory.sh/docs/ecosystem/security) to learn more +about reporting security vulnerabilities. diff --git a/cmd/clidoc/main.go b/cmd/clidoc/main.go index 7bc0eca6b58e..6cd127cd8a5a 100644 --- a/cmd/clidoc/main.go +++ b/cmd/clidoc/main.go @@ -176,9 +176,9 @@ func init() { "NewErrorValidationLoginLinkedCredentialsDoNotMatch": text.NewErrorValidationLoginLinkedCredentialsDoNotMatch(), "NewErrorValidationAddressUnknown": text.NewErrorValidationAddressUnknown(), "NewInfoSelfServiceLoginCodeMFA": text.NewInfoSelfServiceLoginCodeMFA(), - "NewInfoSelfServiceLoginCodeMFAHint": text.NewInfoSelfServiceLoginCodeMFAHint("{maskedIdentifier}"), "NewInfoLoginPassword": text.NewInfoLoginPassword(), "NewErrorValidationAccountNotFound": text.NewErrorValidationAccountNotFound(), + "NewInfoSelfServiceLoginAAL2CodeAddress": text.NewInfoSelfServiceLoginAAL2CodeAddress("{channel}", "{address}"), } } diff --git a/driver/config/config.go b/driver/config/config.go index 3978c24668ac..52762356fcc2 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -181,6 +181,7 @@ const ( ViperKeyLinkLifespan = "selfservice.methods.link.config.lifespan" ViperKeyLinkBaseURL = "selfservice.methods.link.config.base_url" ViperKeyCodeLifespan = "selfservice.methods.code.config.lifespan" + ViperKeyCodeConfigMissingCredentialFallbackEnabled = "selfservice.methods.code.config.missing_credential_fallback_enabled" ViperKeyPasswordHaveIBeenPwnedHost = "selfservice.methods.password.config.haveibeenpwned_host" ViperKeyPasswordHaveIBeenPwnedEnabled = "selfservice.methods.password.config.haveibeenpwned_enabled" ViperKeyPasswordMaxBreaches = "selfservice.methods.password.config.max_breaches" @@ -1330,6 +1331,10 @@ func (p *Config) SelfServiceCodeMethodLifespan(ctx context.Context) time.Duratio return p.GetProvider(ctx).DurationF(ViperKeyCodeLifespan, time.Hour) } +func (p *Config) SelfServiceCodeMethodMissingCredentialFallbackEnabled(ctx context.Context) bool { + return p.GetProvider(ctx).Bool(ViperKeyCodeConfigMissingCredentialFallbackEnabled) +} + func (p *Config) DatabaseCleanupSleepTables(ctx context.Context) time.Duration { return p.GetProvider(ctx).Duration(ViperKeyDatabaseCleanupSleepTables) } diff --git a/driver/config/testhelpers/config.go b/driver/config/testhelpers/config.go index 6d0e4ba0b910..bf68772d9a3f 100644 --- a/driver/config/testhelpers/config.go +++ b/driver/config/testhelpers/config.go @@ -33,7 +33,9 @@ func (t *TestConfigProvider) Config(ctx context.Context, config *configx.Provide if !ok { return config } + opts := make([]configx.OptionModifier, 0, len(values)) + opts = append(opts, configx.WithValues(config.All())) for _, v := range values { opts = append(opts, configx.WithValues(v)) } diff --git a/driver/registry_default.go b/driver/registry_default.go index c77ab5d783f2..8eef98ed28db 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -797,6 +797,10 @@ func (m *RegistryDefault) LoginCodePersister() code.LoginCodePersister { return m.Persister() } +func (m *RegistryDefault) TransactionalPersisterProvider() x.TransactionalPersister { + return m.Persister() +} + func (m *RegistryDefault) Persister() persistence.Persister { return m.persister } diff --git a/embedx/config.schema.json b/embedx/config.schema.json index a2802ed0a0b7..b7f0c468a963 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1574,6 +1574,13 @@ "pattern": "^([0-9]+(ns|us|ms|s|m|h))+$", "default": "1h", "examples": ["1h", "1m", "1s"] + }, + "missing_credential_fallback_enabled": { + "type": "boolean", + "title": "Enable Code OTP as a Fallback", + "description": "Enabling this allows users to sign in with the code method, even if their identity schema or their credentials are not set up to use the code method. If enabled, a verified address (such as an email) will be used to send the code to the user. Use with caution and only if actually needed.", + + "default": false } } } diff --git a/go.mod b/go.mod index 53c7ff98683f..f3a839bacaff 100644 --- a/go.mod +++ b/go.mod @@ -110,6 +110,7 @@ require ( github.com/cortesi/moddwatch v0.1.0 // indirect github.com/cortesi/termlog v0.0.0-20210222042314-a1eec763abec // indirect github.com/rjeczalik/notify v0.9.3 // indirect + github.com/wI2L/jsondiff v0.6.0 // indirect golang.org/x/term v0.22.0 // indirect gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect mvdan.cc/sh/v3 v3.6.0 // indirect @@ -257,7 +258,7 @@ require ( github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/term v0.5.0 // indirect - github.com/nyaruka/phonenumbers v1.3.6 // indirect + github.com/nyaruka/phonenumbers v1.3.6 github.com/ogier/pflag v0.0.1 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect diff --git a/go.sum b/go.sum index 7e9d0ef27bb4..44ff24e8a182 100644 --- a/go.sum +++ b/go.sum @@ -872,6 +872,8 @@ github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9r github.com/toqueteos/webbrowser v1.2.0/go.mod h1:XWoZq4cyp9WeUeak7w7LXRUQf1F1ATJMir8RTqb4ayM= github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= +github.com/wI2L/jsondiff v0.6.0 h1:zrsH3FbfVa3JO9llxrcDy/XLkYPLgoMX6Mz3T2PP2AI= +github.com/wI2L/jsondiff v0.6.0/go.mod h1:D6aQ5gKgPF9g17j+E9N7aasmU1O+XvfmWm1y8UMmNpw= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=0.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=0.json new file mode 100644 index 000000000000..9d4512a908be --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=0.json @@ -0,0 +1,6 @@ +{ + "type": "password", + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=1.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=1.json new file mode 100644 index 000000000000..9d4512a908be --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=1.json @@ -0,0 +1,6 @@ +{ + "type": "password", + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=10.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=10.json new file mode 100644 index 000000000000..b189c3022df0 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=10.json @@ -0,0 +1,18 @@ +{ + "type": "code", + "config": { + "addresses": [ + { + "channel": "sms", + "address": "+4917667111638" + }, + { + "channel": "email", + "address": "foo@ory.sh" + } + ] + }, + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=11.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=11.json new file mode 100644 index 000000000000..b189c3022df0 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=11.json @@ -0,0 +1,18 @@ +{ + "type": "code", + "config": { + "addresses": [ + { + "channel": "sms", + "address": "+4917667111638" + }, + { + "channel": "email", + "address": "foo@ory.sh" + } + ] + }, + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=2.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=2.json new file mode 100644 index 000000000000..5a19603d9a00 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=2.json @@ -0,0 +1,6 @@ +{ + "type": "webauthn", + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=3.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=3.json new file mode 100644 index 000000000000..9d4512a908be --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=3.json @@ -0,0 +1,6 @@ +{ + "type": "password", + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=4.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=4.json new file mode 100644 index 000000000000..5a19603d9a00 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=4.json @@ -0,0 +1,6 @@ +{ + "type": "webauthn", + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=5.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=5.json new file mode 100644 index 000000000000..5a19603d9a00 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=5.json @@ -0,0 +1,6 @@ +{ + "type": "webauthn", + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=6.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=6.json new file mode 100644 index 000000000000..2663d5b1297a --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=6.json @@ -0,0 +1,14 @@ +{ + "type": "code", + "config": { + "addresses": [ + { + "channel": "email", + "address": "foo@ory.sh" + } + ] + }, + "version": 1, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=7.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=7.json new file mode 100644 index 000000000000..6b8c682b5077 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=7.json @@ -0,0 +1,14 @@ +{ + "type": "code", + "config": { + "addresses": [ + { + "channel": "email", + "address": "foo@ory.sh" + } + ] + }, + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=8.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=8.json new file mode 100644 index 000000000000..6b8c682b5077 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=8.json @@ -0,0 +1,14 @@ +{ + "type": "code", + "config": { + "addresses": [ + { + "channel": "email", + "address": "foo@ory.sh" + } + ] + }, + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestSchemaExtensionCredentials-case=9.json b/identity/.snapshots/TestSchemaExtensionCredentials-case=9.json new file mode 100644 index 000000000000..b189c3022df0 --- /dev/null +++ b/identity/.snapshots/TestSchemaExtensionCredentials-case=9.json @@ -0,0 +1,18 @@ +{ + "type": "code", + "config": { + "addresses": [ + { + "channel": "sms", + "address": "+4917667111638" + }, + { + "channel": "email", + "address": "foo@ory.sh" + } + ] + }, + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_correct_value.json b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_correct_value.json new file mode 100644 index 000000000000..3673c6943f03 --- /dev/null +++ b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_correct_value.json @@ -0,0 +1,30 @@ +{ + "id": "4d64fa08-20fc-450d-bebd-ebd7c7b6e249", + "credentials": { + "code": { + "type": "code", + "identifiers": [ + "hi@example.org" + ], + "config": { + "addresses": [ + { + "channel": "email", + "address": "hi@example.org" + } + ] + }, + "version": 1, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" + } + }, + "schema_id": "", + "schema_url": "", + "state": "", + "traits": null, + "metadata_public": null, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z", + "organization_id": null +} diff --git a/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_email_empty_space_value-with_one_identifier.json b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_email_empty_space_value-with_one_identifier.json new file mode 100644 index 000000000000..3673c6943f03 --- /dev/null +++ b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_email_empty_space_value-with_one_identifier.json @@ -0,0 +1,30 @@ +{ + "id": "4d64fa08-20fc-450d-bebd-ebd7c7b6e249", + "credentials": { + "code": { + "type": "code", + "identifiers": [ + "hi@example.org" + ], + "config": { + "addresses": [ + { + "channel": "email", + "address": "hi@example.org" + } + ] + }, + "version": 1, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" + } + }, + "schema_id": "", + "schema_url": "", + "state": "", + "traits": null, + "metadata_public": null, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z", + "organization_id": null +} diff --git a/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_email_empty_space_value-with_two_identifiers.json b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_email_empty_space_value-with_two_identifiers.json new file mode 100644 index 000000000000..9163f3c3886b --- /dev/null +++ b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_email_empty_space_value-with_two_identifiers.json @@ -0,0 +1,35 @@ +{ + "id": "4d64fa08-20fc-450d-bebd-ebd7c7b6e249", + "credentials": { + "code": { + "type": "code", + "identifiers": [ + "foo@example.org", + "bar@example.org" + ], + "config": { + "addresses": [ + { + "channel": "email", + "address": "foo@example.org" + }, + { + "channel": "email", + "address": "bar@example.org" + } + ] + }, + "version": 1, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" + } + }, + "schema_id": "", + "schema_url": "", + "state": "", + "traits": null, + "metadata_public": null, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z", + "organization_id": null +} diff --git a/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_empty_value.json b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_empty_value.json new file mode 100644 index 000000000000..3673c6943f03 --- /dev/null +++ b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_empty_value.json @@ -0,0 +1,30 @@ +{ + "id": "4d64fa08-20fc-450d-bebd-ebd7c7b6e249", + "credentials": { + "code": { + "type": "code", + "identifiers": [ + "hi@example.org" + ], + "config": { + "addresses": [ + { + "channel": "email", + "address": "hi@example.org" + } + ] + }, + "version": 1, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" + } + }, + "schema_id": "", + "schema_url": "", + "state": "", + "traits": null, + "metadata_public": null, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z", + "organization_id": null +} diff --git a/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_unknown_value.json b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_unknown_value.json new file mode 100644 index 000000000000..3673c6943f03 --- /dev/null +++ b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v0_with_unknown_value.json @@ -0,0 +1,30 @@ +{ + "id": "4d64fa08-20fc-450d-bebd-ebd7c7b6e249", + "credentials": { + "code": { + "type": "code", + "identifiers": [ + "hi@example.org" + ], + "config": { + "addresses": [ + { + "channel": "email", + "address": "hi@example.org" + } + ] + }, + "version": 1, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" + } + }, + "schema_id": "", + "schema_url": "", + "state": "", + "traits": null, + "metadata_public": null, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z", + "organization_id": null +} diff --git a/identity/.snapshots/TestUpgradeCredentials-type=code-from=v2_with_empty_value.json b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v2_with_empty_value.json new file mode 100644 index 000000000000..a52b01ffd526 --- /dev/null +++ b/identity/.snapshots/TestUpgradeCredentials-type=code-from=v2_with_empty_value.json @@ -0,0 +1,35 @@ +{ + "id": "4d64fa08-20fc-450d-bebd-ebd7c7b6e249", + "credentials": { + "code": { + "type": "code", + "identifiers": [ + "foo@example.org", + "+12341234" + ], + "config": { + "addresses": [ + { + "address": "foo@example.org", + "channel": "email" + }, + { + "address": "+12341234", + "channel": "sms" + } + ] + }, + "version": 1, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" + } + }, + "schema_id": "", + "schema_url": "", + "state": "", + "traits": null, + "metadata_public": null, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z", + "organization_id": null +} diff --git a/identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json b/identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json index 60038539418a..4931ed3f972d 100644 --- a/identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json +++ b/identity/.snapshots/TestUpgradeCredentials-type=webauthn-from=v1.json @@ -3,7 +3,7 @@ "credentials": { "webauthn": { "type": "webauthn", - "identifiers": null, + "identifiers": [], "config": { "credentials": [ { diff --git a/identity/address.go b/identity/address.go index 2e9175642e84..ae7ab83e38e7 100644 --- a/identity/address.go +++ b/identity/address.go @@ -5,4 +5,5 @@ package identity const ( AddressTypeEmail = "email" + AddressTypeSMS = "sms" ) diff --git a/identity/credentials.go b/identity/credentials.go index 3cc910c5a74e..9fc2d93851bb 100644 --- a/identity/credentials.go +++ b/identity/credentials.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gofrs/uuid" + "github.com/wI2L/jsondiff" "github.com/ory/kratos/ui/node" "github.com/ory/x/sqlxx" @@ -215,8 +216,8 @@ type ( // swagger:ignore ActiveCredentialsCounter interface { ID() CredentialsType - CountActiveFirstFactorCredentials(cc map[CredentialsType]Credentials) (int, error) - CountActiveMultiFactorCredentials(cc map[CredentialsType]Credentials) (int, error) + CountActiveFirstFactorCredentials(context.Context, map[CredentialsType]Credentials) (int, error) + CountActiveMultiFactorCredentials(context.Context, map[CredentialsType]Credentials) (int, error) } // swagger:ignore @@ -248,7 +249,13 @@ func CredentialsEqual(a, b map[CredentialsType]Credentials) bool { return false } - if string(expect.Config) != string(actual.Config) { + // Try to normalize configs (remove spaces etc). + patch, err := jsondiff.CompareJSON(actual.Config, expect.Config) + if err != nil { + return false + } + + if len(patch) > 0 { return false } diff --git a/identity/credentials_code.go b/identity/credentials_code.go index b7f828ae09a1..e3e5f174f84c 100644 --- a/identity/credentials_code.go +++ b/identity/credentials_code.go @@ -4,22 +4,63 @@ package identity import ( - "database/sql" + "encoding/json" + + "github.com/ory/herodot" + + "github.com/pkg/errors" + + "github.com/ory/x/stringsx" ) -type CodeAddressType = string +type CodeChannel string const ( - CodeAddressTypeEmail CodeAddressType = AddressTypeEmail + CodeChannelEmail CodeChannel = AddressTypeEmail + CodeChannelSMS CodeChannel = AddressTypeSMS ) +func NewCodeChannel(value string) (CodeChannel, error) { + switch f := stringsx.SwitchExact(value); { + case f.AddCase(string(CodeChannelEmail)): + return CodeChannelEmail, nil + case f.AddCase(string(CodeChannelSMS)): + return CodeChannelSMS, nil + default: + return "", errors.Wrap(ErrInvalidCodeAddressType, f.ToUnknownCaseErr().Error()) + } +} + // CredentialsCode represents a one time login/registration code // // swagger:model identityCredentialsCode type CredentialsCode struct { + Addresses []CredentialsCodeAddress `json:"addresses"` +} + +// swagger:model identityCredentialsCodeAddress +type CredentialsCodeAddress struct { // The type of the address for this code - AddressType CodeAddressType `json:"address_type"` + Channel CodeChannel `json:"channel"` + + // The address for this code + Address string `json:"address"` +} + +var ErrInvalidCodeAddressType = herodot.ErrInternalServerError.WithReasonf("The address type for sending OTP codes is not supported.") + +func (c *CredentialsCodeAddress) UnmarshalJSON(data []byte) (err error) { + type alias CredentialsCodeAddress + var ac alias + if err := json.Unmarshal(data, &ac); err != nil { + return err + } + + ac.Channel, err = NewCodeChannel(string(ac.Channel)) + if err != nil { + return err + } - // UsedAt indicates whether and when a recovery code was used. - UsedAt sql.NullTime `json:"used_at,omitempty"` + *c = CredentialsCodeAddress(ac) + return nil } diff --git a/identity/credentials_code_test.go b/identity/credentials_code_test.go new file mode 100644 index 000000000000..475885024faa --- /dev/null +++ b/identity/credentials_code_test.go @@ -0,0 +1,110 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package identity + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCredentialsCodeAddressUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + want CredentialsCodeAddress + wantErr bool + }{ + { + name: "valid email address", + input: `{"channel": "email", "address": "user@example.com"}`, + want: CredentialsCodeAddress{ + Channel: CodeChannelEmail, + Address: "user@example.com", + }, + wantErr: false, + }, + { + name: "valid SMS address", + input: `{"channel": "sms", "address": "+1234567890"}`, + want: CredentialsCodeAddress{ + Channel: CodeChannelSMS, + Address: "+1234567890", + }, + wantErr: false, + }, + { + name: "invalid address type", + input: `{"channel": "invalid", "address": "user@example.com"}`, + want: CredentialsCodeAddress{}, + wantErr: true, + }, + { + name: "missing channel field", + input: `{"address": "user@example.com"}`, + want: CredentialsCodeAddress{}, + wantErr: true, + }, + { + name: "invalid JSON structure", + input: `{"channel": "email", "address": "user@example.com"`, + want: CredentialsCodeAddress{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got CredentialsCodeAddress + err := json.Unmarshal([]byte(tt.input), &got) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestNewCodeAddressType(t *testing.T) { + tests := []struct { + name string + input string + want CodeChannel + wantErr bool + }{ + { + name: "valid email address type", + input: "email", + want: CodeChannelEmail, + wantErr: false, + }, + { + name: "valid SMS address type", + input: "sms", + want: CodeChannelSMS, + wantErr: false, + }, + { + name: "invalid address type", + input: "invalid", + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewCodeChannel(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/identity/credentials_migrate.go b/identity/credentials_migrate.go index f856be135a29..dea1a8a44b8d 100644 --- a/identity/credentials_migrate.go +++ b/identity/credentials_migrate.go @@ -6,6 +6,7 @@ package identity import ( "encoding/json" "fmt" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -60,7 +61,61 @@ func UpgradeCredentials(i *Identity) error { if err := UpgradeWebAuthnCredentials(i, &c); err != nil { return errors.WithStack(err) } + if err := UpgradeCodeCredentials(&c); err != nil { + return errors.WithStack(err) + } i.Credentials[k] = c } return nil } + +func UpgradeCodeCredentials(c *Credentials) (err error) { + if c.Type != CredentialsTypeCodeAuth { + return nil + } + + version := c.Version + if version == 0 { + addressType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(c.Config, "address_type").String())) + + channel, err := NewCodeChannel(addressType) + if err != nil { + // We know that in some cases the address type can be empty. In this case, we default to email + // as sms is a new addition to the address_type introduced in this PR. + channel = CodeChannelEmail + } + + c.Config, err = sjson.DeleteBytes(c.Config, "used_at") + if err != nil { + return errors.WithStack(err) + } + + c.Config, err = sjson.DeleteBytes(c.Config, "address_type") + if err != nil { + return errors.WithStack(err) + } + + for _, id := range c.Identifiers { + if id == "" { + continue + } + + c.Config, err = sjson.SetBytes(c.Config, "addresses.-1", &CredentialsCodeAddress{ + Address: id, + Channel: channel, + }) + if err != nil { + return errors.WithStack(err) + } + } + + // This is needed because sjson adds spaces which can trip string comparisons. + c.Config, err = json.Marshal(json.RawMessage(c.Config)) + if err != nil { + return errors.WithStack(err) + } + + c.Version = 1 + } + return nil +} diff --git a/identity/credentials_migrate_test.go b/identity/credentials_migrate_test.go index 2916bc892bde..43291d63e771 100644 --- a/identity/credentials_migrate_test.go +++ b/identity/credentials_migrate_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -30,43 +31,63 @@ func TestUpgradeCredentials(t *testing.T) { snapshotx.SnapshotTExcept(t, &wc, nil) }) - identityID := uuid.FromStringOrNil("4d64fa08-20fc-450d-bebd-ebd7c7b6e249") + run := func(t *testing.T, identifiers []string, config string, version int, credentialsType CredentialsType, expectedVersion int) { + if identifiers == nil { + identifiers = []string{"hi@example.org"} + } + i := &Identity{ + ID: uuid.FromStringOrNil("4d64fa08-20fc-450d-bebd-ebd7c7b6e249"), + Credentials: map[CredentialsType]Credentials{ + credentialsType: { + Identifiers: identifiers, + Type: credentialsType, + Version: version, + Config: []byte(config), + }}, + } + + require.NoError(t, UpgradeCredentials(i)) + wc := WithCredentialsAndAdminMetadataInJSON(*i) + snapshotx.SnapshotT(t, &wc) + assert.Equal(t, expectedVersion, i.Credentials[credentialsType].Version) + } + + t.Run("type=code", func(t *testing.T) { + t.Run("from=v0 with email empty space value", func(t *testing.T) { + t.Run("with one identifier", func(t *testing.T) { + run(t, nil, `{"address_type": "email ", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, 0, CredentialsTypeCodeAuth, 1) + }) + + t.Run("with two identifiers", func(t *testing.T) { + run(t, []string{"foo@example.org", "bar@example.org"}, `{"address_type": "email ", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, 0, CredentialsTypeCodeAuth, 1) + }) + }) + + t.Run("from=v0 with empty value", func(t *testing.T) { + run(t, nil, `{"address_type": "", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, 0, CredentialsTypeCodeAuth, 1) + }) + + t.Run("from=v0 with correct value", func(t *testing.T) { + run(t, nil, `{"address_type": "email", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, 0, CredentialsTypeCodeAuth, 1) + }) + + t.Run("from=v0 with unknown value", func(t *testing.T) { + run(t, nil, `{"address_type": "other", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, 0, CredentialsTypeCodeAuth, 1) + }) + + t.Run("from=v2 with empty value", func(t *testing.T) { + run(t, []string{"foo@example.org", "+12341234"}, `{"addresses": [{"address":"foo@example.org","channel":"email"},{"address":"+12341234","channel":"sms"}]}`, 1, CredentialsTypeCodeAuth, 1) + }) + }) + t.Run("type=webauthn", func(t *testing.T) { t.Run("from=v0", func(t *testing.T) { - i := &Identity{ - ID: identityID, - Credentials: map[CredentialsType]Credentials{ - CredentialsTypeWebAuthn: { - Identifiers: []string{"4d64fa08-20fc-450d-bebd-ebd7c7b6e249"}, - Type: CredentialsTypeWebAuthn, - Version: 0, - Config: webAuthnV0, - }}, - } - - require.NoError(t, UpgradeCredentials(i)) - wc := WithCredentialsAndAdminMetadataInJSON(*i) - snapshotx.SnapshotTExcept(t, &wc, nil) - - assert.Equal(t, 1, i.Credentials[CredentialsTypeWebAuthn].Version) + run(t, []string{"4d64fa08-20fc-450d-bebd-ebd7c7b6e249"}, string(webAuthnV0), 0, CredentialsTypeWebAuthn, 1) }) t.Run("from=v1", func(t *testing.T) { - i := &Identity{ - ID: identityID, - Credentials: map[CredentialsType]Credentials{ - CredentialsTypeWebAuthn: { - Type: CredentialsTypeWebAuthn, - Version: 1, - Config: webAuthnV1, - }}, - } - - require.NoError(t, UpgradeCredentials(i)) - wc := WithCredentialsAndAdminMetadataInJSON(*i) - snapshotx.SnapshotTExcept(t, &wc, nil) - - assert.Equal(t, 1, i.Credentials[CredentialsTypeWebAuthn].Version) + + run(t, []string{}, string(webAuthnV1), 1, CredentialsTypeWebAuthn, 1) }) }) } diff --git a/identity/extension_credentials.go b/identity/extension_credentials.go index 3baa826b2e9c..faae191b89ca 100644 --- a/identity/extension_credentials.go +++ b/identity/extension_credentials.go @@ -4,16 +4,21 @@ package identity import ( + "encoding/json" "fmt" + "sort" "strings" "sync" + "github.com/pkg/errors" + "github.com/samber/lo" + + "github.com/ory/kratos/x" + "github.com/ory/jsonschema/v3" + "github.com/ory/kratos/schema" "github.com/ory/x/sqlxx" "github.com/ory/x/stringslice" - "github.com/ory/x/stringsx" - - "github.com/ory/kratos/schema" ) type SchemaExtensionCredentials struct { @@ -35,6 +40,7 @@ func (r *SchemaExtensionCredentials) setIdentifier(ct CredentialsType, value int Config: sqlxx.JSONRawMessage{}, } } + if r.v == nil { r.v = make(map[CredentialsType][]string) } @@ -57,22 +63,71 @@ func (r *SchemaExtensionCredentials) Run(ctx jsonschema.ValidationContext, s sch } if s.Credentials.Code.Identifier { - switch f := stringsx.SwitchExact(s.Credentials.Code.Via); { - case f.AddCase(AddressTypeEmail): - if !jsonschema.Formats["email"](value) { - return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via) + via, err := NewCodeChannel(s.Credentials.Code.Via) + if err != nil { + return ctx.Error("ory.sh~/kratos/credentials/code/via", "channel type %q must be one of %s", s.Credentials.Code.Via, strings.Join([]string{ + string(CodeChannelEmail), + string(CodeChannelSMS), + }, ", ")) + } + + cred := r.i.GetCredentialsOr(CredentialsTypeCodeAuth, &Credentials{ + Type: CredentialsTypeCodeAuth, + Identifiers: []string{}, + Config: sqlxx.JSONRawMessage("{}"), + Version: 1, + }) + + var conf CredentialsCode + if len(cred.Config) > 0 { + // Only decode the config if it is not empty. + if err := json.Unmarshal(cred.Config, &conf); err != nil { + return &jsonschema.ValidationError{Message: "unable to unmarshal identity credentials"} } + } - r.setIdentifier(CredentialsTypeCodeAuth, value) - // case f.AddCase(AddressTypePhone): - // if !jsonschema.Formats["tel"](value) { - // return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via) - // } + if conf.Addresses == nil { + conf.Addresses = []CredentialsCodeAddress{} + } - // r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressTypePhone) - default: - return ctx.Error("", "credentials.code.via has unknown value %q", s.Credentials.Code.Via) + value, err := x.NormalizeIdentifier(fmt.Sprintf("%s", value), string(via)) + if err != nil { + return &jsonschema.ValidationError{Message: err.Error()} } + + conf.Addresses = append(conf.Addresses, CredentialsCodeAddress{ + Channel: via, + Address: value, + }) + + conf.Addresses = lo.UniqBy(conf.Addresses, func(item CredentialsCodeAddress) string { + return fmt.Sprintf("%x:%s", item.Address, item.Channel) + }) + + sort.SliceStable(conf.Addresses, func(i, j int) bool { + if conf.Addresses[i].Address == conf.Addresses[j].Address { + return conf.Addresses[i].Channel < conf.Addresses[j].Channel + } + return conf.Addresses[i].Address < conf.Addresses[j].Address + }) + + if r.v == nil { + r.v = make(map[CredentialsType][]string) + } + + r.v[CredentialsTypeCodeAuth] = stringslice.Unique(append(r.v[CredentialsTypeCodeAuth], + lo.Map(conf.Addresses, func(item CredentialsCodeAddress, _ int) string { + return item.Address + })..., + )) + + cred.Identifiers = r.v[CredentialsTypeCodeAuth] + cred.Config, err = json.Marshal(conf) + if err != nil { + return errors.WithStack(err) + } + + r.i.SetCredentials(CredentialsTypeCodeAuth, *cred) } return nil diff --git a/identity/extension_credentials_test.go b/identity/extension_credentials_test.go index 95cd9d000c6a..fd44bde801c8 100644 --- a/identity/extension_credentials_test.go +++ b/identity/extension_credentials_test.go @@ -9,6 +9,8 @@ import ( "fmt" "testing" + "github.com/ory/x/snapshotx" + "github.com/ory/jsonschema/v3" _ "github.com/ory/jsonschema/v3/fileloader" @@ -87,6 +89,42 @@ func TestSchemaExtensionCredentials(t *testing.T) { }, ct: identity.CredentialsTypeCodeAuth, }, + { + doc: `{"email":"FOO@ory.sh"}`, + schema: "file://./stub/extension/credentials/code.schema.json", + expect: []string{"foo@ory.sh"}, + existing: &identity.Credentials{ + Identifiers: []string{"not-foo@ory.sh", "foo@ory.sh"}, + }, + ct: identity.CredentialsTypeCodeAuth, + }, + { + doc: `{"email":"FOO@ory.sh","phone":"+49 176 671 11 638"}`, + schema: "file://./stub/extension/credentials/code-phone-email.schema.json", + expect: []string{"+4917667111638", "foo@ory.sh"}, + existing: &identity.Credentials{ + Identifiers: []string{"not-foo@ory.sh", "foo@ory.sh"}, + }, + ct: identity.CredentialsTypeCodeAuth, + }, + { + doc: `{"email":"FOO@ory.sh","phone":"+49 176 671 11 638"}`, + schema: "file://./stub/extension/credentials/code-phone-email.schema.json", + expect: []string{"+4917667111638", "foo@ory.sh"}, + existing: &identity.Credentials{ + Identifiers: []string{"not-foo@ory.sh", "foo@ory.sh"}, + }, + ct: identity.CredentialsTypeCodeAuth, + }, + { + doc: `{"email":"FOO@ory.sh","email2":"FOO@ory.sh","phone":"+49 176 671 11 638"}`, + schema: "file://./stub/extension/credentials/code-phone-email.schema.json", + expect: []string{"+4917667111638", "foo@ory.sh"}, + existing: &identity.Credentials{ + Identifiers: []string{"not-foo@ory.sh", "fOo@ory.sh"}, + }, + ct: identity.CredentialsTypeCodeAuth, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { c := jsonschema.NewCompiler() @@ -103,12 +141,15 @@ func TestSchemaExtensionCredentials(t *testing.T) { err = c.MustCompile(ctx, tc.schema).Validate(bytes.NewBufferString(tc.doc)) if tc.expectErr != nil { require.EqualError(t, err, tc.expectErr.Error()) + } else { + require.NoError(t, err) } require.NoError(t, e.Finish()) credentials, ok := i.GetCredentials(tc.ct) require.True(t, ok) assert.ElementsMatch(t, tc.expect, credentials.Identifiers) + snapshotx.SnapshotT(t, credentials, snapshotx.ExceptPaths("identifiers")) }) } } diff --git a/identity/handler_import.go b/identity/handler_import.go index cca346e8e676..babb09579af1 100644 --- a/identity/handler_import.go +++ b/identity/handler_import.go @@ -19,7 +19,15 @@ func (h *Handler) importCredentials(ctx context.Context, i *Identity, creds *Ide return nil } + // This method only support password and OIDC import at the moment. + // If other methods are added please ensure that the available AAL is set correctly in the identity. + // + // It would actually be good if we would validate the identity post-creation to see if the credentials are working. if creds.Password != nil { + // This method is somewhat hacky, because it does not set the credential's identifier. It relies on the + // identity validation to set the identifier, which is called after this method. + // + // It would be good to make this explicit. if err := h.importPasswordCredentials(ctx, i, creds.Password); err != nil { return err } diff --git a/identity/identity.go b/identity/identity.go index 3f692b831f2b..0277772c4708 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -68,9 +68,17 @@ type Identity struct { // Credentials represents all credentials that can be used for authenticating this identity. Credentials map[CredentialsType]Credentials `json:"credentials,omitempty" faker:"-" db:"-"` - // AvailableAAL defines the maximum available AAL for this identity. If the user has only a password - // configured, the AAL will be 1. If the user has a password and a TOTP configured, the AAL will be 2. - AvailableAAL NullableAuthenticatorAssuranceLevel `json:"-" faker:"-" db:"available_aal"` + // InternalAvailableAAL defines the maximum available AAL for this identity. + // + // - If the user has at least one two-factor authentication method configured, the AAL will be 2. + // - If the user has only a password configured, the AAL will be 1. + // + // This field is AAL2 as soon as a second factor credential is found. A first factor is not required for this + // field to return `aal2`. + // + // This field is primarily used to determine whether the user needs to upgrade to AAL2 without having to check + // all the credentials in the database. Use with caution! + InternalAvailableAAL NullableAuthenticatorAssuranceLevel `json:"-" faker:"-" db:"available_aal"` // // IdentifierCredentials contains the access and refresh token for oidc identifier // IdentifierCredentials []IdentifierCredential `json:"identifier_credentials,omitempty" faker:"-" db:"-"` @@ -345,24 +353,27 @@ func (i *Identity) UnmarshalJSON(b []byte) error { return err } +// SetAvailableAAL sets the InternalAvailableAAL field based on the credentials stored in the identity. +// +// If a second factor is set up, the AAL will be set to 2. If only a first factor is set up, the AAL will be set to 1. +// +// A first factor is NOT required for the AAL to be set to 2 if a second factor is set up. func (i *Identity) SetAvailableAAL(ctx context.Context, m *Manager) (err error) { - i.AvailableAAL = NewNullableAuthenticatorAssuranceLevel(NoAuthenticatorAssuranceLevel) - if c, err := m.CountActiveFirstFactorCredentials(ctx, i); err != nil { + if c, err := m.CountActiveMultiFactorCredentials(ctx, i); err != nil { return err - } else if c == 0 { - // No first factor set up - AAL is 0 + } else if c > 0 { + i.InternalAvailableAAL = NewNullableAuthenticatorAssuranceLevel(AuthenticatorAssuranceLevel2) return nil } - i.AvailableAAL = NewNullableAuthenticatorAssuranceLevel(AuthenticatorAssuranceLevel1) - if c, err := m.CountActiveMultiFactorCredentials(ctx, i); err != nil { + if c, err := m.CountActiveFirstFactorCredentials(ctx, i); err != nil { return err - } else if c == 0 { - // No second factor set up - AAL is 1 + } else if c > 0 { + i.InternalAvailableAAL = NewNullableAuthenticatorAssuranceLevel(AuthenticatorAssuranceLevel1) return nil } - i.AvailableAAL = NewNullableAuthenticatorAssuranceLevel(AuthenticatorAssuranceLevel2) + i.InternalAvailableAAL = NewNullableAuthenticatorAssuranceLevel(NoAuthenticatorAssuranceLevel) return nil } diff --git a/identity/manager.go b/identity/manager.go index 04fb3edae500..2429df260fbe 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -29,13 +29,13 @@ import ( "github.com/ory/herodot" "github.com/ory/jsonschema/v3" - "github.com/ory/x/errorsx" - "github.com/ory/kratos/courier" ) -var ErrProtectedFieldModified = herodot.ErrForbidden. - WithReasonf(`A field was modified that updates one or more credentials-related settings. This action was blocked because an unprivileged method was used to execute the update. This is either a configuration issue or a bug and should be reported to the system administrator.`) +var ( + ErrProtectedFieldModified = herodot.ErrForbidden. + WithReasonf(`A field was modified that updates one or more credentials-related settings. This action was blocked because an unprivileged method was used to execute the update. This is either a configuration issue or a bug and should be reported to the system administrator.`) +) type ( managerDependencies interface { @@ -96,10 +96,6 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption return err } - if err := i.SetAvailableAAL(ctx, m); err != nil { - return err - } - if err := m.r.PrivilegedIdentityPool().CreateIdentity(ctx, i); err != nil { if errors.Is(err, sqlcon.ErrUniqueViolation) { return m.findExistingAuthMethod(ctx, err, i) @@ -329,10 +325,6 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, i.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx) } - if err := i.SetAvailableAAL(ctx, m); err != nil { - return err - } - o := newManagerOptions(opts) if err := m.ValidateIdentity(ctx, i, o); err != nil { return err @@ -358,6 +350,7 @@ func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, update if !CredentialsEqual(updated.Credentials, original.Credentials) { // reset the identity *updated = *original + return errors.WithStack(ErrProtectedFieldModified) } @@ -390,10 +383,6 @@ func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...Manager return err } - if err := updated.SetAvailableAAL(ctx, m); err != nil { - return err - } - return m.r.PrivilegedIdentityPool().UpdateIdentity(ctx, updated) } @@ -444,6 +433,30 @@ func (m *Manager) SetTraits(ctx context.Context, id uuid.UUID, traits Traits, op return updated, nil } +// RefreshAvailableAAL refreshes the available AAL for the identity. +// +// This method is a no-op if everything is up-to date. +// +// Please make sure to load all credentials before using this method. +func (m *Manager) RefreshAvailableAAL(ctx context.Context, i *Identity) (err error) { + if len(i.Credentials) == 0 { + if err := m.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, i, ExpandCredentials); err != nil { + return err + } + } + + aalBefore := i.InternalAvailableAAL + if err := i.SetAvailableAAL(ctx, m); err != nil { + return err + } + + if aalBefore.String != i.InternalAvailableAAL.String || aalBefore.Valid != i.InternalAvailableAAL.Valid { + return m.r.PrivilegedIdentityPool().UpdateIdentityColumns(ctx, i, "available_aal") + } + + return nil +} + func (m *Manager) UpdateTraits(ctx context.Context, id uuid.UUID, traits Traits, opts ...ManagerOption) (err error) { ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.UpdateTraits") defer otelx.End(span, &err) @@ -458,17 +471,18 @@ func (m *Manager) UpdateTraits(ctx context.Context, id uuid.UUID, traits Traits, } func (m *Manager) ValidateIdentity(ctx context.Context, i *Identity, o *ManagerOptions) (err error) { - // This trace is more noisy than it's worth in diagnostic power. - // ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.validate") - // defer otelx.End(span, &err) - if err := m.r.IdentityValidator().Validate(ctx, i); err != nil { - if _, ok := errorsx.Cause(err).(*jsonschema.ValidationError); ok && !o.ExposeValidationErrors { + var validationErr *jsonschema.ValidationError + if errors.As(err, &validationErr) && !o.ExposeValidationErrors { return herodot.ErrBadRequest.WithReasonf("%s", err).WithWrap(err) } return err } + if err := i.SetAvailableAAL(ctx, m); err != nil { + return err + } + return nil } @@ -478,7 +492,7 @@ func (m *Manager) CountActiveFirstFactorCredentials(ctx context.Context, i *Iden // defer otelx.End(span, &err) for _, strategy := range m.r.ActiveCredentialsCounterStrategies(ctx) { - current, err := strategy.CountActiveFirstFactorCredentials(i.Credentials) + current, err := strategy.CountActiveFirstFactorCredentials(ctx, i.Credentials) if err != nil { return 0, err } @@ -494,7 +508,7 @@ func (m *Manager) CountActiveMultiFactorCredentials(ctx context.Context, i *Iden // defer otelx.End(span, &err) for _, strategy := range m.r.ActiveCredentialsCounterStrategies(ctx) { - current, err := strategy.CountActiveMultiFactorCredentials(i.Credentials) + current, err := strategy.CountActiveMultiFactorCredentials(ctx, i.Credentials) if err != nil { return 0, err } diff --git a/identity/manager_test.go b/identity/manager_test.go index f45a3f05e4fc..dcd2b2762c8f 100644 --- a/identity/manager_test.go +++ b/identity/manager_test.go @@ -4,6 +4,7 @@ package identity_test import ( + "encoding/json" "fmt" "testing" "time" @@ -12,6 +13,8 @@ import ( "github.com/ory/x/pointerx" "github.com/ory/x/sqlcon" + _ "embed" + "github.com/gofrs/uuid" "github.com/ory/x/sqlxx" @@ -28,6 +31,9 @@ import ( "github.com/ory/kratos/x" ) +//go:embed stub/aal.json +var refreshAALStubs []byte + func TestManager(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t, configx.WithValues(map[string]interface{}{ config.ViperKeyPublicBaseURL: "https://www.ory.sh/", @@ -70,6 +76,37 @@ func TestManager(t *testing.T) { } } + t.Run("method=CreateIdentities", func(t *testing.T) { + t.Run("case=should set AAL to 2 if password and TOTP is set", func(t *testing.T) { + email := uuid.Must(uuid.NewV4()).String() + "@ory.sh" + original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + original.Traits = newTraits(email, "") + original.Credentials = map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypePassword: { + Type: identity.CredentialsTypePassword, + // By explicitly not setting the identifier, we mimic the behavior of the PATCH endpoint. + // This tests a bug we introduced on the PATCH endpoint where the AAL value would not be correct. + Identifiers: []string{}, + Config: sqlxx.JSONRawMessage(`{"hashed_password":"$2a$08$.cOYmAd.vCpDOoiVJrO5B.hjTLKQQ6cAK40u8uB.FnZDyPvVvQ9Q."}`), + }, + identity.CredentialsTypeTOTP: { + Type: identity.CredentialsTypeTOTP, + // By explicitly not setting the identifier, we mimic the behavior of the PATCH endpoint. + // This tests a bug we introduced on the PATCH endpoint where the AAL value would not be correct. + Identifiers: []string{}, + Config: sqlxx.JSONRawMessage(`{"totp_url":"otpauth://totp/test"}`), + }, + } + require.NoError(t, reg.IdentityManager().CreateIdentities(ctx, []*identity.Identity{original})) + fromStore, err := reg.PrivilegedIdentityPool().GetIdentity(ctx, original.ID, identity.ExpandNothing) + require.NoError(t, err) + + got, ok := fromStore.InternalAvailableAAL.ToAAL() + require.True(t, ok) + assert.Equal(t, identity.AuthenticatorAssuranceLevel2, got) + }) + }) + t.Run("method=Create", func(t *testing.T) { t.Run("case=should create identity and track extension fields", func(t *testing.T) { email := uuid.Must(uuid.NewV4()).String() + "@ory.sh" @@ -77,7 +114,7 @@ func TestManager(t *testing.T) { original.Traits = newTraits(email, "") require.NoError(t, reg.IdentityManager().Create(ctx, original)) checkExtensionFieldsForIdentities(t, email, original) - got, ok := original.AvailableAAL.ToAAL() + got, ok := original.InternalAvailableAAL.ToAAL() require.True(t, ok) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, got) }) @@ -88,7 +125,7 @@ func TestManager(t *testing.T) { original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) original.Traits = newTraits(email, "") require.NoError(t, reg.IdentityManager().Create(ctx, original)) - got, ok := original.AvailableAAL.ToAAL() + got, ok := original.InternalAvailableAAL.ToAAL() require.True(t, ok) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, got) }) @@ -105,7 +142,7 @@ func TestManager(t *testing.T) { }, } require.NoError(t, reg.IdentityManager().Create(ctx, original)) - got, ok := original.AvailableAAL.ToAAL() + got, ok := original.InternalAvailableAAL.ToAAL() require.True(t, ok) assert.Equal(t, identity.AuthenticatorAssuranceLevel1, got) }) @@ -127,12 +164,12 @@ func TestManager(t *testing.T) { }, } require.NoError(t, reg.IdentityManager().Create(ctx, original)) - got, ok := original.AvailableAAL.ToAAL() + got, ok := original.InternalAvailableAAL.ToAAL() require.True(t, ok) assert.Equal(t, identity.AuthenticatorAssuranceLevel2, got) }) - t.Run("case=should set AAL to 0 if only TOTP is set", func(t *testing.T) { + t.Run("case=should set AAL to 2 if only TOTP is set", func(t *testing.T) { email := uuid.Must(uuid.NewV4()).String() + "@ory.sh" original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) original.Traits = newTraits(email, "") @@ -144,9 +181,9 @@ func TestManager(t *testing.T) { }, } require.NoError(t, reg.IdentityManager().Create(ctx, original)) - got, ok := original.AvailableAAL.ToAAL() + got, ok := original.InternalAvailableAAL.ToAAL() require.True(t, ok) - assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, got) + assert.Equal(t, identity.AuthenticatorAssuranceLevel2, got) }) }) @@ -378,7 +415,7 @@ func TestManager(t *testing.T) { }, } require.NoError(t, reg.IdentityManager().Update(ctx, original, identity.ManagerAllowWriteProtectedTraits)) - assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.AvailableAAL.String) + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.InternalAvailableAAL.String) }) t.Run("case=should set AAL to 2 if password and TOTP is set", func(t *testing.T) { @@ -393,19 +430,19 @@ func TestManager(t *testing.T) { }, } require.NoError(t, reg.IdentityManager().Create(ctx, original)) - assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.AvailableAAL.String) + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.InternalAvailableAAL.String) require.NoError(t, reg.IdentityManager().Update(ctx, original, identity.ManagerAllowWriteProtectedTraits)) - assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.AvailableAAL.String, "Updating without changes should not change AAL") + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.InternalAvailableAAL.String, "Updating without changes should not change AAL") original.Credentials[identity.CredentialsTypeTOTP] = identity.Credentials{ Type: identity.CredentialsTypeTOTP, Identifiers: []string{email}, Config: sqlxx.JSONRawMessage(`{"totp_url":"otpauth://totp/test"}`), } require.NoError(t, reg.IdentityManager().Update(ctx, original, identity.ManagerAllowWriteProtectedTraits)) - assert.EqualValues(t, identity.AuthenticatorAssuranceLevel2, original.AvailableAAL.String) + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel2, original.InternalAvailableAAL.String) }) - t.Run("case=should set AAL to 0 if only TOTP is set", func(t *testing.T) { + t.Run("case=should set AAL to 2 if only TOTP is set", func(t *testing.T) { email := uuid.Must(uuid.NewV4()).String() + "@ory.sh" original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) original.Traits = newTraits(email, "") @@ -418,8 +455,8 @@ func TestManager(t *testing.T) { }, } require.NoError(t, reg.IdentityManager().Update(ctx, original, identity.ManagerAllowWriteProtectedTraits)) - assert.True(t, original.AvailableAAL.Valid) - assert.EqualValues(t, identity.NoAuthenticatorAssuranceLevel, original.AvailableAAL.String) + assert.True(t, original.InternalAvailableAAL.Valid) + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel2, original.InternalAvailableAAL.String) }) t.Run("case=should not update protected traits without option", func(t *testing.T) { @@ -618,6 +655,51 @@ func TestManager(t *testing.T) { }) }) + t.Run("method=RefreshAvailableAAL", func(t *testing.T) { + var cases []struct { + Credentials []identity.Credentials `json:"credentials"` + Description string `json:"description"` + Expected string `json:"expected"` + } + require.NoError(t, json.Unmarshal(refreshAALStubs, &cases)) + + for k, tc := range cases { + t.Run("case="+tc.Description, func(t *testing.T) { + email := x.NewUUID().String() + "@ory.sh" + id := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + id.Traits = identity.Traits(`{"email":"` + email + `"}`) + require.NoError(t, reg.IdentityManager().Create(ctx, id)) + assert.EqualValues(t, identity.NoAuthenticatorAssuranceLevel, id.InternalAvailableAAL.String) + + for _, c := range tc.Credentials { + for k := range c.Identifiers { + switch c.Identifiers[k] { + case "{email}": + c.Identifiers[k] = email + case "{id}": + c.Identifiers[k] = id.ID.String() + } + } + id.SetCredentials(c.Type, c) + } + + // We use the privileged pool here because we don't want to refresh AAL here but in the code below. + require.NoError(t, reg.PrivilegedIdentityPool().UpdateIdentity(ctx, id)) + + expand := identity.ExpandNothing + if k%2 == 1 { // expand every other test case to test if RefreshAvailableAAL behaves correctly + expand = identity.ExpandCredentials + } + + actual, err := reg.IdentityPool().GetIdentity(ctx, id.ID, expand) + require.NoError(t, err) + require.NoError(t, reg.IdentityManager().RefreshAvailableAAL(ctx, actual)) + assert.NotEmpty(t, actual.Credentials) + assert.EqualValues(t, tc.Expected, actual.InternalAvailableAAL.String) + }) + } + }) + t.Run("method=ConflictingIdentity", func(t *testing.T) { ctx := ctx diff --git a/identity/pool.go b/identity/pool.go index 3ea7d4129f6f..8a94aad3e075 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -78,7 +78,13 @@ type ( // UpdateIdentity updates an identity including its confidential / privileged / protected data. UpdateIdentity(context.Context, *Identity) error - // GetIdentityConfidential returns the identity including it's raw credentials. This should only be used internally. + // UpdateIdentityColumns updates targeted columns of an identity. + UpdateIdentityColumns(ctx context.Context, i *Identity, columns ...string) error + + // GetIdentityConfidential returns the identity including it's raw credentials. + // + // This should only be used internally. Please be aware that this method uses HydrateIdentityAssociations + // internally, which must not be executed as part of a transaction. GetIdentityConfidential(context.Context, uuid.UUID) (*Identity, error) // ListVerifiableAddresses lists all tracked verifiable addresses, regardless of whether they are already verified @@ -89,6 +95,9 @@ type ( ListRecoveryAddresses(ctx context.Context, page, itemsPerPage int) ([]RecoveryAddress, error) // HydrateIdentityAssociations hydrates the associations of an identity. + // + // Please be aware that this method must not be called within a transaction if more than one element is expanded. + // It may error with "conn busy" otherwise. HydrateIdentityAssociations(ctx context.Context, i *Identity, expandables Expandables) error // InjectTraitsSchemaURL sets the identity's traits JSON schema URL from the schema's ID. diff --git a/identity/stub/aal.json b/identity/stub/aal.json new file mode 100644 index 000000000000..bb206c54b395 --- /dev/null +++ b/identity/stub/aal.json @@ -0,0 +1,76 @@ +[ + { + "description": "password is available aal1", + "expected": "aal1", + "credentials": [ + { + "type": "password", + "identifiers": [ + "{email}" + ], + "config": { + "hashed_password": "$2a$fake" + } + } + ] + }, + { + "description": "password without identifier is no credential and ergo aal0", + "expected": "aal0", + "credentials": [ + { + "type": "password", + "config": { + "hashed_password": "$2a$fake" + } + } + ] + }, + { + "description": "second factor totp returns available aal2 even if no password is set", + "expected": "aal2", + "credentials": [ + { + "type": "totp", + "config": { + "totp_url": "totp://" + } + } + ] + }, + { + "description": "second factor totp returns aal0 if totp credentials is not set up", + "expected": "aal0", + "credentials": [ + { + "type": "totp", + "identifiers": [ + "{email}" + ], + "config": {} + } + ] + }, + { + "description": "password and totp is also available aal2", + "expected": "aal1", + "credentials": [ + { + "type": "password", + "identifiers": [ + "{email}" + ], + "config": { + "hashed_password": "$2a$fake" + } + }, + { + "type": "totp", + "identifiers": [ + "{email}" + ], + "config": {} + } + ] + } +] diff --git a/identity/stub/extension/credentials/code-phone-email.schema.json b/identity/stub/extension/credentials/code-phone-email.schema.json new file mode 100644 index 000000000000..0b8166b9c337 --- /dev/null +++ b/identity/stub/extension/credentials/code-phone-email.schema.json @@ -0,0 +1,59 @@ +{ + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "webauthn": { + "identifier": true + }, + "code": { + "identifier": true, + "via": "email" + } + } + } + }, + "email2": { + "type": "string", + "format": "email", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "webauthn": { + "identifier": true + }, + "code": { + "identifier": true, + "via": "email" + } + } + } + }, + "phone": { + "type": "string", + "format": "tel", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "webauthn": { + "identifier": true + }, + "code": { + "identifier": true, + "via": "sms" + } + } + } + } + } +} diff --git a/identity/test/pool.go b/identity/test/pool.go index 458c057da916..f997eaddc630 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -316,14 +316,14 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, t.Run("case=create with null AAL", func(t *testing.T) { expected := passwordIdentity("", "id-"+uuid.Must(uuid.NewV4()).String()) - expected.AvailableAAL.Valid = false + expected.InternalAvailableAAL.Valid = false require.NoError(t, p.CreateIdentity(ctx, expected)) createdIDs = append(createdIDs, expected.ID) actual, err := p.GetIdentity(ctx, expected.ID, identity.ExpandDefault) require.NoError(t, err) - assert.False(t, actual.AvailableAAL.Valid) + assert.False(t, actual.InternalAvailableAAL.Valid) }) t.Run("suite=create multiple identities", func(t *testing.T) { @@ -549,6 +549,22 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, require.Contains(t, err.Error(), "malformed") }) + t.Run("case=update an identity column", func(t *testing.T) { + initial := oidcIdentity("", x.NewUUID().String()) + initial.InternalAvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.NoAuthenticatorAssuranceLevel) + require.NoError(t, p.CreateIdentity(ctx, initial)) + createdIDs = append(createdIDs, initial.ID) + + initial.InternalAvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.AuthenticatorAssuranceLevel1) + initial.State = identity.StateInactive + require.NoError(t, p.UpdateIdentityColumns(ctx, initial, "available_aal")) + + actual, err := p.GetIdentity(ctx, initial.ID, identity.ExpandDefault) + require.NoError(t, err) + assert.Equal(t, string(identity.AuthenticatorAssuranceLevel1), actual.InternalAvailableAAL.String) + assert.Equal(t, identity.StateActive, actual.State, "the state remains unchanged") + }) + t.Run("case=should fail to insert identity because credentials from traits exist", func(t *testing.T) { first := passwordIdentity("", "test-identity@ory.sh") first.Traits = identity.Traits(`{}`) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/testhelpers/courier.go b/internal/testhelpers/courier.go index 796dac29989d..fcb77f47005d 100644 --- a/internal/testhelpers/courier.go +++ b/internal/testhelpers/courier.go @@ -31,7 +31,7 @@ func CourierExpectMessage(ctx context.Context, t *testing.T, reg interface { }) for _, m := range messages { - if strings.EqualFold(m.Recipient, recipient) && strings.EqualFold(m.Subject, subject) { + if strings.EqualFold(m.Recipient, recipient) && (strings.EqualFold(m.Subject, subject) || strings.Contains(m.Body, subject)) { return &m } } diff --git a/internal/testhelpers/handler_mock.go b/internal/testhelpers/handler_mock.go index bcc68a1e61c1..36a51edcc1eb 100644 --- a/internal/testhelpers/handler_mock.go +++ b/internal/testhelpers/handler_mock.go @@ -5,6 +5,7 @@ package testhelpers import ( "context" + "encoding/json" "io" "net/http" "net/http/cookiejar" @@ -26,6 +27,7 @@ import ( type mockDeps interface { identity.PrivilegedPoolProvider + identity.ManagementProvider session.ManagementProvider session.PersistenceProvider config.Provider @@ -34,15 +36,24 @@ type mockDeps interface { func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) - require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i)) + i.NID = uuid.Must(uuid.NewV4()) + require.NoError(t, i.SetCredentialsWithConfig( + identity.CredentialsTypePassword, + identity.Credentials{ + Type: identity.CredentialsTypePassword, + Identifiers: []string{faker.Email()}, + }, + json.RawMessage(`{"hashed_password":"$"}`))) + require.NoError(t, reg.IdentityManager().Create(context.Background(), i)) MockSetSessionWithIdentity(t, reg, conf, i)(w, r, ps) } } -func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, conf *config.Config, i *identity.Identity) httprouter.Handle { +func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, _ *config.Config, i *identity.Identity) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - activeSession, _ := session.NewActiveSession(r, i, conf, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + activeSession, err := NewActiveSession(r, reg, i, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + require.NoError(t, err) if aal := r.URL.Query().Get("set_aal"); len(aal) > 0 { activeSession.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel(aal) } @@ -52,18 +63,6 @@ func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, conf *config.Config, } } -func MockGetSession(t *testing.T, reg mockDeps) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - _, err := reg.SessionManager().FetchFromRequest(r.Context(), r) - if r.URL.Query().Get("has") == "yes" { - require.NoError(t, err) - } else { - require.Error(t, err) - } - w.WriteHeader(http.StatusNoContent) - } -} - func MockMakeAuthenticatedRequest(t *testing.T, reg mockDeps, conf *config.Config, router *httprouter.Router, req *http.Request) ([]byte, *http.Response) { return MockMakeAuthenticatedRequestWithClient(t, reg, conf, router, req, NewClientWithCookies(t)) } diff --git a/internal/testhelpers/identity.go b/internal/testhelpers/identity.go index 5c7cdd5be692..6bbc7c5da27b 100644 --- a/internal/testhelpers/identity.go +++ b/internal/testhelpers/identity.go @@ -19,7 +19,7 @@ func CreateSession(t *testing.T, reg driver.Registry) *session.Session { req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(req.Context(), i)) - sess, err := session.NewActiveSession(req, i, reg.Config(), time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + sess, err := NewActiveSession(req, reg, i, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) require.NoError(t, reg.SessionPersister().UpsertSession(req.Context(), sess)) return sess diff --git a/internal/testhelpers/selfservice.go b/internal/testhelpers/selfservice.go index 8c7b4c588d78..56903b04ac79 100644 --- a/internal/testhelpers/selfservice.go +++ b/internal/testhelpers/selfservice.go @@ -11,6 +11,8 @@ import ( "net/url" "testing" + "github.com/gofrs/uuid" + "github.com/go-faker/faker/v4" "github.com/gobuffalo/httptest" "github.com/stretchr/testify/assert" @@ -90,6 +92,7 @@ func SelfServiceHookFakeIdentity(t *testing.T) *identity.Identity { require.NoError(t, faker.FakeData(&i)) i.Traits = identity.Traits(`{}`) i.State = identity.StateActive + i.NID = uuid.Must(uuid.NewV4()) return &i } diff --git a/internal/testhelpers/session.go b/internal/testhelpers/session.go index 1d2cecc824db..c614f1afe36d 100644 --- a/internal/testhelpers/session.go +++ b/internal/testhelpers/session.go @@ -10,6 +10,8 @@ import ( "testing" "time" + confighelpers "github.com/ory/kratos/driver/config/testhelpers" + "github.com/ory/nosurf" "github.com/stretchr/testify/assert" @@ -46,7 +48,7 @@ func maybePersistSession(t *testing.T, ctx context.Context, reg *driver.Registry id, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, sess.Identity.ID) if err != nil { require.NoError(t, sess.Identity.SetAvailableAAL(ctx, reg.IdentityManager())) - require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, sess.Identity)) + require.NoError(t, reg.IdentityManager().Create(ctx, sess.Identity)) id, err = reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, sess.Identity.ID) require.NoError(t, err) } @@ -145,10 +147,9 @@ func NewHTTPClientWithArbitrarySessionToken(t *testing.T, ctx context.Context, r } func NewHTTPClientWithArbitrarySessionTokenAndTraits(t *testing.T, ctx context.Context, reg *driver.RegistryDefault, traits identity.Traits) *http.Client { - req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) - s, err := session.NewActiveSession(req, - &identity.Identity{ID: x.NewUUID(), State: identity.StateActive, Traits: traits}, - NewSessionLifespanProvider(time.Hour), + req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil).WithContext(confighelpers.WithConfigValue(ctx, "session.lifespan", time.Hour)) + s, err := NewActiveSession(req, reg, + &identity.Identity{ID: x.NewUUID(), State: identity.StateActive, Traits: traits, NID: x.NewUUID(), SchemaID: "default"}, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -160,9 +161,12 @@ func NewHTTPClientWithArbitrarySessionTokenAndTraits(t *testing.T, ctx context.C func NewHTTPClientWithArbitrarySessionCookie(t *testing.T, ctx context.Context, reg *driver.RegistryDefault) *http.Client { req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) - s, err := session.NewActiveSession(req, - &identity.Identity{ID: x.NewUUID(), State: identity.StateActive, Traits: []byte("{}")}, - NewSessionLifespanProvider(time.Hour), + req = req.WithContext(confighelpers.WithConfigValue(ctx, "session.lifespan", time.Hour)) + id := x.NewUUID() + s, err := NewActiveSession(req, reg, + &identity.Identity{ID: id, State: identity.StateActive, Traits: []byte("{}"), Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypePassword: {Type: "password", Identifiers: []string{id.String()}, Config: []byte(`{"hashed_password":"$2a$04$zvZz1zV"}`)}, + }}, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -174,9 +178,13 @@ func NewHTTPClientWithArbitrarySessionCookie(t *testing.T, ctx context.Context, func NewNoRedirectHTTPClientWithArbitrarySessionCookie(t *testing.T, ctx context.Context, reg *driver.RegistryDefault) *http.Client { req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) - s, err := session.NewActiveSession(req, - &identity.Identity{ID: x.NewUUID(), State: identity.StateActive}, - NewSessionLifespanProvider(time.Hour), + req = req.WithContext(confighelpers.WithConfigValue(ctx, "session.lifespan", time.Hour)) + id := x.NewUUID() + s, err := NewActiveSession(req, reg, + &identity.Identity{ID: id, State: identity.StateActive, + Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypePassword: {Type: "password", Identifiers: []string{id.String()}, Config: []byte(`{"hashed_password":"$2a$04$zvZz1zV"}`)}, + }}, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -188,9 +196,9 @@ func NewNoRedirectHTTPClientWithArbitrarySessionCookie(t *testing.T, ctx context func NewHTTPClientWithIdentitySessionCookie(t *testing.T, ctx context.Context, reg *driver.RegistryDefault, id *identity.Identity) *http.Client { req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) - s, err := session.NewActiveSession(req, + req = req.WithContext(confighelpers.WithConfigValue(ctx, "session.lifespan", time.Hour)) + s, err := NewActiveSession(req, reg, id, - NewSessionLifespanProvider(time.Hour), time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -202,9 +210,8 @@ func NewHTTPClientWithIdentitySessionCookie(t *testing.T, ctx context.Context, r func NewHTTPClientWithIdentitySessionCookieLocalhost(t *testing.T, ctx context.Context, reg *driver.RegistryDefault, id *identity.Identity) *http.Client { req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) - s, err := session.NewActiveSession(req, + s, err := NewActiveSession(req, reg, id, - NewSessionLifespanProvider(time.Hour), time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -216,9 +223,9 @@ func NewHTTPClientWithIdentitySessionCookieLocalhost(t *testing.T, ctx context.C func NewHTTPClientWithIdentitySessionToken(t *testing.T, ctx context.Context, reg *driver.RegistryDefault, id *identity.Identity) *http.Client { req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) - s, err := session.NewActiveSession(req, + req = req.WithContext(confighelpers.WithConfigValue(ctx, "session.lifespan", time.Hour)) + s, err := NewActiveSession(req, reg, id, - NewSessionLifespanProvider(time.Hour), time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, diff --git a/internal/testhelpers/session_active.go b/internal/testhelpers/session_active.go new file mode 100644 index 000000000000..a245abce45e1 --- /dev/null +++ b/internal/testhelpers/session_active.go @@ -0,0 +1,23 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package testhelpers + +import ( + "net/http" + "time" + + "github.com/ory/kratos/identity" + "github.com/ory/kratos/session" +) + +func NewActiveSession(r *http.Request, reg interface { + session.ManagementProvider +}, i *identity.Identity, authenticatedAt time.Time, completedLoginFor identity.CredentialsType, completedLoginAAL identity.AuthenticatorAssuranceLevel) (*session.Session, error) { + s := session.NewInactiveSession() + s.CompletedLoginFor(completedLoginFor, completedLoginAAL) + if err := reg.SessionManager().ActivateSession(r, s, i, authenticatedAt); err != nil { + return nil, err + } + return s, nil +} diff --git a/persistence/reference.go b/persistence/reference.go index 56a7ca1712df..d3ceeb8d26b5 100644 --- a/persistence/reference.go +++ b/persistence/reference.go @@ -7,6 +7,8 @@ import ( "context" "time" + "github.com/ory/kratos/x" + "github.com/ory/kratos/selfservice/sessiontokenexchange" "github.com/ory/x/networkx" @@ -63,7 +65,7 @@ type Persister interface { Migrator() *popx.Migrator MigrationBox() *popx.MigrationBox GetConnection(ctx context.Context) *pop.Connection - Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error + x.TransactionalPersister Networker } diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 984fb0199da2..49474436b334 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -949,6 +949,19 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. return is, nextPage, nil } +func (p *IdentityPersister) UpdateIdentityColumns(ctx context.Context, i *identity.Identity, columns ...string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateIdentity", + trace.WithAttributes( + attribute.Stringer("identity.id", i.ID), + attribute.Stringer("network.id", p.NetworkID(ctx)))) + defer otelx.End(span, &err) + + return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + _, err := tx.Where("id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).UpdateQuery(i, columns...) + return sqlcon.HandleError(err) + }) +} + func (p *IdentityPersister) UpdateIdentity(ctx context.Context, i *identity.Identity) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateIdentity", trace.WithAttributes( diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index 63c38731ef19..e252a23e7011 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -9,6 +9,8 @@ import ( "strconv" "time" + "github.com/ory/x/otelx" + "github.com/gofrs/uuid" "github.com/julienschmidt/httprouter" "github.com/pkg/errors" @@ -54,6 +56,7 @@ type ( x.WriterProvider x.CSRFTokenGeneratorProvider x.CSRFProvider + x.TracingProvider config.Provider ErrorHandlerProvider sessiontokenexchange.PersistenceProvider @@ -330,6 +333,9 @@ type createNativeLoginFlow struct { // Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows. // + // DEPRECATED: This field is deprecated. Please remove it from your requests. The user will now see a choice + // of MFA credentials to choose from to perform the second factor instead. + // // in: query Via string `json:"via"` } @@ -369,6 +375,11 @@ type createNativeLoginFlow struct { // 400: errorGeneric // default: errorGeneric func (h *Handler) createNativeLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + var err error + ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.createNativeLoginFlow") + r = r.WithContext(ctx) + defer otelx.End(span, &err) + f, _, err := h.NewLoginFlow(w, r, flow.TypeAPI) if err != nil { h.d.Writer().WriteError(w, r, err) @@ -437,6 +448,9 @@ type createBrowserLoginFlow struct { // Via should contain the identity's credential the code should be sent to. Only relevant in aal2 flows. // + // DEPRECATED: This field is deprecated. Please remove it from your requests. The user will now see a choice + // of MFA credentials to choose from to perform the second factor instead. + // // in: query Via string `json:"via"` } @@ -480,6 +494,11 @@ type createBrowserLoginFlow struct { // 400: errorGeneric // default: errorGeneric func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + var err error + ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.createBrowserLoginFlow") + r = r.WithContext(ctx) + defer otelx.End(span, &err) + var ( hydraLoginRequest *hydraclientgo.OAuth2LoginRequest hydraLoginChallenge sqlxx.NullString @@ -488,13 +507,13 @@ func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request, var err error hydraLoginChallenge, err = hydra.GetLoginChallengeID(h.d.Config(), r) if err != nil { - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, err) return } - hydraLoginRequest, err = h.d.Hydra().GetLoginRequest(r.Context(), string(hydraLoginChallenge)) + hydraLoginRequest, err = h.d.Hydra().GetLoginRequest(ctx, string(hydraLoginChallenge)) if err != nil { - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, err) return } @@ -510,7 +529,7 @@ func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request, // different flows, such as login to registration and login to recovery. // After completing a complex flow, such as recovery, we want the user // to be redirected back to the original OAuth2 login flow. - if hydraLoginRequest.RequestUrl != "" && h.d.Config().OAuth2ProviderOverrideReturnTo(r.Context()) { + if hydraLoginRequest.RequestUrl != "" && h.d.Config().OAuth2ProviderOverrideReturnTo(ctx) { // replace the return_to query parameter q := r.URL.Query() q.Set("return_to", hydraLoginRequest.RequestUrl) @@ -522,11 +541,11 @@ func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request, if errors.Is(err, ErrAlreadyLoggedIn) { if hydraLoginRequest != nil { if !hydraLoginRequest.GetSkip() { - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrInternalServerError.WithReason("ErrAlreadyLoggedIn indicated we can skip login, but Hydra asked us to refresh"))) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, errors.WithStack(herodot.ErrInternalServerError.WithReason("ErrAlreadyLoggedIn indicated we can skip login, but Hydra asked us to refresh"))) return } - rt, err := h.d.Hydra().AcceptLoginRequest(r.Context(), + rt, err := h.d.Hydra().AcceptLoginRequest(ctx, hydra.AcceptLoginRequestParams{ LoginChallenge: string(hydraLoginChallenge), IdentityID: sess.IdentityID.String(), @@ -534,37 +553,37 @@ func (h *Handler) createBrowserLoginFlow(w http.ResponseWriter, r *http.Request, AuthenticationMethods: sess.AMR, }) if err != nil { - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, err) return } returnTo, err := url.Parse(rt) if err != nil { - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to parse URL: %s", rt))) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to parse URL: %s", rt))) return } x.AcceptToRedirectOrJSON(w, r, h.d.Writer(), err, returnTo.String()) return } - returnTo, redirErr := x.SecureRedirectTo(r, h.d.Config().SelfServiceBrowserDefaultReturnTo(r.Context()), - x.SecureRedirectAllowSelfServiceURLs(h.d.Config().SelfPublicURL(r.Context())), - x.SecureRedirectAllowURLs(h.d.Config().SelfServiceBrowserAllowedReturnToDomains(r.Context())), + returnTo, redirErr := x.SecureRedirectTo(r, h.d.Config().SelfServiceBrowserDefaultReturnTo(ctx), + x.SecureRedirectAllowSelfServiceURLs(h.d.Config().SelfPublicURL(ctx)), + x.SecureRedirectAllowURLs(h.d.Config().SelfServiceBrowserAllowedReturnToDomains(ctx)), ) if redirErr != nil { - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, redirErr) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, redirErr) return } x.AcceptToRedirectOrJSON(w, r, h.d.Writer(), err, returnTo.String()) return } else if err != nil { - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, err) return } a.HydraLoginRequest = hydraLoginRequest - x.AcceptToRedirectOrJSON(w, r, h.d.Writer(), a, a.AppendTo(h.d.Config().SelfServiceFlowLoginUI(r.Context())).String()) + x.AcceptToRedirectOrJSON(w, r, h.d.Writer(), a, a.AppendTo(h.d.Config().SelfServiceFlowLoginUI(ctx)).String()) } // Get Login Flow Parameters @@ -633,7 +652,12 @@ type getLoginFlow struct { // 410: errorGeneric // default: errorGeneric func (h *Handler) getLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - ar, err := h.d.LoginFlowPersister().GetLoginFlow(r.Context(), x.ParseUUID(r.URL.Query().Get("id"))) + var err error + ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.getLoginFlow") + r = r.WithContext(ctx) + defer otelx.End(span, &err) + + ar, err := h.d.LoginFlowPersister().GetLoginFlow(ctx, x.ParseUUID(r.URL.Query().Get("id"))) if err != nil { h.d.Writer().WriteError(w, r, err) return @@ -649,7 +673,7 @@ func (h *Handler) getLoginFlow(w http.ResponseWriter, r *http.Request, _ httprou if ar.ExpiresAt.Before(time.Now()) { if ar.Type == flow.TypeBrowser { - redirectURL := flow.GetFlowExpiredRedirectURL(r.Context(), h.d.Config(), RouteInitBrowserFlow, ar.ReturnTo) + redirectURL := flow.GetFlowExpiredRedirectURL(ctx, h.d.Config(), RouteInitBrowserFlow, ar.ReturnTo) h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.WithID(text.ErrIDSelfServiceFlowExpired). WithReason("The login flow has expired. Redirect the user to the login flow init endpoint to initialize a new login flow."). @@ -659,16 +683,16 @@ func (h *Handler) getLoginFlow(w http.ResponseWriter, r *http.Request, _ httprou } h.d.Writer().WriteError(w, r, errors.WithStack(x.ErrGone.WithID(text.ErrIDSelfServiceFlowExpired). WithReason("The login flow has expired. Call the login flow init API endpoint to initialize a new login flow."). - WithDetail("api", urlx.AppendPaths(h.d.Config().SelfPublicURL(r.Context()), RouteInitAPIFlow).String()))) + WithDetail("api", urlx.AppendPaths(h.d.Config().SelfPublicURL(ctx), RouteInitAPIFlow).String()))) return } if ar.OAuth2LoginChallenge != "" { - hlr, err := h.d.Hydra().GetLoginRequest(r.Context(), string(ar.OAuth2LoginChallenge)) + hlr, err := h.d.Hydra().GetLoginRequest(ctx, string(ar.OAuth2LoginChallenge)) if err != nil { // We don't redirect back to the third party on errors because Hydra doesn't // give us the 3rd party return_uri when it redirects to the login UI. - h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err) + h.d.SelfServiceErrorManager().Forward(ctx, w, r, err) return } ar.HydraLoginRequest = hlr @@ -770,19 +794,24 @@ type updateLoginFlowBody struct{} // 422: errorBrowserLocationChangeRequired // default: errorGeneric func (h *Handler) updateLoginFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + var err error + ctx, span := h.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.flow.login.updateLoginFlow") + r = r.WithContext(ctx) + defer otelx.End(span, &err) + rid, err := flow.GetFlowID(r) if err != nil { h.d.LoginFlowErrorHandler().WriteFlowError(w, r, nil, node.DefaultGroup, err) return } - f, err := h.d.LoginFlowPersister().GetLoginFlow(r.Context(), rid) + f, err := h.d.LoginFlowPersister().GetLoginFlow(ctx, rid) if err != nil { h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, err) return } - sess, err := h.d.SessionManager().FetchFromRequest(r.Context(), r) + sess, err := h.d.SessionManager().FetchFromRequest(ctx, r) if err == nil { if f.Refresh { // If we want to refresh, continue the login @@ -800,7 +829,7 @@ func (h *Handler) updateLoginFlow(w http.ResponseWriter, r *http.Request, _ http return } - http.Redirect(w, r, h.d.Config().SelfServiceBrowserDefaultReturnTo(r.Context()).String(), http.StatusSeeOther) + http.Redirect(w, r, h.d.Config().SelfServiceBrowserDefaultReturnTo(ctx).String(), http.StatusSeeOther) return } else if e := new(session.ErrNoActiveSessionFound); errors.As(err, &e) { // Only failure scenario here is if we try to upgrade the session to a higher AAL without actually @@ -842,7 +871,7 @@ continueLogin: sess = session.NewInactiveSession() } - method := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR) + method := ss.CompletedAuthenticationMethod(ctx) sess.CompletedLoginForMethod(method) i = interim break diff --git a/selfservice/flow/login/handler_test.go b/selfservice/flow/login/handler_test.go index c8d5ac97772e..504db9436100 100644 --- a/selfservice/flow/login/handler_test.go +++ b/selfservice/flow/login/handler_test.go @@ -21,12 +21,11 @@ import ( "github.com/ory/x/sqlxx" + stdtotp "github.com/pquerna/otp/totp" + "github.com/ory/kratos/hydra" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/strategy/totp" - "github.com/ory/kratos/session" - - stdtotp "github.com/pquerna/otp/totp" "github.com/ory/kratos/ui/container" @@ -458,7 +457,7 @@ func TestFlowLifecycle(t *testing.T) { require.NoError(t, reg.IdentityManager().Update(context.Background(), id, identity.ManagerAllowWriteProtectedTraits)) h := func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - sess, err := session.NewActiveSession(r, id, reg.Config(), time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + sess, err := testhelpers.NewActiveSession(r, reg, id, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) sess.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel1 require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), sess)) diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index 73a484c22e50..e5ba20428887 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -49,6 +49,7 @@ type ( config.Provider hydra.Provider identity.PrivilegedPoolProvider + identity.ManagementProvider session.ManagementProvider session.PersistenceProvider x.CSRFTokenGeneratorProvider @@ -147,7 +148,7 @@ func (e *HookExecutor) PostLoginHook( return err } - if err := s.Activate(r, i, e.d.Config(), time.Now().UTC()); err != nil { + if err := e.d.SessionManager().ActivateSession(r, s, i, time.Now().UTC()); err != nil { return err } @@ -385,7 +386,7 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S return err } - method := strategy.CompletedAuthenticationMethod(ctx, sess.AMR) + method := strategy.CompletedAuthenticationMethod(ctx) sess.CompletedLoginForMethod(method) return nil diff --git a/selfservice/flow/login/strategy.go b/selfservice/flow/login/strategy.go index fec71d3beb1d..8ea671343c76 100644 --- a/selfservice/flow/login/strategy.go +++ b/selfservice/flow/login/strategy.go @@ -21,7 +21,7 @@ type Strategy interface { NodeGroup() node.UiNodeGroup RegisterLoginRoutes(*x.RouterPublic) Login(w http.ResponseWriter, r *http.Request, f *Flow, sess *session.Session) (i *identity.Identity, err error) - CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods) session.AuthenticationMethod + CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod } type Strategies []Strategy diff --git a/selfservice/flow/recovery/hook.go b/selfservice/flow/recovery/hook.go index 15c3078e918c..212eb061b7f4 100644 --- a/selfservice/flow/recovery/hook.go +++ b/selfservice/flow/recovery/hook.go @@ -81,10 +81,14 @@ func NewHookExecutor(d executorDependencies) *HookExecutor { } func (e *HookExecutor) PostRecoveryHook(w http.ResponseWriter, r *http.Request, a *Flow, s *session.Session) error { - e.d.Logger(). - WithRequest(r). - WithField("identity_id", s.Identity.ID). - Debug("Running ExecutePostRecoveryHooks.") + logger := e.d.Logger(). + WithRequest(r) + + if s.Identity != nil { + logger = logger.WithField("identity_id", s.Identity.ID) + } + + logger.Debug("Running ExecutePostRecoveryHooks.") for k, executor := range e.d.PostRecoveryHooks(r.Context()) { if err := executor.ExecutePostRecoveryHook(w, r, a, s); err != nil { var traits identity.Traits @@ -94,20 +98,16 @@ func (e *HookExecutor) PostRecoveryHook(w http.ResponseWriter, r *http.Request, return flow.HandleHookError(w, r, a, traits, node.LinkGroup, err, e.d, e.d) } - e.d.Logger().WithRequest(r). + logger. WithField("executor", fmt.Sprintf("%T", executor)). WithField("executor_position", k). WithField("executors", PostHookRecoveryExecutorNames(e.d.PostRecoveryHooks(r.Context()))). - WithField("identity_id", s.Identity.ID). Debug("ExecutePostRecoveryHook completed successfully.") } trace.SpanFromContext(r.Context()).AddEvent(events.NewRecoverySucceeded(r.Context(), s.Identity.ID, string(a.Type), a.Active.String())) - e.d.Logger(). - WithRequest(r). - WithField("identity_id", s.Identity.ID). - Debug("Post recovery execution hooks completed successfully.") + logger.Debug("Post recovery execution hooks completed successfully.") return nil } diff --git a/selfservice/flow/recovery/hook_test.go b/selfservice/flow/recovery/hook_test.go index deb4b0426363..ce4ccf6deb76 100644 --- a/selfservice/flow/recovery/hook_test.go +++ b/selfservice/flow/recovery/hook_test.go @@ -10,8 +10,6 @@ import ( "testing" "time" - "github.com/ory/kratos/session" - "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/strategy/code" @@ -31,6 +29,7 @@ import ( func TestRecoveryExecutor(t *testing.T) { ctx := context.Background() conf, reg := internal.NewFastRegistryWithMocks(t) + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") s := code.NewStrategy(reg) newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server { @@ -46,13 +45,14 @@ func TestRecoveryExecutor(t *testing.T) { router.GET("/recovery/post", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { a, err := recovery.NewFlow(conf, time.Minute, x.FakeCSRFToken, r, s, ft) require.NoError(t, err) - s, _ := session.NewActiveSession(r, + s, err := testhelpers.NewActiveSession(r, + reg, i, - conf, time.Now().UTC(), identity.CredentialsTypeRecoveryLink, identity.AuthenticatorAssuranceLevel1, ) + require.NoError(t, err) a.RequestURL = x.RequestURL(r).String() if testhelpers.SelfServiceHookErrorHandler(t, w, r, recovery.ErrHookAbortFlow, reg.RecoveryExecutor().PostRecoveryHook(w, r, a, s)) { _, _ = w.Write([]byte("ok")) diff --git a/selfservice/flow/registration/hook.go b/selfservice/flow/registration/hook.go index e44be2487bbb..c1c7b7ed4b2c 100644 --- a/selfservice/flow/registration/hook.go +++ b/selfservice/flow/registration/hook.go @@ -214,7 +214,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque s.CompletedLoginForWithProvider(ct, identity.AuthenticatorAssuranceLevel1, provider, httprouter.ParamsFromContext(r.Context()).ByName("organization")) - if err := s.Activate(r, i, c, time.Now().UTC()); err != nil { + if err := e.d.SessionManager().ActivateSession(r, s, i, time.Now().UTC()); err != nil { return err } diff --git a/selfservice/flow/settings/error_test.go b/selfservice/flow/settings/error_test.go index 1a2c32e9c5c0..41b0c1e50104 100644 --- a/selfservice/flow/settings/error_test.go +++ b/selfservice/flow/settings/error_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + confighelpers "github.com/ory/kratos/driver/config/testhelpers" + "github.com/pkg/errors" "github.com/gofrs/uuid" @@ -149,9 +151,10 @@ func TestHandleError(t *testing.T) { t.Cleanup(reset) req := httptest.NewRequest("GET", "/sessions/whoami", nil) + req.WithContext(confighelpers.WithConfigValue(ctx, config.ViperKeySessionLifespan, time.Hour)) // This needs an authenticated client in order to call the RouteGetFlow endpoint - s, err := session.NewActiveSession(req, &id, testhelpers.NewSessionLifespanProvider(time.Hour), time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, err := testhelpers.NewActiveSession(req, reg, &id, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) c := testhelpers.NewHTTPClientWithSessionToken(t, ctx, reg, s) diff --git a/selfservice/flow/settings/hook_test.go b/selfservice/flow/settings/hook_test.go index 5253bb2886f2..242eacf0e8da 100644 --- a/selfservice/flow/settings/hook_test.go +++ b/selfservice/flow/settings/hook_test.go @@ -24,7 +24,6 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/hook" - "github.com/ory/kratos/session" "github.com/ory/kratos/x" ) @@ -54,7 +53,7 @@ func TestSettingsExecutor(t *testing.T) { if i == nil { i = testhelpers.SelfServiceHookCreateFakeIdentity(t, reg) } - sess, _ := session.NewActiveSession(r, i, conf, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + sess, _ := testhelpers.NewActiveSession(r, reg, i, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) f, err := settings.NewFlow(conf, time.Minute, r, sess.Identity, ft) require.NoError(t, err) @@ -67,7 +66,7 @@ func TestSettingsExecutor(t *testing.T) { if i == nil { i = testhelpers.SelfServiceHookCreateFakeIdentity(t, reg) } - sess, _ := session.NewActiveSession(r, i, conf, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + sess, _ := testhelpers.NewActiveSession(r, reg, i, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) a, err := settings.NewFlow(conf, time.Minute, r, sess.Identity, ft) require.NoError(t, err) diff --git a/selfservice/flowhelpers/login_test.go b/selfservice/flowhelpers/login_test.go index 3506c73c8513..0a400ff7d4b7 100644 --- a/selfservice/flowhelpers/login_test.go +++ b/selfservice/flowhelpers/login_test.go @@ -17,7 +17,6 @@ import ( "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flowhelpers" - "github.com/ory/kratos/session" ) func TestGuessForcedLoginIdentifier(t *testing.T) { @@ -34,9 +33,9 @@ func TestGuessForcedLoginIdentifier(t *testing.T) { req := httptest.NewRequest("GET", "/sessions/whoami", nil) - sess, err := session.NewActiveSession(req, i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + sess, err := testhelpers.NewActiveSession(req, reg, i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) - reg.SessionPersister().UpsertSession(context.Background(), sess) + require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), sess)) r := httptest.NewRequest("GET", "/login", nil) r.Header.Set("Authorization", "Bearer "+sess.Token) diff --git a/selfservice/hook/code_address_verifier_test.go b/selfservice/hook/code_address_verifier_test.go index 579432e60556..000002cdac36 100644 --- a/selfservice/hook/code_address_verifier_test.go +++ b/selfservice/hook/code_address_verifier_test.go @@ -40,7 +40,7 @@ func TestCodeAddressVerifier(t *testing.T) { _, err := reg.RegistrationCodePersister().CreateRegistrationCode(ctx, &code.CreateRegistrationCodeParams{ Address: address, - AddressType: identity.CodeAddressTypeEmail, + AddressType: identity.AddressTypeEmail, RawCode: rawCode, ExpiresIn: time.Hour, FlowID: rf.ID, diff --git a/selfservice/strategy/code/.schema/login.schema.json b/selfservice/strategy/code/.schema/login.schema.json index 1bcc36b12c88..9a8a0aa5cf25 100644 --- a/selfservice/strategy/code/.schema/login.schema.json +++ b/selfservice/strategy/code/.schema/login.schema.json @@ -15,6 +15,9 @@ "identifier": { "type": "string" }, + "address": { + "type": "string" + }, "resend": { "type": "string", "enum": [ diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodIdentifierFirstCredentials-case=WithIdentifier-case=code_is_used_for_passwordless_login.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodIdentifierFirstCredentials-case=WithIdentifier-case=code_is_used_for_passwordless_login.json deleted file mode 100644 index 66b84bf1a436..000000000000 --- a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodIdentifierFirstCredentials-case=WithIdentifier-case=code_is_used_for_passwordless_login.json +++ /dev/null @@ -1,21 +0,0 @@ -[ - { - "type": "input", - "group": "code", - "attributes": { - "name": "method", - "type": "submit", - "value": "code", - "disabled": false, - "node_type": "input" - }, - "messages": [], - "meta": { - "label": { - "id": 1010015, - "text": "Send sign in code", - "type": "info" - } - } - } -] diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor#01-case=code_is_used_for_2fa_and_request_is_2fa.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor#01-case=code_is_used_for_2fa_and_request_is_2fa.json deleted file mode 100644 index 60b142ed4181..000000000000 --- a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor#01-case=code_is_used_for_2fa_and_request_is_2fa.json +++ /dev/null @@ -1,66 +0,0 @@ -[ - { - "type": "input", - "group": "default", - "attributes": { - "name": "identifier", - "type": "text", - "value": "", - "required": true, - "disabled": false, - "node_type": "input" - }, - "messages": [ - { - "id": 1010020, - "text": "We will send a code to fo****@ory.sh. To verify that this is your address please enter it here.", - "type": "info", - "context": { - "masked_to": "fo****@ory.sh" - } - } - ], - "meta": { - "label": { - "id": 1070002, - "text": "", - "type": "info", - "context": { - "title": "" - } - } - } - }, - { - "type": "input", - "group": "code", - "attributes": { - "name": "method", - "type": "submit", - "value": "code", - "disabled": false, - "node_type": "input" - }, - "messages": [], - "meta": { - "label": { - "id": 1010019, - "text": "Continue with code", - "type": "info" - } - } - }, - { - "type": "input", - "group": "default", - "attributes": { - "name": "csrf_token", - "type": "hidden", - "required": true, - "disabled": false, - "node_type": "input" - }, - "messages": [], - "meta": {} - } -] diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_2fa.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_2fa.json deleted file mode 100644 index 60b142ed4181..000000000000 --- a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_2fa.json +++ /dev/null @@ -1,66 +0,0 @@ -[ - { - "type": "input", - "group": "default", - "attributes": { - "name": "identifier", - "type": "text", - "value": "", - "required": true, - "disabled": false, - "node_type": "input" - }, - "messages": [ - { - "id": 1010020, - "text": "We will send a code to fo****@ory.sh. To verify that this is your address please enter it here.", - "type": "info", - "context": { - "masked_to": "fo****@ory.sh" - } - } - ], - "meta": { - "label": { - "id": 1070002, - "text": "", - "type": "info", - "context": { - "title": "" - } - } - } - }, - { - "type": "input", - "group": "code", - "attributes": { - "name": "method", - "type": "submit", - "value": "code", - "disabled": false, - "node_type": "input" - }, - "messages": [], - "meta": { - "label": { - "id": 1010019, - "text": "Continue with code", - "type": "info" - } - } - }, - { - "type": "input", - "group": "default", - "attributes": { - "name": "csrf_token", - "type": "hidden", - "required": true, - "disabled": false, - "node_type": "input" - }, - "messages": [], - "meta": {} - } -] diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_2fa_and_request_is_2fa.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_2fa_and_request_is_2fa.json new file mode 100644 index 000000000000..364b8abc331c --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_2fa_and_request_is_2fa.json @@ -0,0 +1,15 @@ +[ + { + "type": "input", + "group": "default", + "attributes": { + "name": "csrf_token", + "type": "hidden", + "required": true, + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": {} + } +] diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodIdentifierFirstCredentials-case=WithIdentifier-case=code_is_used_for_2fa.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_passwordless_login_and_request_is_2fa.json similarity index 100% rename from selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodIdentifierFirstCredentials-case=WithIdentifier-case=code_is_used_for_2fa.json rename to selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_passwordless_login_and_request_is_2fa.json diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-using_via-case=code_is_used_for_2fa.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-using_via-case=code_is_used_for_2fa.json new file mode 100644 index 000000000000..0680486de62c --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-using_via-case=code_is_used_for_2fa.json @@ -0,0 +1,38 @@ +[ + { + "type": "input", + "group": "code", + "attributes": { + "name": "address", + "type": "submit", + "value": "populateloginmethodsecondfactor-code-mfa-via-2fa@ory.sh", + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": { + "label": { + "id": 1010023, + "text": "Send code to populateloginmethodsecondfactor-code-mfa-via-2fa@ory.sh", + "type": "info", + "context": { + "address": "populateloginmethodsecondfactor-code-mfa-via-2fa@ory.sh", + "channel": "email" + } + } + } + }, + { + "type": "input", + "group": "default", + "attributes": { + "name": "csrf_token", + "type": "hidden", + "required": true, + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": {} + } +] diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor#01-case=code_is_used_for_passwordless_login_and_request_is_2fa.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-using_via-case=code_is_used_for_passwordless_login.json similarity index 100% rename from selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor#01-case=code_is_used_for_passwordless_login_and_request_is_2fa.json rename to selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-using_via-case=code_is_used_for_passwordless_login.json diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-without_via-case=code_is_used_for_2fa.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-without_via-case=code_is_used_for_2fa.json new file mode 100644 index 000000000000..d050a31ef32d --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-without_via-case=code_is_used_for_2fa.json @@ -0,0 +1,84 @@ +[ + { + "type": "input", + "group": "code", + "attributes": { + "name": "address", + "type": "submit", + "value": "+4917655138291", + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": { + "label": { + "id": 1010023, + "text": "Send code to +4917655138291", + "type": "info", + "context": { + "address": "+4917655138291", + "channel": "sms" + } + } + } + }, + { + "type": "input", + "group": "code", + "attributes": { + "name": "address", + "type": "submit", + "value": "populateloginmethodsecondfactor-no-via-2fa-0@ory.sh", + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": { + "label": { + "id": 1010023, + "text": "Send code to populateloginmethodsecondfactor-no-via-2fa-0@ory.sh", + "type": "info", + "context": { + "address": "populateloginmethodsecondfactor-no-via-2fa-0@ory.sh", + "channel": "email" + } + } + } + }, + { + "type": "input", + "group": "code", + "attributes": { + "name": "address", + "type": "submit", + "value": "populateloginmethodsecondfactor-no-via-2fa-1@ory.sh", + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": { + "label": { + "id": 1010023, + "text": "Send code to populateloginmethodsecondfactor-no-via-2fa-1@ory.sh", + "type": "info", + "context": { + "address": "populateloginmethodsecondfactor-no-via-2fa-1@ory.sh", + "channel": "email" + } + } + } + }, + { + "type": "input", + "group": "default", + "attributes": { + "name": "csrf_token", + "type": "hidden", + "required": true, + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": {} + } +] diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_passwordless_login.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-without_via-case=code_is_used_for_passwordless_login.json similarity index 100% rename from selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-case=code_is_used_for_passwordless_login.json rename to selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactor-without_via-case=code_is_used_for_passwordless_login.json diff --git a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactorRefresh-case=code_is_used_for_2fa_and_request_is_2fa_with_refresh.json b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactorRefresh-case=code_is_used_for_2fa_and_request_is_2fa_with_refresh.json index 60b142ed4181..364b8abc331c 100644 --- a/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactorRefresh-case=code_is_used_for_2fa_and_request_is_2fa_with_refresh.json +++ b/selfservice/strategy/code/.snapshots/TestFormHydration-method=PopulateLoginMethodSecondFactorRefresh-case=code_is_used_for_2fa_and_request_is_2fa_with_refresh.json @@ -1,55 +1,4 @@ [ - { - "type": "input", - "group": "default", - "attributes": { - "name": "identifier", - "type": "text", - "value": "", - "required": true, - "disabled": false, - "node_type": "input" - }, - "messages": [ - { - "id": 1010020, - "text": "We will send a code to fo****@ory.sh. To verify that this is your address please enter it here.", - "type": "info", - "context": { - "masked_to": "fo****@ory.sh" - } - } - ], - "meta": { - "label": { - "id": 1070002, - "text": "", - "type": "info", - "context": { - "title": "" - } - } - } - }, - { - "type": "input", - "group": "code", - "attributes": { - "name": "method", - "type": "submit", - "value": "code", - "disabled": false, - "node_type": "input" - }, - "messages": [], - "meta": { - "label": { - "id": 1010019, - "text": "Continue with code", - "type": "info" - } - } - }, { "type": "input", "group": "default", diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=verify_initial_payload.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=verify_initial_payload.json index 14efb22f25a3..612c5c980dc2 100644 --- a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=verify_initial_payload.json +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=verify_initial_payload.json @@ -30,49 +30,21 @@ { "attributes": { "disabled": false, - "name": "identifier", - "node_type": "input", - "required": true, - "type": "text", - "value": "" - }, - "group": "default", - "messages": [ - { - "context": { - "masked_to": "fi****@ory.sh" - }, - "id": 1010020, - "text": "We will send a code to fi****@ory.sh. To verify that this is your address please enter it here.", - "type": "info" - } - ], - "meta": { - "label": { - "context": { - "title": "Email" - }, - "id": 1070002, - "text": "Email", - "type": "info" - } - }, - "type": "input" - }, - { - "attributes": { - "disabled": false, - "name": "method", + "name": "address", "node_type": "input", "type": "submit", - "value": "code" + "value": "fixed_mfa_test_browser@ory.sh" }, "group": "code", "messages": [], "meta": { "label": { - "id": 1010019, - "text": "Continue with code", + "context": { + "address": "fixed_mfa_test_browser@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to fixed_mfa_test_browser@ory.sh", "type": "info" } }, diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json new file mode 100644 index 000000000000..1d8d8153054e --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213110" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213110", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213110", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1browser@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1browser@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1browser@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2browser@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2browser@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2browser@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json new file mode 100644 index 000000000000..1d8d8153054e --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213110" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213110", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213110", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1browser@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1browser@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1browser@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2browser@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2browser@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2browser@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json new file mode 100644 index 000000000000..1d8d8153054e --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Browser_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213110" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213110", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213110", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1browser@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1browser@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1browser@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2browser@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2browser@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2browser@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=verify_initial_payload.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=verify_initial_payload.json index 49f7d3a37cc1..b2dfa774866c 100644 --- a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=verify_initial_payload.json +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=verify_initial_payload.json @@ -17,48 +17,20 @@ { "attributes": { "disabled": false, - "name": "identifier", + "name": "address", "node_type": "input", - "required": true, - "type": "text" - }, - "group": "default", - "messages": [ - { - "context": { - "masked_to": "fi****@ory.sh" - }, - "id": 1010020, - "text": "We will send a code to fi****@ory.sh. To verify that this is your address please enter it here.", - "type": "info" - } - ], - "meta": { - "label": { - "context": { - "title": "Email" - }, - "id": 1070002, - "text": "Email", - "type": "info" - } - }, - "type": "input" - }, - { - "attributes": { - "disabled": false, - "name": "method", - "node_type": "input", - "type": "submit", - "value": "code" + "type": "submit" }, "group": "code", "messages": [], "meta": { "label": { - "id": 1010019, - "text": "Continue with code", + "context": { + "address": "fixed_mfa_test_api@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to fixed_mfa_test_api@ory.sh", "type": "info" } }, diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json new file mode 100644 index 000000000000..e3510d16393a --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213111" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213111", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213111", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1api@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1api@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1api@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2api@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2api@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2api@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json new file mode 100644 index 000000000000..e3510d16393a --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213111" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213111", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213111", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1api@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1api@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1api@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2api@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2api@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2api@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json new file mode 100644 index 000000000000..e3510d16393a --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=Native_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213111" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213111", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213111", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1api@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1api@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1api@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2api@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2api@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2api@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=verify_initial_payload.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=verify_initial_payload.json index 14efb22f25a3..41c14e29d43d 100644 --- a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=verify_initial_payload.json +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=verify_initial_payload.json @@ -30,49 +30,21 @@ { "attributes": { "disabled": false, - "name": "identifier", - "node_type": "input", - "required": true, - "type": "text", - "value": "" - }, - "group": "default", - "messages": [ - { - "context": { - "masked_to": "fi****@ory.sh" - }, - "id": 1010020, - "text": "We will send a code to fi****@ory.sh. To verify that this is your address please enter it here.", - "type": "info" - } - ], - "meta": { - "label": { - "context": { - "title": "Email" - }, - "id": 1070002, - "text": "Email", - "type": "info" - } - }, - "type": "input" - }, - { - "attributes": { - "disabled": false, - "name": "method", + "name": "address", "node_type": "input", "type": "submit", - "value": "code" + "value": "fixed_mfa_test_spa@ory.sh" }, "group": "code", "messages": [], "meta": { "label": { - "id": 1010019, - "text": "Continue with code", + "context": { + "address": "fixed_mfa_test_spa@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to fixed_mfa_test_spa@ory.sh", "type": "info" } }, diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json new file mode 100644 index 000000000000..721c86a79617 --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-email.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213112" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213112", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213112", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1spa@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1spa@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1spa@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2spa@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2spa@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2spa@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json new file mode 100644 index 000000000000..721c86a79617 --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=address-phone.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213112" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213112", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213112", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1spa@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1spa@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1spa@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2spa@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2spa@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2spa@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json new file mode 100644 index 000000000000..721c86a79617 --- /dev/null +++ b/selfservice/strategy/code/.snapshots/TestLoginCodeStrategy-test=SPA_client-suite=mfa-case=without_via_parameter_all_options_are_shown-field=identifier-email.json @@ -0,0 +1,71 @@ +[ + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "+4917613213112" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "+4917613213112", + "channel": "sms" + }, + "id": 1010023, + "text": "Send code to +4917613213112", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-1spa@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-1spa@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-1spa@ory.sh", + "type": "info" + } + }, + "type": "input" + }, + { + "attributes": { + "disabled": false, + "name": "address", + "node_type": "input", + "type": "submit", + "value": "code-mfa-2spa@ory.sh" + }, + "group": "code", + "messages": [], + "meta": { + "label": { + "context": { + "address": "code-mfa-2spa@ory.sh", + "channel": "email" + }, + "id": 1010023, + "text": "Send code to code-mfa-2spa@ory.sh", + "type": "info" + } + }, + "type": "input" + } +] diff --git a/selfservice/strategy/code/code_login.go b/selfservice/strategy/code/code_login.go index 689d52f0cb4f..820a1818a140 100644 --- a/selfservice/strategy/code/code_login.go +++ b/selfservice/strategy/code/code_login.go @@ -32,7 +32,7 @@ type LoginCode struct { // AddressType represents the type of the address // this can be an email address or a phone number. - AddressType identity.CodeAddressType `json:"-" db:"address_type"` + AddressType identity.CodeChannel `json:"-" db:"address_type"` // CodeHMAC represents the HMACed value of the verification code CodeHMAC string `json:"-" db:"code"` @@ -94,7 +94,7 @@ type CreateLoginCodeParams struct { // AddressType is the type of the address (email or phone number). // required: true - AddressType identity.CodeAddressType + AddressType identity.CodeChannel // Code represents the recovery code // required: true diff --git a/selfservice/strategy/code/code_registration.go b/selfservice/strategy/code/code_registration.go index 4093480fb91e..4015474112dd 100644 --- a/selfservice/strategy/code/code_registration.go +++ b/selfservice/strategy/code/code_registration.go @@ -32,7 +32,7 @@ type RegistrationCode struct { // AddressType represents the type of the address // this can be an email address or a phone number. - AddressType identity.CodeAddressType `json:"-" db:"address_type"` + AddressType identity.CodeChannel `json:"-" db:"address_type"` // CodeHMAC represents the HMACed value of the verification code CodeHMAC string `json:"-" db:"code"` @@ -93,7 +93,7 @@ type CreateRegistrationCodeParams struct { // AddressType is the type of the address (email or phone number). // required: true - AddressType identity.CodeAddressType + AddressType identity.CodeChannel // Code represents the recovery code // required: true diff --git a/selfservice/strategy/code/code_sender.go b/selfservice/strategy/code/code_sender.go index fdc5e8b38b2d..02fda31f3d0a 100644 --- a/selfservice/strategy/code/code_sender.go +++ b/selfservice/strategy/code/code_sender.go @@ -54,7 +54,7 @@ type ( } Address struct { To string - Via identity.CodeAddressType + Via identity.CodeChannel } ) @@ -87,7 +87,7 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit code, err := s.deps. RegistrationCodePersister(). CreateRegistrationCode(ctx, &CreateRegistrationCodeParams{ - AddressType: identity.CodeAddressType(address.Via), + AddressType: address.Via, RawCode: rawCode, ExpiresIn: s.deps.Config().SelfServiceCodeMethodLifespan(ctx), FlowID: f.GetID(), @@ -123,7 +123,7 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit code, err := s.deps. LoginCodePersister(). CreateLoginCode(ctx, &CreateLoginCodeParams{ - AddressType: identity.CodeAddressType(address.Via), + AddressType: address.Via, Address: address.To, RawCode: rawCode, ExpiresIn: s.deps.Config().SelfServiceCodeMethodLifespan(ctx), diff --git a/selfservice/strategy/code/strategy.go b/selfservice/strategy/code/strategy.go index 2d275fdff85f..1e8fb5cb6af4 100644 --- a/selfservice/strategy/code/strategy.go +++ b/selfservice/strategy/code/strategy.go @@ -6,8 +6,11 @@ package code import ( "context" "net/http" + "sort" "strings" + "github.com/samber/lo" + "github.com/pkg/errors" "github.com/tidwall/gjson" @@ -36,20 +39,21 @@ import ( ) var ( - _ recovery.Strategy = new(Strategy) - _ recovery.AdminHandler = new(Strategy) - _ recovery.PublicHandler = new(Strategy) + _ recovery.Strategy = (*Strategy)(nil) + _ recovery.AdminHandler = (*Strategy)(nil) + _ recovery.PublicHandler = (*Strategy)(nil) ) var ( - _ verification.Strategy = new(Strategy) - _ verification.AdminHandler = new(Strategy) - _ verification.PublicHandler = new(Strategy) + _ verification.Strategy = (*Strategy)(nil) + _ verification.AdminHandler = (*Strategy)(nil) + _ verification.PublicHandler = (*Strategy)(nil) ) var ( - _ login.Strategy = new(Strategy) - _ registration.Strategy = new(Strategy) + _ login.Strategy = (*Strategy)(nil) + _ registration.Strategy = (*Strategy)(nil) + _ identity.ActiveCredentialsCounter = (*Strategy)(nil) ) type ( @@ -121,6 +125,25 @@ type ( } ) +func (s *Strategy) CountActiveFirstFactorCredentials(ctx context.Context, cc map[identity.CredentialsType]identity.Credentials) (int, error) { + codeConfig := s.deps.Config().SelfServiceCodeStrategy(ctx) + if codeConfig.PasswordlessEnabled { + // Login with code for passwordless is enabled + return 1, nil + } + + return 0, nil +} + +func (s *Strategy) CountActiveMultiFactorCredentials(ctx context.Context, cc map[identity.CredentialsType]identity.Credentials) (int, error) { + codeConfig := s.deps.Config().SelfServiceCodeStrategy(ctx) + if codeConfig.MFAEnabled { + return 1, nil + } + + return 0, nil +} + func NewStrategy(deps any) *Strategy { return &Strategy{deps: deps.(strategyDependencies), dx: decoderx.NewHTTP()} } @@ -186,14 +209,16 @@ func (s *Strategy) PopulateMethod(r *http.Request, f flow.Flow) error { func (s *Strategy) populateChooseMethodFlow(r *http.Request, f flow.Flow) error { ctx := r.Context() - var codeMetaLabel *text.Message switch f := f.(type) { case *recovery.Flow, *verification.Flow: f.GetUI().Nodes.Append( node.NewInputField("email", nil, node.CodeGroup, node.InputAttributeTypeEmail, node.WithRequiredInputAttribute). WithMetaLabel(text.NewInfoNodeInputEmail()), ) - codeMetaLabel = text.NewInfoNodeLabelContinue() + f.GetUI().Nodes.Append( + node.NewInputField("method", s.ID(), node.CodeGroup, node.InputAttributeTypeSubmit). + WithMetaLabel(text.NewInfoNodeLabelContinue()), + ) case *login.Flow: ds, err := s.deps.Config().DefaultIdentityTraitsSchemaURL(ctx) if err != nil { @@ -201,48 +226,80 @@ func (s *Strategy) populateChooseMethodFlow(r *http.Request, f flow.Flow) error } if f.RequestedAAL == identity.AuthenticatorAssuranceLevel2 { via := r.URL.Query().Get("via") - if via == "" { - return errors.WithStack(herodot.ErrBadRequest.WithReason("AAL2 login via code requires the `via` query parameter")) - } sess, err := s.deps.SessionManager().FetchFromRequest(r.Context(), r) if err != nil { return err } - allSchemas, err := s.deps.IdentityTraitsSchemas(ctx) - if err != nil { - return err - } - iSchema, err := allSchemas.GetByID(sess.Identity.SchemaID) - if err != nil { - return err - } - identifierLabel, err := login.GetIdentifierLabelFromSchemaWithField(ctx, iSchema.RawURL, via) - if err != nil { - return err + // We need to load the identity's credentials. + if len(sess.Identity.Credentials) == 0 { + if err := s.deps.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, sess.Identity, identity.ExpandCredentials); err != nil { + return err + } } - value := gjson.GetBytes(sess.Identity.Traits, via).String() - if value == "" { - return errors.WithStack(herodot.ErrBadRequest.WithReasonf("No value found for trait %s in the current identity", via)) - } + // The via parameter lets us hint at the OTP address to use for 2fa. + if via == "" { + addresses, found, err := FindCodeAddressCandidates(sess.Identity, s.deps.Config().SelfServiceCodeMethodMissingCredentialFallbackEnabled(ctx)) + if err != nil { + return err + } else if !found { + return nil + } + + sort.SliceStable(addresses, func(i, j int) bool { + return addresses[i].To < addresses[j].To && addresses[i].Via < addresses[j].Via + }) + + for _, address := range addresses { + f.GetUI().Nodes.Append(node.NewInputField("address", address.To, node.CodeGroup, node.InputAttributeTypeSubmit). + WithMetaLabel(text.NewInfoSelfServiceLoginAAL2CodeAddress(string(address.Via), address.To))) + } + } else { + value := gjson.GetBytes(sess.Identity.Traits, via).String() + if value == "" { + return errors.WithStack(herodot.ErrBadRequest.WithReasonf("No value found for trait %s in the current identity.", via)) + } + + // TODO Remove this normalization once the via parameter is deprecated. + // + // Here we need to normalize the via parameter to the actual address. This is necessary because otherwise + // we won't find the address in the list of addresses. + // + // Since we don't know if the via parameter is an email address or a phone number, we need to normalize for both. + value = x.GracefulNormalization(value) + + addresses, found, err := FindCodeAddressCandidates(sess.Identity, s.deps.Config().SelfServiceCodeMethodMissingCredentialFallbackEnabled(ctx)) + if err != nil { + return err + } else if !found { + return nil + } + + address, found := lo.Find(addresses, func(item Address) bool { + return item.To == value + }) + if !found { + return errors.WithStack(herodot.ErrBadRequest.WithReasonf("You can only reference a trait that matches a verification email address in the via parameter, or a registered credential.")) + } - codeMetaLabel = text.NewInfoSelfServiceLoginCodeMFA() - idNode := node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(identifierLabel) - idNode.Messages.Add(text.NewInfoSelfServiceLoginCodeMFAHint(MaskAddress(value))) - f.GetUI().Nodes.Upsert(idNode) + f.GetUI().Nodes.Append(node.NewInputField("address", address.To, node.CodeGroup, node.InputAttributeTypeSubmit). + WithMetaLabel(text.NewInfoSelfServiceLoginAAL2CodeAddress(string(address.Via), address.To))) + } } else { - codeMetaLabel = text.NewInfoSelfServiceLoginCode() identifierLabel, err := login.GetIdentifierLabelFromSchema(ctx, ds.String()) if err != nil { return err } f.GetUI().Nodes.Upsert(node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(identifierLabel)) + f.GetUI().Nodes.Append( + node.NewInputField("method", s.ID(), node.CodeGroup, node.InputAttributeTypeSubmit).WithMetaLabel(text.NewInfoSelfServiceLoginCode()), + ) } + case *registration.Flow: - codeMetaLabel = text.NewInfoSelfServiceRegistrationRegisterCode() ds, err := s.deps.Config().DefaultIdentityTraitsSchemaURL(ctx) if err != nil { return err @@ -258,12 +315,12 @@ func (s *Strategy) populateChooseMethodFlow(r *http.Request, f flow.Flow) error for _, n := range traitNodes { f.GetUI().Nodes.Upsert(n) } - } - - methodButton := node.NewInputField("method", s.ID(), node.CodeGroup, node.InputAttributeTypeSubmit). - WithMetaLabel(codeMetaLabel) - f.GetUI().Nodes.Append(methodButton) + f.GetUI().Nodes.Append( + node.NewInputField("method", s.ID(), node.CodeGroup, node.InputAttributeTypeSubmit). + WithMetaLabel(text.NewInfoSelfServiceRegistrationRegisterCode()), + ) + } return nil } @@ -300,10 +357,11 @@ func (s *Strategy) populateEmailSentFlow(ctx context.Context, f flow.Flow) error // preserve the login identifier that was submitted // so we can retry the code flow with the same data for _, n := range f.GetUI().Nodes { - if n.ID() == "identifier" { + if n.ID() == "identifier" || n.ID() == "address" { if input, ok := n.Attributes.(*node.InputAttributes); ok { input.Type = "hidden" n.Attributes = input + input.Name = "identifier" } freshNodes = append(freshNodes, n) } diff --git a/selfservice/strategy/code/strategy_login.go b/selfservice/strategy/code/strategy_login.go index a6d9af4fad0f..652de2104248 100644 --- a/selfservice/strategy/code/strategy_login.go +++ b/selfservice/strategy/code/strategy_login.go @@ -4,12 +4,16 @@ package code import ( + "cmp" "context" - "database/sql" "encoding/json" "net/http" "strings" + "go.opentelemetry.io/otel/attribute" + + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/selfservice/strategy/idfirst" "github.com/ory/kratos/text" @@ -61,6 +65,10 @@ type updateLoginFlowWithCodeMethod struct { // required: false Identifier string `json:"identifier" form:"identifier"` + // Address is the address to send the code to, in case that there are multiple addresses. This field + // is only used in two-factor flows and is ineffective for passwordless flows. + Address string `json:"address" form:"address"` + // Resend is set when the user wants to resend the code // required: false Resend string `json:"resend" form:"resend"` @@ -73,19 +81,15 @@ type updateLoginFlowWithCodeMethod struct { func (s *Strategy) RegisterLoginRoutes(*x.RouterPublic) {} -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods) session.AuthenticationMethod { - aal1Satisfied := lo.ContainsBy(amr, func(am session.AuthenticationMethod) bool { - return am.Method != identity.CredentialsTypeCodeAuth && am.AAL == identity.AuthenticatorAssuranceLevel1 - }) - if aal1Satisfied { - return session.AuthenticationMethod{ - Method: identity.CredentialsTypeCodeAuth, - AAL: identity.AuthenticatorAssuranceLevel2, - } +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { + aal := identity.AuthenticatorAssuranceLevel1 + if s.deps.Config().SelfServiceCodeStrategy(ctx).MFAEnabled { + aal = identity.AuthenticatorAssuranceLevel2 } + return session.AuthenticationMethod{ - Method: identity.CredentialsTypeCodeAuth, - AAL: identity.AuthenticatorAssuranceLevel1, + Method: s.ID(), + AAL: aal, } } @@ -95,24 +99,16 @@ func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *update } if f != nil { - email := "" + identifier := "" if body != nil { - email = body.Identifier + identifier = cmp.Or(body.Address, body.Identifier) } - ds, err := s.deps.Config().DefaultIdentityTraitsSchemaURL(r.Context()) - if err != nil { - return err - } - identifierLabel, err := login.GetIdentifierLabelFromSchema(r.Context(), ds.String()) - if err != nil { - return err - } f.UI.SetCSRF(s.deps.GenerateCSRFToken(r)) - f.UI.GetNodes().Upsert( - node.NewInputField("identifier", email, node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute). - WithMetaLabel(identifierLabel), - ) + identifierNode := node.NewInputField("identifier", identifier, node.DefaultGroup, node.InputAttributeTypeHidden) + + identifierNode.Attributes.SetValue(identifier) + f.UI.GetNodes().Upsert(identifierNode) } return err @@ -122,52 +118,85 @@ func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *update // If the identity does not have a code credential, it will attempt to find // the identity through other credentials matching the identifier. // the fallback mechanism is used for migration purposes of old accounts that do not have a code credential. -func (s *Strategy) findIdentityByIdentifier(ctx context.Context, identifier string) (_ *identity.Identity, isFallback bool, err error) { +func (s *Strategy) findIdentityByIdentifier(ctx context.Context, identifier string) (id *identity.Identity, cred *identity.Credentials, isFallback bool, err error) { ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.findIdentityByIdentifier") defer otelx.End(span, &err) - id, cred, err := s.deps.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), identifier) + id, cred, err = s.deps.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), identifier) if errors.Is(err, sqlcon.ErrNoRows) { // this is a migration for old identities that do not have a code credential // we might be able to do a fallback login since we could not find a credential on this identifier // Case insensitive because we only care about emails. id, err := s.deps.PrivilegedIdentityPool().FindIdentityByCredentialIdentifier(ctx, identifier, false) if err != nil { - return nil, false, errors.WithStack(schema.NewNoCodeAuthnCredentials()) + return nil, nil, false, errors.WithStack(schema.NewNoCodeAuthnCredentials()) } // we don't know if the user has verified the code yet, so we just return the identity // and let the caller decide what to do with it - return id, true, nil + return id, nil, true, nil } else if err != nil { - return nil, false, errors.WithStack(schema.NewNoCodeAuthnCredentials()) + return nil, nil, false, errors.WithStack(schema.NewNoCodeAuthnCredentials()) } if len(cred.Identifiers) == 0 { - return nil, false, errors.WithStack(schema.NewNoCodeAuthnCredentials()) + return nil, nil, false, errors.WithStack(schema.NewNoCodeAuthnCredentials()) } // we don't need the code credential, we just need to know that it exists - return id, false, nil + return id, cred, false, nil } -func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, sess *session.Session) (_ *identity.Identity, err error) { - ctx, span := s.deps.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.code.strategy.Login") - defer otelx.End(span, &err) +type decodedMethod struct { + Method string `json:"method" form:"method"` + Address string `json:"address" form:"address"` +} - if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.deps); err != nil { +func (s *Strategy) methodEnabledAndAllowedFromRequest(r *http.Request, f *login.Flow) (*decodedMethod, error) { + var method decodedMethod + + compiler, err := decoderx.HTTPRawJSONSchemaCompiler(loginMethodSchema) + if err != nil { + return nil, errors.WithStack(err) + } + + if err := decoderx.NewHTTP().Decode(r, &method, compiler, + decoderx.HTTPKeepRequestBody(true), + decoderx.HTTPDecoderAllowedMethods("POST", "PUT", "PATCH", "GET"), + decoderx.HTTPDecoderSetValidatePayloads(false), + decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil { + return nil, errors.WithStack(err) + } + + if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.ID().String(), method.Method, s.deps); err != nil { return nil, err } - var aal identity.AuthenticatorAssuranceLevel + return &method, nil +} + +func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, sess *session.Session) (_ *identity.Identity, err error) { + ctx, span := s.deps.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.code.strategy.Login") + defer otelx.End(span, &err) if s.deps.Config().SelfServiceCodeStrategy(ctx).PasswordlessEnabled { - aal = identity.AuthenticatorAssuranceLevel1 + if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil { + return nil, err + } } else if s.deps.Config().SelfServiceCodeStrategy(ctx).MFAEnabled { - aal = identity.AuthenticatorAssuranceLevel2 + if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel2); err != nil { + return nil, err + } + } else { + return nil, errors.WithStack(flow.ErrStrategyNotResponsible) } - if err := login.CheckAAL(f, aal); err != nil { + if p, err := s.methodEnabledAndAllowedFromRequest(r, f); errors.Is(err, flow.ErrStrategyNotResponsible) { + if !(s.deps.Config().SelfServiceCodeStrategy(ctx).MFAEnabled && (p == nil || len(p.Address) > 0)) { + return nil, err + } + // In this special case we only expect `address` to be set. + } else if err != nil { return nil, err } @@ -197,7 +226,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } return nil, nil case flow.StateEmailSent: - i, err := s.loginVerifyCode(ctx, r, f, &p) + i, err := s.loginVerifyCode(ctx, r, f, &p, sess) if err != nil { return nil, s.HandleLoginError(r, f, &p, err) } @@ -209,40 +238,151 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, s.HandleLoginError(r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unexpected flow state: %s", f.GetState()))) } -func (s *Strategy) loginSendCode(ctx context.Context, w http.ResponseWriter, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod, sess *session.Session) (err error) { - ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginSendCode") +func (s *Strategy) findIdentifierInVerifiableAddress(i *identity.Identity, identifier string) (*Address, error) { + verifiableAddress, found := lo.Find(i.VerifiableAddresses, func(va identity.VerifiableAddress) bool { + return va.Value == identifier + }) + if !found { + return nil, errors.WithStack(schema.NewUnknownAddressError()) + } + + // This should be fine for legacy cases because we use `UpgradeCredentials` to normalize all address types prior + // to calling this method. + parsed, err := identity.NewCodeChannel(verifiableAddress.Via) + if err != nil { + return nil, err + } + + return &Address{ + To: verifiableAddress.Value, + Via: parsed, + }, nil +} + +func (s *Strategy) findIdentityForIdentifier(ctx context.Context, identifier string, requestedAAL identity.AuthenticatorAssuranceLevel, session *session.Session) (_ *identity.Identity, _ []Address, err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.findIdentityForIdentifier") defer otelx.End(span, &err) - if len(p.Identifier) == 0 { - return errors.WithStack(schema.NewRequiredError("#/identifier", "identifier")) + if len(identifier) == 0 { + return nil, nil, errors.WithStack(schema.NewRequiredError("#/identifier", "identifier")) } - p.Identifier = maybeNormalizeEmail(p.Identifier) + identifier = x.GracefulNormalization(identifier) var addresses []Address - var i *identity.Identity - if f.RequestedAAL > identity.AuthenticatorAssuranceLevel1 { - address, found := lo.Find(sess.Identity.VerifiableAddresses, func(va identity.VerifiableAddress) bool { - return va.Value == p.Identifier - }) - if !found { - return errors.WithStack(schema.NewUnknownAddressError()) + + // Step 1: Get the identity + i, cred, isFallback, err := s.findIdentityByIdentifier(ctx, identifier) + if err != nil { + if requestedAAL == identity.AuthenticatorAssuranceLevel2 { + // When using two-factor auth, the identity used to not have any code credential associated. Therefore, + // we need to gracefully handle this flow. + // + // TODO this section should be removed at some point when we are sure that all identities have a code credential. + if errors.Is(err, schema.NewNoCodeAuthnCredentials()) { + fallbackAllowed := s.deps.Config().SelfServiceCodeMethodMissingCredentialFallbackEnabled(ctx) + span.SetAttributes( + attribute.Bool(config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, fallbackAllowed), + ) + + if !fallbackAllowed { + s.deps.Logger().Warn("The identity does not have a code credential but the fallback mechanism is disabled. Login failed.") + return nil, nil, errors.WithStack(schema.NewNoCodeAuthnCredentials()) + } + + _, err := s.findIdentifierInVerifiableAddress(session.Identity, identifier) + if err != nil { + return nil, nil, err + } + + // We only end up here if the identity's identity schema does not have the `code` identifier extension defined. + // We know that this is the case for a couple of projects who use 2FA with the code credential. + // + // In those scenarios, the identity has no code credential, and the code credential will also not be created by + // the identity schema. + // + // To avoid future regressions, we will not perform an update on the identity here. Effectively, whenever + // the identity would be updated again (and the identity schema + extensions parsed), it would be likely + // that the code credentials are overwritten. + // + // So we accept that the identity in this case will simply not have code credentials, and we will rely on the + // fallback mechanism to authenticate the user. + } else if err != nil { + return nil, nil, err + } } - i = sess.Identity + return nil, nil, err + } else if isFallback { + fallbackAllowed := s.deps.Config().SelfServiceCodeMethodMissingCredentialFallbackEnabled(ctx) + span.SetAttributes( + attribute.String("identity.id", i.ID.String()), + attribute.String("network.id", i.NID.String()), + attribute.Bool(config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, fallbackAllowed), + ) + + if !fallbackAllowed { + s.deps.Logger().Warn("The identity does not have a code credential but the fallback mechanism is disabled. Login failed.") + return nil, nil, errors.WithStack(schema.NewNoCodeAuthnCredentials()) + } + + // We don't have a code credential, but we can still login the user if they have a verified address. + // This is a migration path for old accounts that do not have a code credential. addresses = []Address{{ - To: address.Value, - Via: address.Via, + To: identifier, + Via: identity.CodeChannelEmail, }} + + // We only end up here if the identity's identity schema does not have the `code` identifier extension defined. + // We know that this is the case for a couple of projects who use 2FA with the code credential. + // + // In those scenarios, the identity has no code credential, and the code credential will also not be created by + // the identity schema. + // + // To avoid future regressions, we will not perform an update on the identity here. Effectively, whenever + // the identity would be updated again (and the identity schema + extensions parsed), it would be likely + // that the code credentials are overwritten. + // + // So we accept that the identity in this case will simply not have code credentials, and we will rely on the + // fallback mechanism to authenticate the user. } else { - // Step 1: Get the identity - i, _, err = s.findIdentityByIdentifier(ctx, p.Identifier) - if err != nil { - return err + span.SetAttributes( + attribute.String("identity.id", i.ID.String()), + attribute.String("network.id", i.NID.String()), + ) + + var conf identity.CredentialsCode + if err := json.Unmarshal(cred.Config, &conf); err != nil { + return nil, nil, errors.WithStack(err) } - addresses = []Address{{ - To: p.Identifier, - Via: identity.CodeAddressType(identity.AddressTypeEmail), - }} + + for _, address := range conf.Addresses { + addresses = append(addresses, Address{ + To: address.Address, + Via: address.Channel, + }) + } + } + + return i, addresses, nil +} + +func (s *Strategy) loginSendCode(ctx context.Context, w http.ResponseWriter, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod, sess *session.Session) (err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginSendCode") + defer otelx.End(span, &err) + + p.Identifier = maybeNormalizeEmail( + cmp.Or(p.Identifier, p.Address), + ) + + i, addresses, err := s.findIdentityForIdentifier(ctx, p.Identifier, f.RequestedAAL, sess) + if err != nil { + return err + } + + if address, found := lo.Find(addresses, func(item Address) bool { + return item.To == x.GracefulNormalization(p.Identifier) + }); found { + addresses = []Address{address} } // Step 2: Delete any previous login codes for this flow ID @@ -287,7 +427,7 @@ func maybeNormalizeEmail(input string) string { return input } -func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod) (_ *identity.Identity, err error) { +func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod, sess *session.Session) (_ *identity.Identity, err error) { ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginVerifyCode") defer otelx.End(span, &err) @@ -297,27 +437,21 @@ func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *logi return nil, errors.WithStack(schema.NewRequiredError("#/code", "code")) } - if len(p.Identifier) == 0 { - return nil, errors.WithStack(schema.NewRequiredError("#/identifier", "identifier")) - } - - p.Identifier = maybeNormalizeEmail(p.Identifier) + p.Identifier = maybeNormalizeEmail( + cmp.Or( + p.Address, + p.Identifier, // Older versions of Kratos required us to send the identifier here. + ), + ) - isFallback := false var i *identity.Identity - if f.RequestedAAL > identity.AuthenticatorAssuranceLevel1 { - // Don't require the code credential if the user already has a session (e.g. this is an MFA flow) - sess, err := s.deps.SessionManager().FetchFromRequest(ctx, r) - if err != nil { - return nil, err - } + if f.RequestedAAL == identity.AuthenticatorAssuranceLevel2 { i = sess.Identity } else { - // Step 1: Get the identity - i, isFallback, err = s.findIdentityByIdentifier(ctx, p.Identifier) - if err != nil { - return nil, err - } + i, _, err = s.findIdentityForIdentifier(ctx, p.Identifier, f.RequestedAAL, sess) + } + if err != nil { + return nil, err } loginCode, err := s.deps.LoginCodePersister().UseLoginCode(ctx, f.ID, i.ID, p.Code) @@ -333,22 +467,6 @@ func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *logi return nil, errors.WithStack(err) } - // the code is correct, if the login happened through a different credential, we need to update the identity - if isFallback { - if err := i.SetCredentialsWithConfig( - s.ID(), - // p.Identifier was normalized prior. - identity.Credentials{Type: s.ID(), Identifiers: []string{p.Identifier}}, - &identity.CredentialsCode{UsedAt: sql.NullTime{}}, - ); err != nil { - return nil, errors.WithStack(err) - } - - if err := s.deps.PrivilegedIdentityPool().UpdateIdentity(ctx, i); err != nil { - return nil, errors.WithStack(err) - } - } - // Step 2: The code was correct f.Active = identity.CredentialsTypeCodeAuth @@ -360,6 +478,14 @@ func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *logi return nil, errors.WithStack(err) } + // Step 3: Verify the address + if err := s.verifyAddress(ctx, i, Address{ + To: loginCode.Address, + Via: loginCode.AddressType, + }); err != nil { + return nil, err + } + for idx := range i.VerifiableAddresses { va := i.VerifiableAddresses[idx] if !va.Verified && loginCode.Address == va.Value { @@ -375,6 +501,31 @@ func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *logi return i, nil } +func (s *Strategy) verifyAddress(ctx context.Context, i *identity.Identity, verified Address) error { + for idx := range i.VerifiableAddresses { + va := i.VerifiableAddresses[idx] + if va.Verified { + continue + } + + if verified.To != va.Value || string(verified.Via) != va.Via { + continue + } + + va.Verified = true + va.Status = identity.VerifiableAddressStatusCompleted + if err := s.deps.PrivilegedIdentityPool().UpdateVerifiableAddress(ctx, &va); errors.Is(err, sqlcon.ErrNoRows) { + // This happens when the verified address does not yet exist, for example during registration. In this case we just skip. + continue + } else if err != nil { + return err + } + break + } + + return nil +} + func (s *Strategy) PopulateLoginMethodFirstFactorRefresh(r *http.Request, f *login.Flow) error { return s.PopulateMethod(r, f) } diff --git a/selfservice/strategy/code/strategy_login_test.go b/selfservice/strategy/code/strategy_login_test.go index 880d64f5971c..9c98a929dcc9 100644 --- a/selfservice/strategy/code/strategy_login_test.go +++ b/selfservice/strategy/code/strategy_login_test.go @@ -10,9 +10,12 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" + "github.com/ory/kratos/courier" + "github.com/ory/kratos/selfservice/strategy/idfirst" configtesthelpers "github.com/ory/kratos/driver/config/testhelpers" @@ -45,6 +48,7 @@ import ( func createIdentity(ctx context.Context, t *testing.T, reg driver.Registry, withoutCodeCredential bool, moreIdentifiers ...string) *identity.Identity { t.Helper() i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + i.NID = x.NewUUID() email := testhelpers.RandomEmail() ids := fmt.Sprintf(`"email":"%s"`, email) @@ -60,7 +64,7 @@ func createIdentity(ctx context.Context, t *testing.T, reg driver.Registry, with identity.CredentialsTypeWebAuthn: {Type: identity.CredentialsTypeWebAuthn, Identifiers: append([]string{email}, moreIdentifiers...), Config: sqlxx.JSONRawMessage("{\"some\" : \"secret\", \"user_handle\": \"rVIFaWRcTTuQLkXFmQWpgA==\"}")}, } if !withoutCodeCredential { - credentials[identity.CredentialsTypeCodeAuth] = identity.Credentials{Type: identity.CredentialsTypeCodeAuth, Identifiers: append([]string{email}, moreIdentifiers...), Config: sqlxx.JSONRawMessage("{\"address_type\": \"email\", \"used_at\": \"2023-07-26T16:59:06+02:00\"}")} + credentials[identity.CredentialsTypeCodeAuth] = identity.Credentials{Type: identity.CredentialsTypeCodeAuth, Identifiers: append([]string{email}, moreIdentifiers...), Config: sqlxx.JSONRawMessage(`{"addresses":[{"channel":"email","address":"` + email + `"}]}`)} } i.Credentials = credentials @@ -73,7 +77,7 @@ func createIdentity(ctx context.Context, t *testing.T, reg driver.Registry, with i.VerifiableAddresses = va - require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, i)) + require.NoError(t, reg.IdentityManager().Create(ctx, i)) return i } @@ -109,11 +113,9 @@ func TestLoginCodeStrategy(t *testing.T) { ApiTypeNative ApiType = "api" ) - createLoginFlow := func(ctx context.Context, t *testing.T, public *httptest.Server, apiType ApiType, withoutCodeCredential bool, moreIdentifiers ...string) *state { + createLoginFlowWithIdentity := func(ctx context.Context, t *testing.T, public *httptest.Server, apiType ApiType, user *identity.Identity) *state { t.Helper() - identity := createIdentity(ctx, t, reg, withoutCodeCredential, moreIdentifiers...) - var client *http.Client if apiType == ApiTypeNative { client = &http.Client{} @@ -140,18 +142,23 @@ func TestLoginCodeStrategy(t *testing.T) { require.NotEmptyf(t, csrfToken, "could not find csrf_token in: %s", body) } - loginEmail := gjson.Get(identity.Traits.String(), "email").String() - require.NotEmptyf(t, loginEmail, "could not find the email trait inside the identity: %s", identity.Traits.String()) - return &state{ - flowID: clientInit.GetId(), - identity: identity, - identityEmail: loginEmail, - client: client, - testServer: public, + flowID: clientInit.GetId(), + identity: user, + client: client, + testServer: public, } } + createLoginFlow := func(ctx context.Context, t *testing.T, public *httptest.Server, apiType ApiType, withoutCodeCredential bool, moreIdentifiers ...string) *state { + t.Helper() + s := createLoginFlowWithIdentity(ctx, t, public, apiType, createIdentity(ctx, t, reg, withoutCodeCredential, moreIdentifiers...)) + loginEmail := gjson.Get(s.identity.Traits.String(), "email").String() + require.NotEmptyf(t, loginEmail, "could not find the email trait inside the identity: %s", s.identity.Traits.String()) + s.identityEmail = loginEmail + return s + } + type onSubmitAssertion func(t *testing.T, s *state, body string, res *http.Response) submitLogin := func(ctx context.Context, t *testing.T, s *state, apiType ApiType, vals func(v *url.Values), mustHaveSession bool, submitAssertion onSubmitAssertion) *state { @@ -187,8 +194,8 @@ func TestLoginCodeStrategy(t *testing.T) { resp, err = s.client.Do(req) require.NoError(t, err) - require.EqualValues(t, http.StatusOK, resp.StatusCode) body = string(ioutilx.MustReadAll(resp.Body)) + require.EqualValues(t, http.StatusOK, resp.StatusCode, "%s", body) } else { // SPAs need to be informed that the login has not yet completed using status 400. // Browser clients will redirect back to the login URL. @@ -241,7 +248,7 @@ func TestLoginCodeStrategy(t *testing.T) { }, true, nil) }) - t.Run("case=should be able to log in with code", func(t *testing.T) { + t.Run("case=should be able to log in with code sent to email", func(t *testing.T) { // create login flow s := createLoginFlow(ctx, t, public, tc.apiType, false) @@ -268,6 +275,148 @@ func TestLoginCodeStrategy(t *testing.T) { } }) + t.Run("case=should be able to log in legacy cases", func(t *testing.T) { + run := func(t *testing.T, s *state) { + // submit email + s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("identifier", s.identityEmail) + }, false, nil) + + t.Logf("s.body: %s", s.body) + + message := testhelpers.CourierExpectMessage(ctx, t, reg, s.identityEmail, "Login to your account") + assert.Contains(t, message.Body, "please login to your account by entering the following code") + + loginCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) + assert.NotEmpty(t, loginCode) + + // 3. Submit OTP + state := submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("code", loginCode) + }, true, nil) + if tc.apiType == ApiTypeSPA { + assert.Contains(t, gjson.Get(state.body, "continue_with.0.redirect_browser_to").String(), conf.SelfServiceBrowserDefaultReturnTo(ctx).String(), "%s", state.body) + } else { + assert.Empty(t, gjson.Get(state.body, "continue_with").Array(), "%s", state.body) + } + } + + initDefault := func(t *testing.T, cf string) *state { + i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + i.NID = x.NewUUID() + + // valid fake phone number for libphonenumber + email := testhelpers.RandomEmail() + i.Traits = identity.Traits(fmt.Sprintf(`{"tos": true, "email": "%s"}`, email)) + i.Credentials = map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypeCodeAuth: { + Type: identity.CredentialsTypeCodeAuth, + Identifiers: []string{email}, + Version: 0, + Config: sqlxx.JSONRawMessage(cf), + }, + } + require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentities(ctx, i)) // We explicitly bypass identity validation to test the legacy code path + s := createLoginFlowWithIdentity(ctx, t, public, tc.apiType, i) + s.identityEmail = email + return s + } + + t.Run("case=should be able to send address type with spaces", func(t *testing.T) { + run(t, + initDefault(t, `{"address_type": "email ", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`), + ) + }) + + t.Run("case=should be able to send to empty address type", func(t *testing.T) { + run(t, + initDefault(t, `{"address_type": "", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`), + ) + }) + + t.Run("case=should be able to send to empty credentials config", func(t *testing.T) { + run(t, + initDefault(t, `{}`), + ) + }) + + t.Run("case=should be able to send to identity with no credentials at all when fallback is enabled", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, true) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, nil) + }) + + i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + i.NID = x.NewUUID() + email := testhelpers.RandomEmail() + i.Traits = identity.Traits(fmt.Sprintf(`{"tos": true, "email": "%s"}`, email)) + i.Credentials = map[identity.CredentialsType]identity.Credentials{ + // This makes it possible for our code to find the identity identifier here. + identity.CredentialsTypePassword: {Type: identity.CredentialsTypePassword, Identifiers: []string{email}, Config: sqlxx.JSONRawMessage(`{}`)}, + } + + require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentities(ctx, i)) // We explicitly bypass identity validation to test the legacy code path + s := createLoginFlowWithIdentity(ctx, t, public, tc.apiType, i) + s.identityEmail = email + run(t, s) + }) + + t.Run("case=should fail to send to identity with no credentials at all when fallback is disabled", func(t *testing.T) { + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, nil) + }) + + i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + i.NID = x.NewUUID() + email := testhelpers.RandomEmail() + i.Traits = identity.Traits(fmt.Sprintf(`{"tos": true, "email": "%s"}`, email)) + require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentities(ctx, i)) // We explicitly bypass identity validation to test the legacy code path + s := createLoginFlowWithIdentity(ctx, t, public, tc.apiType, i) + s.identityEmail = email + // submit email + s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("identifier", s.identityEmail) + }, false, nil) + assert.Contains(t, s.body, "4000035", "Should not find the account") + }) + }) + + t.Run("case=should be able to log in with code to sms and normalize the number", func(t *testing.T) { + i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + i.NID = x.NewUUID() + + // valid fake phone number for libphonenumber + phone := "+1 (415) 55526-71" + i.Traits = identity.Traits(fmt.Sprintf(`{"tos": true, "phone_1": "%s"}`, phone)) + require.NoError(t, reg.IdentityManager().Create(ctx, i)) + t.Cleanup(func() { + require.NoError(t, reg.PrivilegedIdentityPool().DeleteIdentity(ctx, i.ID)) + }) + + s := createLoginFlowWithIdentity(ctx, t, public, tc.apiType, i) + + // submit email + s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("identifier", phone) + }, false, nil) + + message := testhelpers.CourierExpectMessage(ctx, t, reg, x.GracefulNormalization(phone), "Your login code is:") + loginCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) + assert.NotEmpty(t, loginCode) + + // 3. Submit OTP + state := submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("code", loginCode) + }, true, nil) + if tc.apiType == ApiTypeSPA { + assert.EqualValues(t, flow.ContinueWithActionRedirectBrowserToString, gjson.Get(state.body, "continue_with.0.action").String(), "%s", state.body) + assert.Contains(t, gjson.Get(state.body, "continue_with.0.redirect_browser_to").String(), conf.SelfServiceBrowserDefaultReturnTo(ctx).String(), "%s", state.body) + } else { + assert.Empty(t, gjson.Get(state.body, "continue_with").Array(), "%s", state.body) + } + }) + t.Run("case=new identities automatically have login with code", func(t *testing.T) { ctx := context.Background() @@ -602,46 +751,260 @@ func TestLoginCodeStrategy(t *testing.T) { }) t.Run("case=should be able to get AAL2 session", func(t *testing.T) { - t.Cleanup(testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/default.schema.json")) // doesn't have the code credential - identity := createIdentity(ctx, t, reg, true) + run := func(t *testing.T, withoutCodeCredential bool, overrideCodeCredential *identity.Credentials) (*state, *http.Client) { + user := createIdentity(ctx, t, reg, withoutCodeCredential) + if overrideCodeCredential != nil { + toUpdate := user.Credentials[identity.CredentialsTypeCodeAuth] + if overrideCodeCredential.Config != nil { + toUpdate.Config = overrideCodeCredential.Config + } + if overrideCodeCredential.Identifiers != nil { + toUpdate.Identifiers = overrideCodeCredential.Identifiers + } + user.Credentials[identity.CredentialsTypeCodeAuth] = toUpdate + } + + var cl *http.Client + var f *oryClient.LoginFlow + if tc.apiType == ApiTypeNative { + cl = testhelpers.NewHTTPClientWithIdentitySessionToken(t, ctx, reg, user) + f = testhelpers.InitializeLoginFlowViaAPI(t, cl, public, false, testhelpers.InitFlowWithAAL("aal2"), testhelpers.InitFlowWithVia("email")) + } else { + cl = testhelpers.NewHTTPClientWithIdentitySessionCookieLocalhost(t, ctx, reg, user) + f = testhelpers.InitializeLoginFlowViaBrowser(t, cl, public, false, tc.apiType == ApiTypeSPA, false, false, testhelpers.InitFlowWithAAL("aal2"), testhelpers.InitFlowWithVia("email")) + } + + body, err := json.Marshal(f) + require.NoError(t, err) + require.Len(t, gjson.GetBytes(body, "ui.nodes.#(group==code)").Array(), 1, "%s", body) + require.Len(t, gjson.GetBytes(body, "ui.messages").Array(), 1, "%s", body) + require.EqualValues(t, gjson.GetBytes(body, "ui.messages.0.id").Int(), text.InfoSelfServiceLoginMFA, "%s", body) + + s := &state{ + flowID: f.GetId(), + identity: user, + client: cl, + testServer: public, + identityEmail: gjson.Get(user.Traits.String(), "email").String(), + } + s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("identifier", s.identityEmail) + }, false, nil) + + message := testhelpers.CourierExpectMessage(ctx, t, reg, s.identityEmail, "Login to your account") + assert.Contains(t, message.Body, "please login to your account by entering the following code") + loginCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) + assert.NotEmpty(t, loginCode) + + return submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("code", loginCode) + }, true, nil), cl + } + + t.Run("case=correct code credential without fallback works", func(t *testing.T) { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") // has code identifier + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) // fallback enabled + + _, cl := run(t, true, nil) + testhelpers.EnsureAAL(t, cl, public, "aal2", "code") + }) + + t.Run("case=disabling mfa does not lock out the users", func(t *testing.T) { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") // has code identifier + + s, cl := run(t, true, nil) + testhelpers.EnsureAAL(t, cl, public, "aal2", "code") + + email := gjson.GetBytes(s.identity.Traits, "email").String() + s.identityEmail = email + + // We change now disable code mfa and enable passwordless instead. + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.mfa_enabled", false) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true) + + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", false) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.mfa_enabled", true) + }) + + s = createLoginFlowWithIdentity(ctx, t, public, tc.apiType, s.identity) + s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("identifier", email) + v.Set("method", "code") + }, false, nil) + + message := testhelpers.CourierExpectMessage(ctx, t, reg, email, "Login to your account") + assert.Contains(t, message.Body, "please login to your account by entering the following code") + loginCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) + assert.NotEmpty(t, loginCode) + + loginResult := submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("code", loginCode) + }, true, nil) + + if tc.apiType == ApiTypeNative { + assert.EqualValues(t, "aal1", gjson.Get(loginResult.body, "session.authenticator_assurance_level").String()) + assert.EqualValues(t, "code", gjson.Get(loginResult.body, "session.authentication_methods.#(method==code).method").String()) + } else { + // The user should be able to sign in correctly even though, probably, the internal state was aal2 for available AAL. + res, err := s.client.Get(public.URL + session.RouteWhoami) + require.NoError(t, err) + assert.EqualValues(t, http.StatusOK, res.StatusCode, loginResult.body) + sess := x.MustReadAll(res.Body) + require.NoError(t, res.Body.Close()) + + assert.EqualValues(t, "aal1", gjson.GetBytes(sess, "authenticator_assurance_level").String()) + assert.EqualValues(t, "code", gjson.GetBytes(sess, "authentication_methods.#(method==code).method").String()) + } + }) + + t.Run("case=missing code credential with fallback works when identity schema has the code identifier set", func(t *testing.T) { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") // has code identifier + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, true) // fallback enabled + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) + }) + + _, cl := run(t, false, nil) + testhelpers.EnsureAAL(t, cl, public, "aal2", "code") + }) + + t.Run("case=missing code credential with fallback works even when identity schema has no code identifier set", func(t *testing.T) { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/default.schema.json") // missing the code identifier + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, true) // fallback enabled + t.Cleanup(func() { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) + }) + + _, cl := run(t, false, nil) + testhelpers.EnsureAAL(t, cl, public, "aal2", "code") + }) + + t.Run("case=legacy code credential with fallback works when identity schema has the code identifier not set", func(t *testing.T) { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/default.schema.json") // has code identifier + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, true) // fallback enabled + t.Cleanup(func() { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") // has code identifier + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) + }) + + _, cl := run(t, false, &identity.Credentials{Config: []byte(`{"via":""}`)}) + testhelpers.EnsureAAL(t, cl, public, "aal2", "code") + }) + + t.Run("case=legacy code credential with fallback works when identity schema has the code identifier not set", func(t *testing.T) { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/default.schema.json") // has code identifier + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, true) // fallback enabled + t.Cleanup(func() { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") // has code identifier + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) + }) + + for k, credentialsConfig := range []string{ + `{"address_type": "email ", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, + `{"address_type": "email", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, + `{"address_type": "", "used_at": {"Time": "0001-01-01T00:00:00Z", "Valid": false}}`, + `{"address_type": ""}`, + `{"address_type": "sms"}`, + `{"address_type": "phone"}`, + `{}`, + } { + t.Run(fmt.Sprintf("config=%d", k), func(t *testing.T) { + _, cl := run(t, false, &identity.Credentials{Config: []byte(credentialsConfig)}) + testhelpers.EnsureAAL(t, cl, public, "aal2", "code") + }) + } + }) + }) + + t.Run("case=without via parameter all options are shown", func(t *testing.T) { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code-mfa.identity.schema.json") + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) + t.Cleanup(func() { + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") + }) + var cl *http.Client var f *oryClient.LoginFlow + + user := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + user.NID = x.NewUUID() + email1 := "code-mfa-1" + string(tc.apiType) + "@ory.sh" + email2 := "code-mfa-2" + string(tc.apiType) + "@ory.sh" + phone1 := 4917613213110 if tc.apiType == ApiTypeNative { - cl = testhelpers.NewHTTPClientWithIdentitySessionToken(t, ctx, reg, identity) - f = testhelpers.InitializeLoginFlowViaAPI(t, cl, public, false, testhelpers.InitFlowWithAAL("aal2"), testhelpers.InitFlowWithVia("email")) - } else { - cl = testhelpers.NewHTTPClientWithIdentitySessionCookieLocalhost(t, ctx, reg, identity) - f = testhelpers.InitializeLoginFlowViaBrowser(t, cl, public, false, tc.apiType == ApiTypeSPA, false, false, testhelpers.InitFlowWithAAL("aal2"), testhelpers.InitFlowWithVia("email")) + phone1 += 1 + } else if tc.apiType == ApiTypeSPA { + phone1 += 2 } + user.Traits = identity.Traits(fmt.Sprintf(`{"email1":"%s","email2":"%s","phone1":"+%d"}`, email1, email2, phone1)) + require.NoError(t, reg.IdentityManager().Create(ctx, user)) - body, err := json.Marshal(f) - require.NoError(t, err) - require.Len(t, gjson.GetBytes(body, "ui.nodes.#(group==code)").Array(), 1) - require.Len(t, gjson.GetBytes(body, "ui.messages").Array(), 1, "%s", body) - require.EqualValues(t, gjson.GetBytes(body, "ui.messages.0.id").Int(), text.InfoSelfServiceLoginMFA, "%s", body) + run := func(t *testing.T, identifierField string, identifier string) { + if tc.apiType == ApiTypeNative { + cl = testhelpers.NewHTTPClientWithIdentitySessionToken(t, ctx, reg, user) + f = testhelpers.InitializeLoginFlowViaAPI(t, cl, public, false, testhelpers.InitFlowWithAAL("aal2")) + } else { + cl = testhelpers.NewHTTPClientWithIdentitySessionCookieLocalhost(t, ctx, reg, user) + f = testhelpers.InitializeLoginFlowViaBrowser(t, cl, public, false, tc.apiType == ApiTypeSPA, false, false, testhelpers.InitFlowWithAAL("aal2")) + } - s := &state{ - flowID: f.GetId(), - identity: identity, - client: cl, - testServer: public, - identityEmail: gjson.Get(identity.Traits.String(), "email").String(), + body, err := json.Marshal(f) + require.NoError(t, err) + + snapshotx.SnapshotT(t, json.RawMessage(gjson.GetBytes(body, "ui.nodes.#(group==code)#").Raw)) + require.Len(t, gjson.GetBytes(body, "ui.messages").Array(), 1, "%s", body) + require.EqualValues(t, gjson.GetBytes(body, "ui.messages.0.id").Int(), text.InfoSelfServiceLoginMFA, "%s", body) + + s := &state{ + flowID: f.GetId(), + identity: user, + client: cl, + testServer: public, + identityEmail: gjson.Get(user.Traits.String(), "email").String(), + } + + s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Del("method") + v.Set(identifierField, identifier) + }, false, nil) + + var message *courier.Message + if !strings.HasPrefix(identifier, "+") { + // email + message = testhelpers.CourierExpectMessage(ctx, t, reg, x.GracefulNormalization(identifier), "Login to your account") + assert.Contains(t, message.Body, "please login to your account by entering the following code") + } else { + // SMS + message = testhelpers.CourierExpectMessage(ctx, t, reg, x.GracefulNormalization(identifier), "Your login code is:") + } + loginCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) + assert.NotEmpty(t, loginCode) + + t.Logf("loginCode: %s", loginCode) + + s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { + v.Set("code", loginCode) + v.Set(identifierField, identifier) + }, true, nil) + + testhelpers.EnsureAAL(t, cl, public, "aal2", "code") } - s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { - v.Set("identifier", s.identityEmail) - }, true, nil) - message := testhelpers.CourierExpectMessage(ctx, t, reg, s.identityEmail, "Login to your account") - assert.Contains(t, message.Body, "please login to your account by entering the following code") - loginCode := testhelpers.CourierExpectCodeInMessage(t, message, 1) - assert.NotEmpty(t, loginCode) + t.Run("field=identifier-email", func(t *testing.T) { + run(t, "identifier", email1) + }) - s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { - v.Set("code", loginCode) - }, true, nil) + t.Run("field=address-email", func(t *testing.T) { + run(t, "address", email2) + }) - testhelpers.EnsureAAL(t, cl, public, "aal2", "code") + t.Run("field=address-phone", func(t *testing.T) { + run(t, "address", fmt.Sprintf("+%d", phone1)) + }) }) + t.Run("case=cannot use different identifier", func(t *testing.T) { identity := createIdentity(ctx, t, reg, false) var cl *http.Client @@ -670,9 +1033,9 @@ func TestLoginCodeStrategy(t *testing.T) { email := testhelpers.RandomEmail() s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { v.Set("identifier", email) - }, true, nil) + }, false, nil) - require.Equal(t, "The address you entered does not match any known addresses in the current account.", gjson.Get(s.body, "ui.messages.0.text").String(), "%s", body) + require.Equal(t, "This account does not exist or has not setup sign in with code.", gjson.Get(s.body, "ui.messages.0.text").String(), "%s", body) }) t.Run("case=verify initial payload", func(t *testing.T) { @@ -711,28 +1074,7 @@ func TestLoginCodeStrategy(t *testing.T) { if tc.apiType == ApiTypeNative { body = []byte(gjson.GetBytes(body, "error").Raw) } - require.Equal(t, "Trait does not exist in identity schema", gjson.GetBytes(body, "reason").String(), "%s", body) - }) - - t.Run("case=missing via parameter results results in an error", func(t *testing.T) { - identity := createIdentity(ctx, t, reg, false) - var cl *http.Client - var res *http.Response - var err error - if tc.apiType == ApiTypeNative { - cl = testhelpers.NewHTTPClientWithIdentitySessionToken(t, ctx, reg, identity) - res, err = cl.Get(public.URL + "/self-service/login/api?aal=aal2") - } else { - cl = testhelpers.NewHTTPClientWithIdentitySessionCookieLocalhost(t, ctx, reg, identity) - res, err = cl.Get(public.URL + "/self-service/login/browser?aal=aal2") - } - require.NoError(t, err) - - body := ioutilx.MustReadAll(res.Body) - if tc.apiType == ApiTypeNative { - body = []byte(gjson.GetBytes(body, "error").Raw) - } - require.Equal(t, "AAL2 login via code requires the `via` query parameter", gjson.GetBytes(body, "reason").String(), "%s", body) + require.Equal(t, "No value found for trait doesnt_exist in the current identity.", gjson.GetBytes(body, "reason").String(), "%s", body) }) t.Run("case=unset trait in identity should lead to an error", func(t *testing.T) { @@ -753,8 +1095,9 @@ func TestLoginCodeStrategy(t *testing.T) { if tc.apiType == ApiTypeNative { body = []byte(gjson.GetBytes(body, "error").Raw) } - require.Equal(t, "No value found for trait email_1 in the current identity", gjson.GetBytes(body, "reason").String(), "%s", body) + require.Equal(t, "No value found for trait email_1 in the current identity.", gjson.GetBytes(body, "reason").String(), "%s", body) }) + }) }) } @@ -767,7 +1110,7 @@ func TestFormHydration(t *testing.T) { "enabled": true, "passwordless_enabled": true, }) - ctx = testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/default.schema.json") + ctx = testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/code.identity.schema.json") s, err := reg.AllLoginStrategies().Strategy(identity.CredentialsTypeCodeAuth) require.NoError(t, err) @@ -801,37 +1144,13 @@ func TestFormHydration(t *testing.T) { "mfa_enabled": true, }) - toMFARequest := func(r *http.Request, f *login.Flow) { + toMFARequest := func(t *testing.T, r *http.Request, f *login.Flow, traits string) { f.RequestedAAL = identity.AuthenticatorAssuranceLevel2 r.URL = &url.URL{Path: "/", RawQuery: "via=email"} // I only fear god. - r.Header = testhelpers.NewHTTPClientWithArbitrarySessionTokenAndTraits(t, ctx, reg, []byte(`{"email":"foo@ory.sh"}`)).Transport.(*testhelpers.TransportWithHeader).GetHeader() + r.Header = testhelpers.NewHTTPClientWithArbitrarySessionTokenAndTraits(t, ctx, reg, []byte(traits)).Transport.(*testhelpers.TransportWithHeader).GetHeader() } - t.Run("method=PopulateLoginMethodSecondFactor", func(t *testing.T) { - test := func(t *testing.T, ctx context.Context) { - r, f := newFlow(ctx, t) - toMFARequest(r, f) - - r.Header = testhelpers.NewHTTPClientWithArbitrarySessionTokenAndTraits(t, ctx, reg, []byte(`{"email":"foo@ory.sh"}`)).Transport.(*testhelpers.TransportWithHeader).GetHeader() - - // We still use the legacy hydrator under the hood here and thus need to set this correctly. - f.RequestedAAL = identity.AuthenticatorAssuranceLevel2 - r.URL = &url.URL{Path: "/", RawQuery: "via=email"} - - require.NoError(t, fh.PopulateLoginMethodSecondFactor(r, f)) - toSnapshot(t, f) - } - - t.Run("case=code is used for 2fa", func(t *testing.T) { - test(t, mfaEnabled) - }) - - t.Run("case=code is used for passwordless login", func(t *testing.T) { - test(t, passwordlessEnabled) - }) - }) - t.Run("method=PopulateLoginMethodFirstFactor", func(t *testing.T) { t.Run("case=code is used for 2fa but request is 1fa", func(t *testing.T) { r, f := newFlow(mfaEnabled, t) @@ -867,16 +1186,62 @@ func TestFormHydration(t *testing.T) { }) t.Run("method=PopulateLoginMethodSecondFactor", func(t *testing.T) { + t.Run("using via", func(t *testing.T) { + test := func(t *testing.T, ctx context.Context, email string) { + r, f := newFlow(ctx, t) + toMFARequest(t, r, f, `{"email":"`+email+`"}`) + + // We still use the legacy hydrator under the hood here and thus need to set this correctly. + f.RequestedAAL = identity.AuthenticatorAssuranceLevel2 + r.URL = &url.URL{Path: "/", RawQuery: "via=email"} + + require.NoError(t, fh.PopulateLoginMethodSecondFactor(r, f)) + toSnapshot(t, f) + } + + t.Run("case=code is used for 2fa", func(t *testing.T) { + test(t, mfaEnabled, "PopulateLoginMethodSecondFactor-code-mfa-via-2fa@ory.sh") + }) + + t.Run("case=code is used for passwordless login", func(t *testing.T) { + test(t, passwordlessEnabled, "PopulateLoginMethodSecondFactor-code-mfa-via-passwordless@ory.sh") + }) + }) + + t.Run("without via", func(t *testing.T) { + test := func(t *testing.T, ctx context.Context, traits string) { + r, f := newFlow(ctx, t) + toMFARequest(t, r, f, traits) + + // We still use the legacy hydrator under the hood here and thus need to set this correctly. + f.RequestedAAL = identity.AuthenticatorAssuranceLevel2 + r.URL = &url.URL{Path: "/"} + + require.NoError(t, fh.PopulateLoginMethodSecondFactor(r, f)) + toSnapshot(t, f) + } + + t.Run("case=code is used for 2fa", func(t *testing.T) { + ctx = testhelpers.WithDefaultIdentitySchema(mfaEnabled, "file://./stub/code-mfa.identity.schema.json") + test(t, ctx, `{"email1":"PopulateLoginMethodSecondFactor-no-via-2fa-0@ory.sh","email2":"PopulateLoginMethodSecondFactor-no-via-2fa-1@ory.sh","phone1":"+4917655138291"}`) + }) + + t.Run("case=code is used for passwordless login", func(t *testing.T) { + ctx = testhelpers.WithDefaultIdentitySchema(passwordlessEnabled, "file://./stub/code-mfa.identity.schema.json") + test(t, ctx, `{"email1":"PopulateLoginMethodSecondFactor-no-via-passwordless-0@ory.sh","email2":"PopulateLoginMethodSecondFactor-no-via-passwordless-1@ory.sh","phone1":"+4917655138292"}`) + }) + }) + t.Run("case=code is used for 2fa and request is 2fa", func(t *testing.T) { r, f := newFlow(mfaEnabled, t) - toMFARequest(r, f) + toMFARequest(t, r, f, `{"email":"foo@ory.sh"}`) require.NoError(t, fh.PopulateLoginMethodSecondFactor(r, f)) toSnapshot(t, f) }) t.Run("case=code is used for passwordless login and request is 2fa", func(t *testing.T) { r, f := newFlow(passwordlessEnabled, t) - toMFARequest(r, f) + toMFARequest(t, r, f, `{"email":"foo@ory.sh"}`) require.NoError(t, fh.PopulateLoginMethodSecondFactor(r, f)) toSnapshot(t, f) }) @@ -885,7 +1250,7 @@ func TestFormHydration(t *testing.T) { t.Run("method=PopulateLoginMethodSecondFactorRefresh", func(t *testing.T) { t.Run("case=code is used for 2fa and request is 2fa with refresh", func(t *testing.T) { r, f := newFlow(mfaEnabled, t) - toMFARequest(r, f) + toMFARequest(t, r, f, `{"email":"foo@ory.sh"}`) f.Refresh = true require.NoError(t, fh.PopulateLoginMethodSecondFactorRefresh(r, f)) toSnapshot(t, f) @@ -893,7 +1258,7 @@ func TestFormHydration(t *testing.T) { t.Run("case=code is used for passwordless login and request is 2fa with refresh", func(t *testing.T) { r, f := newFlow(passwordlessEnabled, t) - toMFARequest(r, f) + toMFARequest(t, r, f, `{"email":"foo@ory.sh"}`) f.Refresh = true require.NoError(t, fh.PopulateLoginMethodSecondFactorRefresh(r, f)) toSnapshot(t, f) diff --git a/selfservice/strategy/code/strategy_mfa.go b/selfservice/strategy/code/strategy_mfa.go new file mode 100644 index 000000000000..89fe79c393e8 --- /dev/null +++ b/selfservice/strategy/code/strategy_mfa.go @@ -0,0 +1,57 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package code + +import ( + "encoding/json" + + "github.com/pkg/errors" + "github.com/samber/lo" + + "github.com/ory/herodot" + "github.com/ory/kratos/identity" +) + +func FindAllIdentifiers(i *identity.Identity) (result []Address) { + for _, a := range i.VerifiableAddresses { + if len(a.Via) == 0 || len(a.Value) == 0 { + continue + } + + result = append(result, Address{Via: identity.CodeChannel(a.Via), To: a.Value}) + } + return result +} + +func FindCodeAddressCandidates(i *identity.Identity, fallbackEnabled bool) (result []Address, found bool, _ error) { + // If no hint was given, we show all OTP addresses from the credentials. + creds, ok := i.GetCredentials(identity.CredentialsTypeCodeAuth) + if !ok { + if !fallbackEnabled { + // Without a fallback and with no credentials found, we can't really do a lot and exit early. + return nil, false, nil + } + + return FindAllIdentifiers(i), true, nil + } else { + var conf identity.CredentialsCode + if len(creds.Config) > 0 { + if err := json.Unmarshal(creds.Config, &conf); err != nil { + return nil, false, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to unmarshal credentials config: %s", err)) + } + } + + if len(conf.Addresses) == 0 { + if !fallbackEnabled { + // Without a fallback and with no credentials found, we can't really do a lot and exit early. + return nil, false, nil + } + + return FindAllIdentifiers(i), true, nil + } + return lo.Map(conf.Addresses, func(item identity.CredentialsCodeAddress, _ int) Address { + return Address{Via: item.Channel, To: item.Address} + }), true, nil + } +} diff --git a/selfservice/strategy/code/strategy_mfa_test.go b/selfservice/strategy/code/strategy_mfa_test.go new file mode 100644 index 000000000000..5a97b881b179 --- /dev/null +++ b/selfservice/strategy/code/strategy_mfa_test.go @@ -0,0 +1,161 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package code + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ory/kratos/identity" +) + +func TestFindAllIdentifiers(t *testing.T) { + tests := []struct { + name string + input *identity.Identity + expected []Address + }{ + { + name: "valid verifiable addresses", + input: &identity.Identity{ + VerifiableAddresses: []identity.VerifiableAddress{ + {Via: "email", Value: "user@example.com"}, + {Via: "sms", Value: "+1234567890"}, + }, + }, + expected: []Address{ + {Via: identity.CodeChannel("email"), To: "user@example.com"}, + {Via: identity.CodeChannel("sms"), To: "+1234567890"}, + }, + }, + { + name: "empty verifiable addresses", + input: &identity.Identity{ + VerifiableAddresses: []identity.VerifiableAddress{}, + }, + }, + { + name: "verifiable address with empty fields", + input: &identity.Identity{ + VerifiableAddresses: []identity.VerifiableAddress{ + {Via: "", Value: ""}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FindAllIdentifiers(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestFindCodeAddressCandidates(t *testing.T) { + tests := []struct { + name string + input *identity.Identity + fallbackEnabled bool + expected []Address + found bool + wantErr bool + }{ + { + name: "valid credentials with addresses", + input: &identity.Identity{ + Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypeCodeAuth: { + Config: []byte(`{"addresses":[{"channel":"email","address":"user@example.com"},{"channel":"sms","address":"+1234567890"}]}`), + }, + }, + }, + fallbackEnabled: false, + expected: []Address{ + {Via: identity.CodeChannel("email"), To: "user@example.com"}, + {Via: identity.CodeChannel("sms"), To: "+1234567890"}, + }, + found: true, + wantErr: false, + }, + { + name: "no credentials, fallback enabled", + input: &identity.Identity{ + VerifiableAddresses: []identity.VerifiableAddress{ + {Via: "email", Value: "user@example.com"}, + {Via: "sms", Value: "+1234567890"}, + }, + }, + fallbackEnabled: true, + expected: []Address{ + {Via: identity.CodeChannel("email"), To: "user@example.com"}, + {Via: identity.CodeChannel("sms"), To: "+1234567890"}, + }, + found: true, + wantErr: false, + }, + { + name: "no credentials, fallback disabled", + input: &identity.Identity{ + VerifiableAddresses: []identity.VerifiableAddress{ + {Via: "email", Value: "user@example.com"}, + {Via: "sms", Value: "+1234567890"}, + }, + }, + fallbackEnabled: false, + expected: nil, + found: false, + wantErr: false, + }, + { + name: "invalid credentials config", + input: &identity.Identity{ + Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypeCodeAuth: { + Config: []byte(`invalid`), + }, + }, + }, + fallbackEnabled: false, + expected: nil, + found: false, + wantErr: true, + }, + { + name: "invalid credentials config, fallback enabled, verifiable addresses exist", + input: &identity.Identity{ + Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypeCodeAuth: { + Config: []byte(`invalid`), + }, + }, + VerifiableAddresses: []identity.VerifiableAddress{ + {Via: "email", Value: "user@example.com"}, + {Via: "sms", Value: "+1234567890"}, + }, + }, + fallbackEnabled: true, + expected: []Address{ + {Via: identity.CodeChannel("email"), To: "user@example.com"}, + {Via: identity.CodeChannel("sms"), To: "+1234567890"}, + }, + found: true, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, found, err := FindCodeAddressCandidates(tt.input, tt.fallbackEnabled) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + assert.Equal(t, tt.found, found) + } + }) + } +} diff --git a/selfservice/strategy/code/strategy_recovery.go b/selfservice/strategy/code/strategy_recovery.go index f33356f2df31..8ab34c4f1ce2 100644 --- a/selfservice/strategy/code/strategy_recovery.go +++ b/selfservice/strategy/code/strategy_recovery.go @@ -190,9 +190,9 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, return s.retryRecoveryFlow(w, r, f.Type, RetryWithError(err)) } - sess, err := session.NewActiveSession(r, id, s.deps.Config(), time.Now().UTC(), - identity.CredentialsTypeRecoveryCode, identity.AuthenticatorAssuranceLevel1) - if err != nil { + sess := session.NewInactiveSession() + sess.CompletedLoginFor(identity.CredentialsTypeRecoveryCode, identity.AuthenticatorAssuranceLevel1) + if err := s.deps.SessionManager().ActivateSession(r, sess, id, time.Now().UTC()); err != nil { return s.retryRecoveryFlow(w, r, f.Type, RetryWithError(err)) } diff --git a/selfservice/strategy/code/strategy_recovery_test.go b/selfservice/strategy/code/strategy_recovery_test.go index 2652ca1823c5..483e8dfc6f89 100644 --- a/selfservice/strategy/code/strategy_recovery_test.go +++ b/selfservice/strategy/code/strategy_recovery_test.go @@ -15,6 +15,8 @@ import ( "testing" "time" + confighelpers "github.com/ory/kratos/driver/config/testhelpers" + "github.com/davecgh/go-spew/spew" "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" @@ -520,10 +522,10 @@ func TestRecovery(t *testing.T) { } req := httptest.NewRequest("GET", "/sessions/whoami", nil) - session, err := session.NewActiveSession( - req, - &identity.Identity{ID: x.NewUUID(), State: identity.StateActive}, - testhelpers.NewSessionLifespanProvider(time.Hour), + req.WithContext(confighelpers.WithConfigValue(ctx, config.ViperKeySessionLifespan, time.Hour)) + session, err := testhelpers.NewActiveSession(req, + reg, + &identity.Identity{ID: x.NewUUID(), State: identity.StateActive, NID: x.NewUUID()}, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -632,7 +634,7 @@ func TestRecovery(t *testing.T) { id := createIdentityToRecover(t, reg, email) req := httptest.NewRequest("GET", "/sessions/whoami", nil) - sess, err := session.NewActiveSession(req, id, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + sess, err := testhelpers.NewActiveSession(req, reg, id, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), sess)) @@ -1360,11 +1362,12 @@ func TestRecovery_WithContinueWith(t *testing.T) { f = testhelpers.InitializeRecoveryFlowViaBrowser(t, client, isSPA, public, nil) } req := httptest.NewRequest("GET", "/sessions/whoami", nil) + req.WithContext(confighelpers.WithConfigValue(ctx, config.ViperKeySessionLifespan, time.Hour)) - session, err := session.NewActiveSession( + session, err := testhelpers.NewActiveSession( req, - &identity.Identity{ID: x.NewUUID(), State: identity.StateActive}, - testhelpers.NewSessionLifespanProvider(time.Hour), + reg, + &identity.Identity{ID: x.NewUUID(), State: identity.StateActive, NID: x.NewUUID()}, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -1463,7 +1466,7 @@ func TestRecovery_WithContinueWith(t *testing.T) { email := testhelpers.RandomEmail() id := createIdentityToRecover(t, reg, email) - otherSession, err := session.NewActiveSession(httptest.NewRequest("GET", "/sessions/whoami", nil), id, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + otherSession, err := testhelpers.NewActiveSession(httptest.NewRequest("GET", "/sessions/whoami", nil), reg, id, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) require.NoError(t, reg.SessionPersister().UpsertSession(ctx, otherSession)) diff --git a/selfservice/strategy/code/strategy_registration.go b/selfservice/strategy/code/strategy_registration.go index dca3054e0c8e..4c6679ff374e 100644 --- a/selfservice/strategy/code/strategy_registration.go +++ b/selfservice/strategy/code/strategy_registration.go @@ -5,9 +5,11 @@ package code import ( "context" - "database/sql" "encoding/json" "net/http" + "strings" + + "github.com/tidwall/gjson" "github.com/ory/herodot" "github.com/ory/x/otelx" @@ -91,30 +93,10 @@ func (s *Strategy) PopulateRegistrationMethod(r *http.Request, rf *registration. return s.PopulateMethod(r, rf) } -type options func(*identity.Identity) error - -func withCredentials(via identity.CodeAddressType, usedAt sql.NullTime) options { - return func(i *identity.Identity) error { - return i.SetCredentialsWithConfig(identity.CredentialsTypeCodeAuth, identity.Credentials{Type: identity.CredentialsTypePassword, Identifiers: []string{}}, &identity.CredentialsCode{AddressType: via, UsedAt: usedAt}) - } -} - -func (s *Strategy) handleIdentityTraits(ctx context.Context, f *registration.Flow, traits, transientPayload json.RawMessage, i *identity.Identity, opts ...options) error { - f.TransientPayload = transientPayload - if len(traits) == 0 { - traits = json.RawMessage("{}") - } - - // we explicitly set the Code credentials type - i.Traits = identity.Traits(traits) - if err := i.SetCredentialsWithConfig(s.ID(), identity.Credentials{Type: s.ID(), Identifiers: []string{}}, &identity.CredentialsCode{UsedAt: sql.NullTime{}}); err != nil { - return err - } - - for _, opt := range opts { - if err := opt(i); err != nil { - return err - } +func (s *Strategy) validateTraits(ctx context.Context, traits json.RawMessage, i *identity.Identity) error { + i.Traits = []byte("{}") + if gjson.ValidBytes(traits) { + i.Traits = identity.Traits(traits) } // Validate the identity @@ -125,18 +107,26 @@ func (s *Strategy) handleIdentityTraits(ctx context.Context, f *registration.Flo return nil } -func (s *Strategy) getCredentialsFromTraits(ctx context.Context, f *registration.Flow, i *identity.Identity, traits, transientPayload json.RawMessage) (*identity.Credentials, error) { - if err := s.handleIdentityTraits(ctx, f, traits, transientPayload, i); err != nil { - return nil, errors.WithStack(err) +func (s *Strategy) validateAndGetCredentialsFromTraits(ctx context.Context, i *identity.Identity, traits json.RawMessage) (*identity.Credentials, *identity.CredentialsCode, error) { + if err := s.validateTraits(ctx, traits, i); err != nil { + return nil, nil, errors.WithStack(err) } cred, ok := i.GetCredentials(identity.CredentialsTypeCodeAuth) if !ok { - return nil, errors.WithStack(schema.NewMissingIdentifierError()) - } else if len(cred.Identifiers) == 0 { - return nil, errors.WithStack(schema.NewMissingIdentifierError()) + return nil, nil, errors.WithStack(schema.NewMissingIdentifierError()) + } else if len(strings.Join(cred.Identifiers, "")) == 0 { + return nil, nil, errors.WithStack(schema.NewMissingIdentifierError()) + } + + var conf identity.CredentialsCode + if len(cred.Config) > 0 { + if err := json.Unmarshal(cred.Config, &conf); err != nil { + return nil, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to unmarshal credentials config: %s", err)) + } } - return cred, nil + + return cred, &conf, nil } func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) (err error) { @@ -184,7 +174,7 @@ func (s *Strategy) registrationSendEmail(ctx context.Context, w http.ResponseWri // Create the Registration code // Step 1: validate the identity's traits - cred, err := s.getCredentialsFromTraits(ctx, f, i, p.Traits, p.TransientPayload) + _, conf, err := s.validateAndGetCredentialsFromTraits(ctx, i, p.Traits) if err != nil { return err } @@ -196,9 +186,10 @@ func (s *Strategy) registrationSendEmail(ctx context.Context, w http.ResponseWri // Step 3: Get the identity email and send the code var addresses []Address - for _, identifier := range cred.Identifiers { - addresses = append(addresses, Address{To: identifier, Via: identity.AddressTypeEmail}) + for _, address := range conf.Addresses { + addresses = append(addresses, Address{To: address.Address, Via: address.Channel}) } + // kratos only supports `email` identifiers at the moment with the code method // this is validated in the identity validation step above if err := s.deps.CodeSender().SendCode(ctx, f, i, addresses...); err != nil { @@ -246,7 +237,7 @@ func (s *Strategy) registrationVerifyCode(ctx context.Context, f *registration.F // Step 1: Re-validate the identity's traits // this is important since the client could have switched out the identity's traits // this method also returns the credentials for a temporary identity - cred, err := s.getCredentialsFromTraits(ctx, f, i, p.Traits, p.TransientPayload) + cred, _, err := s.validateAndGetCredentialsFromTraits(ctx, i, p.Traits) if err != nil { return err } @@ -267,9 +258,12 @@ func (s *Strategy) registrationVerifyCode(ctx context.Context, f *registration.F return errors.WithStack(err) } - // Step 4: The code was correct, populate the Identity credentials and traits - if err := s.handleIdentityTraits(ctx, f, p.Traits, p.TransientPayload, i, withCredentials(registrationCode.AddressType, registrationCode.UsedAt)); err != nil { - return errors.WithStack(err) + // Step 4: Verify the address + if err := s.verifyAddress(ctx, i, Address{ + To: registrationCode.Address, + Via: registrationCode.AddressType, + }); err != nil { + return err } // since nothing has errored yet, we can assume that the code is correct diff --git a/selfservice/strategy/code/strategy_registration_test.go b/selfservice/strategy/code/strategy_registration_test.go index ecef2a41ace8..c5988047d0a3 100644 --- a/selfservice/strategy/code/strategy_registration_test.go +++ b/selfservice/strategy/code/strategy_registration_test.go @@ -224,7 +224,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { return s } - require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode, body) verifiableAddress, err := reg.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, s.email) require.NoError(t, err) @@ -348,7 +348,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { require.NotEmpty(t, attr) val := gjson.Get(attr, "#(attributes.type==hidden).attributes.value").String() - require.Equal(t, "code", val) + require.Equal(t, "code", val, body) }) message := testhelpers.CourierExpectMessage(ctx, t, reg, s.email, "Complete your account registration") @@ -526,8 +526,11 @@ func TestRegistrationCodeStrategy(t *testing.T) { } { t.Run("test="+tc.d, func(t *testing.T) { t.Run("case=should fail when schema does not contain the `code` extension", func(t *testing.T) { - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/default.schema.json") + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/no-code.schema.json") + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, false) + t.Cleanup(func() { + conf.MustSet(ctx, config.ViperKeyCodeConfigMissingCredentialFallbackEnabled, true) testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") }) diff --git a/selfservice/strategy/code/strategy_test.go b/selfservice/strategy/code/strategy_test.go index 4561a280fb96..5edbef5ec718 100644 --- a/selfservice/strategy/code/strategy_test.go +++ b/selfservice/strategy/code/strategy_test.go @@ -5,8 +5,14 @@ package code_test import ( "context" + "fmt" "testing" + "github.com/stretchr/testify/require" + + confighelpers "github.com/ory/kratos/driver/config/testhelpers" + "github.com/ory/kratos/internal" + "github.com/stretchr/testify/assert" "github.com/ory/kratos/internal/testhelpers" @@ -74,3 +80,123 @@ func TestMaskAddress(t *testing.T) { }) } } + +func TestCountActiveCredentials(t *testing.T) { + _, reg := internal.NewFastRegistryWithMocks(t) + strategy := code.NewStrategy(reg) + ctx := context.Background() + + t.Run("first factor", func(t *testing.T) { + for k, tc := range []struct { + in map[identity.CredentialsType]identity.Credentials + expected int + passwordlessEnabled bool + enabled bool + }{ + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte{}, + }}, + passwordlessEnabled: false, + enabled: true, + expected: 0, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte{}, + }}, + passwordlessEnabled: true, + enabled: false, + expected: 1, + }, + { + in: map[identity.CredentialsType]identity.Credentials{}, + passwordlessEnabled: true, + enabled: true, + expected: 1, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte(`{}`), + }}, + passwordlessEnabled: true, + enabled: true, + expected: 1, + }, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + ctx := confighelpers.WithConfigValue(ctx, "selfservice.methods.code.passwordless_enabled", tc.passwordlessEnabled) + ctx = confighelpers.WithConfigValue(ctx, "selfservice.methods.code.enabled", tc.enabled) + + cc := map[identity.CredentialsType]identity.Credentials{} + for _, c := range tc.in { + cc[c.Type] = c + } + + actual, err := strategy.CountActiveFirstFactorCredentials(ctx, cc) + require.NoError(t, err) + assert.Equal(t, tc.expected, actual) + }) + } + }) + + t.Run("second factor", func(t *testing.T) { + for k, tc := range []struct { + in map[identity.CredentialsType]identity.Credentials + expected int + mfaEnabled bool + enabled bool + }{ + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte{}, + }}, + mfaEnabled: false, + enabled: true, + expected: 0, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte{}, + }}, + mfaEnabled: true, + enabled: false, + expected: 1, + }, + { + in: map[identity.CredentialsType]identity.Credentials{}, + mfaEnabled: true, + enabled: true, + expected: 1, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte(`{}`), + }}, + mfaEnabled: true, + enabled: true, + expected: 1, + }, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + ctx := confighelpers.WithConfigValue(ctx, "selfservice.methods.code.mfa_enabled", tc.mfaEnabled) + ctx = confighelpers.WithConfigValue(ctx, "selfservice.methods.code.enabled", tc.enabled) + + cc := map[identity.CredentialsType]identity.Credentials{} + for _, c := range tc.in { + cc[c.Type] = c + } + + actual, err := strategy.CountActiveMultiFactorCredentials(ctx, cc) + require.NoError(t, err) + assert.Equal(t, tc.expected, actual) + }) + } + }) +} diff --git a/selfservice/strategy/code/stub/code-mfa.identity.schema.json b/selfservice/strategy/code/stub/code-mfa.identity.schema.json new file mode 100644 index 000000000000..9ccdaa4cc8ed --- /dev/null +++ b/selfservice/strategy/code/stub/code-mfa.identity.schema.json @@ -0,0 +1,50 @@ +{ + "$id": "https://example.com/person.schema.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Person", + "type": "object", + "properties": { + "traits": { + "type": "object", + "properties": { + "email1": { + "type": "string", + "format": "email", + "ory.sh/kratos": { + "credentials": { + "code": { + "identifier": true, + "via": "email" + } + } + } + }, + "email2": { + "type": "string", + "format": "email", + "ory.sh/kratos": { + "credentials": { + "code": { + "identifier": true, + "via": "email" + } + } + } + }, + "phone1": { + "type": "string", + "format": "tel", + "ory.sh/kratos": { + "credentials": { + "code": { + "identifier": true, + "via": "sms" + } + } + } + } + }, + "required": [] + } + } +} diff --git a/selfservice/strategy/code/stub/code.identity.schema.json b/selfservice/strategy/code/stub/code.identity.schema.json index a7a4e4448442..ecbeb33574c5 100644 --- a/selfservice/strategy/code/stub/code.identity.schema.json +++ b/selfservice/strategy/code/stub/code.identity.schema.json @@ -58,13 +58,29 @@ } } }, + "phone_1": { + "type": "string", + "format": "tel", + "title": "Phone", + "ory.sh/kratos": { + "credentials": { + "code": { + "identifier": true, + "via": "sms" + } + }, + "verification": { + "via": "sms" + } + } + }, "tos": { "type": "boolean", "title": "Tos", "description": "Please accept the terms and conditions" } }, - "required": ["email", "tos"] + "required": [] } } } diff --git a/selfservice/strategy/code/stub/default.schema.json b/selfservice/strategy/code/stub/default.schema.json index 8dc923266050..f13da2b4d1da 100644 --- a/selfservice/strategy/code/stub/default.schema.json +++ b/selfservice/strategy/code/stub/default.schema.json @@ -13,6 +13,10 @@ "credentials": { "password": { "identifier": true + }, + "code": { + "identifier": true, + "via": "email" } }, "verification": { diff --git a/selfservice/strategy/code/stub/no-code.schema.json b/selfservice/strategy/code/stub/no-code.schema.json new file mode 100644 index 000000000000..8dc923266050 --- /dev/null +++ b/selfservice/strategy/code/stub/no-code.schema.json @@ -0,0 +1,29 @@ +{ + "$id": "https://example.com/person.schema.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Person", + "type": "object", + "properties": { + "traits": { + "type": "object", + "properties": { + "email": { + "type": "string", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + } + }, + "verification": { + "via": "email" + }, + "recovery": { + "via": "email" + } + } + } + } + } + } +} diff --git a/selfservice/strategy/idfirst/strategy.go b/selfservice/strategy/idfirst/strategy.go index b4590ce45634..792fff7bed95 100644 --- a/selfservice/strategy/idfirst/strategy.go +++ b/selfservice/strategy/idfirst/strategy.go @@ -22,6 +22,7 @@ type dependencies interface { x.WriterProvider x.CSRFTokenGeneratorProvider x.CSRFProvider + x.TracingProvider config.Provider @@ -56,10 +57,10 @@ func (s *Strategy) ID() identity.CredentialsType { return identity.CredentialsType(node.IdentifierFirstGroup) } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), - AAL: identity.AuthenticatorAssuranceLevel1, + AAL: identity.NoAuthenticatorAssuranceLevel, } } diff --git a/selfservice/strategy/idfirst/strategy_login.go b/selfservice/strategy/idfirst/strategy_login.go index eaca286f1fb7..3a3f018d4cdb 100644 --- a/selfservice/strategy/idfirst/strategy_login.go +++ b/selfservice/strategy/idfirst/strategy_login.go @@ -6,6 +6,8 @@ package idfirst import ( "net/http" + "github.com/ory/x/otelx" + "github.com/ory/kratos/schema" "github.com/pkg/errors" @@ -39,7 +41,10 @@ func (s *Strategy) handleLoginError(r *http.Request, f *login.Flow, payload upda } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ *session.Session) (_ *identity.Identity, err error) { - if !s.d.Config().SelfServiceLoginFlowIdentifierFirstEnabled(r.Context()) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.link.strategy.Login") + defer otelx.End(span, &err) + + if !s.d.Config().SelfServiceLoginFlowIdentifierFirstEnabled(ctx) { return nil, errors.WithStack(flow.ErrStrategyNotResponsible) } @@ -56,14 +61,14 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } f.TransientPayload = p.TransientPayload - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return nil, s.handleLoginError(r, f, p, err) } var opts []login.FormHydratorModifier // Look up the user by the identifier. - identityHint, err := s.d.PrivilegedIdentityPool().FindIdentityByCredentialIdentifier(r.Context(), p.Identifier, + identityHint, err := s.d.PrivilegedIdentityPool().FindIdentityByCredentialIdentifier(ctx, p.Identifier, // We are dealing with user input -> lookup should be case-insensitive. false, ) @@ -75,9 +80,9 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } else if err != nil { // An error happened during lookup return nil, s.handleLoginError(r, f, p, err) - } else if !s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { + } else if !s.d.Config().SecurityAccountEnumerationMitigate(ctx) { // Hydrate credentials - if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(r.Context(), identityHint, identity.ExpandCredentials); err != nil { + if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, identityHint, identity.ExpandCredentials); err != nil { return nil, s.handleLoginError(r, f, p, err) } } @@ -90,7 +95,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, opts = append(opts, login.WithIdentifier(p.Identifier)) didPopulate := false - for _, ls := range s.d.LoginStrategies(r.Context()) { + for _, ls := range s.d.LoginStrategies(ctx) { populator, ok := ls.(login.FormHydrator) if !ok { continue @@ -110,7 +115,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, // If no strategy populated, it means that the account (very likely) does not exist. We show a user not found error, // but only if account enumeration mitigation is disabled. Otherwise, we proceed to render the rest of the form. - if !didPopulate && !s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { + if !didPopulate && !s.d.Config().SecurityAccountEnumerationMitigate(ctx) { return nil, s.handleLoginError(r, f, p, errors.WithStack(schema.NewAccountNotFoundError())) } @@ -133,14 +138,14 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } f.Active = s.ID() - if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { + if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { return nil, s.handleLoginError(r, f, p, err) } if x.IsJSONRequest(r) { s.d.Writer().WriteCode(w, r, http.StatusBadRequest, f) } else { - http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowLoginUI(r.Context())).String(), http.StatusSeeOther) + http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowLoginUI(ctx)).String(), http.StatusSeeOther) } return nil, flow.ErrCompletedByStrategy diff --git a/selfservice/strategy/idfirst/strategy_test.go b/selfservice/strategy/idfirst/strategy_test.go index f6d483090abb..9a05358d25cf 100644 --- a/selfservice/strategy/idfirst/strategy_test.go +++ b/selfservice/strategy/idfirst/strategy_test.go @@ -22,7 +22,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/ory/kratos/identity" - "github.com/ory/kratos/session" ) func TestCountActiveFirstFactorCredentials(t *testing.T) { @@ -50,9 +49,9 @@ func TestCompletedAuthenticationMethod(t *testing.T) { s := idfirst.NewStrategy(reg) ctx := context.Background() - method := s.CompletedAuthenticationMethod(ctx, session.AuthenticationMethods{}) + method := s.CompletedAuthenticationMethod(ctx) assert.Equal(t, s.ID(), method.Method) - assert.Equal(t, identity.AuthenticatorAssuranceLevel1, method.AAL) + assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, method.AAL) } func TestNodeGroup(t *testing.T) { diff --git a/selfservice/strategy/link/strategy_recovery.go b/selfservice/strategy/link/strategy_recovery.go index e6d91051c2c4..718adb8d1923 100644 --- a/selfservice/strategy/link/strategy_recovery.go +++ b/selfservice/strategy/link/strategy_recovery.go @@ -306,8 +306,9 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request, return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err) } - sess, err := session.NewActiveSession(r, id, s.d.Config(), time.Now().UTC(), identity.CredentialsTypeRecoveryLink, identity.AuthenticatorAssuranceLevel1) - if err != nil { + sess := session.NewInactiveSession() + sess.CompletedLoginFor(identity.CredentialsTypeRecoveryLink, identity.AuthenticatorAssuranceLevel1) + if err := s.d.SessionManager().ActivateSession(r, sess, id, time.Now().UTC()); err != nil { return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err) } diff --git a/selfservice/strategy/link/strategy_recovery_test.go b/selfservice/strategy/link/strategy_recovery_test.go index 6ff7ecb196a3..5cc6c510d73e 100644 --- a/selfservice/strategy/link/strategy_recovery_test.go +++ b/selfservice/strategy/link/strategy_recovery_test.go @@ -14,6 +14,8 @@ import ( "testing" "time" + confighelpers "github.com/ory/kratos/driver/config/testhelpers" + "github.com/davecgh/go-spew/spew" "github.com/gofrs/uuid" "github.com/pkg/errors" @@ -251,6 +253,7 @@ func TestRecovery(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".link.enabled", true) + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/default.schema.json") initViper(t, conf) _ = testhelpers.NewRecoveryUIFlowEchoServer(t, reg) @@ -372,9 +375,9 @@ func TestRecovery(t *testing.T) { authClient := testhelpers.NewHTTPClientWithArbitrarySessionToken(t, ctx, reg) if isAPI { req := httptest.NewRequest("GET", "/sessions/whoami", nil) - s, err := session.NewActiveSession(req, - &identity.Identity{ID: x.NewUUID(), State: identity.StateActive}, - testhelpers.NewSessionLifespanProvider(time.Hour), + req.WithContext(confighelpers.WithConfigValue(ctx, config.ViperKeySessionLifespan, time.Hour)) + s, err := testhelpers.NewActiveSession(req, reg, + &identity.Identity{ID: x.NewUUID(), State: identity.StateActive, NID: x.NewUUID()}, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1, @@ -694,7 +697,7 @@ func TestRecovery(t *testing.T) { id := createIdentityToRecover(t, reg, email) req := httptest.NewRequest("GET", "/sessions/whoami", nil) - sess, err := session.NewActiveSession(req, id, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + sess, err := testhelpers.NewActiveSession(req, reg, id, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), sess)) diff --git a/selfservice/strategy/link/strategy_verification.go b/selfservice/strategy/link/strategy_verification.go index 200334700c80..8be16ed15b48 100644 --- a/selfservice/strategy/link/strategy_verification.go +++ b/selfservice/strategy/link/strategy_verification.go @@ -132,7 +132,6 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.link.strategy.Verify") span.SetAttributes(attribute.String("selfservice_flows_verification_use", s.d.Config().SelfServiceFlowVerificationUse(ctx))) defer otelx.End(span, &err) - r = r.WithContext(ctx) body, err := s.decodeVerification(r) if err != nil { @@ -141,14 +140,14 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio f.TransientPayload = body.TransientPayload if len(body.Token) > 0 { - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), s.VerificationStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.VerificationStrategyID(), s.VerificationStrategyID(), s.d); err != nil { return s.handleVerificationError(r, nil, body, err) } - return s.verificationUseToken(w, r, body, f) + return s.verificationUseToken(ctx, w, r, body, f) } - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.VerificationStrategyID(), body.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.VerificationStrategyID(), body.Method, s.d); err != nil { return s.handleVerificationError(r, f, body, err) } @@ -158,15 +157,15 @@ func (s *Strategy) Verify(w http.ResponseWriter, r *http.Request, f *verificatio switch f.State { case flow.StateChooseMethod, flow.StateEmailSent: - return s.verificationHandleFormSubmission(r, f) + return s.verificationHandleFormSubmission(ctx, r, f) case flow.StatePassedChallenge: - return s.retryVerificationFlowWithMessage(w, r, f.Type, text.NewErrorValidationVerificationRetrySuccess()) + return s.retryVerificationFlowWithMessage(ctx, w, r, f.Type, text.NewErrorValidationVerificationRetrySuccess()) default: - return s.retryVerificationFlowWithMessage(w, r, f.Type, text.NewErrorValidationVerificationStateFailure()) + return s.retryVerificationFlowWithMessage(ctx, w, r, f.Type, text.NewErrorValidationVerificationStateFailure()) } } -func (s *Strategy) verificationHandleFormSubmission(r *http.Request, f *verification.Flow) error { +func (s *Strategy) verificationHandleFormSubmission(ctx context.Context, r *http.Request, f *verification.Flow) error { body, err := s.decodeVerification(r) if err != nil { return s.handleVerificationError(r, f, body, err) @@ -176,11 +175,11 @@ func (s *Strategy) verificationHandleFormSubmission(r *http.Request, f *verifica return s.handleVerificationError(r, f, body, schema.NewRequiredError("#/email", "email")) } - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, body.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, body.CSRFToken); err != nil { return s.handleVerificationError(r, f, body, err) } - if err := s.d.LinkSender().SendVerificationLink(r.Context(), f, identity.VerifiableAddressTypeEmail, body.Email); err != nil { + if err := s.d.LinkSender().SendVerificationLink(ctx, f, identity.VerifiableAddressTypeEmail, body.Email); err != nil { if !errors.Is(err, ErrUnknownAddress) { return s.handleVerificationError(r, f, body, err) } @@ -203,18 +202,18 @@ func (s *Strategy) verificationHandleFormSubmission(r *http.Request, f *verifica return nil } -func (s *Strategy) verificationUseToken(w http.ResponseWriter, r *http.Request, body *verificationSubmitPayload, f *verification.Flow) error { - token, err := s.d.VerificationTokenPersister().UseVerificationToken(r.Context(), f.ID, body.Token) +func (s *Strategy) verificationUseToken(ctx context.Context, w http.ResponseWriter, r *http.Request, body *verificationSubmitPayload, f *verification.Flow) error { + token, err := s.d.VerificationTokenPersister().UseVerificationToken(ctx, f.ID, body.Token) if err != nil { if errors.Is(err, sqlcon.ErrNoRows) { - return s.retryVerificationFlowWithMessage(w, r, flow.TypeBrowser, text.NewErrorValidationVerificationTokenInvalidOrAlreadyUsed()) + return s.retryVerificationFlowWithMessage(ctx, w, r, flow.TypeBrowser, text.NewErrorValidationVerificationTokenInvalidOrAlreadyUsed()) } - return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) + return s.retryVerificationFlowWithError(ctx, w, r, flow.TypeBrowser, err) } if err := token.Valid(); err != nil { - return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) + return s.retryVerificationFlowWithError(ctx, w, r, flow.TypeBrowser, err) } address := token.VerifiableAddress @@ -223,12 +222,12 @@ func (s *Strategy) verificationUseToken(w http.ResponseWriter, r *http.Request, address.VerifiedAt = &verifiedAt address.Status = identity.VerifiableAddressStatusCompleted if err := s.d.PrivilegedIdentityPool().UpdateVerifiableAddress(r.Context(), address); err != nil { - return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) + return s.retryVerificationFlowWithError(ctx, w, r, flow.TypeBrowser, err) } i, err := s.d.IdentityPool().GetIdentity(r.Context(), token.VerifiableAddress.IdentityID, identity.ExpandDefault) if err != nil { - return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) + return s.retryVerificationFlowWithError(ctx, w, r, flow.TypeBrowser, err) } returnTo := f.ContinueURL(r.Context(), s.d.Config()) @@ -253,65 +252,65 @@ func (s *Strategy) verificationUseToken(w http.ResponseWriter, r *http.Request, WithMetaLabel(text.NewInfoNodeLabelContinue())) if err := s.d.VerificationFlowPersister().UpdateVerificationFlow(r.Context(), f); err != nil { - return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) + return s.retryVerificationFlowWithError(ctx, w, r, flow.TypeBrowser, err) } if err := s.d.VerificationExecutor().PostVerificationHook(w, r, f, i); err != nil { - return s.retryVerificationFlowWithError(w, r, flow.TypeBrowser, err) + return s.retryVerificationFlowWithError(ctx, w, r, flow.TypeBrowser, err) } return nil } -func (s *Strategy) retryVerificationFlowWithMessage(w http.ResponseWriter, r *http.Request, ft flow.Type, message *text.Message) error { +func (s *Strategy) retryVerificationFlowWithMessage(ctx context.Context, w http.ResponseWriter, r *http.Request, ft flow.Type, message *text.Message) error { s.d.Logger().WithRequest(r).WithField("message", message).Debug("A verification flow is being retried because a validation error occurred.") f, err := verification.NewFlow(s.d.Config(), - s.d.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft) + s.d.Config().SelfServiceFlowVerificationRequestLifespan(ctx), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft) if err != nil { return s.handleVerificationError(r, f, nil, err) } f.UI.Messages.Add(message) - if err := s.d.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil { + if err := s.d.VerificationFlowPersister().CreateVerificationFlow(ctx, f); err != nil { return s.handleVerificationError(r, f, nil, err) } if ft == flow.TypeBrowser { - http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowVerificationUI(r.Context())).String(), http.StatusSeeOther) + http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowVerificationUI(ctx)).String(), http.StatusSeeOther) } else { - http.Redirect(w, r, urlx.CopyWithQuery(urlx.AppendPaths(s.d.Config().SelfPublicURL(r.Context()), + http.Redirect(w, r, urlx.CopyWithQuery(urlx.AppendPaths(s.d.Config().SelfPublicURL(ctx), verification.RouteGetFlow), url.Values{"id": {f.ID.String()}}).String(), http.StatusSeeOther) } return errors.WithStack(flow.ErrCompletedByStrategy) } -func (s *Strategy) retryVerificationFlowWithError(w http.ResponseWriter, r *http.Request, ft flow.Type, verErr error) error { +func (s *Strategy) retryVerificationFlowWithError(ctx context.Context, w http.ResponseWriter, r *http.Request, ft flow.Type, verErr error) error { s.d.Logger().WithRequest(r).WithError(verErr).Debug("A verification flow is being retried because an error occurred.") f, err := verification.NewFlow(s.d.Config(), - s.d.Config().SelfServiceFlowVerificationRequestLifespan(r.Context()), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft) + s.d.Config().SelfServiceFlowVerificationRequestLifespan(ctx), s.d.CSRFHandler().RegenerateToken(w, r), r, s, ft) if err != nil { return s.handleVerificationError(r, f, nil, err) } if expired := new(flow.ExpiredError); errors.As(verErr, &expired) { - return s.retryVerificationFlowWithMessage(w, r, ft, text.NewErrorValidationVerificationFlowExpired(expired.ExpiredAt)) + return s.retryVerificationFlowWithMessage(ctx, w, r, ft, text.NewErrorValidationVerificationFlowExpired(expired.ExpiredAt)) } else { if err := f.UI.ParseError(node.LinkGroup, verErr); err != nil { return err } } - if err := s.d.VerificationFlowPersister().CreateVerificationFlow(r.Context(), f); err != nil { + if err := s.d.VerificationFlowPersister().CreateVerificationFlow(ctx, f); err != nil { return s.handleVerificationError(r, f, nil, err) } if ft == flow.TypeBrowser { - http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowVerificationUI(r.Context())).String(), http.StatusSeeOther) + http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowVerificationUI(ctx)).String(), http.StatusSeeOther) } else { - http.Redirect(w, r, urlx.CopyWithQuery(urlx.AppendPaths(s.d.Config().SelfPublicURL(r.Context()), + http.Redirect(w, r, urlx.CopyWithQuery(urlx.AppendPaths(s.d.Config().SelfPublicURL(ctx), verification.RouteGetFlow), url.Values{"id": {f.ID.String()}}).String(), http.StatusSeeOther) } diff --git a/selfservice/strategy/lookup/login.go b/selfservice/strategy/lookup/login.go index 2eda9b5796c8..21b04e54e39d 100644 --- a/selfservice/strategy/lookup/login.go +++ b/selfservice/strategy/lookup/login.go @@ -8,6 +8,8 @@ import ( "net/http" "time" + "github.com/ory/x/otelx" + "github.com/ory/x/sqlcon" "github.com/ory/x/sqlxx" @@ -89,6 +91,9 @@ type updateLoginFlowWithLookupSecretMethod struct { } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, sess *session.Session) (i *identity.Identity, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.lookup.strategy.Login") + defer otelx.End(span, &err) + if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel2); err != nil { return nil, err } @@ -105,11 +110,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, s.handleLoginError(r, f, err) } - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return nil, s.handleLoginError(r, f, err) } - i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), s.ID(), sess.IdentityID.String()) + i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), sess.IdentityID.String()) if errors.Is(err, sqlcon.ErrNoRows) { return nil, s.handleLoginError(r, f, errors.WithStack(schema.NewNoLookupDefined())) } else if err != nil { @@ -137,25 +142,30 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, s.handleLoginError(r, f, errors.WithStack(schema.NewErrorValidationLookupInvalid())) } - toUpdate, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), sess.IdentityID) + // We can't use a transaction here because HydrateIdentityAssociations (used by update) does not support transactions. + toUpdate, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, sess.IdentityID) if err != nil { - return nil, err + return nil, s.handleLoginError(r, f, err) } encoded, err := json.Marshal(&o) if err != nil { - return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to encoded updated lookup secrets.").WithDebug(err.Error()))) + return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to encode updated lookup secrets.").WithDebug(err.Error()))) } c.Config = encoded toUpdate.SetCredentials(s.ID(), *c) - if err := s.d.PrivilegedIdentityPool().UpdateIdentity(r.Context(), toUpdate); err != nil { + // We can't use a transaction here because HydrateIdentityAssociations (used by update) does not support transactions. + if err := s.d.IdentityManager().Update(ctx, toUpdate, + // We need to allow write protected traits because we are updating the lookup secrets. + identity.ManagerAllowWriteProtectedTraits, + ); err != nil { return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to update identity.").WithDebug(err.Error()))) } f.Active = s.ID() - if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { + if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow.").WithDebug(err.Error()))) } diff --git a/selfservice/strategy/lookup/settings.go b/selfservice/strategy/lookup/settings.go index 3eee82d8d28a..4102a2f94857 100644 --- a/selfservice/strategy/lookup/settings.go +++ b/selfservice/strategy/lookup/settings.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + "github.com/ory/x/otelx" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -97,7 +99,10 @@ func (p *updateSettingsFlowWithLookupMethod) SetFlowID(rid uuid.UUID) { p.Flow = rid.String() } -func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (*settings.UpdateContext, error) { +func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (_ *settings.UpdateContext, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.lookup.strategy.Settings") + defer otelx.End(span, &err) + var p updateSettingsFlowWithLookupMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { @@ -113,7 +118,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. if p.RegenerateLookup || p.RevealLookup || p.ConfirmLookup || p.DisableLookup { // This method has only two submit buttons p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else { @@ -148,11 +153,11 @@ func (s *Strategy) continueSettingsFlow(r *http.Request, ctxUpdate *settings.Upd return err } - if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return err } - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return errors.WithStack(settings.NewFlowNeedsReAuth()) } } else { @@ -329,7 +334,7 @@ func (s *Strategy) identityHasLookup(ctx context.Context, id uuid.UUID) (bool, e return false, err } - count, err := s.CountActiveMultiFactorCredentials(confidential.Credentials) + count, err := s.CountActiveMultiFactorCredentials(ctx, confidential.Credentials) if err != nil { return false, err } diff --git a/selfservice/strategy/lookup/strategy.go b/selfservice/strategy/lookup/strategy.go index e8f12cac9948..ffd1081cee0f 100644 --- a/selfservice/strategy/lookup/strategy.go +++ b/selfservice/strategy/lookup/strategy.go @@ -34,6 +34,8 @@ type lookupStrategyDependencies interface { x.WriterProvider x.CSRFTokenGeneratorProvider x.CSRFProvider + x.TransactionPersistenceProvider + x.TracingProvider config.Provider @@ -61,6 +63,7 @@ type lookupStrategyDependencies interface { identity.PrivilegedPoolProvider identity.ValidationProvider + identity.ManagementProvider session.HandlerProvider session.ManagementProvider @@ -78,15 +81,15 @@ func NewStrategy(d any) *Strategy { } } -func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveFirstFactorCredentials(_ context.Context, _ map[identity.CredentialsType]identity.Credentials) (count int, err error) { return 0, nil } -func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { for _, c := range cc { if c.Type == s.ID() && len(c.Config) > 0 { var conf identity.CredentialsLookupConfig - if err = json.Unmarshal(c.Config, &conf); err != nil { + if err := json.Unmarshal(c.Config, &conf); err != nil { return 0, errors.WithStack(err) } @@ -106,7 +109,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.LookupGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, diff --git a/selfservice/strategy/lookup/strategy_test.go b/selfservice/strategy/lookup/strategy_test.go index f9674ab2805e..5c8cf126f59d 100644 --- a/selfservice/strategy/lookup/strategy_test.go +++ b/selfservice/strategy/lookup/strategy_test.go @@ -21,7 +21,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { strategy := lookup.NewStrategy(reg) t.Run("first factor", func(t *testing.T) { - actual, err := strategy.CountActiveFirstFactorCredentials(nil) + actual, err := strategy.CountActiveFirstFactorCredentials(nil, nil) require.NoError(t, err) assert.Equal(t, 0, actual) }) @@ -66,7 +66,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - actual, err := strategy.CountActiveMultiFactorCredentials(tc.in) + actual, err := strategy.CountActiveMultiFactorCredentials(nil, tc.in) require.NoError(t, err) assert.Equal(t, tc.expected, actual) }) diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 9b0acecc05fa..1a0e1c968e66 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -177,7 +177,7 @@ func parseState(s string) (*State, error) { } } -func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveFirstFactorCredentials(_ context.Context, cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { for _, c := range cc { if c.Type == s.ID() && gjson.ValidBytes(c.Config) { var conf identity.CredentialsOIDC @@ -202,7 +202,7 @@ func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.Credentials return } -func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, _ map[identity.CredentialsType]identity.Credentials) (count int, err error) { return 0, nil } @@ -404,22 +404,22 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt req, cntnr, err := s.ValidateCallback(w, r) if err != nil { if req != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) } else { - s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, s.handleError(w, r, nil, pid, nil, err)) + s.d.SelfServiceErrorManager().Forward(ctx, w, r, s.handleError(ctx, w, r, nil, pid, nil, err)) } return } if authenticated, err := s.alreadyAuthenticated(w, r, req); err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) } else if authenticated { return } provider, err := s.provider(r.Context(), r, pid) if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) return } @@ -429,37 +429,37 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt case OAuth2Provider: token, err := s.ExchangeCode(r.Context(), provider, code) if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) return } et, err = s.encryptOAuth2Tokens(r.Context(), token) if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) return } claims, err = p.Claims(r.Context(), token, r.URL.Query()) if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) return } case OAuth1Provider: token, err := p.ExchangeToken(r.Context(), r) if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) return } claims, err = p.Claims(r.Context(), token) if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) return } } if err = claims.Validate(); err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err)) return } @@ -469,7 +469,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt case *login.Flow: a.Active = s.ID() a.TransientPayload = cntnr.TransientPayload - if ff, err := s.processLogin(w, r, a, et, claims, provider, cntnr); err != nil { + if ff, err := s.processLogin(ctx, w, r, a, et, claims, provider, cntnr); err != nil { if errors.Is(err, flow.ErrCompletedByStrategy) { return } @@ -483,7 +483,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt case *registration.Flow: a.Active = s.ID() a.TransientPayload = cntnr.TransientPayload - if ff, err := s.processRegistration(w, r, a, et, claims, provider, cntnr, ""); err != nil { + if ff, err := s.processRegistration(ctx, w, r, a, et, claims, provider, cntnr, ""); err != nil { if ff != nil { s.forwardError(w, r, ff, err) return @@ -496,16 +496,16 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt a.TransientPayload = cntnr.TransientPayload sess, err := s.d.SessionManager().FetchFromRequest(r.Context(), r) if err != nil { - s.forwardError(w, r, a, s.handleError(w, r, a, pid, nil, err)) + s.forwardError(w, r, a, s.handleError(ctx, w, r, a, pid, nil, err)) return } if err := s.linkProvider(w, r, &settings.UpdateContext{Session: sess, Flow: a}, et, claims, provider); err != nil { - s.forwardError(w, r, a, s.handleError(w, r, a, pid, nil, err)) + s.forwardError(w, r, a, s.handleError(ctx, w, r, a, pid, nil, err)) return } return default: - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, errors.WithStack(x.PseudoPanic. + s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, errors.WithStack(x.PseudoPanic. WithDetailf("cause", "Unexpected type in OpenID Connect flow: %T", a)))) return } @@ -589,7 +589,7 @@ func (s *Strategy) forwardError(w http.ResponseWriter, r *http.Request, f flow.F } } -func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Flow, providerID string, traits []byte, err error) error { +func (s *Strategy) handleError(ctx context.Context, w http.ResponseWriter, r *http.Request, f flow.Flow, providerID string, traits []byte, err error) error { switch rf := f.(type) { case *login.Flow: return err @@ -697,7 +697,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.OpenIDConnectGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel1, diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 88f0c7727537..25dd1765ef14 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -5,6 +5,7 @@ package oidc import ( "bytes" + "context" "encoding/json" "net/http" "strings" @@ -107,8 +108,11 @@ type UpdateLoginFlowWithOidcMethod struct { TransientPayload json.RawMessage `json:"transient_payload,omitempty" form:"transient_payload"` } -func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (*registration.Flow, error) { - i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)) +func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (*registration.Flow, error) { + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.strategy.processLogin") + defer otelx.End(span, &err) + + i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)) if err != nil { if errors.Is(err, sqlcon.ErrNoRows) { // If no account was found we're "manually" creating a new registration flow and redirecting the browser @@ -139,12 +143,12 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo registrationFlow, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r, loginFlow.Type, opts...) if err != nil { - return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } - err = s.d.SessionTokenExchangePersister().MoveToNewFlow(r.Context(), loginFlow.ID, registrationFlow.ID) + err = s.d.SessionTokenExchangePersister().MoveToNewFlow(ctx, loginFlow.ID, registrationFlow.ID) if err != nil { - return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } registrationFlow.OrganizationID = loginFlow.OrganizationID @@ -155,37 +159,37 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo registrationFlow.Active = s.ID() if err != nil { - return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } - if _, err := s.processRegistration(w, r, registrationFlow, token, claims, provider, container, loginFlow.IDToken); err != nil { + if _, err := s.processRegistration(ctx, w, r, registrationFlow, token, claims, provider, container, loginFlow.IDToken); err != nil { return registrationFlow, err } return nil, nil } - return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } var oidcCredentials identity.CredentialsOIDC if err := json.NewDecoder(bytes.NewBuffer(c.Config)).Decode(&oidcCredentials); err != nil { - return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The password credentials could not be decoded properly").WithDebug(err.Error()))) + return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The password credentials could not be decoded properly").WithDebug(err.Error()))) } sess := session.NewInactiveSession() sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID, - httprouter.ParamsFromContext(r.Context()).ByName("organization")) + httprouter.ParamsFromContext(ctx).ByName("organization")) for _, c := range oidcCredentials.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID, login.WithClaims(claims)); err != nil { - return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } return nil, nil } } - return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject))) + return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject))) } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ *session.Session) (i *identity.Identity, err error) { @@ -198,7 +202,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, var p UpdateLoginFlowWithOidcMethod if err := s.newLinkDecoder(&p, r); err != nil { - return nil, s.handleError(w, r, f, "", nil, err) + return nil, s.handleError(ctx, w, r, f, "", nil, err) } f.IDToken = p.IDToken @@ -221,21 +225,21 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } provider, err := s.provider(ctx, r, pid) if err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } req, err := s.validateFlow(ctx, r, f.ID) if err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } if authenticated, err := s.alreadyAuthenticated(w, r, req); err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } else if authenticated { return i, nil } @@ -243,14 +247,14 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, if p.IDToken != "" { claims, err := s.processIDToken(w, r, provider, p.IDToken, p.IDTokenNonce) if err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } - _, err = s.processLogin(w, r, f, nil, claims, provider, &AuthCodeContainer{ + _, err = s.processLogin(ctx, w, r, f, nil, claims, provider, &AuthCodeContainer{ FlowID: f.ID.String(), Traits: p.Traits, }) if err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } return nil, errors.WithStack(flow.ErrCompletedByStrategy) } @@ -267,12 +271,12 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, TransientPayload: f.TransientPayload, }), continuity.WithLifespan(time.Minute*30)); err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } f.Active = s.ID() if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { - return nil, s.handleError(w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) + return nil, s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) } var up map[string]string @@ -282,7 +286,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, codeURL, err := getAuthRedirectURL(ctx, provider, f, state, up) if err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) + return nil, s.handleError(ctx, w, r, f, pid, nil, err) } if x.IsJSONRequest(r) { @@ -340,8 +344,11 @@ func (s *Strategy) PopulateLoginMethodSecondFactorRefresh(r *http.Request, sr *l return nil } -func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request, f *login.Flow, mods ...login.FormHydratorModifier) error { - conf, err := s.Config(r.Context()) +func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request, f *login.Flow, mods ...login.FormHydratorModifier) (err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.PopulateLoginMethodIdentifierFirstCredentials") + defer otelx.End(span, &err) + + conf, err := s.Config(ctx) if err != nil { return err } @@ -359,7 +366,7 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request if len(linked) == 0 { // If we found no credentials: - if s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { + if s.d.Config().SecurityAccountEnumerationMitigate(ctx) { // We found no credentials but do not want to leak that we know that. So we return early and do not // modify the initial provider list. return nil @@ -370,7 +377,7 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request return idfirst.ErrNoCredentialsFound } - if !s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { + if !s.d.Config().SecurityAccountEnumerationMitigate(ctx) { // Account enumeration is disabled, so we show all providers that are linked to the identity. // User is found and enumeration mitigation is disabled. Filter the list! f.GetUI().UnsetNode("provider") diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 5d5cb6ead25c..44b821fb89ee 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -5,6 +5,7 @@ package oidc import ( "bytes" + "context" "encoding/json" "net/http" "strings" @@ -155,7 +156,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat var p UpdateRegistrationFlowWithOidcMethod if err := s.newLinkDecoder(&p, r); err != nil { - return s.handleError(w, r, f, "", nil, err) + return s.handleError(ctx, w, r, f, "", nil, err) } pid := p.Provider // this can come from both url query and post body @@ -178,21 +179,21 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat } if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } provider, err := s.provider(ctx, r, pid) if err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } req, err := s.validateFlow(ctx, r, f.ID) if err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } if authenticated, err := s.alreadyAuthenticated(w, r, req); err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } else if authenticated { return errors.WithStack(registration.ErrAlreadyLoggedIn) } @@ -200,15 +201,15 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat if p.IDToken != "" { claims, err := s.processIDToken(w, r, provider, p.IDToken, p.IDTokenNonce) if err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } - _, err = s.processRegistration(w, r, f, nil, claims, provider, &AuthCodeContainer{ + _, err = s.processRegistration(ctx, w, r, f, nil, claims, provider, &AuthCodeContainer{ FlowID: f.ID.String(), Traits: p.Traits, TransientPayload: f.TransientPayload, }, p.IDToken) if err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } return errors.WithStack(flow.ErrCompletedByStrategy) } @@ -225,7 +226,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat TransientPayload: f.TransientPayload, }), continuity.WithLifespan(time.Minute*30)); err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } var up map[string]string @@ -235,7 +236,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat codeURL, err := getAuthRedirectURL(ctx, provider, f, state, up) if err != nil { - return s.handleError(w, r, f, pid, nil, err) + return s.handleError(ctx, w, r, f, pid, nil, err) } if x.IsJSONRequest(r) { s.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(codeURL)) @@ -279,7 +280,10 @@ func (s *Strategy) registrationToLogin(w http.ResponseWriter, r *http.Request, r return lf, nil } -func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer, idToken string) (*login.Flow, error) { +func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer, idToken string) (*login.Flow, error) { + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.strategy.processRegistration") + defer otelx.End(span, &err) + if _, _, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)); err == nil { // If the identity already exists, we should perform the login flow instead. @@ -296,11 +300,11 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r lf, err := s.registrationToLogin(w, r, rf, provider.Config().ID) if err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, nil, err) } - if _, err := s.processLogin(w, r, lf, token, claims, provider, container); err != nil { - return lf, s.handleError(w, r, rf, provider.Config().ID, nil, err) + if _, err := s.processLogin(ctx, w, r, lf, token, claims, provider, container); err != nil { + return lf, s.handleError(ctx, w, r, rf, provider.Config().ID, nil, err) } return nil, nil @@ -309,17 +313,17 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context())), fetcher.WithCache(jsonnetCache, 60*time.Minute)) jsonnetMapperSnippet, err := fetch.FetchContext(r.Context(), provider.Config().Mapper) if err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, nil, err) } - i, va, err := s.createIdentity(w, r, rf, claims, provider, container, jsonnetMapperSnippet.Bytes()) + i, va, err := s.createIdentity(ctx, w, r, rf, claims, provider, container, jsonnetMapperSnippet.Bytes()) if err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, nil, err) + return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, nil, err) } // Validate the identity itself if err := s.d.IdentityValidator().Validate(r.Context(), i); err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) + return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err) } for n := range i.VerifiableAddresses { @@ -336,50 +340,50 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r creds, err := identity.NewCredentialsOIDC(token, provider.Config().ID, claims.Subject, provider.Config().OrganizationID) if err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) + return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err) } i.SetCredentials(s.ID(), *creds) if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, identity.CredentialsTypeOIDC, provider.Config().ID, rf, i); err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) + return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err) } return nil, nil } -func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a flow.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) { +func (s *Strategy) createIdentity(ctx context.Context, w http.ResponseWriter, r *http.Request, a flow.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) { var jsonClaims bytes.Buffer if err := json.NewEncoder(&jsonClaims).Encode(claims); err != nil { - return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err) + return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, nil, err) } vm, err := s.d.JsonnetVM(r.Context()) if err != nil { - return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err) + return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, nil, err) } vm.ExtCode("claims", jsonClaims.String()) evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, string(jsonnetSnippet)) if err != nil { - return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err) + return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, nil, err) } i := identity.NewIdentity(s.d.Config().DefaultIdentityTraitsSchemaID(r.Context())) - if err := s.setTraits(w, r, a, claims, provider, container, evaluated, i); err != nil { - return nil, nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err) + if err := s.setTraits(ctx, w, r, a, claims, provider, container, evaluated, i); err != nil { + return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, i.Traits, err) } if err := s.setMetadata(evaluated, i, PublicMetadata); err != nil { - return nil, nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err) + return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, i.Traits, err) } if err := s.setMetadata(evaluated, i, AdminMetadata); err != nil { - return nil, nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err) + return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, i.Traits, err) } va, err := s.extractVerifiedAddresses(evaluated) if err != nil { - return nil, nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err) + return nil, nil, s.handleError(ctx, w, r, a, provider.Config().ID, i.Traits, err) } if orgID := httprouter.ParamsFromContext(r.Context()).ByName("organization"); orgID != "" { @@ -396,7 +400,7 @@ func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a flow return i, va, nil } -func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, a flow.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, evaluated string, i *identity.Identity) error { +func (s *Strategy) setTraits(ctx context.Context, w http.ResponseWriter, r *http.Request, a flow.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, evaluated string, i *identity.Identity) error { jsonTraits := gjson.Get(evaluated, "identity.traits") if !jsonTraits.IsObject() { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("OpenID Connect Jsonnet mapper did not return an object for key identity.traits. Please check your Jsonnet code!")) @@ -405,7 +409,7 @@ func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, a flow.Flow if container != nil { traits, err := merge(container.Traits, json.RawMessage(jsonTraits.Raw)) if err != nil { - return s.handleError(w, r, a, provider.Config().ID, nil, err) + return s.handleError(ctx, w, r, a, provider.Config().ID, nil, err) } i.Traits = traits diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index d8b1cfab8695..a98c5a39d324 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -11,6 +11,8 @@ import ( "net/http" "time" + "github.com/ory/x/otelx" + "github.com/ory/x/sqlxx" "github.com/ory/x/stringsx" @@ -261,7 +263,10 @@ func (p *updateSettingsFlowWithOidcMethod) SetFlowID(rid uuid.UUID) { p.FlowID = rid.String() } -func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (*settings.UpdateContext, error) { +func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (_ *settings.UpdateContext, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.Settings") + defer otelx.End(span, &err) + var p updateSettingsFlowWithOidcMethod if err := s.decoderSettings(&p, r); err != nil { return nil, err @@ -270,7 +275,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - if !s.d.Config().SelfServiceStrategy(r.Context(), s.SettingsStrategyID()).Enabled { + if !s.d.Config().SelfServiceStrategy(ctx, s.SettingsStrategyID()).Enabled { return nil, errors.WithStack(herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage)) } @@ -297,7 +302,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. return nil, errors.WithStack(flow.ErrStrategyNotResponsible) } - if !s.d.Config().SelfServiceStrategy(r.Context(), s.SettingsStrategyID()).Enabled { + if !s.d.Config().SelfServiceStrategy(ctx, s.SettingsStrategyID()).Enabled { return nil, errors.WithStack(herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage)) } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 25d45486391b..244232f638d2 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -569,7 +569,7 @@ func TestStrategy(t *testing.T) { // We essentially run into this bit: // // if authenticated, err := s.alreadyAuthenticated(w, r, req); err != nil { - // s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + // s.forwardError(w, r, req, s.handleError(ctx, w, , r, req, pid, nil, err)) // } else if authenticated { // return <-- we end up here on the second call // } @@ -1543,7 +1543,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { for _, v := range tc.in { in[v.Type] = v } - actual, err := strategy.CountActiveFirstFactorCredentials(in) + actual, err := strategy.CountActiveFirstFactorCredentials(nil, in) require.NoError(t, err) assert.Equal(t, tc.expected, actual) }) diff --git a/selfservice/strategy/passkey/passkey_login.go b/selfservice/strategy/passkey/passkey_login.go index 857d6e824d32..a5235083ac46 100644 --- a/selfservice/strategy/passkey/passkey_login.go +++ b/selfservice/strategy/passkey/passkey_login.go @@ -9,6 +9,8 @@ import ( "net/http" "strings" + "github.com/ory/x/otelx" + "github.com/ory/kratos/selfservice/strategy/idfirst" "github.com/ory/kratos/x/webauthnx/js" @@ -144,12 +146,16 @@ type updateLoginFlowWithPasskeyMethod struct { } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ *session.Session) (i *identity.Identity, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.passkey.strategy.Login") + defer otelx.End(span, &err) + if f.Type != flow.TypeBrowser { return nil, flow.ErrStrategyNotResponsible } var p updateLoginFlowWithPasskeyMethod if err := s.hd.Decode(r, &p, + decoderx.HTTPKeepRequestBody(true), decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.MustHTTPRawJSONSchemaCompiler(loginSchema), decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil { @@ -163,11 +169,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, flow.ErrStrategyNotResponsible } - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleLoginError(r, f, err) } - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return nil, s.handleLoginError(r, f, err) } @@ -291,7 +297,7 @@ func (s *Strategy) PopulateLoginMethodFirstFactorRefresh(r *http.Request, f *log return nil } - id, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), id.ID) + id, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, id.ID) if err != nil { return err } @@ -430,6 +436,7 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request return errors.WithStack(idfirst.ErrNoCredentialsFound) } + ctx := r.Context() o := login.NewFormHydratorOptions(opts) var count int @@ -437,13 +444,13 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request var err error // If we have an identity hint we can perform identity credentials discovery and // hide this credential if it should not be included. - count, err = s.CountActiveFirstFactorCredentials(o.IdentityHint.Credentials) + count, err = s.CountActiveFirstFactorCredentials(ctx, o.IdentityHint.Credentials) if err != nil { return err } } - if count > 0 || s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { + if count > 0 || s.d.Config().SecurityAccountEnumerationMitigate(ctx) { sr.UI.Nodes.Append(node.NewInputField( node.PasskeyLoginTrigger, "", diff --git a/selfservice/strategy/passkey/passkey_registration.go b/selfservice/strategy/passkey/passkey_registration.go index 9be753f70c40..4d136893047f 100644 --- a/selfservice/strategy/passkey/passkey_registration.go +++ b/selfservice/strategy/passkey/passkey_registration.go @@ -11,6 +11,8 @@ import ( "net/url" "strings" + "github.com/ory/x/otelx" + "github.com/ory/kratos/x/webauthnx/js" "github.com/go-webauthn/webauthn/protocol" @@ -97,7 +99,8 @@ func (s *Strategy) decode(r *http.Request) (*updateRegistrationFlowWithPasskeyMe } func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *registration.Flow, ident *identity.Identity) (err error) { - ctx := r.Context() + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.passkey.strategy.Register") + defer otelx.End(span, &err) if regFlow.Type != flow.TypeBrowser { return flow.ErrStrategyNotResponsible diff --git a/selfservice/strategy/passkey/passkey_registration_test.go b/selfservice/strategy/passkey/passkey_registration_test.go index d7191207cedb..86a7a7992e68 100644 --- a/selfservice/strategy/passkey/passkey_registration_test.go +++ b/selfservice/strategy/passkey/passkey_registration_test.go @@ -299,7 +299,7 @@ func TestRegistration(t *testing.T) { i, _, err := fix.reg.PrivilegedIdentityPool().FindByCredentialsIdentifier(fix.ctx, identity.CredentialsTypePasskey, userID) require.NoError(t, err) - assert.Equal(t, "aal1", i.AvailableAAL.String) + assert.Equal(t, "aal1", i.InternalAvailableAAL.String) assert.Equal(t, email, gjson.GetBytes(i.Traits, "username").String(), "%s", actual) }) } diff --git a/selfservice/strategy/passkey/passkey_settings.go b/selfservice/strategy/passkey/passkey_settings.go index 7214e32660d3..ada1877be662 100644 --- a/selfservice/strategy/passkey/passkey_settings.go +++ b/selfservice/strategy/passkey/passkey_settings.go @@ -4,6 +4,7 @@ package passkey import ( + "context" _ "embed" "encoding/json" "fmt" @@ -11,6 +12,8 @@ import ( "strings" "time" + "github.com/ory/x/otelx" + "github.com/ory/kratos/x/webauthnx/js" "github.com/go-webauthn/webauthn/protocol" @@ -160,14 +163,17 @@ func (s *Strategy) identityListWebAuthn(id *identity.Identity) (*identity.Creden return &cc, nil } -func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (*settings.UpdateContext, error) { +func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (_ *settings.UpdateContext, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.passkey.strategy.Settings") + defer otelx.End(span, &err) + if f.Type != flow.TypeBrowser { return nil, errors.WithStack(flow.ErrStrategyNotResponsible) } var p updateSettingsFlowWithPasskeyMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, p) + return ctxUpdate, s.continueSettingsFlow(ctx, w, r, ctxUpdate, p) } else if err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -179,7 +185,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. if len(p.Register+p.Remove) > 0 { // This method has only two submit buttons p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else { @@ -188,7 +194,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(w, r, ctxUpdate, p); err != nil { + if err := s.continueSettingsFlow(ctx, w, r, ctxUpdate, p); err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -235,19 +241,20 @@ func (p *updateSettingsFlowWithPasskeyMethod) SetFlowID(rid uuid.UUID) { } func (s *Strategy) continueSettingsFlow( + ctx context.Context, w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod, ) error { if len(p.Register+p.Remove) > 0 { - if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return err } - if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return err } - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return errors.WithStack(settings.NewFlowNeedsReAuth()) } } else { @@ -256,16 +263,16 @@ func (s *Strategy) continueSettingsFlow( switch { case len(p.Remove) > 0: - return s.continueSettingsFlowRemove(w, r, ctxUpdate, p) + return s.continueSettingsFlowRemove(ctx, w, r, ctxUpdate, p) case len(p.Register) > 0: - return s.continueSettingsFlowAdd(r, ctxUpdate, p) + return s.continueSettingsFlowAdd(ctx, ctxUpdate, p) default: return errors.New("ended up in unexpected state") } } -func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod) error { - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.IdentityID) +func (s *Strategy) continueSettingsFlowRemove(ctx context.Context, w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod) error { + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.IdentityID) if err != nil { return err } @@ -291,7 +298,7 @@ func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Req return errors.WithStack(herodot.ErrBadRequest.WithReasonf("You tried to remove a passkey which does not exist.")) } - count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), i) + count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(ctx, i) if err != nil { return err } @@ -317,7 +324,7 @@ func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Req return nil } -func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod) error { +func (s *Strategy) continueSettingsFlowAdd(ctx context.Context, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasskeyMethod) error { webAuthnSession := gjson.GetBytes(ctxUpdate.Flow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeySessionData)) if !webAuthnSession.IsObject() { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected WebAuthN in internal context to be an object.")) @@ -333,7 +340,7 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. return errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to parse WebAuthn response: %s", err)) } - web, err := webauthn.New(s.d.Config().PasskeyConfig(r.Context())) + web, err := webauthn.New(s.d.Config().PasskeyConfig(ctx)) if err != nil { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to get webAuthn config.").WithDebug(err.Error())) } @@ -346,7 +353,7 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to create WebAuthn credential: %s", err)) } - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.IdentityID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.IdentityID) if err != nil { return err } @@ -367,7 +374,7 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. } i.UpsertCredentialsConfig(s.ID(), credentialsConfig, 1, identity.WithAdditionalIdentifier(string(webAuthnSess.UserID))) - if err := s.validateCredentials(r.Context(), i); err != nil { + if err := s.validateCredentials(ctx, i); err != nil { return err } @@ -377,14 +384,14 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. return err } - if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), ctxUpdate.Flow); err != nil { + if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(ctx, ctxUpdate.Flow); err != nil { return err } aal := identity.AuthenticatorAssuranceLevel1 // Since we added the method, it also means that we have authenticated it - if err := s.d.SessionManager().SessionAddAuthenticationMethods(r.Context(), ctxUpdate.Session.ID, session.AuthenticationMethod{ + if err := s.d.SessionManager().SessionAddAuthenticationMethods(ctx, ctxUpdate.Session.ID, session.AuthenticationMethod{ Method: s.ID(), AAL: aal, }); err != nil { @@ -402,6 +409,7 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { } return decoderx.NewHTTP().Decode(r, dest, compiler, + decoderx.HTTPKeepRequestBody(true), decoderx.HTTPDecoderAllowedMethods("POST", "GET"), decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.HTTPDecoderJSONFollowsFormFormat(), diff --git a/selfservice/strategy/passkey/passkey_settings_test.go b/selfservice/strategy/passkey/passkey_settings_test.go index 606caa838774..781af829be3e 100644 --- a/selfservice/strategy/passkey/passkey_settings_test.go +++ b/selfservice/strategy/passkey/passkey_settings_test.go @@ -78,19 +78,6 @@ func TestCompleteSettings(t *testing.T) { }) }) - t.Run("case=invalid credentials", func(t *testing.T) { - id, _ := fix.createIdentityAndReturnIdentifier(t, []byte(`{invalid}`)) - - apiClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, fix.reg, id) - - req, err := http.NewRequest("GET", fix.publicTS.URL+settings.RouteInitBrowserFlow, nil) - require.NoError(t, err) - req.Header.Set("Accept", "application/json") - res, err := apiClient.Do(req) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - }) - t.Run("case=one activation element is shown", func(t *testing.T) { id := fix.createIdentityWithoutPasskey(t) require.NoError(t, fix.reg.PrivilegedIdentityPool().UpdateIdentity(fix.ctx, id)) @@ -233,6 +220,7 @@ func TestCompleteSettings(t *testing.T) { // We load our identity which we will use to replay the webauth session var id identity.Identity require.NoError(t, json.Unmarshal(settingsFixtureSuccessIdentity, &id)) + id.NID = x.NewUUID() _ = fix.reg.PrivilegedIdentityPool().DeleteIdentity(fix.ctx, id.ID) browserClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, fix.reg, &id) f := testhelpers.InitializeSettingsFlowViaBrowser(t, browserClient, spa, fix.publicTS) @@ -439,6 +427,7 @@ func TestCompleteSettings(t *testing.T) { var id identity.Identity require.NoError(t, json.Unmarshal(settingsFixtureSuccessIdentity, &id)) _ = fix.reg.PrivilegedIdentityPool().DeleteIdentity(fix.ctx, id.ID) + id.NID = x.NewUUID() browserClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, fix.reg, &id) req, err := http.NewRequest("GET", fix.publicTS.URL+settings.RouteInitBrowserFlow, nil) diff --git a/selfservice/strategy/passkey/passkey_strategy.go b/selfservice/strategy/passkey/passkey_strategy.go index b590a7e93b6d..53102ae48982 100644 --- a/selfservice/strategy/passkey/passkey_strategy.go +++ b/selfservice/strategy/passkey/passkey_strategy.go @@ -6,6 +6,7 @@ package passkey import ( "context" "encoding/json" + "strings" "github.com/pkg/errors" @@ -28,6 +29,7 @@ type strategyDependencies interface { x.WriterProvider x.CSRFTokenGeneratorProvider x.CSRFProvider + x.TracingProvider config.Provider @@ -88,24 +90,24 @@ func (*Strategy) NodeGroup() node.UiNodeGroup { return node.PasskeyGroup } -func (s *Strategy) CompletedAuthenticationMethod(context.Context, session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(context.Context) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: identity.CredentialsTypePasskey, AAL: identity.AuthenticatorAssuranceLevel1, } } -func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, _ map[identity.CredentialsType]identity.Credentials) (count int, err error) { return 0, nil } -func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveFirstFactorCredentials(_ context.Context, cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { return s.countCredentials(cc) } func (s *Strategy) countCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { for _, c := range cc { - if c.Type == s.ID() && len(c.Config) > 0 && len(c.Identifiers) > 0 { + if c.Type == s.ID() && len(c.Config) > 0 && len(strings.Join(c.Identifiers, "")) > 0 { var conf identity.CredentialsWebAuthnConfig if err = json.Unmarshal(c.Config, &conf); err != nil { return 0, errors.WithStack(err) diff --git a/selfservice/strategy/password/login.go b/selfservice/strategy/password/login.go index 91e59085c8ef..be96b0672409 100644 --- a/selfservice/strategy/password/login.go +++ b/selfservice/strategy/password/login.go @@ -10,6 +10,8 @@ import ( "net/http" "time" + "github.com/ory/x/otelx" + "github.com/gofrs/uuid" "github.com/pkg/errors" @@ -48,6 +50,9 @@ func (s *Strategy) handleLoginError(r *http.Request, f *login.Flow, payload upda } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ *session.Session) (i *identity.Identity, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.password.strategy.Login") + defer otelx.End(span, &err) + if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil { return nil, err } @@ -65,14 +70,14 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } f.TransientPayload = p.TransientPayload - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return nil, s.handleLoginError(r, f, p, err) } identifier := stringsx.Coalesce(p.Identifier, p.LegacyIdentifier) - i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), s.ID(), identifier) + i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), identifier) if err != nil { - time.Sleep(x.RandomDelay(s.d.Config().HasherArgon2(r.Context()).ExpectedDuration, s.d.Config().HasherArgon2(r.Context()).ExpectedDeviation)) + time.Sleep(x.RandomDelay(s.d.Config().HasherArgon2(ctx).ExpectedDuration, s.d.Config().HasherArgon2(ctx).ExpectedDeviation)) return nil, s.handleLoginError(r, f, p, errors.WithStack(schema.NewInvalidCredentialsError())) } @@ -83,41 +88,44 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } if o.ShouldUsePasswordMigrationHook() { - pwHook := s.d.Config().PasswordMigrationHook(r.Context()) + pwHook := s.d.Config().PasswordMigrationHook(ctx) if !pwHook.Enabled { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Password migration hook is not enabled but password migration is requested.")) } migrationHook := hook.NewPasswordMigrationHook(s.d, pwHook.Config) - err = migrationHook.Execute(r.Context(), &hook.PasswordMigrationRequest{Identifier: identifier, Password: p.Password}) + err = migrationHook.Execute(ctx, &hook.PasswordMigrationRequest{Identifier: identifier, Password: p.Password}) if err != nil { return nil, s.handleLoginError(r, f, p, err) } - if err := s.migratePasswordHash(r.Context(), i.ID, []byte(p.Password)); err != nil { + if err := s.migratePasswordHash(ctx, i.ID, []byte(p.Password)); err != nil { return nil, s.handleLoginError(r, f, p, err) } } else { - if err := hash.Compare(r.Context(), []byte(p.Password), []byte(o.HashedPassword)); err != nil { + if err := hash.Compare(ctx, []byte(p.Password), []byte(o.HashedPassword)); err != nil { return nil, s.handleLoginError(r, f, p, errors.WithStack(schema.NewInvalidCredentialsError())) } - if !s.d.Hasher(r.Context()).Understands([]byte(o.HashedPassword)) { - if err := s.migratePasswordHash(r.Context(), i.ID, []byte(p.Password)); err != nil { + if !s.d.Hasher(ctx).Understands([]byte(o.HashedPassword)) { + if err := s.migratePasswordHash(ctx, i.ID, []byte(p.Password)); err != nil { return nil, s.handleLoginError(r, f, p, err) } } } f.Active = s.ID() - if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { + if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { return nil, s.handleLoginError(r, f, p, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) } return i, nil } -func (s *Strategy) migratePasswordHash(ctx context.Context, identifier uuid.UUID, password []byte) error { +func (s *Strategy) migratePasswordHash(ctx context.Context, identifier uuid.UUID, password []byte) (err error) { + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.password.strategy.migratePasswordHash") + defer otelx.End(span, &err) + hpw, err := s.d.Hasher(ctx).Generate(ctx, password) if err != nil { return err @@ -140,17 +148,21 @@ func (s *Strategy) migratePasswordHash(ctx context.Context, identifier uuid.UUID c.Config = co i.SetCredentials(s.ID(), *c) - return s.d.PrivilegedIdentityPool().UpdateIdentity(ctx, i) + return s.d.IdentityManager().Update(ctx, i, identity.ManagerAllowWriteProtectedTraits) } -func (s *Strategy) PopulateLoginMethodFirstFactorRefresh(r *http.Request, sr *login.Flow) error { +func (s *Strategy) PopulateLoginMethodFirstFactorRefresh(r *http.Request, sr *login.Flow) (err error) { + ctx := r.Context() + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.password.strategy.PopulateLoginMethodFirstFactorRefresh") + defer otelx.End(span, &err) + identifier, id, _ := flowhelpers.GuessForcedLoginIdentifier(r, s.d, sr, s.ID()) if identifier == "" { return nil } // If we don't have a password set, do not show the password field. - count, err := s.CountActiveFirstFactorCredentials(id.Credentials) + count, err := s.CountActiveFirstFactorCredentials(ctx, id.Credentials) if err != nil { return err } else if count == 0 { @@ -198,7 +210,10 @@ func (s *Strategy) PopulateLoginMethodFirstFactor(r *http.Request, sr *login.Flo return nil } -func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request, sr *login.Flow, opts ...login.FormHydratorModifier) error { +func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request, sr *login.Flow, opts ...login.FormHydratorModifier) (err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.password.strategy.PopulateLoginMethodIdentifierFirstCredentials") + defer otelx.End(span, &err) + o := login.NewFormHydratorOptions(opts) var count int @@ -206,12 +221,12 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request var err error // If we have an identity hint we can perform identity credentials discovery and // hide this credential if it should not be included. - if count, err = s.CountActiveFirstFactorCredentials(o.IdentityHint.Credentials); err != nil { + if count, err = s.CountActiveFirstFactorCredentials(ctx, o.IdentityHint.Credentials); err != nil { return err } } - if count > 0 || s.d.Config().SecurityAccountEnumerationMitigate(r.Context()) { + if count > 0 || s.d.Config().SecurityAccountEnumerationMitigate(ctx) { sr.UI.SetCSRF(s.d.GenerateCSRFToken(r)) sr.UI.SetNode(NewPasswordNode("password", node.InputAttributeAutocompleteCurrentPassword)) sr.UI.GetNodes().Append(node.NewInputField("method", "password", node.PasswordGroup, node.InputAttributeTypeSubmit).WithMetaLabel(text.NewInfoLoginPassword())) diff --git a/selfservice/strategy/password/login_test.go b/selfservice/strategy/password/login_test.go index 036aa684169b..c955bf7d8a20 100644 --- a/selfservice/strategy/password/login_test.go +++ b/selfservice/strategy/password/login_test.go @@ -96,7 +96,11 @@ func TestCompleteLogin(t *testing.T) { conf.MustSet(ctx, config.ViperKeySelfServiceErrorUI, errTS.URL+"/error-ts") conf.MustSet(ctx, config.ViperKeySelfServiceLoginUI, uiTS.URL+"/login-ts") - testhelpers.SetDefaultIdentitySchemaFromRaw(conf, loginSchema) + testhelpers.SetIdentitySchemas(t, conf, map[string]string{ + "migration": "file://./stub/migration.schema.json", + "default": "file://./stub/login.schema.json", + }) + conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"not-a-secure-session-key"}) ensureFieldsExist := func(t *testing.T, body []byte) { @@ -519,7 +523,8 @@ func TestCompleteLogin(t *testing.T) { }) t.Run("do not show password method if identity has no password set", func(t *testing.T) { - id := identity.NewIdentity("") + id := identity.NewIdentity("default") + id.NID = x.NewUUID() browserClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, id) res, err := browserClient.Get(publicTS.URL + login.RouteInitBrowserFlow + "?refresh=true") @@ -579,7 +584,8 @@ func TestCompleteLogin(t *testing.T) { }) t.Run("do not show password method if identity has no password set", func(t *testing.T) { - id := identity.NewIdentity("") + id := identity.NewIdentity("default") + id.NID = x.NewUUID() hc := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, id) res, err := hc.Do(testhelpers.NewHTTPGetAJAXRequest(t, publicTS.URL+login.RouteInitBrowserFlow+"?refresh=true")) @@ -638,7 +644,8 @@ func TestCompleteLogin(t *testing.T) { }) t.Run("do not show password method if identity has no password set", func(t *testing.T) { - id := identity.NewIdentity("") + id := identity.NewIdentity("default") + id.NID = x.NewUUID() hc := testhelpers.NewHTTPClientWithIdentitySessionToken(t, ctx, reg, id) res, err := hc.Do(testhelpers.NewHTTPGetAJAXRequest(t, publicTS.URL+login.RouteInitAPIFlow+"?refresh=true")) @@ -655,8 +662,6 @@ func TestCompleteLogin(t *testing.T) { }) t.Run("case=should return an error because not passing validation and reset previous errors and values", func(t *testing.T) { - testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/login.schema.json") - check := func(t *testing.T, actual string) { assert.NotEmpty(t, gjson.Get(actual, "id").String(), "%s", actual) assert.Contains(t, gjson.Get(actual, "ui.action").String(), publicTS.URL+login.RouteSubmitFlow, "%s", actual) @@ -839,7 +844,7 @@ func TestCompleteLogin(t *testing.T) { }) t.Run("should upgrade password not primary hashing algorithm", func(t *testing.T) { - identifier, pwd := x.NewUUID().String(), "password" + identifier, pwd := x.NewUUID().String()+"@google.com", "password" h := &hash.Pbkdf2{ Algorithm: "sha256", Iterations: 100000, @@ -850,8 +855,9 @@ func TestCompleteLogin(t *testing.T) { iId := x.NewUUID() require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &identity.Identity{ - ID: iId, - Traits: identity.Traits(fmt.Sprintf(`{"subject":"%s"}`, identifier)), + ID: iId, + SchemaID: "migration", + Traits: identity.Traits(fmt.Sprintf(`{"email":"%s"}`, identifier)), Credentials: map[identity.CredentialsType]identity.Credentials{ identity.CredentialsTypePassword: { Type: identity.CredentialsTypePassword, @@ -881,7 +887,7 @@ func TestCompleteLogin(t *testing.T) { body := testhelpers.SubmitLoginForm(t, false, browserClient, publicTS, values, false, false, http.StatusOK, redirTS.URL) - assert.Equal(t, identifier, gjson.Get(body, "identity.traits.subject").String(), "%s", body) + assert.Equal(t, identifier, gjson.Get(body, "identity.traits.email").String(), "%s", body) // check if password hash algorithm is upgraded _, c, err := reg.PrivilegedIdentityPool().FindByCredentialsIdentifier(context.Background(), identity.CredentialsTypePassword, identifier) @@ -894,7 +900,7 @@ func TestCompleteLogin(t *testing.T) { // retry after upgraded body = testhelpers.SubmitLoginForm(t, false, browserClient, publicTS, values, false, true, http.StatusOK, redirTS.URL) - assert.Equal(t, identifier, gjson.Get(body, "identity.traits.subject").String(), "%s", body) + assert.Equal(t, identifier, gjson.Get(body, "identity.traits.email").String(), "%s", body) }) t.Run("suite=password migration hook", func(t *testing.T) { @@ -1024,12 +1030,13 @@ func TestCompleteLogin(t *testing.T) { t.Cleanup(cleanup) } - identifier := x.NewUUID().String() + identifier := x.NewUUID().String() + "@google.com" password := x.NewUUID().String() iId := x.NewUUID() require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, &identity.Identity{ - ID: iId, - Traits: identity.Traits(fmt.Sprintf(`{"subject":"%s"}`, identifier)), + ID: iId, + SchemaID: "migration", + Traits: identity.Traits(fmt.Sprintf(`{"email":"%s"}`, identifier)), Credentials: map[identity.CredentialsType]identity.Credentials{ identity.CredentialsTypePassword: { Type: identity.CredentialsTypePassword, @@ -1063,7 +1070,7 @@ func TestCompleteLogin(t *testing.T) { if tc.expectSuccess { body := testhelpers.SubmitLoginForm(t, false, browserClient, publicTS, values, false, false, http.StatusOK, redirTS.URL) - assert.Equal(t, identifier, gjson.Get(body, "identity.traits.subject").String(), "%s", body) + assert.Equal(t, identifier, gjson.Get(body, "identity.traits.email").String(), "%s", body) // check if password hash algorithm is upgraded _, c, err := reg.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, identity.CredentialsTypePassword, identifier) @@ -1076,7 +1083,7 @@ func TestCompleteLogin(t *testing.T) { // retry after upgraded body = testhelpers.SubmitLoginForm(t, false, browserClient, publicTS, values, false, true, http.StatusOK, redirTS.URL) - assert.Equal(t, identifier, gjson.Get(body, "identity.traits.subject").String(), "%s", body) + assert.Equal(t, identifier, gjson.Get(body, "identity.traits.email").String(), "%s", body) } else { body := testhelpers.SubmitLoginForm(t, false, browserClient, publicTS, values, false, false, http.StatusOK, "") diff --git a/selfservice/strategy/password/registration.go b/selfservice/strategy/password/registration.go index ba11733fd0ec..10ee958e7a2c 100644 --- a/selfservice/strategy/password/registration.go +++ b/selfservice/strategy/password/registration.go @@ -8,6 +8,8 @@ import ( "encoding/json" "net/http" + "github.com/ory/x/otelx" + "github.com/ory/kratos/text" "github.com/pkg/errors" @@ -71,11 +73,14 @@ func (s *Strategy) handleRegistrationError(r *http.Request, f *registration.Flow return err } -func (s *Strategy) decode(p *UpdateRegistrationFlowWithPasswordMethod, r *http.Request) error { +func (s *Strategy) decode(p *UpdateRegistrationFlowWithPasswordMethod, r *http.Request) (err error) { return registration.DecodeBody(p, r, s.hd, s.d.Config(), registrationSchema) } func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) (err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.password.strategy.Register") + defer otelx.End(span, &err) + if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.ID().String(), s.d); err != nil { return err } @@ -87,7 +92,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat f.TransientPayload = p.TransientPayload - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return s.handleRegistrationError(r, f, p, err) } @@ -105,7 +110,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat defer close(hpw) defer close(errC) - h, err := s.d.Hasher(r.Context()).Generate(r.Context(), []byte(p.Password)) + h, err := s.d.Hasher(ctx).Generate(ctx, []byte(p.Password)) if err != nil { errC <- err return @@ -124,7 +129,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return s.handleRegistrationError(r, f, p, err) } - if err := s.validateCredentials(r.Context(), i, p.Password); err != nil { + if err := s.validateCredentials(ctx, i, p.Password); err != nil { return s.handleRegistrationError(r, f, p, err) } @@ -142,7 +147,10 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return nil } -func (s *Strategy) validateCredentials(ctx context.Context, i *identity.Identity, pw string) error { +func (s *Strategy) validateCredentials(ctx context.Context, i *identity.Identity, pw string) (err error) { + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.password.strategy.validateCredentials") + defer otelx.End(span, &err) + if err := s.d.IdentityValidator().Validate(ctx, i); err != nil { return err } diff --git a/selfservice/strategy/password/settings.go b/selfservice/strategy/password/settings.go index 0995a73f5076..33ab114230c3 100644 --- a/selfservice/strategy/password/settings.go +++ b/selfservice/strategy/password/settings.go @@ -8,6 +8,10 @@ import ( "net/http" "time" + "golang.org/x/net/context" + + "github.com/ory/x/otelx" + "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -71,11 +75,14 @@ func (p *updateSettingsFlowWithPasswordMethod) SetFlowID(rid uuid.UUID) { p.Flow = rid.String() } -func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (*settings.UpdateContext, error) { +func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (_ *settings.UpdateContext, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.password.strategy.Settings") + defer otelx.End(span, &err) + var p updateSettingsFlowWithPasswordMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(r, ctxUpdate, p) + return ctxUpdate, s.continueSettingsFlow(ctx, r, ctxUpdate, p) } else if err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -90,7 +97,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(r, ctxUpdate, p); err != nil { + if err := s.continueSettingsFlow(ctx, r, ctxUpdate, p); err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -104,22 +111,23 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { } return decoderx.NewHTTP().Decode(r, dest, compiler, + decoderx.HTTPKeepRequestBody(true), decoderx.HTTPDecoderAllowedMethods("POST", "GET"), decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.HTTPDecoderJSONFollowsFormFormat(), ) } -func (s *Strategy) continueSettingsFlow(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasswordMethod) error { - if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { +func (s *Strategy) continueSettingsFlow(ctx context.Context, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithPasswordMethod) error { + if err := flow.MethodEnabledAndAllowed(ctx, flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } - if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return err } - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return errors.WithStack(settings.NewFlowNeedsReAuth()) } @@ -131,7 +139,7 @@ func (s *Strategy) continueSettingsFlow(r *http.Request, ctxUpdate *settings.Upd go func() { defer close(hpw) defer close(errC) - h, err := s.d.Hasher(r.Context()).Generate(r.Context(), []byte(p.Password)) + h, err := s.d.Hasher(ctx).Generate(ctx, []byte(p.Password)) if err != nil { errC <- err return @@ -139,13 +147,13 @@ func (s *Strategy) continueSettingsFlow(r *http.Request, ctxUpdate *settings.Upd hpw <- h }() - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.Identity.ID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.Identity.ID) if err != nil { return err } i.UpsertCredentialsConfig(s.ID(), []byte("{}"), 0) - if err := s.validateCredentials(r.Context(), i, p.Password); err != nil { + if err := s.validateCredentials(ctx, i, p.Password); err != nil { return err } diff --git a/selfservice/strategy/password/settings_test.go b/selfservice/strategy/password/settings_test.go index e5ea909856b0..c670395d5f24 100644 --- a/selfservice/strategy/password/settings_test.go +++ b/selfservice/strategy/password/settings_test.go @@ -44,7 +44,8 @@ func init() { func newIdentityWithPassword(email string) *identity.Identity { return &identity.Identity{ - ID: x.NewUUID(), + ID: x.NewUUID(), + NID: x.NewUUID(), Credentials: map[identity.CredentialsType]identity.Credentials{ "password": { Type: "password", @@ -61,6 +62,7 @@ func newIdentityWithPassword(email string) *identity.Identity { func newEmptyIdentity() *identity.Identity { return &identity.Identity{ ID: x.NewUUID(), + NID: x.NewUUID(), State: identity.StateActive, Traits: identity.Traits(`{}`), SchemaID: config.DefaultIdentityTraitsSchemaID, @@ -70,6 +72,7 @@ func newEmptyIdentity() *identity.Identity { func newIdentityWithoutCredentials(email string) *identity.Identity { return &identity.Identity{ ID: x.NewUUID(), + NID: x.NewUUID(), State: identity.StateActive, Traits: identity.Traits(`{"email":"` + email + `"}`), SchemaID: config.DefaultIdentityTraitsSchemaID, diff --git a/selfservice/strategy/password/strategy.go b/selfservice/strategy/password/strategy.go index ae57982dd89f..9aa36ebeb8a1 100644 --- a/selfservice/strategy/password/strategy.go +++ b/selfservice/strategy/password/strategy.go @@ -6,6 +6,7 @@ package password import ( "context" "encoding/json" + "strings" "github.com/go-playground/validator/v10" "github.com/pkg/errors" @@ -67,6 +68,7 @@ type registrationStrategyDependencies interface { identity.PrivilegedPoolProvider identity.ValidationProvider + identity.ManagementProvider session.HandlerProvider session.ManagementProvider @@ -86,7 +88,7 @@ func NewStrategy(d any) *Strategy { } } -func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveFirstFactorCredentials(ctx context.Context, cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { for _, c := range cc { if c.Type == s.ID() && len(c.Config) > 0 { var conf identity.CredentialsPassword @@ -94,8 +96,9 @@ func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.Credentials return 0, errors.WithStack(err) } - if len(c.Identifiers) > 0 && len(c.Identifiers[0]) > 0 && - (hash.IsBcryptHash([]byte(conf.HashedPassword)) || hash.IsArgon2idHash([]byte(conf.HashedPassword))) { + if len(strings.Join(c.Identifiers, "")) > 0 && + ((s.d.Config().PasswordMigrationHook(ctx).Enabled && conf.UsePasswordMigrationHook) || + len(conf.HashedPassword) > 0) { count++ } } @@ -103,7 +106,7 @@ func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.Credentials return } -func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, _ map[identity.CredentialsType]identity.Credentials) (count int, err error) { return 0, nil } @@ -111,7 +114,7 @@ func (s *Strategy) ID() identity.CredentialsType { return identity.CredentialsTypePassword } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(_ context.Context) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel1, diff --git a/selfservice/strategy/password/strategy_test.go b/selfservice/strategy/password/strategy_test.go index e27a0950ca28..dc86e5f1a85e 100644 --- a/selfservice/strategy/password/strategy_test.go +++ b/selfservice/strategy/password/strategy_test.go @@ -4,20 +4,36 @@ package password_test import ( + "cmp" "context" + "encoding/json" "fmt" "testing" + "github.com/ory/kratos/driver/config" + confighelpers "github.com/ory/kratos/driver/config/testhelpers" + hash2 "github.com/ory/kratos/hash" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/go-faker/faker/v4" + "github.com/ory/kratos/identity" "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/password" ) +func generateRandomConfig(t *testing.T) (identity.CredentialsPassword, []byte) { + t.Helper() + var cred identity.CredentialsPassword + require.NoError(t, faker.FakeData(&cred)) + c, err := json.Marshal(cred) + require.NoError(t, err) + return cred, c +} + func TestCountActiveFirstFactorCredentials(t *testing.T) { ctx := context.Background() _, reg := internal.NewFastRegistryWithMocks(t) @@ -28,75 +44,127 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { h2, err := reg.Hasher(ctx).Generate(context.Background(), []byte("a password")) require.NoError(t, err) - for k, tc := range []struct { - in map[identity.CredentialsType]identity.Credentials - expected int - }{ - { - in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { - Type: strategy.ID(), - Config: []byte{}, - }}, - expected: 0, - }, - { - in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { - Type: strategy.ID(), - Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), - }}, - expected: 0, - }, - { - in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { - Type: strategy.ID(), - Identifiers: []string{""}, - Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), - }}, - expected: 0, - }, - { - in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { - Type: strategy.ID(), - Identifiers: []string{"foo"}, - Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), - }}, - expected: 1, - }, - { - in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { - Type: strategy.ID(), - Identifiers: []string{"foo"}, - Config: []byte(`{"hashed_password": "` + string(h2) + `"}`), - }}, - expected: 1, - }, - { - in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { - Type: strategy.ID(), - Config: []byte(`{"hashed_password": "asdf"}`), - }}, - expected: 0, - }, - { - in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { - Type: strategy.ID(), - Config: []byte(`{}`), - }}, - expected: 0, - }, - { - in: nil, - expected: 0, - }, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - actual, err := strategy.CountActiveFirstFactorCredentials(tc.in) - assert.NoError(t, err) - assert.Equal(t, tc.expected, actual) + t.Run("test regressions fixtures", func(t *testing.T) { + // This test ensures we do not add regressions to this method by, for example, adding a new field. + for k := 0; k < 100; k++ { + t.Run(fmt.Sprintf("run=%d", k), func(t *testing.T) { + cred, c := generateRandomConfig(t) + actual, err := strategy.CountActiveFirstFactorCredentials(ctx, map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{"foo"}, + Config: c, + }}) + assert.NoError(t, err) + + if len(cred.HashedPassword) == 0 && cred.UsePasswordMigrationHook { + // This case is OK + assert.Equal(t, 0, actual) + return + } + + assert.Equal(t, 1, actual) + }) + } + }) + + t.Run("with fixtures", func(t *testing.T) { + for k, tc := range []struct { + in map[identity.CredentialsType]identity.Credentials + expected int + ctx context.Context + }{ + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte{}, + }}, + expected: 0, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), + }}, + expected: 0, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{""}, + Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), + }}, + expected: 0, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{"foo"}, + Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), + }}, + expected: 1, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{"foo"}, + Config: []byte(`{"hashed_password": "` + string(h2) + `"}`), + }}, + expected: 1, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{"foo"}, + Config: []byte(`{"use_password_migration_hook":true}`), + }}, + expected: 1, + ctx: confighelpers.WithConfigValue(ctx, config.ViperKeyPasswordMigrationHook+".enabled", true), + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{"foo"}, + Config: []byte(`{"use_password_migration_hook":true}`), + }}, + expected: 0, + ctx: confighelpers.WithConfigValue(ctx, config.ViperKeyPasswordMigrationHook+".enabled", false), + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{"foo"}, + Config: []byte(`{"use_password_migration_hook":false}`), + }}, + expected: 0, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte(`{"hashed_password": "asdf"}`), + }}, + expected: 0, + }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte(`{}`), + }}, + expected: 0, + }, + { + in: nil, + expected: 0, + }, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + actual, err := strategy.CountActiveFirstFactorCredentials(cmp.Or(tc.ctx, ctx), tc.in) + assert.NoError(t, err) + assert.Equal(t, tc.expected, actual) - actual, err = strategy.CountActiveMultiFactorCredentials(tc.in) - assert.NoError(t, err) - assert.Equal(t, 0, actual) - }) - } + actual, err = strategy.CountActiveMultiFactorCredentials(cmp.Or(tc.ctx, ctx), tc.in) + assert.NoError(t, err) + assert.Equal(t, 0, actual) + }) + } + }) } diff --git a/selfservice/strategy/password/stub/migration.schema.json b/selfservice/strategy/password/stub/migration.schema.json new file mode 100644 index 000000000000..5ac7314cd7d9 --- /dev/null +++ b/selfservice/strategy/password/stub/migration.schema.json @@ -0,0 +1,36 @@ +{ + "$id": "https://example.com/person.schema.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Person", + "type": "object", + "properties": { + "traits": { + "type": "object", + "properties": { + "email": { + "type": "string", + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "webauthn": { + "identifier": true + } + }, + "verification": { + "via": "email" + }, + "recovery": { + "via": "email" + } + } + } + }, + "required": [ + "email" + ] + } + }, + "additionalProperties": false +} diff --git a/selfservice/strategy/profile/strategy.go b/selfservice/strategy/profile/strategy.go index 644f8dd6c263..1f71fdf3f6aa 100644 --- a/selfservice/strategy/profile/strategy.go +++ b/selfservice/strategy/profile/strategy.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + "github.com/ory/x/otelx" + "github.com/ory/jsonschema/v3" "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/text" @@ -40,6 +42,7 @@ type ( x.CSRFTokenGeneratorProvider x.WriterProvider x.LoggingProvider + x.TracingProvider config.Provider @@ -109,11 +112,14 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity return nil } -func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (*settings.UpdateContext, error) { +func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (_ *settings.UpdateContext, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.profile.strategy.Settings") + defer otelx.End(span, &err) + var p updateSettingsFlowWithProfileMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueFlow(r, ctxUpdate, p) + return ctxUpdate, s.continueFlow(ctx, r, ctxUpdate, p) } else if err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, p, err) } @@ -122,7 +128,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. return ctxUpdate, err } - option, err := s.newSettingsProfileDecoder(r.Context(), ctxUpdate.GetSessionIdentity()) + option, err := s.newSettingsProfileDecoder(ctx, ctxUpdate.GetSessionIdentity()) if err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, p, err) } @@ -138,19 +144,19 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. // Reset after decoding form p.SetFlowID(ctxUpdate.Flow.ID) - if err := s.continueFlow(r, ctxUpdate, p); err != nil { + if err := s.continueFlow(ctx, r, ctxUpdate, p); err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, nil, p, err) } return ctxUpdate, nil } -func (s *Strategy) continueFlow(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithProfileMethod) error { - if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { +func (s *Strategy) continueFlow(ctx context.Context, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithProfileMethod) error { + if err := flow.MethodEnabledAndAllowed(ctx, flow.SettingsFlow, s.SettingsStrategyID(), p.Method, s.d); err != nil { return err } - if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return err } @@ -163,12 +169,12 @@ func (s *Strategy) continueFlow(r *http.Request, ctxUpdate *settings.UpdateConte } options := []identity.ManagerOption{identity.ManagerExposeValidationErrorsForInternalTypeAssertion} - ttl := s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context()) + ttl := s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx) if ctxUpdate.Session.AuthenticatedAt.Add(ttl).After(time.Now()) { options = append(options, identity.ManagerAllowWriteProtectedTraits) } - update, err := s.d.IdentityManager().SetTraits(r.Context(), ctxUpdate.GetSessionIdentity().ID, identity.Traits(p.Traits), options...) + update, err := s.d.IdentityManager().SetTraits(ctx, ctxUpdate.GetSessionIdentity().ID, identity.Traits(p.Traits), options...) if err != nil { if errors.Is(err, identity.ErrProtectedFieldModified) { return settings.NewFlowNeedsReAuth() diff --git a/selfservice/strategy/profile/strategy_test.go b/selfservice/strategy/profile/strategy_test.go index 996ac08796fb..d34c3f9e94f6 100644 --- a/selfservice/strategy/profile/strategy_test.go +++ b/selfservice/strategy/profile/strategy_test.go @@ -89,8 +89,14 @@ func TestStrategyTraits(t *testing.T) { browserIdentity1 := newIdentityWithPassword("john-browser@doe.com") apiIdentity1 := newIdentityWithPassword("john-api@doe.com") - browserIdentity2 := &identity.Identity{ID: x.NewUUID(), Traits: identity.Traits(`{}`), State: identity.StateActive} - apiIdentity2 := &identity.Identity{ID: x.NewUUID(), Traits: identity.Traits(`{}`), State: identity.StateActive} + browserID2 := x.NewUUID() + browserIdentity2 := &identity.Identity{ID: browserID2, Traits: identity.Traits(`{}`), State: identity.StateActive, Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypePassword: {Type: "password", Identifiers: []string{browserID2.String()}, Config: []byte(`{"hashed_password":"$2a$04$zvZz1zV"}`)}, + }} + apiID2 := x.NewUUID() + apiIdentity2 := &identity.Identity{ID: apiID2, Traits: identity.Traits(`{}`), State: identity.StateActive, Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypePassword: {Type: "password", Identifiers: []string{apiID2.String()}, Config: []byte(`{"hashed_password":"$2a$04$zvZz1zV"}`)}, + }} browserUser1 := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, browserIdentity1) browserUser2 := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, browserIdentity2) diff --git a/selfservice/strategy/totp/login.go b/selfservice/strategy/totp/login.go index 2aaface8dc5c..2db4cd36038c 100644 --- a/selfservice/strategy/totp/login.go +++ b/selfservice/strategy/totp/login.go @@ -7,6 +7,8 @@ import ( "encoding/json" "net/http" + "github.com/ory/x/otelx" + "github.com/pkg/errors" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" @@ -91,6 +93,9 @@ type updateLoginFlowWithTotpMethod struct { } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, sess *session.Session) (i *identity.Identity, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.totp.strategy.Login") + defer otelx.End(span, &err) + if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel2); err != nil { return nil, err } @@ -108,11 +113,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } f.TransientPayload = p.TransientPayload - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return nil, s.handleLoginError(r, f, err) } - i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), s.ID(), sess.IdentityID.String()) + i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), sess.IdentityID.String()) if err != nil { return nil, s.handleLoginError(r, f, errors.WithStack(schema.NewNoTOTPDeviceRegistered())) } @@ -132,7 +137,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } f.Active = s.ID() - if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { + if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) } diff --git a/selfservice/strategy/totp/settings.go b/selfservice/strategy/totp/settings.go index 0c24d915868b..367b521a1709 100644 --- a/selfservice/strategy/totp/settings.go +++ b/selfservice/strategy/totp/settings.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + "github.com/ory/x/otelx" + "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/tidwall/gjson" @@ -83,7 +85,10 @@ func (p *updateSettingsFlowWithTotpMethod) SetFlowID(rid uuid.UUID) { p.Flow = rid.String() } -func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (*settings.UpdateContext, error) { +func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (_ *settings.UpdateContext, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.oidc.strategy.Settings") + defer otelx.End(span, &err) + var p updateSettingsFlowWithTotpMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { @@ -99,7 +104,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. if p.UnlinkTOTP { // This is a submit so we need to manually set the type to TOTP p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else if err := flow.MethodEnabledAndAllowedFromRequest(r, f.GetFlowName(), s.SettingsStrategyID(), s.d); err != nil { @@ -122,6 +127,7 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { } return decoderx.NewHTTP().Decode(r, dest, compiler, + decoderx.HTTPKeepRequestBody(true), decoderx.HTTPDecoderAllowedMethods("POST", "GET"), decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.HTTPDecoderJSONFollowsFormFormat(), @@ -245,7 +251,7 @@ func (s *Strategy) identityHasTOTP(ctx context.Context, id uuid.UUID) (bool, err return false, err } - count, err := s.CountActiveMultiFactorCredentials(confidential.Credentials) + count, err := s.CountActiveMultiFactorCredentials(ctx, confidential.Credentials) if err != nil { return false, err } diff --git a/selfservice/strategy/totp/strategy.go b/selfservice/strategy/totp/strategy.go index 6c3205abd9ac..d4f30ba09c67 100644 --- a/selfservice/strategy/totp/strategy.go +++ b/selfservice/strategy/totp/strategy.go @@ -35,6 +35,7 @@ type totpStrategyDependencies interface { x.WriterProvider x.CSRFTokenGeneratorProvider x.CSRFProvider + x.TracingProvider config.Provider @@ -80,11 +81,11 @@ func NewStrategy(d any) *Strategy { } } -func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveFirstFactorCredentials(_ context.Context, _ map[identity.CredentialsType]identity.Credentials) (count int, err error) { return 0, nil } -func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { for _, c := range cc { if c.Type == s.ID() && len(c.Config) > 0 { var conf identity.CredentialsTOTPConfig @@ -93,7 +94,7 @@ func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.Credentials } _, err := otp.NewKeyFromURL(conf.TOTPURL) - if len(c.Identifiers) > 0 && len(c.Identifiers[0]) > 0 && len(conf.TOTPURL) > 0 && err == nil { + if len(conf.TOTPURL) > 0 && err == nil { count++ } } @@ -109,7 +110,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.TOTPGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, diff --git a/selfservice/strategy/totp/strategy_test.go b/selfservice/strategy/totp/strategy_test.go index 17c507c9d144..7bce7fe16961 100644 --- a/selfservice/strategy/totp/strategy_test.go +++ b/selfservice/strategy/totp/strategy_test.go @@ -25,7 +25,7 @@ func TestCountActiveCredentials(t *testing.T) { require.NoError(t, err) t.Run("first factor", func(t *testing.T) { - actual, err := strategy.CountActiveFirstFactorCredentials(nil) + actual, err := strategy.CountActiveFirstFactorCredentials(nil, nil) require.NoError(t, err) assert.Equal(t, 0, actual) }) @@ -75,7 +75,7 @@ func TestCountActiveCredentials(t *testing.T) { cc[c.Type] = c } - actual, err := strategy.CountActiveMultiFactorCredentials(cc) + actual, err := strategy.CountActiveMultiFactorCredentials(nil, cc) require.NoError(t, err) assert.Equal(t, tc.expected, actual) }) diff --git a/selfservice/strategy/webauthn/login.go b/selfservice/strategy/webauthn/login.go index fe98d1d88c55..d3a46ddc5085 100644 --- a/selfservice/strategy/webauthn/login.go +++ b/selfservice/strategy/webauthn/login.go @@ -4,11 +4,14 @@ package webauthn import ( + "context" "encoding/json" "net/http" "strings" "time" + "github.com/ory/x/otelx" + "github.com/ory/kratos/selfservice/strategy/idfirst" "github.com/ory/kratos/selfservice/flowhelpers" @@ -145,12 +148,16 @@ type updateLoginFlowWithWebAuthnMethod struct { } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, sess *session.Session) (i *identity.Identity, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.webauthn.strategy.Login") + defer otelx.End(span, &err) + if f.Type != flow.TypeBrowser { return nil, flow.ErrStrategyNotResponsible } var p updateLoginFlowWithWebAuthnMethod if err := s.hd.Decode(r, &p, + decoderx.HTTPKeepRequestBody(true), decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.MustHTTPRawJSONSchemaCompiler(loginSchema), decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil { @@ -165,27 +172,30 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, flow.ErrStrategyNotResponsible } - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleLoginError(r, f, err) } - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return nil, s.handleLoginError(r, f, err) } - if s.d.Config().WebAuthnForPasswordless(r.Context()) || f.IsRefresh() && f.RequestedAAL == identity.AuthenticatorAssuranceLevel1 { - return s.loginPasswordless(w, r, f, &p) + if s.d.Config().WebAuthnForPasswordless(ctx) || f.IsRefresh() && f.RequestedAAL == identity.AuthenticatorAssuranceLevel1 { + return s.loginPasswordless(ctx, w, r, f, &p) } - return s.loginMultiFactor(w, r, f, sess.IdentityID, &p) + return s.loginMultiFactor(ctx, w, r, f, sess.IdentityID, &p) } -func (s *Strategy) loginPasswordless(w http.ResponseWriter, r *http.Request, f *login.Flow, p *updateLoginFlowWithWebAuthnMethod) (i *identity.Identity, err error) { +func (s *Strategy) loginPasswordless(ctx context.Context, w http.ResponseWriter, r *http.Request, f *login.Flow, p *updateLoginFlowWithWebAuthnMethod) (i *identity.Identity, err error) { + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.webauthn.strategy.loginPasswordless") + defer otelx.End(span, &err) + if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil { return nil, s.handleLoginError(r, f, err) } - if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, f.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return nil, s.handleLoginError(r, f, err) } @@ -193,9 +203,9 @@ func (s *Strategy) loginPasswordless(w http.ResponseWriter, r *http.Request, f * return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrBadRequest.WithReason("identifier is required"))) } - i, _, err = s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), s.ID(), p.Identifier) + i, _, err = s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, s.ID(), p.Identifier) if err != nil { - time.Sleep(x.RandomDelay(s.d.Config().HasherArgon2(r.Context()).ExpectedDuration, s.d.Config().HasherArgon2(r.Context()).ExpectedDeviation)) + time.Sleep(x.RandomDelay(s.d.Config().HasherArgon2(ctx).ExpectedDuration, s.d.Config().HasherArgon2(ctx).ExpectedDeviation)) return nil, s.handleLoginError(r, f, errors.WithStack(schema.NewNoWebAuthnCredentials())) } @@ -216,25 +226,28 @@ func (s *Strategy) loginPasswordless(w http.ResponseWriter, r *http.Request, f * f.UI.SetCSRF(s.d.GenerateCSRFToken(r)) f.UI.Messages.Add(text.NewInfoLoginWebAuthnPasswordless()) f.UI.SetNode(node.NewInputField("identifier", p.Identifier, node.DefaultGroup, node.InputAttributeTypeHidden, node.WithRequiredInputAttribute)) - if err := s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { + if err := s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { return nil, s.handleLoginError(r, f, err) } - redirectTo := f.AppendTo(s.d.Config().SelfServiceFlowLoginUI(r.Context())).String() + redirectTo := f.AppendTo(s.d.Config().SelfServiceFlowLoginUI(ctx)).String() if x.IsJSONRequest(r) { s.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(redirectTo)) } else { - http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowLoginUI(r.Context())).String(), http.StatusSeeOther) + http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowLoginUI(ctx)).String(), http.StatusSeeOther) } return nil, errors.WithStack(flow.ErrCompletedByStrategy) } - return s.loginAuthenticate(w, r, f, i.ID, p, identity.AuthenticatorAssuranceLevel1) + return s.loginAuthenticate(ctx, r, f, i.ID, p, identity.AuthenticatorAssuranceLevel1) } -func (s *Strategy) loginAuthenticate(_ http.ResponseWriter, r *http.Request, f *login.Flow, identityID uuid.UUID, p *updateLoginFlowWithWebAuthnMethod, aal identity.AuthenticatorAssuranceLevel) (*identity.Identity, error) { - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), identityID) +func (s *Strategy) loginAuthenticate(ctx context.Context, r *http.Request, f *login.Flow, identityID uuid.UUID, p *updateLoginFlowWithWebAuthnMethod, aal identity.AuthenticatorAssuranceLevel) (_ *identity.Identity, err error) { + ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.webauthn.strategy.loginAuthenticate") + defer otelx.End(span, &err) + + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, identityID) if err != nil { return nil, s.handleLoginError(r, f, errors.WithStack(schema.NewNoWebAuthnRegistered())) } @@ -249,7 +262,7 @@ func (s *Strategy) loginAuthenticate(_ http.ResponseWriter, r *http.Request, f * return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReason("The WebAuthn credentials could not be decoded properly").WithDebug(err.Error()).WithWrap(err))) } - web, err := webauthn.New(s.d.Config().WebAuthnConfig(r.Context())) + web, err := webauthn.New(s.d.Config().WebAuthnConfig(ctx)) if err != nil { return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to get webAuthn config.").WithDebug(err.Error()))) } @@ -280,18 +293,18 @@ func (s *Strategy) loginAuthenticate(_ http.ResponseWriter, r *http.Request, f * } f.Active = s.ID() - if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil { + if err = s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil { return nil, s.handleLoginError(r, f, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error()))) } return i, nil } -func (s *Strategy) loginMultiFactor(w http.ResponseWriter, r *http.Request, f *login.Flow, identityID uuid.UUID, p *updateLoginFlowWithWebAuthnMethod) (*identity.Identity, error) { +func (s *Strategy) loginMultiFactor(ctx context.Context, w http.ResponseWriter, r *http.Request, f *login.Flow, identityID uuid.UUID, p *updateLoginFlowWithWebAuthnMethod) (*identity.Identity, error) { if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel2); err != nil { return nil, err } - return s.loginAuthenticate(w, r, f, identityID, p, identity.AuthenticatorAssuranceLevel2) + return s.loginAuthenticate(ctx, r, f, identityID, p, identity.AuthenticatorAssuranceLevel2) } func (s *Strategy) populateLoginMethodRefresh(r *http.Request, sr *login.Flow) error { @@ -388,7 +401,7 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request var err error // If we have an identity hint we can perform identity credentials discovery and // hide this credential if it should not be included. - if count, err = s.CountActiveFirstFactorCredentials(o.IdentityHint.Credentials); err != nil { + if count, err = s.CountActiveFirstFactorCredentials(r.Context(), o.IdentityHint.Credentials); err != nil { return err } } diff --git a/selfservice/strategy/webauthn/login_test.go b/selfservice/strategy/webauthn/login_test.go index f1323b9b6789..6a98f3b3d383 100644 --- a/selfservice/strategy/webauthn/login_test.go +++ b/selfservice/strategy/webauthn/login_test.go @@ -284,6 +284,7 @@ func TestCompleteLogin(t *testing.T) { for _, f := range []string{"browser", "spa"} { t.Run(f, func(t *testing.T) { id := identity.NewIdentity("") + id.NID = x.NewUUID() client := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, id) f := testhelpers.InitializeLoginFlowViaBrowser(t, client, publicTS, true, f == "spa", false, false) diff --git a/selfservice/strategy/webauthn/registration.go b/selfservice/strategy/webauthn/registration.go index 81d94e0028e7..df0230abdafb 100644 --- a/selfservice/strategy/webauthn/registration.go +++ b/selfservice/strategy/webauthn/registration.go @@ -8,6 +8,8 @@ import ( "net/http" "strings" + "github.com/ory/x/otelx" + "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/webauthn" "github.com/pkg/errors" @@ -91,7 +93,8 @@ func (s *Strategy) decode(p *updateRegistrationFlowWithWebAuthnMethod, r *http.R } func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, regFlow *registration.Flow, i *identity.Identity) (err error) { - ctx := r.Context() + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.webauthn.strategy.Register") + defer otelx.End(span, &err) if regFlow.Type != flow.TypeBrowser || !s.d.Config().WebAuthnForPasswordless(r.Context()) { return flow.ErrStrategyNotResponsible diff --git a/selfservice/strategy/webauthn/settings.go b/selfservice/strategy/webauthn/settings.go index 6e97c31b54f9..dc2585cf376a 100644 --- a/selfservice/strategy/webauthn/settings.go +++ b/selfservice/strategy/webauthn/settings.go @@ -10,6 +10,10 @@ import ( "strings" "time" + "golang.org/x/net/context" + + "github.com/ory/x/otelx" + "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/x/webauthnx" @@ -98,14 +102,17 @@ func (p *updateSettingsFlowWithWebAuthnMethod) SetFlowID(rid uuid.UUID) { p.Flow = rid.String() } -func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (*settings.UpdateContext, error) { +func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings.Flow, ss *session.Session) (_ *settings.UpdateContext, err error) { + ctx, span := s.d.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.strategy.webauthn.strategy.Settings") + defer otelx.End(span, &err) + if f.Type != flow.TypeBrowser { return nil, flow.ErrStrategyNotResponsible } var p updateSettingsFlowWithWebAuthnMethod ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, f, ss, settings.ContinuityKey(s.SettingsStrategyID()), &p) if errors.Is(err, settings.ErrContinuePreviousAction) { - return ctxUpdate, s.continueSettingsFlow(w, r, ctxUpdate, p) + return ctxUpdate, s.continueSettingsFlow(ctx, w, r, ctxUpdate, p) } else if err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -117,7 +124,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. if len(p.Register+p.Remove) > 0 { // This method has only two submit buttons p.Method = s.SettingsStrategyID() - if err := flow.MethodEnabledAndAllowed(r.Context(), f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, f.GetFlowName(), s.SettingsStrategyID(), p.Method, s.d); err != nil { return nil, s.handleSettingsError(w, r, ctxUpdate, p, err) } } else { @@ -126,7 +133,7 @@ func (s *Strategy) Settings(w http.ResponseWriter, r *http.Request, f *settings. // This does not come from the payload! p.Flow = ctxUpdate.Flow.ID.String() - if err := s.continueSettingsFlow(w, r, ctxUpdate, p); err != nil { + if err := s.continueSettingsFlow(ctx, w, r, ctxUpdate, p); err != nil { return ctxUpdate, s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -140,6 +147,7 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { } return decoderx.NewHTTP().Decode(r, dest, compiler, + decoderx.HTTPKeepRequestBody(true), decoderx.HTTPDecoderAllowedMethods("POST", "GET"), decoderx.HTTPDecoderSetValidatePayloads(true), decoderx.HTTPDecoderJSONFollowsFormFormat(), @@ -147,19 +155,20 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error { } func (s *Strategy) continueSettingsFlow( + ctx context.Context, w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod, ) error { if len(p.Register+p.Remove) > 0 { - if err := flow.MethodEnabledAndAllowed(r.Context(), flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { + if err := flow.MethodEnabledAndAllowed(ctx, flow.SettingsFlow, s.SettingsStrategyID(), s.SettingsStrategyID(), s.d); err != nil { return err } - if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(r.Context()), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { + if err := flow.EnsureCSRF(s.d, r, ctxUpdate.Flow.Type, s.d.Config().DisableAPIFlowEnforcement(ctx), s.d.GenerateCSRFToken, p.CSRFToken); err != nil { return err } - if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(r.Context())).Before(time.Now()) { + if ctxUpdate.Session.AuthenticatedAt.Add(s.d.Config().SelfServiceFlowSettingsPrivilegedSessionMaxAge(ctx)).Before(time.Now()) { return errors.WithStack(settings.NewFlowNeedsReAuth()) } } else { @@ -167,16 +176,16 @@ func (s *Strategy) continueSettingsFlow( } if len(p.Register) > 0 { - return s.continueSettingsFlowAdd(r, ctxUpdate, p) + return s.continueSettingsFlowAdd(ctx, ctxUpdate, p) } else if len(p.Remove) > 0 { - return s.continueSettingsFlowRemove(w, r, ctxUpdate, p) + return s.continueSettingsFlowRemove(ctx, w, r, ctxUpdate, p) } return errors.New("ended up in unexpected state") } -func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod) error { - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.IdentityID) +func (s *Strategy) continueSettingsFlowRemove(ctx context.Context, w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod) error { + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.IdentityID) if err != nil { return err } @@ -205,7 +214,7 @@ func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Req return errors.WithStack(herodot.ErrBadRequest.WithReasonf("You tried to remove a WebAuthn credential which does not exist.")) } - count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), i) + count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(ctx, i) if err != nil { return err } @@ -231,7 +240,7 @@ func (s *Strategy) continueSettingsFlowRemove(w http.ResponseWriter, r *http.Req return nil } -func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod) error { +func (s *Strategy) continueSettingsFlowAdd(ctx context.Context, ctxUpdate *settings.UpdateContext, p updateSettingsFlowWithWebAuthnMethod) error { webAuthnSession := gjson.GetBytes(ctxUpdate.Flow.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeySessionData)) if !webAuthnSession.IsObject() { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected WebAuthN in internal context to be an object.")) @@ -247,7 +256,7 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. return errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to parse WebAuthn response: %s", err)) } - web, err := webauthn.New(s.d.Config().WebAuthnConfig(r.Context())) + web, err := webauthn.New(s.d.Config().WebAuthnConfig(ctx)) if err != nil { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to get webAuthn config.").WithDebug(err.Error())) } @@ -257,7 +266,7 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to create WebAuthn credential: %s", err)) } - i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), ctxUpdate.Session.IdentityID) + i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, ctxUpdate.Session.IdentityID) if err != nil { return err } @@ -269,10 +278,10 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to decode identity credentials.").WithDebug(err.Error())) } - wc := identity.CredentialFromWebAuthn(credential, s.d.Config().WebAuthnForPasswordless(r.Context())) + wc := identity.CredentialFromWebAuthn(credential, s.d.Config().WebAuthnForPasswordless(ctx)) wc.AddedAt = time.Now().UTC().Round(time.Second) wc.DisplayName = p.RegisterDisplayName - wc.IsPasswordless = s.d.Config().WebAuthnForPasswordless(r.Context()) + wc.IsPasswordless = s.d.Config().WebAuthnForPasswordless(ctx) cc.UserHandle = ctxUpdate.Session.IdentityID[:] cc.Credentials = append(cc.Credentials, *wc) @@ -282,7 +291,7 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. } i.UpsertCredentialsConfig(s.ID(), co, 1) - if err := s.validateCredentials(r.Context(), i); err != nil { + if err := s.validateCredentials(ctx, i); err != nil { return err } @@ -292,17 +301,17 @@ func (s *Strategy) continueSettingsFlowAdd(r *http.Request, ctxUpdate *settings. return err } - if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(r.Context(), ctxUpdate.Flow); err != nil { + if err := s.d.SettingsFlowPersister().UpdateSettingsFlow(ctx, ctxUpdate.Flow); err != nil { return err } aal := identity.AuthenticatorAssuranceLevel1 - if !s.d.Config().WebAuthnForPasswordless(r.Context()) { + if !s.d.Config().WebAuthnForPasswordless(ctx) { aal = identity.AuthenticatorAssuranceLevel2 } // Since we added the method, it also means that we have authenticated it - if err := s.d.SessionManager().SessionAddAuthenticationMethods(r.Context(), ctxUpdate.Session.ID, session.AuthenticationMethod{ + if err := s.d.SessionManager().SessionAddAuthenticationMethods(ctx, ctxUpdate.Session.ID, session.AuthenticationMethod{ Method: s.ID(), AAL: aal, }); err != nil { diff --git a/selfservice/strategy/webauthn/settings_test.go b/selfservice/strategy/webauthn/settings_test.go index dd75fc335204..bf37258d5706 100644 --- a/selfservice/strategy/webauthn/settings_test.go +++ b/selfservice/strategy/webauthn/settings_test.go @@ -66,6 +66,7 @@ func createIdentityAndReturnIdentifier(t *testing.T, ctx context.Context, reg dr require.NoError(t, err) i := &identity.Identity{ SchemaID: "default", + NID: uuid.Must(uuid.NewV4()), Traits: identity.Traits(fmt.Sprintf(`{"subject":"%s"}`, identifier)), VerifiableAddresses: []identity.VerifiableAddress{ { @@ -316,6 +317,7 @@ func TestCompleteSettings(t *testing.T) { // We load our identity which we will use to replay the webauth session var id identity.Identity require.NoError(t, json.Unmarshal(settingsFixtureSuccessIdentity, &id)) + id.NID = uuid.Must(uuid.NewV4()) _ = reg.PrivilegedIdentityPool().DeleteIdentity(context.Background(), id.ID) browserClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, &id) f := testhelpers.InitializeSettingsFlowViaBrowser(t, browserClient, spa, publicTS) @@ -538,6 +540,7 @@ func TestCompleteSettings(t *testing.T) { isSPA := f == "spa" var id identity.Identity + id.NID = uuid.Must(uuid.NewV4()) require.NoError(t, json.Unmarshal(settingsFixtureSuccessIdentity, &id)) _ = reg.PrivilegedIdentityPool().DeleteIdentity(context.Background(), id.ID) browserClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, ctx, reg, &id) diff --git a/selfservice/strategy/webauthn/strategy.go b/selfservice/strategy/webauthn/strategy.go index 998490055996..82a0b7df9b2a 100644 --- a/selfservice/strategy/webauthn/strategy.go +++ b/selfservice/strategy/webauthn/strategy.go @@ -6,6 +6,7 @@ package webauthn import ( "context" "encoding/json" + "strings" "github.com/pkg/errors" @@ -34,6 +35,7 @@ type webauthnStrategyDependencies interface { x.WriterProvider x.CSRFTokenGeneratorProvider x.CSRFProvider + x.TracingProvider config.Provider @@ -80,26 +82,37 @@ func NewStrategy(d any) *Strategy { } } -func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveMultiFactorCredentials(_ context.Context, cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { return s.countCredentials(cc, false) } -func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { +func (s *Strategy) CountActiveFirstFactorCredentials(_ context.Context, cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { return s.countCredentials(cc, true) } -func (s *Strategy) countCredentials(cc map[identity.CredentialsType]identity.Credentials, passwordless bool) (count int, err error) { +func (s *Strategy) countCredentials(cc map[identity.CredentialsType]identity.Credentials, onlyPasswordlessCredentials bool) (count int, err error) { for _, c := range cc { - if c.Type == s.ID() && len(c.Config) > 0 && len(c.Identifiers) > 0 { + if c.Type == s.ID() && len(c.Config) > 0 { var conf identity.CredentialsWebAuthnConfig if err = json.Unmarshal(c.Config, &conf); err != nil { return 0, errors.WithStack(err) } - for _, c := range conf.Credentials { - if c.IsPasswordless == passwordless { - count++ + for _, cred := range conf.Credentials { + if cred.IsPasswordless && len(strings.Join(c.Identifiers, "")) == 0 { + // If this is a passwordless credential, it will only work if the identifier is set, as + // we use the identifier to look up the identity. If the identifier is not set, we can + // assume that the user can't sign in using this method. + continue } + + if cred.IsPasswordless != onlyPasswordlessCredentials { + continue + } + + // If the credential is passwordless and we require passwordless credentials, or if the credential is not + // passwordless and we require non-passwordless credentials, we count it. + count++ } } } @@ -114,7 +127,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.WebAuthnGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { aal := identity.AuthenticatorAssuranceLevel1 if !s.d.Config().WebAuthnForPasswordless(ctx) { aal = identity.AuthenticatorAssuranceLevel2 diff --git a/selfservice/strategy/webauthn/strategy_test.go b/selfservice/strategy/webauthn/strategy_test.go index cc5f6fafb475..5ce70310d9e9 100644 --- a/selfservice/strategy/webauthn/strategy_test.go +++ b/selfservice/strategy/webauthn/strategy_test.go @@ -26,13 +26,13 @@ func TestCompletedAuthenticationMethod(t *testing.T) { assert.Equal(t, session.AuthenticationMethod{ Method: strategy.ID(), AAL: identity.AuthenticatorAssuranceLevel2, - }, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{})) + }, strategy.CompletedAuthenticationMethod(context.Background())) conf.MustSet(ctx, config.ViperKeyWebAuthnPasswordless, true) assert.Equal(t, session.AuthenticationMethod{ Method: strategy.ID(), AAL: identity.AuthenticatorAssuranceLevel1, - }, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{})) + }, strategy.CompletedAuthenticationMethod(context.Background())) } func TestCountActiveFirstFactorCredentials(t *testing.T) { @@ -64,6 +64,14 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, expectedMulti: 1, }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Identifiers: []string{}, // also works without identifier + Config: []byte(`{"credentials": [{}]}`), + }}, + expectedMulti: 1, + }, { in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), @@ -72,6 +80,13 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, expectedFirst: 1, }, + { + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { + Type: strategy.ID(), + Config: []byte(`{"credentials": [{"is_passwordless": true}]}`), + }}, + expectedFirst: 0, // missing identifier + }, { in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), @@ -105,11 +120,11 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { cc[c.Type] = c } - actual, err := strategy.CountActiveFirstFactorCredentials(cc) + actual, err := strategy.CountActiveFirstFactorCredentials(ctx, cc) require.NoError(t, err) assert.Equal(t, tc.expectedFirst, actual) - actual, err = strategy.CountActiveMultiFactorCredentials(cc) + actual, err = strategy.CountActiveMultiFactorCredentials(ctx, cc) require.NoError(t, err) assert.Equal(t, tc.expectedMulti, actual) }) diff --git a/session/handler_test.go b/session/handler_test.go index c11a8f4a548e..8783a3dcc604 100644 --- a/session/handler_test.go +++ b/session/handler_test.go @@ -253,7 +253,7 @@ func TestSessionWhoAmI(t *testing.T) { i := identity.Identity{Traits: []byte("{}"), State: identity.StateActive} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, err := session.NewActiveSession(&i, conf, time.Now(), identity.CredentialsTypePassword) + s, err := testhelpers.NewActiveSession(&i, conf, time.Now(), identity.CredentialsTypePassword) require.NoError(t, err) require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s)) require.NotEmpty(t, s.Token) diff --git a/session/manager.go b/session/manager.go index d82cf002c5b4..596d0ab7da97 100644 --- a/session/manager.go +++ b/session/manager.go @@ -7,6 +7,9 @@ import ( "context" "net/http" "net/url" + "time" + + "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/text" @@ -133,8 +136,17 @@ type Manager interface { // PurgeFromRequest removes an HTTP session. PurgeFromRequest(context.Context, http.ResponseWriter, *http.Request) error - // DoesSessionSatisfy answers if a session is satisfying the AAL. - DoesSessionSatisfy(r *http.Request, sess *Session, requestedAAL string, opts ...ManagerOptions) error + // DoesSessionSatisfy answers if a session is satisfying the AAL of a user. + // + // The matcher value can be one of: + // + // - `highest_available`: If set requires the user to upgrade their session to the highest available AAL for that user. + // - `aal1`: Requires the user to have authenticated with at least one authentication factor. + // + // This method is implemented in such a way, that if a second factor is found for the user, it is always assumed + // that the user is able to authenticate with it. This means that if a user has a second factor, the user is always + // asked to authenticate with it if `highest_available` is set and the session's AAL is `aal1`. + DoesSessionSatisfy(r *http.Request, sess *Session, matcher string, opts ...ManagerOptions) error // SessionAddAuthenticationMethods adds one or more authentication method to the session. SessionAddAuthenticationMethods(ctx context.Context, sid uuid.UUID, methods ...AuthenticationMethod) error @@ -142,6 +154,13 @@ type Manager interface { // MaybeRedirectAPICodeFlow for API+Code flows redirects the user to the return_to URL and adds the code query parameter. // `handled` is true if the request a redirect was written, false otherwise. MaybeRedirectAPICodeFlow(w http.ResponseWriter, r *http.Request, f flow.Flow, sessionID uuid.UUID, uiNode node.UiNodeGroup) (handled bool, err error) + + // ActivateSession activates a session. + // + // This method is used to activate a session after a user authenticated with a first or second factor. It sets + // all computed values (e.g. authenticator assurance level) and updates the session object but does not store + // the session in the database or on the client device. + ActivateSession(r *http.Request, session *Session, i *identity.Identity, authenticatedAt time.Time) error } type ManagementProvider interface { diff --git a/session/manager_http.go b/session/manager_http.go index fbd6574e0b2c..18a6134d3949 100644 --- a/session/manager_http.go +++ b/session/manager_http.go @@ -9,6 +9,8 @@ import ( "net/url" "time" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "github.com/ory/kratos/x/events" @@ -38,6 +40,8 @@ import ( "github.com/ory/kratos/x" ) +var ErrNoAALAvailable = herodot.ErrForbidden.WithReasonf("Unable to detect available authentication methods. Perform account recovery or contact support.") + type ( managerHTTPDependencies interface { config.Provider @@ -45,8 +49,10 @@ type ( identity.PrivilegedPoolProvider identity.ManagementProvider x.CookieProvider + x.LoggingProvider x.CSRFProvider x.TracingProvider + x.TransactionPersistenceProvider PersistenceProvider sessiontokenexchange.PersistenceProvider } @@ -283,69 +289,87 @@ func (s *ManagerHTTP) DoesSessionSatisfy(r *http.Request, sess *Session, request defer otelx.End(span, &err) // If we already have AAL2 there is no need to check further because it is the highest AAL. + sess.SetAuthenticatorAssuranceLevel() if sess.AuthenticatorAssuranceLevel > identity.AuthenticatorAssuranceLevel1 { return nil } managerOpts := &options{} - for _, o := range opts { o(managerOpts) } - sess.SetAuthenticatorAssuranceLevel() + loginURL := urlx.CopyWithQuery(urlx.AppendPaths(s.r.Config().SelfPublicURL(ctx), "/self-service/login/browser"), url.Values{"aal": {"aal2"}}) + + // return to the requestURL if it was set + if managerOpts.requestURL != "" { + loginURL = urlx.CopyWithQuery(loginURL, url.Values{"return_to": {managerOpts.requestURL}}) + } + switch requestedAAL { case string(identity.AuthenticatorAssuranceLevel1): if sess.AuthenticatorAssuranceLevel >= identity.AuthenticatorAssuranceLevel1 { return nil } case config.HighestAvailableAAL: + if sess.AuthenticatorAssuranceLevel >= identity.AuthenticatorAssuranceLevel2 { + // The session has AAL2, nothing to check. + return nil + } + + // The session is AAL1, we asked for `highest_available` AAL, so the only thing we can do + // is actually check what authentication methods the identity has. if sess.Identity == nil { + // This is nil if the session did not expand the identity field. sess.Identity, err = s.r.IdentityPool().GetIdentity(ctx, sess.IdentityID, identity.ExpandNothing) if err != nil { return err } } - i := sess.Identity - available, valid := i.AvailableAAL.ToAAL() - if !valid { - // Available is 0 if the identity was created before the AAL feature was introduced, or if the identity - // was directly created in the persister and not the identity manager. - // - // aal0 indicates that the AAL state of the identity is probably unknown. - // - // In either case, we need to fetch the credentials from the database to determine the AAL. - if len(i.Credentials) == 0 { - // The identity was apparently fetched without credentials. Let's hydrate them. - if err := s.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, i, identity.ExpandCredentials); err != nil { - return err - } - } + if aal, ok := sess.Identity.InternalAvailableAAL.ToAAL(); ok && aal == identity.AuthenticatorAssuranceLevel2 { + // Identity gives us AAL2, but the session is still AAL1. We need to upgrade the session. + return NewErrAALNotSatisfied(loginURL.String()) + } + + // Identity AAL is not 2, we refresh: - if err := i.SetAvailableAAL(ctx, s.r.IdentityManager()); err != nil { + // The identity was apparently fetched without credentials. Let's hydrate them. + if len(sess.Identity.Credentials) == 0 { + if err := s.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, sess.Identity, identity.ExpandCredentials); err != nil { return err } + } - available, _ = i.AvailableAAL.ToAAL() - - // This is the migration strategy for identities that already exist. - if managerOpts.upsertAAL { - if _, err := s.r.SessionPersister().GetConnection(ctx).Where("id = ? AND nid = ?", i.ID, i.NID).UpdateQuery(i, "available_aal"); err != nil { - return err - } - } + // Great, now we determine the identity's available AAL + if err := sess.Identity.SetAvailableAAL(ctx, s.r.IdentityManager()); err != nil { + return err } - if sess.AuthenticatorAssuranceLevel >= available { - return nil + // We override the result with our newly computed values + available, valid := sess.Identity.InternalAvailableAAL.ToAAL() + if !valid { + // Unlikely to happen because SetAvailableAAL will either return an error, or a valid value - but not no error and an invalid value. + return errors.WithStack(x.PseudoPanic.WithReasonf("Unable to determine available authentication methods for session: %s", sess.ID)) } - loginURL := urlx.CopyWithQuery(urlx.AppendPaths(s.r.Config().SelfPublicURL(ctx), "/self-service/login/browser"), url.Values{"aal": {"aal2"}}) + switch available { + case identity.NoAuthenticatorAssuranceLevel: + // The identity has AAL0, the session has AAL1, we're good. + return nil + case identity.AuthenticatorAssuranceLevel1: + // The identity has AAL1, the session has AAL1, we're good. + return nil + case identity.AuthenticatorAssuranceLevel2: + // The identity has AAL2, the session has AAL1, we need to upgrade the session. - // return to the requestURL if it was set - if managerOpts.requestURL != "" { - loginURL = urlx.CopyWithQuery(loginURL, url.Values{"return_to": {managerOpts.requestURL}}) + // Since we ended up here, it also means that `sess.Identity.InternalAvailableAAL` was `aal1` and is now `aal2`. + // Let's update the database. + if managerOpts.upsertAAL { + if err := s.r.PrivilegedIdentityPool().UpdateIdentityColumns(ctx, sess.Identity, "available_aal"); err != nil { + return err + } + } } return NewErrAALNotSatisfied(loginURL.String()) @@ -402,3 +426,41 @@ func (s *ManagerHTTP) MaybeRedirectAPICodeFlow(w http.ResponseWriter, r *http.Re return true, nil } + +func (s *ManagerHTTP) ActivateSession(r *http.Request, session *Session, i *identity.Identity, authenticatedAt time.Time) (err error) { + ctx, span := s.r.Tracer(r.Context()).Tracer().Start(r.Context(), "sessions.ManagerHTTP.ActivateSession", trace.WithAttributes( + attribute.String("session.id", session.ID.String()), + attribute.String("identity.id", session.ID.String()), + attribute.String("authenticated_at", session.ID.String()), + )) + defer otelx.End(span, &err) + + if i == nil { + return errors.WithStack(x.PseudoPanic.WithReasonf("Identity must not be nil when activating a session.")) + } + + if !i.IsActive() { + return errors.WithStack(ErrIdentityDisabled.WithDetail("identity_id", i.ID)) + } + + session.Identity = i + session.IdentityID = i.ID + + session.Active = true + session.IssuedAt = authenticatedAt + session.ExpiresAt = authenticatedAt.Add(s.r.Config().SessionLifespan(ctx)) + session.AuthenticatedAt = authenticatedAt + + session.SetSessionDeviceInformation(r.WithContext(ctx)) + session.SetAuthenticatorAssuranceLevel() + + if err := s.r.IdentityManager().RefreshAvailableAAL(ctx, session.Identity); err != nil { + return err + } + + span.SetAttributes( + attribute.String("identity.available_aal", session.Identity.InternalAvailableAAL.String), + ) + + return nil +} diff --git a/session/manager_http_test.go b/session/manager_http_test.go index 8a1e166da25c..2a6eac16b5b3 100644 --- a/session/manager_http_test.go +++ b/session/manager_http_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + confighelpers "github.com/ory/kratos/driver/config/testhelpers" + "github.com/ory/nosurf" "github.com/ory/x/urlx" @@ -150,6 +152,40 @@ func TestManagerHTTP(t *testing.T) { }) }) + t.Run("suite=SessionActivate", func(t *testing.T) { + req := testhelpers.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) + + conf, reg := internal.NewFastRegistryWithMocks(t) + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") + + i := &identity.Identity{ + Traits: []byte("{}"), State: identity.StateActive, + Credentials: map[identity.CredentialsType]identity.Credentials{ + identity.CredentialsTypePassword: {Type: identity.CredentialsTypePassword, Identifiers: []string{x.NewUUID().String()}, Config: []byte(`{"hashed_password":"foo"}`)}, + }, + } + require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i)) + assert.EqualValues(t, i.InternalAvailableAAL.String, "") + + sess := session.NewInactiveSession() + require.NoError(t, reg.SessionManager().ActivateSession(req, sess, i, time.Now().UTC())) + require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), sess)) + + actual, err := reg.SessionPersister().GetSession(context.Background(), sess.ID, session.ExpandEverything) + require.NoError(t, err) + + assert.EqualValues(t, true, actual.Active) + assert.NotZero(t, actual.IssuedAt) + assert.True(t, time.Now().Before(actual.ExpiresAt)) + require.Len(t, actual.Devices, 1) + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, i.InternalAvailableAAL.String) + + actualIdentity, err := reg.IdentityPool().GetIdentity(ctx, i.ID, identity.ExpandNothing) + require.NoError(t, err) + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, actualIdentity.InternalAvailableAAL.String) + + }) + t.Run("suite=SessionAddAuthenticationMethod", func(t *testing.T) { req := testhelpers.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) @@ -159,7 +195,7 @@ func TestManagerHTTP(t *testing.T) { i := &identity.Identity{Traits: []byte("{}"), State: identity.StateActive} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i)) sess := session.NewInactiveSession() - require.NoError(t, sess.Activate(req, i, conf, time.Now())) + require.NoError(t, reg.SessionManager().ActivateSession(req, sess, i, time.Now().UTC())) require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), sess)) require.NoError(t, reg.SessionManager().SessionAddAuthenticationMethods(context.Background(), sess.ID, session.AuthenticationMethod{ @@ -219,7 +255,7 @@ func TestManagerHTTP(t *testing.T) { i := identity.Identity{Traits: []byte("{}")} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, _ = testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) c := testhelpers.NewClientWithCookies(t) testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set") @@ -240,7 +276,7 @@ func TestManagerHTTP(t *testing.T) { i := identity.Identity{Traits: []byte("{}")} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, _ = testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) c := testhelpers.NewClientWithCookies(t) testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set") @@ -270,7 +306,7 @@ func TestManagerHTTP(t *testing.T) { i := identity.Identity{Traits: []byte("{}")} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, _ = testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) c := testhelpers.NewClientWithCookies(t) res, err := c.Get(pts.URL + "/session/set/invalid") @@ -284,7 +320,7 @@ func TestManagerHTTP(t *testing.T) { i := identity.Identity{Traits: []byte("{}"), State: identity.StateActive} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, err := session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, err := testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s)) require.NotEmpty(t, s.Token) @@ -305,7 +341,7 @@ func TestManagerHTTP(t *testing.T) { i := identity.Identity{Traits: []byte("{}"), State: identity.StateActive} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, err := session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, err := testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s)) @@ -329,7 +365,7 @@ func TestManagerHTTP(t *testing.T) { i := identity.Identity{Traits: []byte("{}")} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, _ = testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) c := testhelpers.NewClientWithCookies(t) testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set") @@ -345,9 +381,9 @@ func TestManagerHTTP(t *testing.T) { req := testhelpers.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) i := identity.Identity{Traits: []byte("{}")} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i)) - s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, _ = testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) - s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, _ = testhelpers.NewActiveSession(req, reg, &i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) c := testhelpers.NewClientWithCookies(t) testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set") @@ -371,7 +407,7 @@ func TestManagerHTTP(t *testing.T) { for _, m := range complete { s.CompletedLoginFor(m, "") } - require.NoError(t, s.Activate(req, i, conf, time.Now().UTC())) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, i, time.Now().UTC())) err := reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(context.Background()), s, requested) if expectedError != nil { require.ErrorAs(t, err, &expectedError) @@ -418,22 +454,22 @@ func TestManagerHTTP(t *testing.T) { s := session.NewInactiveSession() s.CompletedLoginFor(identity.CredentialsTypePassword, "") - require.NoError(t, s.Activate(req, idAAL1, conf, time.Now().UTC())) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, idAAL1, time.Now().UTC())) require.Error(t, reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(context.Background()), s, config.HighestAvailableAAL, session.UpsertAAL)) result, err := reg.IdentityPool().GetIdentity(context.Background(), idAAL1.ID, identity.ExpandNothing) require.NoError(t, err) - assert.EqualValues(t, identity.AuthenticatorAssuranceLevel2, result.AvailableAAL.String) + assert.EqualValues(t, identity.AuthenticatorAssuranceLevel2, result.InternalAvailableAAL.String) }) t.Run("identity available AAL is hydrated without DB", func(t *testing.T) { // We do not create the identity in the database, proving that we do not need // to do any DB roundtrips in this case. idAAL2 := createAAL2Identity(t, reg) - idAAL2.AvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.AuthenticatorAssuranceLevel2) + idAAL2.InternalAvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.AuthenticatorAssuranceLevel2) idAAL1 := createAAL1Identity(t, reg) - idAAL1.AvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.AuthenticatorAssuranceLevel1) + idAAL1.InternalAvailableAAL = identity.NewNullableAuthenticatorAssuranceLevel(identity.AuthenticatorAssuranceLevel1) test(t, idAAL1, idAAL2) }) @@ -443,19 +479,76 @@ func TestManagerHTTP(t *testing.T) { } func TestDoesSessionSatisfy(t *testing.T) { + ctx := context.Background() conf, reg := internal.NewFastRegistryWithMocks(t) testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") + ctx = testhelpers.WithDefaultIdentitySchema(ctx, "file://./stub/identity.schema.json") + passwordEmpty := identity.Credentials{Type: identity.CredentialsTypePassword, Config: []byte(`{}`), Identifiers: []string{testhelpers.RandomEmail()}} password := identity.Credentials{ Type: identity.CredentialsTypePassword, Identifiers: []string{testhelpers.RandomEmail()}, Config: []byte(`{"hashed_password": "$argon2id$v=19$m=32,t=2,p=4$cm94YnRVOW5jZzFzcVE4bQ$MNzk5BtR2vUhrp6qQEjRNw"}`), } + passwordMigration := identity.Credentials{ + Type: identity.CredentialsTypePassword, + Identifiers: []string{testhelpers.RandomEmail()}, + Config: []byte(`{"use_password_migration_hook":true}`), + } + + code := identity.Credentials{ + Type: identity.CredentialsTypeCodeAuth, + Identifiers: []string{testhelpers.RandomEmail()}, + Config: []byte(`{"address_type":"email","used_at":{"Time":"0001-01-01T00:00:00Z","Valid":false}}`), + } + //codeEmpty := identity.Credentials{ + // Type: identity.CredentialsTypeCodeAuth, + // Identifiers: []string{testhelpers.RandomEmail()}, + // Config: []byte(`{}`), + //} + oidc := identity.Credentials{ Type: identity.CredentialsTypeOIDC, Config: []byte(`{"providers":[{"subject":"0.fywegkf7hd@ory.sh","provider":"hydra","initial_id_token":"65794a68624763694f694a53557a49314e694973496d74705a434936496e4231596d7870597a706f6557527959533576634756756157517561575174644739725a5734694c434a30655841694f694a4b5631516966512e65794a686446396f59584e6f496a6f6956484650616b6f324e6c397a613046436555643662315679576b466655534973496d46315a43493657794a72636d463062334d74593278705a573530496c3073496d46316447686664476c745a5349364d5459304e6a55314e6a59784e4377695a586877496a6f784e6a51324e5459774d6a45314c434a70595851694f6a45324e4459314e5459324d545573496d6c7a63794936496d6830644841364c79397362324e6862476876633351364e4451304e4338694c434a7164476b694f694a6a596a4d784d6a51794e6930314e7a4d774c5451314d546374596a51335a53316b4d446379596a51334d6a6b344d4759694c434a79595851694f6a45324e4459314e5459324d544d73496e4e705a434936496a677a4e5755344e47526a4c5463344d544d744e4749324f4330354d544a6d4c5446684d7a646d4e444d354d4463304e534973496e4e3159694936496a41755a6e6c335a5764725a6a646f5a454276636e6b75633267694c434a335a574a7a6158526c496a6f696148523063484d364c7939336433637562334a354c6e4e6f4c794a392e506850623770456358544c3456647730427959686f30794a7232714b794b4f7373646c4b6c74716b4953693762414e58776a7635686538506e6d7a586e713538556f5739657754584a485a33425651614d4e79612d755f5933584a4a61665673543347476c52776f376f5261707a6a564836502d72447657385649524d5361356f783242397164416d796659505734376e56782d4e68787247564c56464b526b5866324e4448534e6d435968524963455539724331366235385331344c314367776972624d507662797870644c63764f4a4546554238324c794574525a786f644748354c69394d6b5f4d6137363969583254776758434179306734475a625957337137317466574c37736d5342394669785076434b6a3738433753546b762d764f737a4e6533523864676133775471466e6253797a6a614f4b47626e424a4a77423869306e416c48496d425337587146645f666d556d4e62377a372d63716e593374395069306248466b46596e6746545279664d4c6f466f576956784842704b4d6c6b304d4e7a5155414e5368546e346769544d5547454a4f6372346f6f445f6770344768734c44542d54465f6f73486c304832544237777a6d546d735f3150506547424e716a316b61576a467038567247726e4a6b354f594c643152473152464c794535544c4d47315f62744762447137334450784c334b3657387348507242504b654133344377373371584e5247724e73574e69496e775f4e596a65554d484b6351436c4e51445a49725339794962456a485a78476a34546e4367664f5974694e76527a4c6c36616a73614265464b7a45592d6348416e6e42694c75744439373168697241684f5463544a42783672716f67717764755356726551456f565a5735616e4a7a7575775234685453354d44314d64457045437471526d416c71555459644e5a365778514d","initial_access_token":"52344752743736552d634a2d4a2d424372447159634967464652446c6455455a6a526e534d62336e3242732e47324f444d64303544774b4e67395649476e306e496b3877324e72444f48384a78635042635a4a58336d63","initial_refresh_token":"327872337a4d382d654273674b6d61644a624e5a497572473374545154615070313264514a314476544d632e77326d34747a6e7950584c38324b794563716468685068635156314f77386a535a345355496f3544744a51"}]}`), Identifiers: []string{"hydra:0.fywegkf7hd@ory.sh"}, } + //oidcEmpty := identity.Credentials{ + // Type: identity.CredentialsTypeOIDC, + // Config: []byte(`{}`), + // Identifiers: []string{"hydra:0.fywegkf7hd@ory.sh"}, + //} + + lookupSecrets := identity.Credentials{ + Type: identity.CredentialsTypeLookup, + Config: []byte(`{"recovery_codes": [{"code": "abcde", "used_at": null}]}`), + } + //lookupSecretsEmpty := identity.Credentials{ + // Type: identity.CredentialsTypeLookup, + // Config: []byte(`{}`), + //} + + totp := identity.Credentials{ + Type: identity.CredentialsTypeTOTP, + Config: []byte(`{"totp_url": "otpauth://totp/..."}`), + } + //totpEmpty := identity.Credentials{ + // Type: identity.CredentialsTypeTOTP, + // Config: []byte(`{}`), + //} + + // passkey + passkey := identity.Credentials{ // passkey + Type: identity.CredentialsTypePasskey, + Config: []byte(`{"credentials":[{}]}`), + Identifiers: []string{testhelpers.RandomEmail()}, + } + //passkeyEmpty := identity.Credentials{ // passkey + // Type: identity.CredentialsTypePasskey, + // Config: []byte(`{"credentials":null}`), + // Identifiers: []string{testhelpers.RandomEmail()}, + //} + + // webAuthn mfaWebAuth := identity.Credentials{ Type: identity.CredentialsTypeWebAuthn, Config: []byte(`{"credentials":[{"is_passwordless":false}]}`), @@ -467,106 +560,265 @@ func TestDoesSessionSatisfy(t *testing.T) { Identifiers: []string{testhelpers.RandomEmail()}, } webAuthEmpty := identity.Credentials{Type: identity.CredentialsTypeWebAuthn, Config: []byte(`{}`), Identifiers: []string{testhelpers.RandomEmail()}} - passwordEmpty := identity.Credentials{Type: identity.CredentialsTypePassword, Config: []byte(`{}`), Identifiers: []string{testhelpers.RandomEmail()}} - amrPassword := session.AuthenticationMethod{Method: identity.CredentialsTypePassword, AAL: identity.AuthenticatorAssuranceLevel1} + amrs := map[identity.CredentialsType]session.AuthenticationMethod{} + for _, strat := range reg.AllLoginStrategies() { + amrs[strat.ID()] = strat.CompletedAuthenticationMethod(ctx) + } for k, tc := range []struct { - d string - err error - requested identity.AuthenticatorAssuranceLevel + desc string + withContext func(*testing.T, context.Context) context.Context + errAs error + errIs error + matcher identity.AuthenticatorAssuranceLevel creds []identity.Credentials - amr session.AuthenticationMethods + withAMR session.AuthenticationMethods sessionManagerOptions []session.ManagerOptions expectedFunc func(t *testing.T, err error, tcError error) }{ { - d: "has=aal1, requested=highest, available=aal1, credential=password", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password}, - amr: session.AuthenticationMethods{amrPassword}, + desc: "with highest_available a password user is aal1", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + // No error }, { - d: "has=aal1, requested=highest, available=aal1, credential=password, legacy=true", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password}, - amr: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + desc: "with highest_available a password migration user is aal1 if password migration is enabled", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{passwordMigration}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + withContext: func(t *testing.T, ctx context.Context) context.Context { + return confighelpers.WithConfigValues(ctx, map[string]any{ + "selfservice.methods.password_migration.enabled": true, + }) + }, + // No error }, { - d: "has=aal1, requested=highest, available=aal1, credential=password+webauth_empty", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, webAuthEmpty}, - amr: session.AuthenticationMethods{amrPassword}, + // This is not an error because DoesSessionSatisfy always assumes at least aal1 + desc: "with highest_available a password migration user is aal1 if password migration is disabled", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{passwordMigration}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + withContext: func(t *testing.T, ctx context.Context) context.Context { + return confighelpers.WithConfigValues(ctx, map[string]any{ + "selfservice.methods.password_migration.enabled": false, + }) + }, + // No error }, { - d: "has=aal1, requested=highest, available=aal1, credential=password+webauth_empty, legacy=true", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, webAuthEmpty}, - amr: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + desc: "with highest_available a otp code user is aal1", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{code}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeCodeAuth]}, + // No error }, { - d: "has=aal1, requested=highest, available=aal1, credential=password+webauth_passwordless", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, passwordlessWebAuth}, - amr: session.AuthenticationMethods{amrPassword}, + desc: "with highest_available a oidc user is aal1", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{oidc}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeOIDC]}, + // No error }, { - d: "has=aal1, requested=highest, available=aal1, credential=password+webauth_passwordless, legacy=true", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, passwordlessWebAuth}, - amr: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + desc: "with highest_available a passkey user is aal1", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{passkey}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePasskey]}, + // No error }, { - d: "has=aal1, requested=highest, available=aal2, credential=password+webauth_mfa", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, mfaWebAuth}, - amr: session.AuthenticationMethods{amrPassword}, - err: new(session.ErrAALNotSatisfied), + desc: "with highest_available a recovery token user is aal1 even if they have no credentials", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryLink]}, + // No error }, { - d: "has=aal1, requested=highest, available=aal2, credential=password+webauth_mfa, legacy=true", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, mfaWebAuth}, - amr: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, - err: new(session.ErrAALNotSatisfied), + desc: "with highest_available a recovery code user is aal1 even if they have no credentials", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryCode]}, + // No error }, + // Test a recovery method with an identity that has only 2fa methods enabled. { - d: "has=aal1, requested=highest, available=aal2, credential=password+webauth_mfa", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, mfaWebAuth}, - amr: session.AuthenticationMethods{amrPassword, {Method: identity.CredentialsTypeWebAuthn, AAL: identity.AuthenticatorAssuranceLevel1}}, - err: new(session.ErrAALNotSatisfied), + desc: "with highest_available a recovery link user requires aal2 if they have 2fa totp configured", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{totp}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryLink]}, + errIs: new(session.ErrAALNotSatisfied), }, { - d: "has=aal1, requested=highest, available=aal2, credential=password+webauth_passwordless", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, passwordlessWebAuth}, - amr: session.AuthenticationMethods{amrPassword, {Method: identity.CredentialsTypeWebAuthn, AAL: identity.AuthenticatorAssuranceLevel1}}, + desc: "with highest_available a recovery code user requires aal2 if they have 2fa lookup configured", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{lookupSecrets}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryCode]}, + errIs: new(session.ErrAALNotSatisfied), }, { - d: "has=aal2, requested=highest, available=aal2, credential=password+webauth_mfa", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, mfaWebAuth}, - amr: session.AuthenticationMethods{amrPassword, {Method: identity.CredentialsTypeWebAuthn, AAL: identity.AuthenticatorAssuranceLevel2}}, + desc: "with highest_available a recovery code user requires aal2 if they have 2fa lookup configured", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{mfaWebAuth}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryCode]}, + errIs: new(session.ErrAALNotSatisfied), }, { - d: "has=aal2, requested=highest, available=aal2, credential=password+webauth_mfa, legacy=true", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, mfaWebAuth}, - amr: session.AuthenticationMethods{amrPassword, {Method: identity.CredentialsTypeWebAuthn}}, + desc: "with highest_available a recovery code user requires aal2 if they have many 2fa methods configured", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{lookupSecrets, mfaWebAuth, totp}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryCode]}, + errIs: new(session.ErrAALNotSatisfied), }, { - d: "has=aal1, requested=highest, available=aal1, credential=oidc_and_empties", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{oidc, webAuthEmpty, passwordEmpty}, - amr: session.AuthenticationMethods{{Method: identity.CredentialsTypeOIDC, AAL: identity.AuthenticatorAssuranceLevel1}}, + desc: "with highest_available a recovery link user requires aal2 if they have 2fa code configured", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryLink]}, + withContext: func(t *testing.T, ctx context.Context) context.Context { + return confighelpers.WithConfigValues(ctx, map[string]any{ + "selfservice.methods.code.passwordless_enabled": false, + "selfservice.methods.code.mfa_enabled": true, + }) + }, + errIs: new(session.ErrAALNotSatisfied), }, + + // Legacy tests { - d: "has=aal1, requested=highest, available=aal1, credentials=password+webauthn_mfa, recovery with session manager options", - requested: config.HighestAvailableAAL, + desc: "has=aal1, requested=highest, available=aal0, credential=code", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{totp}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypeRecoveryCode]}, + errIs: session.ErrNoAALAvailable, + }, + + { + desc: "has=aal1, requested=highest, available=aal1, credential=password", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + }, + { + desc: "has=aal1, requested=highest, available=aal1, credential=password, legacy=true", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password}, + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + }, + { + desc: "has=aal1, requested=highest, available=aal1, credential=password+webauth_empty", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, webAuthEmpty}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + }, + { + desc: "has=aal1, requested=highest, available=aal1, credential=password+webauth_empty, legacy=true", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, webAuthEmpty}, + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + }, + { + desc: "has=aal1, requested=highest, available=aal1, credential=password+webauth_passwordless", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, passwordlessWebAuth}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + }, + { + desc: "has=aal1, requested=highest, available=aal1, credential=password+webauth_passwordless, legacy=true", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, passwordlessWebAuth}, + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + }, + { + desc: "has=aal1, requested=highest, available=aal2, credential=password+webauth_mfa", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, mfaWebAuth}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + errAs: new(session.ErrAALNotSatisfied), + }, + { + desc: "has=aal1, requested=highest, available=aal2, credential=password+totp", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, totp}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + errAs: new(session.ErrAALNotSatisfied), + }, + { + desc: "has=aal1, requested=highest, available=aal2, credential=password+code-mfa", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, code}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + errAs: new(session.ErrAALNotSatisfied), + withContext: func(t *testing.T, ctx context.Context) context.Context { + return confighelpers.WithConfigValues(ctx, map[string]any{ + "selfservice.methods.code.passwordless_enabled": false, + "selfservice.methods.code.mfa_enabled": true, + }) + }, + }, + { + desc: "has=aal1, requested=highest, available=aal2, credential=password+lookup_secrets", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, lookupSecrets}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword]}, + errAs: new(session.ErrAALNotSatisfied), + }, + { + desc: "has=aal1, requested=highest, available=aal2, credential=password+webauth_mfa, legacy=true", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, mfaWebAuth}, + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypePassword}}, + errAs: new(session.ErrAALNotSatisfied), + }, + { + desc: "has=aal1, requested=highest, available=aal2, credential=password+webauth_mfa", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, mfaWebAuth}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword], {Method: identity.CredentialsTypeWebAuthn, AAL: identity.AuthenticatorAssuranceLevel1}}, + errAs: new(session.ErrAALNotSatisfied), + }, + { + desc: "has=aal1, requested=highest, available=aal2, credential=password+webauth_passwordless", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, passwordlessWebAuth}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword], {Method: identity.CredentialsTypeWebAuthn, AAL: identity.AuthenticatorAssuranceLevel1}}, + }, + { + desc: "has=aal2, requested=highest, available=aal2, credential=password+webauth_mfa", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, mfaWebAuth}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword], {Method: identity.CredentialsTypeWebAuthn, AAL: identity.AuthenticatorAssuranceLevel2}}, + }, + { + desc: "has=aal2, requested=highest, available=aal2, credential=password+webauth_mfa, legacy=true", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, mfaWebAuth}, + withAMR: session.AuthenticationMethods{amrs[identity.CredentialsTypePassword], {Method: identity.CredentialsTypeWebAuthn}}, + }, + + // oidc + { + desc: "has=aal1, requested=highest, available=aal1, credential=oidc_and_empties", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{oidc, webAuthEmpty, passwordEmpty}, + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypeOIDC, AAL: identity.AuthenticatorAssuranceLevel1}}, + }, + { + desc: "has=aal1, requested=highest, available=aal1, credential=code and totp", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{code, totp}, + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypeCodeAuth, AAL: identity.AuthenticatorAssuranceLevel1}}, + errAs: session.NewErrAALNotSatisfied(urlx.CopyWithQuery(urlx.AppendPaths(conf.SelfPublicURL(ctx), "/self-service/login/browser"), url.Values{"aal": {"aal2"}, "return_to": {"https://myapp.com/settings?id=123"}}).String()), + }, + { + desc: "has=aal1, requested=highest, available=aal1, credentials=password+webauthn_mfa, recovery with session manager options", + matcher: config.HighestAvailableAAL, creds: []identity.Credentials{password, mfaWebAuth}, - amr: session.AuthenticationMethods{{Method: identity.CredentialsTypeRecoveryCode}}, - err: session.NewErrAALNotSatisfied(urlx.CopyWithQuery(urlx.AppendPaths(conf.SelfPublicURL(context.Background()), "/self-service/login/browser"), url.Values{"aal": {"aal2"}, "return_to": {"https://myapp.com/settings?id=123"}}).String()), + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypeRecoveryCode}}, + errAs: session.NewErrAALNotSatisfied(urlx.CopyWithQuery(urlx.AppendPaths(conf.SelfPublicURL(ctx), "/self-service/login/browser"), url.Values{"aal": {"aal2"}, "return_to": {"https://myapp.com/settings?id=123"}}).String()), sessionManagerOptions: []session.ManagerOptions{session.WithRequestURL("https://myapp.com/settings?id=123")}, expectedFunc: func(t *testing.T, err error, tcError error) { require.Contains(t, err.(*session.ErrAALNotSatisfied).RedirectTo, "myapp.com") @@ -574,64 +826,74 @@ func TestDoesSessionSatisfy(t *testing.T) { }, }, { - d: "has=aal1, requested=highest, available=aal1, credentials=password+webauthn_mfa, recovery without session manager options", - requested: config.HighestAvailableAAL, - creds: []identity.Credentials{password, mfaWebAuth}, - amr: session.AuthenticationMethods{{Method: identity.CredentialsTypeRecoveryCode}}, - err: session.NewErrAALNotSatisfied(urlx.CopyWithQuery(urlx.AppendPaths(conf.SelfPublicURL(context.Background()), "/self-service/login/browser"), url.Values{"aal": {"aal2"}}).String()), + desc: "has=aal1, requested=highest, available=aal1, credentials=password+webauthn_mfa, recovery without session manager options", + matcher: config.HighestAvailableAAL, + creds: []identity.Credentials{password, mfaWebAuth}, + withAMR: session.AuthenticationMethods{{Method: identity.CredentialsTypeRecoveryCode}}, + errAs: session.NewErrAALNotSatisfied(urlx.CopyWithQuery(urlx.AppendPaths(conf.SelfPublicURL(ctx), "/self-service/login/browser"), url.Values{"aal": {"aal2"}}).String()), expectedFunc: func(t *testing.T, err error, tcError error) { require.Equal(t, tcError.(*session.ErrAALNotSatisfied).RedirectTo, err.(*session.ErrAALNotSatisfied).RedirectTo) }, }, } { - t.Run(fmt.Sprintf("run=%d/desc=%s", k, tc.d), func(t *testing.T) { - id := identity.NewIdentity("") + t.Run(fmt.Sprintf("run=%d/desc=%s", k, tc.desc), func(t *testing.T) { + ctx := ctx + if tc.withContext != nil { + ctx = tc.withContext(t, ctx) + } + + id := identity.NewIdentity("default") for _, c := range tc.creds { id.SetCredentials(c.Type, c) } - require.NoError(t, reg.IdentityManager().Create(context.Background(), id, identity.ManagerAllowWriteProtectedTraits)) + require.NoError(t, reg.IdentityManager().Create(ctx, id, identity.ManagerAllowWriteProtectedTraits)) t.Cleanup(func() { - require.NoError(t, reg.PrivilegedIdentityPool().DeleteIdentity(context.Background(), id.ID)) + require.NoError(t, reg.PrivilegedIdentityPool().DeleteIdentity(ctx, id.ID)) }) req := testhelpers.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) s := session.NewInactiveSession() - for _, m := range tc.amr { + for _, m := range tc.withAMR { s.CompletedLoginFor(m.Method, m.AAL) } - require.NoError(t, s.Activate(req, id, conf, time.Now().UTC())) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, id, time.Now().UTC())) - err := reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(context.Background()), s, string(tc.requested), tc.sessionManagerOptions...) - if tc.err != nil { + err := reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(ctx), s, string(tc.matcher), tc.sessionManagerOptions...) + if tc.errAs != nil || tc.errIs != nil { + if tc.expectedFunc != nil { + tc.expectedFunc(t, err, tc.errAs) + } + require.ErrorAs(t, err, &tc.errAs) + } else if tc.errIs != nil { if tc.expectedFunc != nil { - tc.expectedFunc(t, err, tc.err) + tc.expectedFunc(t, err, tc.errIs) } - require.ErrorAs(t, err, &tc.err) + require.ErrorIs(t, err, tc.errIs) } else { require.NoError(t, err) } - // This should still work even if the session does not have identity data attached yet... + // This should still work even if the session does not have identity data attached yet ... s.Identity = nil - err = reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(context.Background()), s, string(tc.requested), tc.sessionManagerOptions...) - if tc.err != nil { + err = reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(ctx), s, string(tc.matcher), tc.sessionManagerOptions...) + if tc.errAs != nil { if tc.expectedFunc != nil { - tc.expectedFunc(t, err, tc.err) + tc.expectedFunc(t, err, tc.errAs) } - require.ErrorAs(t, err, &tc.err) + require.ErrorAs(t, err, &tc.errAs) } else { require.NoError(t, err) } - // ..or no credentials attached. + // ... or no credentials attached. s.Identity = id s.Identity.Credentials = nil - err = reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(context.Background()), s, string(tc.requested), tc.sessionManagerOptions...) - if tc.err != nil { + err = reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(ctx), s, string(tc.matcher), tc.sessionManagerOptions...) + if tc.errAs != nil { if tc.expectedFunc != nil { - tc.expectedFunc(t, err, tc.err) + tc.expectedFunc(t, err, tc.errAs) } - require.ErrorAs(t, err, &tc.err) + require.ErrorAs(t, err, &tc.errAs) } else { require.NoError(t, err) } diff --git a/session/session.go b/session/session.go index e5b826b88f2f..63261dce807f 100644 --- a/session/session.go +++ b/session/session.go @@ -210,6 +210,9 @@ func (s *Session) SetAuthenticatorAssuranceLevel() { isAAL1 = true case identity.AuthenticatorAssuranceLevel2: isAAL2 = true + // The following section is a graceful migration from Ory Kratos v0.9. + // + // TODO remove this section, it is already over 2 years old. case "": // Sessions before Ory Kratos 0.9 did not have the AAL // be part of the AMR. @@ -238,20 +241,11 @@ func (s *Session) SetAuthenticatorAssuranceLevel() { } else if isAAL1 { s.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel1 } else if len(s.AMR) > 0 { - // A fallback. If an AMR is set but we did not satisfy the above, gracefully fall back to level 1. + // A fallback. If an AMR is set, but we did not satisfy the above, gracefully fall back to level 1. s.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel1 } } -func NewActiveSession(r *http.Request, i *identity.Identity, c lifespanProvider, authenticatedAt time.Time, completedLoginFor identity.CredentialsType, completedLoginAAL identity.AuthenticatorAssuranceLevel) (*Session, error) { - s := NewInactiveSession() - s.CompletedLoginFor(completedLoginFor, completedLoginAAL) - if err := s.Activate(r, i, c, authenticatedAt); err != nil { - return nil, err - } - return s, nil -} - func NewInactiveSession() *Session { return &Session{ ID: uuid.Nil, @@ -262,23 +256,6 @@ func NewInactiveSession() *Session { } } -func (s *Session) Activate(r *http.Request, i *identity.Identity, c lifespanProvider, authenticatedAt time.Time) error { - if i != nil && !i.IsActive() { - return ErrIdentityDisabled.WithDetail("identity_id", i.ID) - } - - s.Active = true - s.ExpiresAt = authenticatedAt.Add(c.SessionLifespan(r.Context())) - s.AuthenticatedAt = authenticatedAt - s.IssuedAt = authenticatedAt - s.Identity = i - s.IdentityID = i.ID - - s.SetSessionDeviceInformation(r) - s.SetAuthenticatorAssuranceLevel() - return nil -} - func (s *Session) SetSessionDeviceInformation(r *http.Request) { device := Device{ SessionID: s.ID, diff --git a/session/session_test.go b/session/session_test.go index 4e1efe3b647a..75fc61ea300a 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/ory/kratos/x" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/assert" @@ -22,7 +24,8 @@ import ( func TestSession(t *testing.T) { ctx := context.Background() - conf, _ := internal.NewFastRegistryWithMocks(t) + conf, reg := internal.NewFastRegistryWithMocks(t) + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") authAt := time.Now() t.Run("case=active session", func(t *testing.T) { @@ -30,14 +33,17 @@ func TestSession(t *testing.T) { i := new(identity.Identity) i.State = identity.StateActive - s, _ := session.NewActiveSession(req, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + i.NID = x.NewUUID() + s, err := testhelpers.NewActiveSession(req, reg, i, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + require.NoError(t, err) assert.True(t, s.IsActive()) require.NotEmpty(t, s.Token) require.NotEmpty(t, s.LogoutToken) assert.EqualValues(t, identity.CredentialsTypePassword, s.AMR[0].Method) i = new(identity.Identity) - s, err := session.NewActiveSession(req, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + i.NID = x.NewUUID() + s, err = testhelpers.NewActiveSession(req, reg, i, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) assert.Nil(t, s) assert.ErrorIs(t, err, session.ErrIdentityDisabled) }) @@ -62,13 +68,13 @@ func TestSession(t *testing.T) { req := testhelpers.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil) s := session.NewInactiveSession() - require.NoError(t, s.Activate(req, &identity.Identity{State: identity.StateActive}, conf, authAt)) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, &identity.Identity{NID: x.NewUUID(), State: identity.StateActive}, authAt)) assert.True(t, s.Active) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel) assert.Equal(t, authAt, s.AuthenticatedAt) s = session.NewInactiveSession() - require.ErrorIs(t, s.Activate(req, &identity.Identity{State: identity.StateInactive}, conf, authAt), session.ErrIdentityDisabled) + require.ErrorIs(t, reg.SessionManager().ActivateSession(req, s, &identity.Identity{NID: x.NewUUID(), State: identity.StateInactive}, authAt), session.ErrIdentityDisabled) assert.False(t, s.Active) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel) assert.Empty(t, s.AuthenticatedAt) @@ -98,7 +104,7 @@ func TestSession(t *testing.T) { req.Header.Set("X-Forwarded-For", tc.input) s := session.NewInactiveSession() - require.NoError(t, s.Activate(req, &identity.Identity{State: identity.StateActive}, conf, authAt)) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, &identity.Identity{NID: x.NewUUID(), State: identity.StateActive}, authAt)) assert.True(t, s.Active) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel) assert.Equal(t, authAt, s.AuthenticatedAt) @@ -118,7 +124,7 @@ func TestSession(t *testing.T) { req.Header["X-Forwarded-For"] = []string{"54.155.246.232", "10.145.1.10"} s := session.NewInactiveSession() - require.NoError(t, s.Activate(req, &identity.Identity{State: identity.StateActive}, conf, authAt)) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, &identity.Identity{NID: x.NewUUID(), State: identity.StateActive}, authAt)) assert.True(t, s.Active) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel) assert.Equal(t, authAt, s.AuthenticatedAt) @@ -138,7 +144,7 @@ func TestSession(t *testing.T) { req.Header.Set("X-Forwarded-For", "217.73.188.139,162.158.203.149, 172.19.2.7") s := session.NewInactiveSession() - require.NoError(t, s.Activate(req, &identity.Identity{State: identity.StateActive}, conf, authAt)) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, &identity.Identity{State: identity.StateActive, NID: x.NewUUID()}, authAt)) assert.True(t, s.Active) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel) assert.Equal(t, authAt, s.AuthenticatedAt) @@ -159,7 +165,7 @@ func TestSession(t *testing.T) { req.Header.Set("Cf-Ipcountry", "Germany") s := session.NewInactiveSession() - require.NoError(t, s.Activate(req, &identity.Identity{State: identity.StateActive}, conf, authAt)) + require.NoError(t, reg.SessionManager().ActivateSession(req, s, &identity.Identity{NID: x.NewUUID(), State: identity.StateActive}, authAt)) assert.True(t, s.Active) assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel) assert.Equal(t, authAt, s.AuthenticatedAt) @@ -350,7 +356,9 @@ func TestSession(t *testing.T) { }) i := new(identity.Identity) i.State = identity.StateActive - s, _ := session.NewActiveSession(req, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + i.NID = x.NewUUID() + s, err := testhelpers.NewActiveSession(req, reg, i, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + require.NoError(t, err) assert.False(t, s.CanBeRefreshed(ctx, conf), "fresh session is not refreshable") s.ExpiresAt = s.ExpiresAt.Add(-12 * time.Hour) diff --git a/session/tokenizer_test.go b/session/tokenizer_test.go index fab54f09d99e..29a213f03142 100644 --- a/session/tokenizer_test.go +++ b/session/tokenizer_test.go @@ -10,10 +10,13 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" + + "github.com/ory/kratos/internal/testhelpers" + "github.com/ory/herodot" "github.com/gofrs/uuid" - "github.com/golang-jwt/jwt/v5" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -68,6 +71,7 @@ func TestTokenizer(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "http://localhost/") + testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json") tkn := session.NewTokenizer(reg) nowDate := time.Date(2023, 02, 01, 00, 00, 00, 0, time.UTC) tkn.SetNowFunc(func() time.Time { @@ -77,8 +81,9 @@ func TestTokenizer(t *testing.T) { r := httptest.NewRequest("GET", "/sessions/whoami", nil) i := identity.NewIdentity("default") i.ID = uuid.FromStringOrNil("7458af86-c1d8-401c-978a-8da89133f78b") + i.NID = uuid.Must(uuid.NewV4()) - s, err := session.NewActiveSession(r, i, conf, now, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) + s, err := testhelpers.NewActiveSession(r, reg, i, now, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1) require.NoError(t, err) s.ID = uuid.FromStringOrNil("432caf86-c1d8-401c-978a-8da89133f78b") diff --git a/test/e2e/cypress/integration/profiles/code/login/error.spec.ts b/test/e2e/cypress/integration/profiles/code/login/error.spec.ts index 3310a966627f..244d40ab8132 100644 --- a/test/e2e/cypress/integration/profiles/code/login/error.spec.ts +++ b/test/e2e/cypress/integration/profiles/code/login/error.spec.ts @@ -160,6 +160,12 @@ context("Login error messages with code method", () => { "contain", "Property identifier is missing", ) + } else if (app === "react") { + // The backspace trick is not working in React. + cy.get('[data-testid="ui/message/4010008"]').should( + "contain", + "code is invalid", + ) } else { cy.get('[data-testid="ui/message/4000002"]').should( "contain", diff --git a/test/e2e/cypress/integration/profiles/mfa/code.spec.ts b/test/e2e/cypress/integration/profiles/mfa/code.spec.ts index c7d29c0561c2..313e85e5382c 100644 --- a/test/e2e/cypress/integration/profiles/mfa/code.spec.ts +++ b/test/e2e/cypress/integration/profiles/mfa/code.spec.ts @@ -3,7 +3,6 @@ import { gen, website } from "../../../helpers" import { routes as express } from "../../../helpers/express" -import { routes as react } from "../../../helpers/react" context("2FA code", () => { ;[ @@ -31,72 +30,110 @@ context("2FA code", () => { let email: string let password: string - beforeEach(() => { - email = gen.email() - password = gen.password() - cy.useConfig((builder) => - builder - .longPrivilegedSessionTime() - .useLaxAal() - .enableCode() - .enableCodeMFA(), - ) - - cy.register({ - email, - password, - fields: { "traits.website": website }, + describe("when using highest_available aal", () => { + beforeEach(() => { + cy.useConfig((builder) => + builder + .longPrivilegedSessionTime() + .useHighestAvailable() + .enableCodeMFA(), + ) + }) + + it("should show second factor screen on whoami call", () => { + email = gen.email() + password = gen.password() + cy.register({ + email, + password, + fields: { "traits.website": website }, + }) + cy.deleteMail() + + cy.visit(settings) + cy.location("pathname").should("contain", "/login") // we get redirected to login + + cy.get("[type='submit'][name='address']").should("be.visible").click() + + cy.getLoginCodeFromEmail(email).then((code) => { + cy.get("input[name='code']").type(code) + cy.contains("Continue").click() + }) + + cy.getSession({ + expectAal: "aal2", + expectMethods: ["password", "code"], + }) }) - cy.deleteMail() - cy.visit(login + "?aal=aal2&via=email") }) - it("should be asked to sign in with 2fa if set up", () => { - cy.get("input[name='identifier']").type(email) - cy.contains("Continue with code").click() + describe("when using aal1 required aal", () => { + beforeEach(() => { + email = gen.email() + password = gen.password() + cy.useConfig((builder) => + builder + .longPrivilegedSessionTime() + .useLaxAal() + .enableCode() + .enableCodeMFA(), + ) - cy.get("input[name='code']").should("be.visible") - cy.getLoginCodeFromEmail(email).then((code) => { - cy.get("input[name='code']").type(code) - cy.contains("Continue").click() + cy.register({ + email, + password, + fields: { "traits.website": website }, + }) + cy.deleteMail() + cy.visit(login + "?aal=aal2&via=email") }) - cy.getSession({ - expectAal: "aal2", - expectMethods: ["password", "code"], + it("should be asked to sign in with 2fa if set up", () => { + cy.get("*[name='address']").click() + + cy.get("input[name='code']").should("be.visible") + cy.getLoginCodeFromEmail(email).then((code) => { + cy.get("input[name='code']").type(code) + cy.contains("Continue").click() + }) + + cy.getSession({ + expectAal: "aal2", + expectMethods: ["password", "code"], + }) }) - }) - it("can't use different email in 2fa request", () => { - cy.get("input[name='identifier']").type(gen.email()) - cy.contains("Continue with code").click() + it("can't use different email in 2fa request", () => { + cy.get('[name="address"]').invoke("attr", "value", gen.email()) + + cy.get('[name="address"]').click() - cy.get("*[data-testid='ui/message/4010010']").should("be.visible") - cy.get("input[name='code']").should("not.exist") - cy.get("input[name='identifier']").should("be.visible") + cy.get("*[data-testid='ui/message/4000035']").should("be.visible") + cy.get("input[name='code']").should("not.exist") + cy.get("[name='address']").should("be.visible") - // The current session should be unchanged - cy.getSession({ - expectAal: "aal1", - expectMethods: ["password"], + // The current session should be unchanged + cy.getSession({ + expectAal: "aal1", + expectMethods: ["password"], + }) }) - }) - it("entering wrong code should not invalidate correct codes", () => { - cy.get("input[name='identifier']").type(email) - cy.contains("Continue with code").click() + it("entering wrong code should not invalidate correct codes", () => { + cy.get("*[name='address']").click() + + cy.get("input[name='code']").should("be.visible").type("123456") - cy.get("input[name='code']").should("be.visible") - cy.get("input[name='code']").type("123456") - cy.contains("Continue").click() - cy.getLoginCodeFromEmail(email).then((code) => { - cy.get("input[name='code']").type(code) cy.contains("Continue").click() - }) + cy.getLoginCodeFromEmail(email).then((code) => { + cy.get("input[name='code']").type(code) + cy.contains("Continue").click() + }) - cy.getSession({ - expectAal: "aal2", - expectMethods: ["password", "code"], + cy.getSession({ + expectAal: "aal2", + expectMethods: ["password", "code"], + }) }) }) }) diff --git a/test/e2e/cypress/support/configHelpers.ts b/test/e2e/cypress/support/configHelpers.ts index 0fc72864294a..bce083d63d77 100644 --- a/test/e2e/cypress/support/configHelpers.ts +++ b/test/e2e/cypress/support/configHelpers.ts @@ -132,6 +132,13 @@ export class ConfigBuilder { this.config.session.whoami.required_aal = "aal1" return this } + + public useHighestAvailable() { + this.config.selfservice.flows.settings.required_aal = "highest_available" + this.config.session.whoami.required_aal = "highest_available" + return this + } + public enableCode() { this.config.selfservice.methods.code.enabled = true return this diff --git a/test/e2e/playwright/tests/desktop/identifier_first/code.login.spec.ts b/test/e2e/playwright/tests/desktop/identifier_first/code.login.spec.ts index aed035d4c2ae..ddb0d06c4be9 100644 --- a/test/e2e/playwright/tests/desktop/identifier_first/code.login.spec.ts +++ b/test/e2e/playwright/tests/desktop/identifier_first/code.login.spec.ts @@ -88,6 +88,9 @@ test.describe("account enumeration protection on", () => { mitigateEnumeration: true, selfservice: { methods: { + password: { + enabled: false, + }, code: { passwordless_enabled: true, }, diff --git a/test/e2e/playwright/tests/desktop/identifier_first/oidc.login.spec.ts b/test/e2e/playwright/tests/desktop/identifier_first/oidc.login.spec.ts index cf0a19d23a87..5568b88a1aab 100644 --- a/test/e2e/playwright/tests/desktop/identifier_first/oidc.login.spec.ts +++ b/test/e2e/playwright/tests/desktop/identifier_first/oidc.login.spec.ts @@ -63,7 +63,16 @@ for (const mitigateEnumeration of [true, false]) { mitigateEnumeration ? "on" : "off" }`, () => { test.use({ - configOverride: toConfig({ mitigateEnumeration }), + configOverride: toConfig({ + mitigateEnumeration, + selfservice: { + methods: { + password: { + enabled: true, + }, + }, + }, + }), }) test("login", async ({ page, config, kratosPublicURL }) => { diff --git a/test/e2e/profiles/email/identity.traits.schema.json b/test/e2e/profiles/email/identity.traits.schema.json index 59a799a43d20..e5d40a431bbd 100644 --- a/test/e2e/profiles/email/identity.traits.schema.json +++ b/test/e2e/profiles/email/identity.traits.schema.json @@ -17,6 +17,10 @@ "password": { "identifier": true }, + "code": { + "identifier": true, + "via": "email" + }, "webauthn": { "identifier": true } diff --git a/test/e2e/profiles/mfa/.kratos.yml b/test/e2e/profiles/mfa/.kratos.yml index 99becd59a868..91f0bfb71042 100644 --- a/test/e2e/profiles/mfa/.kratos.yml +++ b/test/e2e/profiles/mfa/.kratos.yml @@ -44,7 +44,7 @@ selfservice: identity: schemas: - id: default - url: file://test/e2e/profiles/email/identity.traits.schema.json + url: file://test/e2e/profiles/mfa/identity.traits.schema.json session: whoami: diff --git a/test/e2e/profiles/mfa/identity.traits.schema.json b/test/e2e/profiles/mfa/identity.traits.schema.json index 87db142ee8d2..36aef3680f4d 100644 --- a/test/e2e/profiles/mfa/identity.traits.schema.json +++ b/test/e2e/profiles/mfa/identity.traits.schema.json @@ -19,6 +19,10 @@ }, "webauthn": { "identifier": true + }, + "code": { + "identifier": true, + "via": "email" } } } diff --git a/test/e2e/profiles/oidc/identity-required.traits.schema.json b/test/e2e/profiles/oidc/identity-required.traits.schema.json index 19241c7f2760..20b2b50e9995 100644 --- a/test/e2e/profiles/oidc/identity-required.traits.schema.json +++ b/test/e2e/profiles/oidc/identity-required.traits.schema.json @@ -19,6 +19,10 @@ }, "webauthn": { "identifier": true + }, + "code": { + "identifier": true, + "via": "email" } } } diff --git a/test/e2e/profiles/oidc/identity.traits.schema.json b/test/e2e/profiles/oidc/identity.traits.schema.json index 10fb49486aed..67857c8ab987 100644 --- a/test/e2e/profiles/oidc/identity.traits.schema.json +++ b/test/e2e/profiles/oidc/identity.traits.schema.json @@ -19,6 +19,10 @@ }, "webauthn": { "identifier": true + }, + "code": { + "identifier": true, + "via": "email" } }, "verification": { diff --git a/text/id.go b/text/id.go index edff417a2738..fafe8e450344 100644 --- a/text/id.go +++ b/text/id.go @@ -32,6 +32,7 @@ const ( InfoSelfServiceLoginCodeMFAHint // 1010020 InfoSelfServiceLoginPasskey // 1010021 InfoSelfServiceLoginPassword // 1010022 + InfoSelfServiceLoginAAL2CodeAddress // 1010023 ) const ( diff --git a/text/message_login.go b/text/message_login.go index 9312a21e97a2..0a6ca5684181 100644 --- a/text/message_login.go +++ b/text/message_login.go @@ -271,17 +271,18 @@ func NewInfoSelfServiceLoginCodeMFA() *Message { return &Message{ ID: InfoSelfServiceLoginCodeMFA, Type: Info, - Text: "Continue with code", + Text: "Request code to continue", } } -func NewInfoSelfServiceLoginCodeMFAHint(maskedTo string) *Message { +func NewInfoSelfServiceLoginAAL2CodeAddress(channel string, to string) *Message { return &Message{ - ID: InfoSelfServiceLoginCodeMFAHint, + ID: InfoSelfServiceLoginAAL2CodeAddress, Type: Info, - Text: fmt.Sprintf("We will send a code to %s. To verify that this is your address please enter it here.", maskedTo), + Text: fmt.Sprintf("Send code to %s", to), Context: context(map[string]any{ - "masked_to": maskedTo, + "address": to, + "channel": channel, }), } } diff --git a/text/message_registration.go b/text/message_registration.go index 7779143a1cc0..e0d62bde2610 100644 --- a/text/message_registration.go +++ b/text/message_registration.go @@ -106,7 +106,7 @@ func NewErrorValidationRegistrationRetrySuccessful() *Message { func NewInfoSelfServiceRegistrationRegisterCode() *Message { return &Message{ ID: InfoSelfServiceRegistrationRegisterCode, - Text: "Sign up with code", + Text: "Send sign up code", Type: Info, } } diff --git a/x/normalize.go b/x/normalize.go new file mode 100644 index 000000000000..9429fa12bd92 --- /dev/null +++ b/x/normalize.go @@ -0,0 +1,77 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "strings" + + "github.com/nyaruka/phonenumbers" + "github.com/pkg/errors" +) + +// NormalizeEmailIdentifier normalizes an email address. +func NormalizeEmailIdentifier(value string) string { + if strings.Contains(value, "@") { + value = strings.TrimSpace(strings.ToLower(value)) + } + return value +} + +// NormalizePhoneIdentifier normalizes a phone number. +func NormalizePhoneIdentifier(value string) string { + if number, err := phonenumbers.Parse(value, ""); err == nil && phonenumbers.IsValidNumber(number) { + value = phonenumbers.Format(number, phonenumbers.E164) + } + return value +} + +// NormalizeOtherIdentifier normalizes an identifier that is not an email or phone number. +func NormalizeOtherIdentifier(value string) string { + return strings.TrimSpace(value) +} + +// GracefulNormalization normalizes an identifier based on the format. +// +// Supported formats are: +// +// - email +// - phone +// - username +func GracefulNormalization(value string) string { + if number, err := phonenumbers.Parse(value, ""); err == nil && phonenumbers.IsValidNumber(number) { + return phonenumbers.Format(number, phonenumbers.E164) + } else if strings.Contains(value, "@") { + return NormalizeEmailIdentifier(value) + } + return NormalizeOtherIdentifier(value) +} + +// NormalizeIdentifier normalizes an identifier based on the format. +// +// Supported formats are: +// +// - email +// - phone +// - username +func NormalizeIdentifier(value, format string) (string, error) { + switch format { + case "email": + return NormalizeEmailIdentifier(value), nil + case "sms": + number, err := phonenumbers.Parse(value, "") + if err != nil { + return "", err + } + + if !phonenumbers.IsValidNumber(number) { + return "", errors.New("the provided number is not a valid phone number") + } + + return phonenumbers.Format(number, phonenumbers.E164), nil + case "username": + fallthrough + default: + return NormalizeOtherIdentifier(value), nil + } +} diff --git a/x/normalize_test.go b/x/normalize_test.go new file mode 100644 index 000000000000..3b7993527c28 --- /dev/null +++ b/x/normalize_test.go @@ -0,0 +1,95 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeEmailIdentifier(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {" EXAMPLE@DOMAIN.COM ", "example@domain.com"}, + {"user@domain.com", "user@domain.com"}, + {"invalid-email", "invalid-email"}, + } + + for _, test := range tests { + assert.Equal(t, test.expected, NormalizeEmailIdentifier(test.input)) + } +} + +func TestNormalizePhoneIdentifier(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"+1 650-253-0000", "+16502530000"}, + {"+1 (650) 253-0000", "+16502530000"}, + {"invalid-phone", "invalid-phone"}, + } + + for _, test := range tests { + assert.Equal(t, test.expected, NormalizePhoneIdentifier(test.input)) + } +} + +func TestNormalizeOtherIdentifier(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {" username ", "username"}, + {"user123", "user123"}, + {" ", ""}, + } + + for _, test := range tests { + assert.Equal(t, test.expected, NormalizeOtherIdentifier(test.input)) + } +} + +func TestGracefulNormalization(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"+1 650-253-0000", "+16502530000"}, + {" EXAMPLE@DOMAIN.COM ", "example@domain.com"}, + {" username ", "username"}, + {"invalid-phone", "invalid-phone"}, + } + + for _, test := range tests { + assert.Equal(t, test.expected, GracefulNormalization(test.input)) + } +} + +func TestNormalizeIdentifier(t *testing.T) { + tests := []struct { + input string + format string + expected string + err bool + }{ + {" EXAMPLE@DOMAIN.COM ", "email", "example@domain.com", false}, + {"+1 650-253-0000", "sms", "+16502530000", false}, + {" username ", "username", "username", false}, + {"invalid-phone", "sms", "", true}, + } + + for _, test := range tests { + result, err := NormalizeIdentifier(test.input, test.format) + if test.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, test.expected, result) + } + } +} diff --git a/x/transaction.go b/x/transaction.go new file mode 100644 index 000000000000..117a1bf5cb6e --- /dev/null +++ b/x/transaction.go @@ -0,0 +1,20 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "context" + + "github.com/gobuffalo/pop/v6" +) + +type ( + TransactionPersistenceProvider interface { + TransactionalPersisterProvider() TransactionalPersister + } + + TransactionalPersister interface { + Transaction(ctx context.Context, callback func(ctx context.Context, connection *pop.Connection) error) error + } +)