Skip to content

Commit 4fe75bc

Browse files
committed
JDBC implementation of RegisteredClientRepository
1 parent 8e9563a commit 4fe75bc

File tree

3 files changed

+581
-0
lines changed

3 files changed

+581
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
/*
2+
* Copyright 2020-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.server.authorization.client;
17+
18+
import org.springframework.jdbc.core.*;
19+
import org.springframework.jdbc.support.lob.DefaultLobHandler;
20+
import org.springframework.jdbc.support.lob.LobCreator;
21+
import org.springframework.jdbc.support.lob.LobHandler;
22+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
23+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
24+
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
25+
import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
26+
import org.springframework.util.Assert;
27+
28+
import java.nio.charset.StandardCharsets;
29+
import java.sql.*;
30+
import java.time.Duration;
31+
import java.time.Instant;
32+
import java.util.*;
33+
import java.util.function.Function;
34+
import java.util.stream.Collectors;
35+
36+
/**
37+
* JDBC-backed registered client repository
38+
*
39+
* @author Rafal Lewczuk
40+
* @since 0.1.2
41+
*/
42+
public class JdbcRegisteredClientRepository implements RegisteredClientRepository {
43+
44+
private static final Map<String, AuthorizationGrantType> AUTHORIZATION_GRANT_TYPE_MAP;
45+
private static final Map<String, ClientAuthenticationMethod> CLIENT_AUTHENTICATION_METHOD_MAP;
46+
47+
private static final String COLUMN_NAMES = "id, "
48+
+ "client_id, "
49+
+ "client_id_issued_at, "
50+
+ "client_secret, "
51+
+ "client_secret_expires_at, "
52+
+ "client_name, "
53+
+ "client_authentication_methods, "
54+
+ "authorization_grant_types, "
55+
+ "redirect_uris, "
56+
+ "scopes, "
57+
+ "require_proof_key, "
58+
+ "require_user_consent, "
59+
+ "access_token_ttl, "
60+
+ "reuse_refresh_tokens, "
61+
+ "refresh_token_ttl";
62+
63+
private static final String TABLE_NAME = "oauth2_registered_client";
64+
65+
private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE ";
66+
67+
private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME
68+
+ "(" + COLUMN_NAMES + ") values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
69+
70+
private RowMapper<RegisteredClient> registeredClientRowMapper;
71+
72+
private Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper;
73+
74+
private final JdbcOperations jdbcOperations;
75+
76+
private final LobHandler lobHandler = new DefaultLobHandler();
77+
78+
public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations) {
79+
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
80+
this.jdbcOperations = jdbcOperations;
81+
this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper();
82+
this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper();
83+
}
84+
85+
/**
86+
* Allows changing of {@link RegisteredClient} row mapper implementation
87+
*
88+
* @param registeredClientRowMapper mapper implementation
89+
*/
90+
public void setRegisteredClientRowMapper(RowMapper<RegisteredClient> registeredClientRowMapper) {
91+
Assert.notNull(registeredClientRowMapper, "registeredClientRowMapper cannot be null");
92+
this.registeredClientRowMapper = registeredClientRowMapper;
93+
}
94+
95+
/**
96+
* Allows changing of SQL parameter mapper for {@link RegisteredClient}
97+
*
98+
* @param registeredClientParametersMapper mapper implementation
99+
*/
100+
public void setRegisteredClientParametersMapper(Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper) {
101+
Assert.notNull(registeredClientParametersMapper, "registeredClientParameterMapper cannot be null");
102+
this.registeredClientParametersMapper = registeredClientParametersMapper;
103+
}
104+
105+
@Override
106+
public void save(RegisteredClient registeredClient) {
107+
Assert.notNull(registeredClient, "registeredClient cannot be null");
108+
RegisteredClient foundClient = this.findBy("id = ? OR client_id = ? OR client_secret = ?",
109+
registeredClient.getId(), registeredClient.getClientId(),
110+
registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8));
111+
112+
if (null != foundClient) {
113+
Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()),
114+
"Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
115+
Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()),
116+
"Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
117+
Assert.isTrue(!foundClient.getClientSecret().equals(registeredClient.getClientSecret()),
118+
"Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId());
119+
}
120+
121+
List<SqlParameterValue> parameters = this.registeredClientParametersMapper.apply(registeredClient);
122+
123+
try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
124+
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, parameters.toArray());
125+
jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss);
126+
}
127+
}
128+
129+
@Override
130+
public RegisteredClient findById(String id) {
131+
Assert.hasText(id, "id cannot be empty");
132+
return findBy("id = ?", id);
133+
}
134+
135+
@Override
136+
public RegisteredClient findByClientId(String clientId) {
137+
Assert.hasText(clientId, "clientId cannot be empty");
138+
return findBy("client_id = ?", clientId);
139+
}
140+
141+
private RegisteredClient findBy(String condStr, Object...args) {
142+
List<RegisteredClient> lst = jdbcOperations.query(
143+
LOAD_REGISTERED_CLIENT_SQL + condStr,
144+
registeredClientRowMapper, args);
145+
return !lst.isEmpty() ? lst.get(0) : null;
146+
}
147+
148+
private static class DefaultRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
149+
150+
private final LobHandler lobHandler = new DefaultLobHandler();
151+
152+
private Collection<String> parseList(String s) {
153+
return s != null ? Arrays.asList(s.split("\\|")) : Collections.emptyList();
154+
}
155+
156+
@Override
157+
public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException {
158+
Collection<String> scopes = parseList(rs.getString("scopes"));
159+
List<AuthorizationGrantType> authGrantTypes = parseList(rs.getString("authorization_grant_types"))
160+
.stream().map(AUTHORIZATION_GRANT_TYPE_MAP::get).collect(Collectors.toList());
161+
List<ClientAuthenticationMethod> clientAuthMethods = parseList(rs.getString("client_authentication_methods"))
162+
.stream().map(CLIENT_AUTHENTICATION_METHOD_MAP::get).collect(Collectors.toList());
163+
Collection<String> redirectUris = parseList(rs.getString("redirect_uris"));
164+
Timestamp clientIssuedAt = rs.getTimestamp("client_id_issued_at");
165+
Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at");
166+
byte[] clientSecretBytes = this.lobHandler.getBlobAsBytes(rs, "client_secret");
167+
String clientSecret = clientSecretBytes != null ? new String(clientSecretBytes, StandardCharsets.UTF_8) : null;
168+
RegisteredClient.Builder builder = RegisteredClient
169+
.withId(rs.getString("id"))
170+
.clientId(rs.getString("client_id"))
171+
.clientIdIssuedAt(clientIssuedAt != null ? clientIssuedAt.toInstant() : null)
172+
.clientSecret(clientSecret)
173+
.clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null)
174+
.clientName(rs.getString("client_name"))
175+
.clientAuthenticationMethods(coll -> coll.addAll(clientAuthMethods))
176+
.authorizationGrantTypes(coll -> coll.addAll(authGrantTypes))
177+
.redirectUris(coll -> coll.addAll(redirectUris))
178+
.scopes(coll -> coll.addAll(scopes));
179+
180+
RegisteredClient rc = builder.build();
181+
182+
TokenSettings ts = rc.getTokenSettings();
183+
ts.accessTokenTimeToLive(Duration.ofMillis(rs.getLong("access_token_ttl")));
184+
ts.refreshTokenTimeToLive(Duration.ofMillis(rs.getLong("refresh_token_ttl")));
185+
ts.reuseRefreshTokens(rs.getBoolean("reuse_refresh_tokens"));
186+
187+
ClientSettings cs = rc.getClientSettings();
188+
cs.requireProofKey(rs.getBoolean("require_proof_key"));
189+
cs.requireUserConsent(rs.getBoolean("require_user_consent"));
190+
191+
return rc;
192+
}
193+
}
194+
195+
private static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
196+
@Override
197+
public List<SqlParameterValue> apply(RegisteredClient registeredClient) {
198+
199+
List<String> clientAuthenticationMethodNames = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size());
200+
for (ClientAuthenticationMethod clientAuthenticationMethod : registeredClient.getClientAuthenticationMethods()) {
201+
clientAuthenticationMethodNames.add(clientAuthenticationMethod.getValue());
202+
}
203+
204+
List<String> authorizationGrantTypeNames = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size());
205+
for (AuthorizationGrantType authorizationGrantType : registeredClient.getAuthorizationGrantTypes()) {
206+
authorizationGrantTypeNames.add(authorizationGrantType.getValue());
207+
}
208+
209+
Instant issuedAt = registeredClient.getClientIdIssuedAt() != null ?
210+
registeredClient.getClientIdIssuedAt() : Instant.now();
211+
212+
Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ?
213+
Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null;
214+
215+
return Arrays.asList(
216+
new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
217+
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
218+
new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
219+
new SqlParameterValue(Types.BLOB, registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8)),
220+
new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
221+
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
222+
new SqlParameterValue(Types.VARCHAR, String.join("|", clientAuthenticationMethodNames)),
223+
new SqlParameterValue(Types.VARCHAR, String.join("|", authorizationGrantTypeNames)),
224+
new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getRedirectUris())),
225+
new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getScopes())),
226+
new SqlParameterValue(Types.BOOLEAN, registeredClient.getClientSettings().requireProofKey()),
227+
new SqlParameterValue(Types.BOOLEAN, registeredClient.getClientSettings().requireUserConsent()),
228+
new SqlParameterValue(Types.NUMERIC, registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis()),
229+
new SqlParameterValue(Types.BOOLEAN, registeredClient.getTokenSettings().reuseRefreshTokens()),
230+
new SqlParameterValue(Types.NUMERIC, registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis()));
231+
}
232+
}
233+
234+
private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
235+
236+
protected final LobCreator lobCreator;
237+
238+
private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) {
239+
super(args);
240+
this.lobCreator = lobCreator;
241+
}
242+
243+
@Override
244+
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
245+
if (argValue instanceof SqlParameterValue) {
246+
SqlParameterValue paramValue = (SqlParameterValue) argValue;
247+
if (paramValue.getSqlType() == Types.BLOB) {
248+
if (paramValue.getValue() != null) {
249+
Assert.isInstanceOf(byte[].class, paramValue.getValue(),
250+
"Value of blob parameter must be byte[]");
251+
}
252+
byte[] valueBytes = (byte[]) paramValue.getValue();
253+
this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
254+
return;
255+
}
256+
}
257+
super.doSetValue(ps, parameterPosition, argValue);
258+
}
259+
260+
}
261+
262+
static {
263+
Map<String, AuthorizationGrantType> am = new HashMap<>();
264+
for (AuthorizationGrantType a : Arrays.asList(
265+
AuthorizationGrantType.AUTHORIZATION_CODE,
266+
AuthorizationGrantType.REFRESH_TOKEN,
267+
AuthorizationGrantType.CLIENT_CREDENTIALS,
268+
AuthorizationGrantType.PASSWORD,
269+
AuthorizationGrantType.IMPLICIT)) {
270+
am.put(a.getValue(), a);
271+
}
272+
AUTHORIZATION_GRANT_TYPE_MAP = Collections.unmodifiableMap(am);
273+
274+
Map<String, ClientAuthenticationMethod> cm = new HashMap<>();
275+
for (ClientAuthenticationMethod c : Arrays.asList(
276+
ClientAuthenticationMethod.NONE,
277+
ClientAuthenticationMethod.BASIC,
278+
ClientAuthenticationMethod.POST)) {
279+
cm.put(c.getValue(), c);
280+
}
281+
CLIENT_AUTHENTICATION_METHOD_MAP = Collections.unmodifiableMap(cm);
282+
}
283+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
CREATE TABLE oauth2_registered_client (
2+
id varchar(100) NOT NULL,
3+
client_id varchar(100) NOT NULL,
4+
client_id_issued_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
5+
client_secret blob NOT NULL,
6+
client_secret_expires_at timestamp DEFAULT NULL,
7+
client_name varchar(200),
8+
client_authentication_methods varchar(1000) NOT NULL,
9+
authorization_grant_types varchar(1000) NOT NULL,
10+
redirect_uris varchar(1000) NOT NULL,
11+
scopes varchar(1000) NOT NULL,
12+
require_proof_key boolean NOT NULL,
13+
require_user_consent boolean NOT NULL,
14+
access_token_ttl integer DEFAULT 300000 NOT NULL,
15+
reuse_refresh_tokens boolean DEFAULT true NOT NULL,
16+
refresh_token_ttl integer DEFAULT 600000 NOT NULL,
17+
PRIMARY KEY (id));

0 commit comments

Comments
 (0)