From 2786e93df2e8ddcb8c822d3801fbbebb36c4d900 Mon Sep 17 00:00:00 2001 From: Ziver Koc Date: Wed, 25 Nov 2020 00:51:26 +0100 Subject: [PATCH] Implemented OAuth2 refresh_token process --- .../page/oauth/OAuth2AuthorizationPage.java | 7 +- .../net/http/page/oauth/OAuth2Registry.java | 123 +++++++++++++----- .../net/http/page/oauth/OAuth2TokenPage.java | 107 ++++++++------- .../oauth/OAuth2AuthorizationPageTest.java | 4 +- .../http/page/oauth/OAuth2TokenPageTest.java | 53 +++++++- 5 files changed, 208 insertions(+), 86 deletions(-) diff --git a/src/zutil/net/http/page/oauth/OAuth2AuthorizationPage.java b/src/zutil/net/http/page/oauth/OAuth2AuthorizationPage.java index 43cbc6d..deceee0 100644 --- a/src/zutil/net/http/page/oauth/OAuth2AuthorizationPage.java +++ b/src/zutil/net/http/page/oauth/OAuth2AuthorizationPage.java @@ -24,6 +24,7 @@ package zutil.net.http.page.oauth; +import zutil.Hasher; import zutil.log.LogUtil; import zutil.net.http.HttpHeader; import zutil.net.http.HttpPage; @@ -145,7 +146,7 @@ public class OAuth2AuthorizationPage implements HttpPage { switch (request.get("response_type")) { case RESPONSE_TYPE_CODE: - String code = generateCode(); + String code = registry.generateCode(); registry.registerAuthorizationCode(clientId, code); url.setParameter("code", code); @@ -165,10 +166,6 @@ public class OAuth2AuthorizationPage implements HttpPage { redirect(out, url); } - private String generateCode() { - return String.valueOf(Math.abs(random.nextLong())); - } - // ------------------------------------------------------ // Error handling // ------------------------------------------------------ diff --git a/src/zutil/net/http/page/oauth/OAuth2Registry.java b/src/zutil/net/http/page/oauth/OAuth2Registry.java index d5abcb8..e7bfd66 100644 --- a/src/zutil/net/http/page/oauth/OAuth2Registry.java +++ b/src/zutil/net/http/page/oauth/OAuth2Registry.java @@ -24,21 +24,27 @@ package zutil.net.http.page.oauth; +import zutil.Hasher; import zutil.Timer; +import java.io.Serializable; import java.util.HashMap; import java.util.Map; +import java.util.Random; /** * A data class containing authentication information for individual * clients going through the OAuth 2 process. */ -public class OAuth2Registry { - private static final long DEFAULT_TIMEOUT = 24 * 60 * 60 * 1000; // 24h +public class OAuth2Registry implements Serializable { + private static final long DEFAULT_CODE_TIMEOUT = 10 * 60 * 1000; // 10min + private static final long DEFAULT_TOKEN_TIMEOUT = 24 * 60 * 60 * 1000; // 24h - private Map clientRegistry = new HashMap<>(); + private Map clientRegisters = new HashMap<>(); private boolean requireWhitelist = true; + transient private Random random = new Random(); + // ------------------------------------------------------ // Whitelist methods @@ -62,8 +68,8 @@ public class OAuth2Registry { * @param clientId A String ID that should be whitelisted */ public void addWhitelist(String clientId) { - if (!clientRegistry.containsKey(clientId)) { - clientRegistry.put(clientId, new ClientRegister()); + if (!clientRegisters.containsKey(clientId)) { + clientRegisters.put(clientId, new ClientRegister()); } } @@ -83,21 +89,17 @@ public class OAuth2Registry { if (!requireWhitelist) return true; - return clientRegistry.containsKey(clientId); + return clientRegisters.containsKey(clientId); } /** * Validates that a authorization code has valid format and has been authorized and not elapsed. * - * @param clientId the id of the requesting client * @param code the code that should be validated * @return true if the given code is valid otherwise false. */ - public boolean isAuthorizationCodeValid(String clientId, String code) { - if (clientId == null || code == null) - return false; - - ClientRegister reg = getClientRegistry(clientId); + public boolean isAuthorizationCodeValid(String code) { + ClientRegister reg = getClientRegisterForAuthCode(code); if (reg != null) { return reg.authCodes.containsKey(code) && @@ -109,34 +111,40 @@ public class OAuth2Registry { /** * Validates that a access token has valid format and has been authorized and not elapsed. * - * @param clientId the id of the requesting client * @param token the token that should be validated * @return true if the given token is valid otherwise false. */ - public boolean isAccessTokenValid(String clientId, String token) { - if (clientId == null || token == null) - return false; - - ClientRegister reg = getClientRegistry(clientId); + public boolean isAccessTokenValid(String token) { + ClientRegister reg = getClientRegisterForToken(token); if (reg != null) { - boolean b1 = reg.accessTokens.containsKey(token); - boolean b2 = reg.accessTokens.get(token).hasTimedOut(); return reg.accessTokens.containsKey(token) && !reg.accessTokens.get(token).hasTimedOut(); } return false; } + // ------------------------------------------------------ + // Revocation + // ------------------------------------------------------ + + public void revokeAuthorizationCode(String code) { + ClientRegister reg = getClientRegisterForAuthCode(code); + + if (reg != null) { + reg.authCodes.remove(code); + } + } + // ------------------------------------------------------ // OAuth2 process methods // ------------------------------------------------------ protected long registerAuthorizationCode(String clientId, String code) { - return registerAuthorizationCode(clientId, code, DEFAULT_TIMEOUT); + return registerAuthorizationCode(clientId, code, DEFAULT_CODE_TIMEOUT); } protected long registerAuthorizationCode(String clientId, String code, long timeoutMillis) { - ClientRegister reg = getClientRegistry(clientId); + ClientRegister reg = getClientRegister(clientId); if (reg != null) { reg.authCodes.put(code, new Timer(timeoutMillis).start()); @@ -146,10 +154,10 @@ public class OAuth2Registry { } protected long registerAccessToken(String clientId, String token) { - return registerAccessToken(clientId, token, DEFAULT_TIMEOUT); + return registerAccessToken(clientId, token, DEFAULT_TOKEN_TIMEOUT); } protected long registerAccessToken(String clientId, String token, long timeoutMillis) { - ClientRegister reg = getClientRegistry(clientId); + ClientRegister reg = getClientRegister(clientId); if (reg != null) { reg.accessTokens.put(token, new Timer(timeoutMillis).start()); @@ -158,16 +166,67 @@ public class OAuth2Registry { return -1; } - // -------------------------------------------------------------------- - - private ClientRegister getClientRegistry(String clientId) { - if (!requireWhitelist && !clientRegistry.containsKey(clientId)) - clientRegistry.put(clientId, new ClientRegister()); - - return clientRegistry.get(clientId); + protected String generateCode() { + return generateToken(); } - private static class ClientRegister { + protected String generateToken() { + return Hasher.SHA1(Math.abs(random.nextLong())); + } + + // ------------------------------------------------------ + // Data methods + // ------------------------------------------------------ + + /** + * @param code is the authentication code given to the client. + * @return The client_id registered for the given code + */ + public String getClientIdForAuthenticationCode(String code) { + for (String clientId : clientRegisters.keySet()) { + if (clientRegisters.get(clientId).authCodes.containsKey(code)) + return clientId; + } + + return null; + } + + /** + * @param token is the access token given to the client. + * @return The client_id registered for the given token + */ + public String getClientIdForAccessToken(String token) { + for (String clientId : clientRegisters.keySet()) { + if (clientRegisters.get(clientId).accessTokens.containsKey(token)) + return clientId; + } + + return null; + } + + // ------------------------------------------------------ + + private ClientRegister getClientRegister(String clientId) { + if (!requireWhitelist && !clientRegisters.containsKey(clientId)) + clientRegisters.put(clientId, new ClientRegister()); + + return clientRegisters.get(clientId); + } + + private ClientRegister getClientRegisterForAuthCode(String code) { + String clientId = getClientIdForAuthenticationCode(code); + + return (clientId == null ? null : clientRegisters.get(clientId)); + } + + private ClientRegister getClientRegisterForToken(String token) { + String clientId = getClientIdForAccessToken(token); + + return (clientId == null ? null : clientRegisters.get(clientId)); + } + + + private static class ClientRegister implements Serializable { Map authCodes = new HashMap<>(); Map accessTokens = new HashMap<>(); } diff --git a/src/zutil/net/http/page/oauth/OAuth2TokenPage.java b/src/zutil/net/http/page/oauth/OAuth2TokenPage.java index 3a66cae..8eaba61 100644 --- a/src/zutil/net/http/page/oauth/OAuth2TokenPage.java +++ b/src/zutil/net/http/page/oauth/OAuth2TokenPage.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Random; import java.util.logging.Logger; +import zutil.Hasher; import zutil.log.LogUtil; import zutil.net.http.HttpHeader; import zutil.net.http.HttpPrintStream; @@ -54,6 +55,8 @@ import zutil.parser.DataNode; public class OAuth2TokenPage extends HttpJsonPage { private static final Logger logger = LogUtil.getLogger(); + private static final long REFRESH_TOKEN_TIMEOUT = 60 * 24 * 60 * 60 * 1000L; // 60 days + /** The request is missing a required parameter, includes an unsupported parameter value (other than grant type), repeats a parameter, includes multiple credentials, utilizes more than one mechanism for authenticating the client, or is otherwise malformed. **/ @@ -72,7 +75,6 @@ public class OAuth2TokenPage extends HttpJsonPage { /** The requested scope is invalid, unknown, malformed, or exceeds the scope granted by the resource owner. **/ private static final String ERROR_INVALID_SCOPE = "invalid_scope"; - private Random random = new Random(); private OAuth2Registry registry; @@ -101,68 +103,85 @@ public class OAuth2TokenPage extends HttpJsonPage { DataNode jsonRes = new DataNode(DataNode.DataType.Map); - // Validate client_id - - if (!request.containsKey("client_id")) - return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: client_id"); - - String clientId = request.get("client_id"); - - if (!registry.isClientIdValid(clientId)) - return errorResponse(out, ERROR_INVALID_CLIENT , request.get("state"), "Invalid client_id value."); - - // Validate code - - if (!request.containsKey("code")) - return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: code"); - - if (!registry.isAuthorizationCodeValid(clientId, request.get("code"))) - return errorResponse(out, ERROR_INVALID_GRANT, request.get("state"), "Invalid authorization code value."); - - // Validate redirect_uri - - if (!request.containsKey("redirect_uri")) - return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: redirect_uri"); - - // TODO: ensure that the "redirect_uri" parameter is present if the - // "redirect_uri" parameter was included in the initial authorization - // request as described in Section 4.1.1, and if included ensure that - // their values are identical. - // Validate grant_type - if (!request.containsKey("grant_type")) + String grantType = request.get("grant_type"); + String codeKey; + String clientId = null; + + if (grantType == null) return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter grant_type."); + switch (grantType) { + case "authorization_code": + codeKey = "code"; + + // Validate client_id + + clientId = request.get("client_id"); + + if (clientId == null) + return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: client_id"); + + if (!registry.isClientIdValid(clientId)) + return errorResponse(out, ERROR_INVALID_CLIENT , request.get("state"), "Invalid client_id value."); + + // Validate redirect_uri + + if (!request.containsKey("redirect_uri")) + return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: redirect_uri"); + + // TODO: ensure that the "redirect_uri" parameter is present if the + // "redirect_uri" parameter was included in the initial authorization + // request as described in Section 4.1.1, and if included ensure that + // their values are identical. + + break; + + case "refresh_token": + codeKey = "refresh_token"; + break; + + default: + return errorResponse(out, ERROR_UNSUPPORTED_GRANT_TYPE, request.get("state"), "Unsupported grant_type: " + request.containsKey("grant_type")); + } + + // Validate code and refresh_token + + String authorizationCode = request.get(codeKey); + + if (authorizationCode == null) + return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: " + codeKey); + + if (!registry.isAuthorizationCodeValid(authorizationCode)) + return errorResponse(out, ERROR_INVALID_GRANT, request.get("state"), "Invalid " + codeKey + " value."); + // ----------------------------------------------- // Handle request // ----------------------------------------------- - String grantType = request.get("grant_type"); + if (clientId == null) + clientId = registry.getClientIdForAuthenticationCode(authorizationCode); - switch (grantType) { - case "authorization_code": - jsonRes.set("refresh_token", "TODO"); // TODO: implement refresh logic - break; - default: - return errorResponse(out, ERROR_UNSUPPORTED_GRANT_TYPE, request.get("state"), "Unsupported grant_type: " + request.containsKey("grant_type")); - } - - String token = generateToken(); + String token = registry.generateToken(); long timeoutMillis = registry.registerAccessToken(clientId, token); + String refreshToken = registry.generateToken(); + registry.registerAuthorizationCode(clientId, refreshToken, REFRESH_TOKEN_TIMEOUT); + jsonRes.set("access_token", token); jsonRes.set("token_type", "bearer"); jsonRes.set("expires_in", timeoutMillis/1000); + jsonRes.set("refresh_token", refreshToken); //jsonRes.set("scope", ?); if (request.containsKey("state")) jsonRes.set("state", request.get("state")); + registry.revokeAuthorizationCode(authorizationCode); + return jsonRes; } - private String generateToken() { - return String.valueOf(Math.abs(random.nextLong())); - } + // ------------------------------------------------------ // Error handling @@ -178,7 +197,7 @@ public class OAuth2TokenPage extends HttpJsonPage { * @return A DataNode containing the error response */ private static DataNode errorResponse(HttpPrintStream out, String error, String state, String description) { - logger.warning("OAuth2 Token Error(" + error + "): " + description); + logger.warning("OAuth2 Token Error(" + error + ") for client: " + description); out.setResponseStatusCode(400); diff --git a/test/zutil/net/http/page/oauth/OAuth2AuthorizationPageTest.java b/test/zutil/net/http/page/oauth/OAuth2AuthorizationPageTest.java index e73417b..a7d6bf3 100644 --- a/test/zutil/net/http/page/oauth/OAuth2AuthorizationPageTest.java +++ b/test/zutil/net/http/page/oauth/OAuth2AuthorizationPageTest.java @@ -153,7 +153,7 @@ public class OAuth2AuthorizationPageTest { assertNotNull(url.getParameter("code")); assertNull(url.getParameter("state")); - assertTrue(registry.isAuthorizationCodeValid("12345", url.getParameter("code"))); + assertTrue(registry.isAuthorizationCodeValid(url.getParameter("code"))); } @Test @@ -169,6 +169,6 @@ public class OAuth2AuthorizationPageTest { HttpURL url = new HttpURL(rspHeader.getHeader("Location")); assertEquals("app_state", url.getParameter("state")); - assertTrue(registry.isAuthorizationCodeValid("12345", url.getParameter("code"))); + assertTrue(registry.isAuthorizationCodeValid(url.getParameter("code"))); } } \ No newline at end of file diff --git a/test/zutil/net/http/page/oauth/OAuth2TokenPageTest.java b/test/zutil/net/http/page/oauth/OAuth2TokenPageTest.java index c60792f..3b95475 100644 --- a/test/zutil/net/http/page/oauth/OAuth2TokenPageTest.java +++ b/test/zutil/net/http/page/oauth/OAuth2TokenPageTest.java @@ -157,8 +157,7 @@ public class OAuth2TokenPageTest { } - @Test - public void requestBasic() throws IOException { + private HttpHeader doBasicRequest() throws IOException { HttpHeader reqHeader = new HttpHeader(); reqHeader.setURLAttribute("client_id", VALID_CLIENT_ID); reqHeader.setURLAttribute("redirect_uri", VALID_REDIRECT_URI); @@ -166,6 +165,13 @@ public class OAuth2TokenPageTest { reqHeader.setURLAttribute("code", VALID_AUTH_CODE); HttpHeader rspHeader = HttpTestUtil.makeRequest(tokenPage, reqHeader); + return rspHeader; + } + + @Test + public void requestBasic() throws IOException { + HttpHeader rspHeader = doBasicRequest(); + assertEquals(200, rspHeader.getResponseStatusCode()); assertEquals("application/json", rspHeader.getHeader("Content-Type")); DataNode json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream())); @@ -174,6 +180,47 @@ public class OAuth2TokenPageTest { assertNotNull(json.getString("expires_in")); assertEquals("bearer", json.getString("token_type")); - assertTrue(registry.isAccessTokenValid(VALID_CLIENT_ID, json.getString("access_token"))); + assertTrue(registry.isAccessTokenValid(json.getString("access_token"))); + } + + @Test + public void revocationCode() throws IOException { + requestBasic(); + + HttpHeader reqHeader = new HttpHeader(); + reqHeader.setURLAttribute("client_id", VALID_CLIENT_ID); + reqHeader.setURLAttribute("redirect_uri", VALID_REDIRECT_URI); + reqHeader.setURLAttribute("grant_type", VALID_GRANT_TYPE); + reqHeader.setURLAttribute("code", VALID_AUTH_CODE); + HttpHeader rspHeader = HttpTestUtil.makeRequest(tokenPage, reqHeader); + + assertEquals(400, rspHeader.getResponseStatusCode()); + DataNode json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream())); + assertEquals("invalid_grant", json.getString("error")); + } + + @Test + public void requestRefreshToken() throws IOException { + HttpHeader rspHeader = doBasicRequest(); + DataNode json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream())); + String refreshToken = json.getString("refresh_token"); + + assertTrue(registry.isAuthorizationCodeValid(refreshToken)); + + HttpHeader reqHeader = new HttpHeader(); + reqHeader.setURLAttribute("grant_type", "refresh_token"); + reqHeader.setURLAttribute("refresh_token", refreshToken); + rspHeader = HttpTestUtil.makeRequest(tokenPage, reqHeader); + + assertEquals(200, rspHeader.getResponseStatusCode()); + json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream())); + assertNotNull(json.getString("refresh_token")); + assertNotNull(json.getString("access_token")); + assertNotNull(json.getString("expires_in")); + assertEquals("bearer", json.getString("token_type")); + + assertTrue(registry.isAccessTokenValid(json.getString("access_token"))); + assertTrue(registry.isAuthorizationCodeValid(json.getString("refresh_token"))); + assertFalse(registry.isAuthorizationCodeValid(refreshToken)); } } \ No newline at end of file