Implemented OAuth2 refresh_token process

This commit is contained in:
Ziver Koc 2020-11-25 00:51:26 +01:00
parent 7519763b62
commit 2786e93df2
5 changed files with 208 additions and 86 deletions

View file

@ -24,6 +24,7 @@
package zutil.net.http.page.oauth; package zutil.net.http.page.oauth;
import zutil.Hasher;
import zutil.log.LogUtil; import zutil.log.LogUtil;
import zutil.net.http.HttpHeader; import zutil.net.http.HttpHeader;
import zutil.net.http.HttpPage; import zutil.net.http.HttpPage;
@ -145,7 +146,7 @@ public class OAuth2AuthorizationPage implements HttpPage {
switch (request.get("response_type")) { switch (request.get("response_type")) {
case RESPONSE_TYPE_CODE: case RESPONSE_TYPE_CODE:
String code = generateCode(); String code = registry.generateCode();
registry.registerAuthorizationCode(clientId, code); registry.registerAuthorizationCode(clientId, code);
url.setParameter("code", code); url.setParameter("code", code);
@ -165,10 +166,6 @@ public class OAuth2AuthorizationPage implements HttpPage {
redirect(out, url); redirect(out, url);
} }
private String generateCode() {
return String.valueOf(Math.abs(random.nextLong()));
}
// ------------------------------------------------------ // ------------------------------------------------------
// Error handling // Error handling
// ------------------------------------------------------ // ------------------------------------------------------

View file

@ -24,21 +24,27 @@
package zutil.net.http.page.oauth; package zutil.net.http.page.oauth;
import zutil.Hasher;
import zutil.Timer; import zutil.Timer;
import java.io.Serializable;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Random;
/** /**
* A data class containing authentication information for individual * A data class containing authentication information for individual
* clients going through the OAuth 2 process. * clients going through the OAuth 2 process.
*/ */
public class OAuth2Registry { public class OAuth2Registry implements Serializable {
private static final long DEFAULT_TIMEOUT = 24 * 60 * 60 * 1000; // 24h private static final long DEFAULT_CODE_TIMEOUT = 10 * 60 * 1000; // 10min
private static final long DEFAULT_TOKEN_TIMEOUT = 24 * 60 * 60 * 1000; // 24h
private Map<String, ClientRegister> clientRegistry = new HashMap<>(); private Map<String, ClientRegister> clientRegisters = new HashMap<>();
private boolean requireWhitelist = true; private boolean requireWhitelist = true;
transient private Random random = new Random();
// ------------------------------------------------------ // ------------------------------------------------------
// Whitelist methods // Whitelist methods
@ -62,8 +68,8 @@ public class OAuth2Registry {
* @param clientId A String ID that should be whitelisted * @param clientId A String ID that should be whitelisted
*/ */
public void addWhitelist(String clientId) { public void addWhitelist(String clientId) {
if (!clientRegistry.containsKey(clientId)) { if (!clientRegisters.containsKey(clientId)) {
clientRegistry.put(clientId, new ClientRegister()); clientRegisters.put(clientId, new ClientRegister());
} }
} }
@ -83,21 +89,17 @@ public class OAuth2Registry {
if (!requireWhitelist) if (!requireWhitelist)
return true; return true;
return clientRegistry.containsKey(clientId); return clientRegisters.containsKey(clientId);
} }
/** /**
* Validates that a authorization code has valid format and has been authorized and not elapsed. * Validates that a authorization code has valid format and has been authorized and not elapsed.
* *
* @param clientId the id of the requesting client
* @param code the code that should be validated * @param code the code that should be validated
* @return true if the given code is valid otherwise false. * @return true if the given code is valid otherwise false.
*/ */
public boolean isAuthorizationCodeValid(String clientId, String code) { public boolean isAuthorizationCodeValid(String code) {
if (clientId == null || code == null) ClientRegister reg = getClientRegisterForAuthCode(code);
return false;
ClientRegister reg = getClientRegistry(clientId);
if (reg != null) { if (reg != null) {
return reg.authCodes.containsKey(code) && return reg.authCodes.containsKey(code) &&
@ -109,34 +111,40 @@ public class OAuth2Registry {
/** /**
* Validates that a access token has valid format and has been authorized and not elapsed. * Validates that a access token has valid format and has been authorized and not elapsed.
* *
* @param clientId the id of the requesting client
* @param token the token that should be validated * @param token the token that should be validated
* @return true if the given token is valid otherwise false. * @return true if the given token is valid otherwise false.
*/ */
public boolean isAccessTokenValid(String clientId, String token) { public boolean isAccessTokenValid(String token) {
if (clientId == null || token == null) ClientRegister reg = getClientRegisterForToken(token);
return false;
ClientRegister reg = getClientRegistry(clientId);
if (reg != null) { if (reg != null) {
boolean b1 = reg.accessTokens.containsKey(token);
boolean b2 = reg.accessTokens.get(token).hasTimedOut();
return reg.accessTokens.containsKey(token) && return reg.accessTokens.containsKey(token) &&
!reg.accessTokens.get(token).hasTimedOut(); !reg.accessTokens.get(token).hasTimedOut();
} }
return false; return false;
} }
// ------------------------------------------------------
// Revocation
// ------------------------------------------------------
public void revokeAuthorizationCode(String code) {
ClientRegister reg = getClientRegisterForAuthCode(code);
if (reg != null) {
reg.authCodes.remove(code);
}
}
// ------------------------------------------------------ // ------------------------------------------------------
// OAuth2 process methods // OAuth2 process methods
// ------------------------------------------------------ // ------------------------------------------------------
protected long registerAuthorizationCode(String clientId, String code) { protected long registerAuthorizationCode(String clientId, String code) {
return registerAuthorizationCode(clientId, code, DEFAULT_TIMEOUT); return registerAuthorizationCode(clientId, code, DEFAULT_CODE_TIMEOUT);
} }
protected long registerAuthorizationCode(String clientId, String code, long timeoutMillis) { protected long registerAuthorizationCode(String clientId, String code, long timeoutMillis) {
ClientRegister reg = getClientRegistry(clientId); ClientRegister reg = getClientRegister(clientId);
if (reg != null) { if (reg != null) {
reg.authCodes.put(code, new Timer(timeoutMillis).start()); reg.authCodes.put(code, new Timer(timeoutMillis).start());
@ -146,10 +154,10 @@ public class OAuth2Registry {
} }
protected long registerAccessToken(String clientId, String token) { protected long registerAccessToken(String clientId, String token) {
return registerAccessToken(clientId, token, DEFAULT_TIMEOUT); return registerAccessToken(clientId, token, DEFAULT_TOKEN_TIMEOUT);
} }
protected long registerAccessToken(String clientId, String token, long timeoutMillis) { protected long registerAccessToken(String clientId, String token, long timeoutMillis) {
ClientRegister reg = getClientRegistry(clientId); ClientRegister reg = getClientRegister(clientId);
if (reg != null) { if (reg != null) {
reg.accessTokens.put(token, new Timer(timeoutMillis).start()); reg.accessTokens.put(token, new Timer(timeoutMillis).start());
@ -158,16 +166,67 @@ public class OAuth2Registry {
return -1; return -1;
} }
// -------------------------------------------------------------------- protected String generateCode() {
return generateToken();
private ClientRegister getClientRegistry(String clientId) {
if (!requireWhitelist && !clientRegistry.containsKey(clientId))
clientRegistry.put(clientId, new ClientRegister());
return clientRegistry.get(clientId);
} }
private static class ClientRegister { protected String generateToken() {
return Hasher.SHA1(Math.abs(random.nextLong()));
}
// ------------------------------------------------------
// Data methods
// ------------------------------------------------------
/**
* @param code is the authentication code given to the client.
* @return The client_id registered for the given code
*/
public String getClientIdForAuthenticationCode(String code) {
for (String clientId : clientRegisters.keySet()) {
if (clientRegisters.get(clientId).authCodes.containsKey(code))
return clientId;
}
return null;
}
/**
* @param token is the access token given to the client.
* @return The client_id registered for the given token
*/
public String getClientIdForAccessToken(String token) {
for (String clientId : clientRegisters.keySet()) {
if (clientRegisters.get(clientId).accessTokens.containsKey(token))
return clientId;
}
return null;
}
// ------------------------------------------------------
private ClientRegister getClientRegister(String clientId) {
if (!requireWhitelist && !clientRegisters.containsKey(clientId))
clientRegisters.put(clientId, new ClientRegister());
return clientRegisters.get(clientId);
}
private ClientRegister getClientRegisterForAuthCode(String code) {
String clientId = getClientIdForAuthenticationCode(code);
return (clientId == null ? null : clientRegisters.get(clientId));
}
private ClientRegister getClientRegisterForToken(String token) {
String clientId = getClientIdForAccessToken(token);
return (clientId == null ? null : clientRegisters.get(clientId));
}
private static class ClientRegister implements Serializable {
Map<String, Timer> authCodes = new HashMap<>(); Map<String, Timer> authCodes = new HashMap<>();
Map<String, Timer> accessTokens = new HashMap<>(); Map<String, Timer> accessTokens = new HashMap<>();
} }

View file

@ -27,6 +27,7 @@ import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.logging.Logger; import java.util.logging.Logger;
import zutil.Hasher;
import zutil.log.LogUtil; import zutil.log.LogUtil;
import zutil.net.http.HttpHeader; import zutil.net.http.HttpHeader;
import zutil.net.http.HttpPrintStream; import zutil.net.http.HttpPrintStream;
@ -54,6 +55,8 @@ import zutil.parser.DataNode;
public class OAuth2TokenPage extends HttpJsonPage { public class OAuth2TokenPage extends HttpJsonPage {
private static final Logger logger = LogUtil.getLogger(); private static final Logger logger = LogUtil.getLogger();
private static final long REFRESH_TOKEN_TIMEOUT = 60 * 24 * 60 * 60 * 1000L; // 60 days
/** The request is missing a required parameter, includes an unsupported parameter value (other than grant type), /** The request is missing a required parameter, includes an unsupported parameter value (other than grant type),
repeats a parameter, includes multiple credentials, utilizes more than one mechanism for authenticating the repeats a parameter, includes multiple credentials, utilizes more than one mechanism for authenticating the
client, or is otherwise malformed. **/ client, or is otherwise malformed. **/
@ -72,7 +75,6 @@ public class OAuth2TokenPage extends HttpJsonPage {
/** The requested scope is invalid, unknown, malformed, or exceeds the scope granted by the resource owner. **/ /** The requested scope is invalid, unknown, malformed, or exceeds the scope granted by the resource owner. **/
private static final String ERROR_INVALID_SCOPE = "invalid_scope"; private static final String ERROR_INVALID_SCOPE = "invalid_scope";
private Random random = new Random();
private OAuth2Registry registry; private OAuth2Registry registry;
@ -101,68 +103,85 @@ public class OAuth2TokenPage extends HttpJsonPage {
DataNode jsonRes = new DataNode(DataNode.DataType.Map); DataNode jsonRes = new DataNode(DataNode.DataType.Map);
// Validate client_id
if (!request.containsKey("client_id"))
return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: client_id");
String clientId = request.get("client_id");
if (!registry.isClientIdValid(clientId))
return errorResponse(out, ERROR_INVALID_CLIENT , request.get("state"), "Invalid client_id value.");
// Validate code
if (!request.containsKey("code"))
return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: code");
if (!registry.isAuthorizationCodeValid(clientId, request.get("code")))
return errorResponse(out, ERROR_INVALID_GRANT, request.get("state"), "Invalid authorization code value.");
// Validate redirect_uri
if (!request.containsKey("redirect_uri"))
return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: redirect_uri");
// TODO: ensure that the "redirect_uri" parameter is present if the
// "redirect_uri" parameter was included in the initial authorization
// request as described in Section 4.1.1, and if included ensure that
// their values are identical.
// Validate grant_type // Validate grant_type
if (!request.containsKey("grant_type")) String grantType = request.get("grant_type");
String codeKey;
String clientId = null;
if (grantType == null)
return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter grant_type."); return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter grant_type.");
switch (grantType) {
case "authorization_code":
codeKey = "code";
// Validate client_id
clientId = request.get("client_id");
if (clientId == null)
return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: client_id");
if (!registry.isClientIdValid(clientId))
return errorResponse(out, ERROR_INVALID_CLIENT , request.get("state"), "Invalid client_id value.");
// Validate redirect_uri
if (!request.containsKey("redirect_uri"))
return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: redirect_uri");
// TODO: ensure that the "redirect_uri" parameter is present if the
// "redirect_uri" parameter was included in the initial authorization
// request as described in Section 4.1.1, and if included ensure that
// their values are identical.
break;
case "refresh_token":
codeKey = "refresh_token";
break;
default:
return errorResponse(out, ERROR_UNSUPPORTED_GRANT_TYPE, request.get("state"), "Unsupported grant_type: " + request.containsKey("grant_type"));
}
// Validate code and refresh_token
String authorizationCode = request.get(codeKey);
if (authorizationCode == null)
return errorResponse(out, ERROR_INVALID_REQUEST , request.get("state"), "Missing mandatory parameter: " + codeKey);
if (!registry.isAuthorizationCodeValid(authorizationCode))
return errorResponse(out, ERROR_INVALID_GRANT, request.get("state"), "Invalid " + codeKey + " value.");
// ----------------------------------------------- // -----------------------------------------------
// Handle request // Handle request
// ----------------------------------------------- // -----------------------------------------------
String grantType = request.get("grant_type"); if (clientId == null)
clientId = registry.getClientIdForAuthenticationCode(authorizationCode);
switch (grantType) { String token = registry.generateToken();
case "authorization_code":
jsonRes.set("refresh_token", "TODO"); // TODO: implement refresh logic
break;
default:
return errorResponse(out, ERROR_UNSUPPORTED_GRANT_TYPE, request.get("state"), "Unsupported grant_type: " + request.containsKey("grant_type"));
}
String token = generateToken();
long timeoutMillis = registry.registerAccessToken(clientId, token); long timeoutMillis = registry.registerAccessToken(clientId, token);
String refreshToken = registry.generateToken();
registry.registerAuthorizationCode(clientId, refreshToken, REFRESH_TOKEN_TIMEOUT);
jsonRes.set("access_token", token); jsonRes.set("access_token", token);
jsonRes.set("token_type", "bearer"); jsonRes.set("token_type", "bearer");
jsonRes.set("expires_in", timeoutMillis/1000); jsonRes.set("expires_in", timeoutMillis/1000);
jsonRes.set("refresh_token", refreshToken);
//jsonRes.set("scope", ?); //jsonRes.set("scope", ?);
if (request.containsKey("state")) jsonRes.set("state", request.get("state")); if (request.containsKey("state")) jsonRes.set("state", request.get("state"));
registry.revokeAuthorizationCode(authorizationCode);
return jsonRes; return jsonRes;
} }
private String generateToken() {
return String.valueOf(Math.abs(random.nextLong()));
}
// ------------------------------------------------------ // ------------------------------------------------------
// Error handling // Error handling
@ -178,7 +197,7 @@ public class OAuth2TokenPage extends HttpJsonPage {
* @return A DataNode containing the error response * @return A DataNode containing the error response
*/ */
private static DataNode errorResponse(HttpPrintStream out, String error, String state, String description) { private static DataNode errorResponse(HttpPrintStream out, String error, String state, String description) {
logger.warning("OAuth2 Token Error(" + error + "): " + description); logger.warning("OAuth2 Token Error(" + error + ") for client: " + description);
out.setResponseStatusCode(400); out.setResponseStatusCode(400);

View file

@ -153,7 +153,7 @@ public class OAuth2AuthorizationPageTest {
assertNotNull(url.getParameter("code")); assertNotNull(url.getParameter("code"));
assertNull(url.getParameter("state")); assertNull(url.getParameter("state"));
assertTrue(registry.isAuthorizationCodeValid("12345", url.getParameter("code"))); assertTrue(registry.isAuthorizationCodeValid(url.getParameter("code")));
} }
@Test @Test
@ -169,6 +169,6 @@ public class OAuth2AuthorizationPageTest {
HttpURL url = new HttpURL(rspHeader.getHeader("Location")); HttpURL url = new HttpURL(rspHeader.getHeader("Location"));
assertEquals("app_state", url.getParameter("state")); assertEquals("app_state", url.getParameter("state"));
assertTrue(registry.isAuthorizationCodeValid("12345", url.getParameter("code"))); assertTrue(registry.isAuthorizationCodeValid(url.getParameter("code")));
} }
} }

View file

@ -157,8 +157,7 @@ public class OAuth2TokenPageTest {
} }
@Test private HttpHeader doBasicRequest() throws IOException {
public void requestBasic() throws IOException {
HttpHeader reqHeader = new HttpHeader(); HttpHeader reqHeader = new HttpHeader();
reqHeader.setURLAttribute("client_id", VALID_CLIENT_ID); reqHeader.setURLAttribute("client_id", VALID_CLIENT_ID);
reqHeader.setURLAttribute("redirect_uri", VALID_REDIRECT_URI); reqHeader.setURLAttribute("redirect_uri", VALID_REDIRECT_URI);
@ -166,6 +165,13 @@ public class OAuth2TokenPageTest {
reqHeader.setURLAttribute("code", VALID_AUTH_CODE); reqHeader.setURLAttribute("code", VALID_AUTH_CODE);
HttpHeader rspHeader = HttpTestUtil.makeRequest(tokenPage, reqHeader); HttpHeader rspHeader = HttpTestUtil.makeRequest(tokenPage, reqHeader);
return rspHeader;
}
@Test
public void requestBasic() throws IOException {
HttpHeader rspHeader = doBasicRequest();
assertEquals(200, rspHeader.getResponseStatusCode()); assertEquals(200, rspHeader.getResponseStatusCode());
assertEquals("application/json", rspHeader.getHeader("Content-Type")); assertEquals("application/json", rspHeader.getHeader("Content-Type"));
DataNode json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream())); DataNode json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream()));
@ -174,6 +180,47 @@ public class OAuth2TokenPageTest {
assertNotNull(json.getString("expires_in")); assertNotNull(json.getString("expires_in"));
assertEquals("bearer", json.getString("token_type")); assertEquals("bearer", json.getString("token_type"));
assertTrue(registry.isAccessTokenValid(VALID_CLIENT_ID, json.getString("access_token"))); assertTrue(registry.isAccessTokenValid(json.getString("access_token")));
}
@Test
public void revocationCode() throws IOException {
requestBasic();
HttpHeader reqHeader = new HttpHeader();
reqHeader.setURLAttribute("client_id", VALID_CLIENT_ID);
reqHeader.setURLAttribute("redirect_uri", VALID_REDIRECT_URI);
reqHeader.setURLAttribute("grant_type", VALID_GRANT_TYPE);
reqHeader.setURLAttribute("code", VALID_AUTH_CODE);
HttpHeader rspHeader = HttpTestUtil.makeRequest(tokenPage, reqHeader);
assertEquals(400, rspHeader.getResponseStatusCode());
DataNode json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream()));
assertEquals("invalid_grant", json.getString("error"));
}
@Test
public void requestRefreshToken() throws IOException {
HttpHeader rspHeader = doBasicRequest();
DataNode json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream()));
String refreshToken = json.getString("refresh_token");
assertTrue(registry.isAuthorizationCodeValid(refreshToken));
HttpHeader reqHeader = new HttpHeader();
reqHeader.setURLAttribute("grant_type", "refresh_token");
reqHeader.setURLAttribute("refresh_token", refreshToken);
rspHeader = HttpTestUtil.makeRequest(tokenPage, reqHeader);
assertEquals(200, rspHeader.getResponseStatusCode());
json = JSONParser.read(IOUtil.readContentAsString(rspHeader.getInputStream()));
assertNotNull(json.getString("refresh_token"));
assertNotNull(json.getString("access_token"));
assertNotNull(json.getString("expires_in"));
assertEquals("bearer", json.getString("token_type"));
assertTrue(registry.isAccessTokenValid(json.getString("access_token")));
assertTrue(registry.isAuthorizationCodeValid(json.getString("refresh_token")));
assertFalse(registry.isAuthorizationCodeValid(refreshToken));
} }
} }