From 5b7141b7efbb8d258deb02932874382109bae0ff Mon Sep 17 00:00:00 2001 From: noppoman Date: Wed, 19 Jul 2017 17:33:24 +0900 Subject: [PATCH] add blockForCallbackURLQueryParams label to Oauth1 to extend callback url depends on the request --- Sources/HexavilleAuth/HexavilleAuth+Router.swift | 3 ++- Sources/HexavilleAuth/OAuth/OAuth1.swift | 8 +++++--- Sources/HexavilleAuth/Providers/CallbackURL.swift | 11 +++++++++++ .../OAuth1/TwitterAuthorizationProvider.swift | 3 ++- .../Providers/OAuth1AuthorizationProvidable.swift | 8 ++++---- 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/Sources/HexavilleAuth/HexavilleAuth+Router.swift b/Sources/HexavilleAuth/HexavilleAuth+Router.swift index 20a53f4..16e86d4 100644 --- a/Sources/HexavilleAuth/HexavilleAuth+Router.swift +++ b/Sources/HexavilleAuth/HexavilleAuth+Router.swift @@ -16,7 +16,8 @@ extension HexavilleAuth { switch type { case .oauth1(let provider): router.use(.get, provider.path) { request, context in - let requestToken = try provider.getRequestToken() + let queryItems: [URLQueryItem] = provider.oauth.blockForCallbackURLQueryParams?(request) ?? [] + let requestToken = try provider.getRequestToken(withCallbackURLQueryItems: queryItems) context.session?["hexaville.oauth_token_secret"] = requestToken.oauthTokenSecret context.session?["hexaville.oauth_token"] = requestToken.oauthToken let location = try provider.createAuthorizeURL(requestToken: requestToken).absoluteString diff --git a/Sources/HexavilleAuth/OAuth/OAuth1.swift b/Sources/HexavilleAuth/OAuth/OAuth1.swift index 92d2498..e631230 100644 --- a/Sources/HexavilleAuth/OAuth/OAuth1.swift +++ b/Sources/HexavilleAuth/OAuth/OAuth1.swift @@ -100,15 +100,17 @@ public class OAuth1 { let authorizeUrl: String let accessTokenUrl: String let callbackURL: CallbackURL + let blockForCallbackURLQueryParams: ((Request) -> [URLQueryItem])? let withAllowedCharacters: CharacterSet - public init(consumerKey: String, consumerSecret: String, requestTokenUrl: String, authorizeUrl: String, accessTokenUrl: String, callbackURL: CallbackURL, withAllowedCharacters: CharacterSet = CharacterSet.alphanumerics) { + public init(consumerKey: String, consumerSecret: String, requestTokenUrl: String, authorizeUrl: String, accessTokenUrl: String, callbackURL: CallbackURL, blockForCallbackURLQueryParams: ((Request) -> [URLQueryItem])? = nil, withAllowedCharacters: CharacterSet = CharacterSet.alphanumerics) { self.consumerKey = consumerKey self.consumerSecret = consumerSecret self.requestTokenUrl = requestTokenUrl self.authorizeUrl = authorizeUrl self.accessTokenUrl = accessTokenUrl self.callbackURL = callbackURL + self.blockForCallbackURLQueryParams = blockForCallbackURLQueryParams self.withAllowedCharacters = withAllowedCharacters } @@ -116,9 +118,9 @@ public class OAuth1 { return dict.map({ "\($0.key)=\($0.value)" }).joined(separator: "&") } - public func getRequestToken() throws -> RequestToken { + public func getRequestToken(withCallbackURLQueryItems queryItems: [URLQueryItem]) throws -> RequestToken { var params = [ - "oauth_callback": callbackURL.absoluteURL()!.absoluteString, + "oauth_callback": callbackURL.absoluteURL(withQueryItems: queryItems)!.absoluteString, "oauth_consumer_key": consumerKey, "oauth_nonce": OAuth1.generateNonce(), "oauth_signature_method": "HMAC-SHA1", diff --git a/Sources/HexavilleAuth/Providers/CallbackURL.swift b/Sources/HexavilleAuth/Providers/CallbackURL.swift index 6b89172..c1a70d8 100644 --- a/Sources/HexavilleAuth/Providers/CallbackURL.swift +++ b/Sources/HexavilleAuth/Providers/CallbackURL.swift @@ -20,4 +20,15 @@ public struct CallbackURL { public func absoluteURL() -> URL? { return URL(string: "\(baseURL)\(path)") } + + public func absoluteURL(withQueryItems queryItems: [URLQueryItem]) -> URL? { + guard let url = absoluteURL() else { return nil } + if queryItems.count > 0 { + let additionalQuery = queryItems.filter({ $0.value != nil }).map({ "\($0.name)=\($0.value!)" }).joined(separator: "&") + let separator = url.queryItems.count == 0 ? "?" : "&" + return URL(string: url.absoluteString+separator+additionalQuery) + } + + return url + } } diff --git a/Sources/HexavilleAuth/Providers/OAuth1/TwitterAuthorizationProvider.swift b/Sources/HexavilleAuth/Providers/OAuth1/TwitterAuthorizationProvider.swift index 89c413c..6e8a8ca 100644 --- a/Sources/HexavilleAuth/Providers/OAuth1/TwitterAuthorizationProvider.swift +++ b/Sources/HexavilleAuth/Providers/OAuth1/TwitterAuthorizationProvider.swift @@ -25,7 +25,7 @@ public struct TwitterAuthorizationProvider: OAuth1AuthorizationProvidable { public let callback: RespodWithCredential - public init(path: String, consumerKey: String, consumerSecret: String, callbackURL: CallbackURL, scope: String, callback: @escaping RespodWithCredential) { + public init(path: String, consumerKey: String, consumerSecret: String, callbackURL: CallbackURL, blockForCallbackURLQueryParams: ((Request) -> [URLQueryItem])? = nil, scope: String, callback: @escaping RespodWithCredential) { self.path = path self.oauth = OAuth1( @@ -35,6 +35,7 @@ public struct TwitterAuthorizationProvider: OAuth1AuthorizationProvidable { authorizeUrl: "https://api.twitter.com/oauth/authenticate", accessTokenUrl: "https://api.twitter.com/oauth/access_token", callbackURL: callbackURL, + blockForCallbackURLQueryParams: blockForCallbackURLQueryParams, withAllowedCharacters: .twitterQueryAllowed ) diff --git a/Sources/HexavilleAuth/Providers/OAuth1AuthorizationProvidable.swift b/Sources/HexavilleAuth/Providers/OAuth1AuthorizationProvidable.swift index 1ed83e2..7c5075c 100644 --- a/Sources/HexavilleAuth/Providers/OAuth1AuthorizationProvidable.swift +++ b/Sources/HexavilleAuth/Providers/OAuth1AuthorizationProvidable.swift @@ -13,16 +13,16 @@ public protocol OAuth1AuthorizationProvidable { var path: String { get } var oauth: OAuth1 { get } var callback: RespodWithCredential { get } - init(path: String, consumerKey: String, consumerSecret: String, callbackURL: CallbackURL, scope: String, callback: @escaping RespodWithCredential) - func getRequestToken() throws -> RequestToken + init(path: String, consumerKey: String, consumerSecret: String, callbackURL: CallbackURL, blockForCallbackURLQueryParams: ((Request) -> [URLQueryItem])?, scope: String, callback: @escaping RespodWithCredential) + func getRequestToken(withCallbackURLQueryItems queryItems: [URLQueryItem]) throws -> RequestToken func createAuthorizeURL(requestToken: RequestToken) throws -> URL func getAccessToken(request: Request, requestToken: RequestToken) throws -> Credential func authorize(request: Request, requestToken: RequestToken) throws -> (Credential, LoginUser) } extension OAuth1AuthorizationProvidable { - public func getRequestToken() throws -> RequestToken { - return try oauth.getRequestToken() + public func getRequestToken(withCallbackURLQueryItems queryItems: [URLQueryItem]) throws -> RequestToken { + return try oauth.getRequestToken(withCallbackURLQueryItems: queryItems) } public func createAuthorizeURL(requestToken: RequestToken) throws -> URL {