Skip to content

Commit 2bd0512

Browse files
jzheauxchao.wang
and
chao.wang
committed
Add JdbcAssertingPartyMetadataRepository#save
Issue gh-16012 Co-Authored-By: chao.wang <[email protected]>
1 parent e2e42a5 commit 2bd0512

File tree

2 files changed

+117
-81
lines changed

2 files changed

+117
-81
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@
1919
import java.sql.ResultSet;
2020
import java.sql.SQLException;
2121
import java.sql.Types;
22+
import java.util.ArrayList;
2223
import java.util.Collection;
2324
import java.util.Iterator;
2425
import java.util.List;
26+
import java.util.function.Function;
2527

2628
import org.apache.commons.logging.Log;
2729
import org.apache.commons.logging.LogFactory;
2830

2931
import org.springframework.core.log.LogMessage;
3032
import org.springframework.core.serializer.DefaultDeserializer;
33+
import org.springframework.core.serializer.DefaultSerializer;
3134
import org.springframework.core.serializer.Deserializer;
35+
import org.springframework.core.serializer.Serializer;
3236
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
3337
import org.springframework.jdbc.core.JdbcOperations;
3438
import org.springframework.jdbc.core.PreparedStatementSetter;
@@ -37,6 +41,7 @@
3741
import org.springframework.security.saml2.core.Saml2X509Credential;
3842
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
3943
import org.springframework.util.Assert;
44+
import org.springframework.util.function.ThrowingFunction;
4045

4146
/**
4247
* A JDBC implementation of {@link AssertingPartyMetadataRepository}.
@@ -51,6 +56,8 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
5156
private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
5257
ResultSet::getBytes);
5358

59+
private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper();
60+
5461
// @formatter:off
5562
static final String COLUMN_NAMES = "entity_id, "
5663
+ "singlesignon_url, "
@@ -77,6 +84,25 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
7784
+ " FROM " + TABLE_NAME;
7885
// @formatter:on
7986

87+
// @formatter:off
88+
private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME
89+
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
90+
// @formatter:on
91+
92+
// @formatter:off
93+
private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME
94+
+ " SET singlesignon_url = ?, "
95+
+ "singlesignon_binding = ?, "
96+
+ "singlesignon_sign_request = ?, "
97+
+ "signing_algorithms = ?, "
98+
+ "verification_credentials = ?, "
99+
+ "encryption_credentials = ?, "
100+
+ "singlelogout_url = ?, "
101+
+ "singlelogout_response_url = ?, "
102+
+ "singlelogout_binding = ?"
103+
+ " WHERE " + ENTITY_ID_FILTER;
104+
// @formatter:on
105+
80106
/**
81107
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
82108
* parameters.
@@ -116,6 +142,30 @@ public Iterator<AssertingPartyMetadata> iterator() {
116142
return result.iterator();
117143
}
118144

145+
/**
146+
* Persist this {@link AssertingPartyMetadata}
147+
* @param metadata the metadata to persist
148+
*/
149+
public void save(AssertingPartyMetadata metadata) {
150+
Assert.notNull(metadata, "metadata cannot be null");
151+
int rows = updateCredentialRecord(metadata);
152+
if (rows == 0) {
153+
insertCredentialRecord(metadata);
154+
}
155+
}
156+
157+
private void insertCredentialRecord(AssertingPartyMetadata metadata) {
158+
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
159+
this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, parameters.toArray());
160+
}
161+
162+
private int updateCredentialRecord(AssertingPartyMetadata metadata) {
163+
List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
164+
SqlParameterValue credentialId = parameters.remove(0);
165+
parameters.add(credentialId);
166+
return this.jdbcOperations.update(UPDATE_CREDENTIAL_RECORD_SQL, parameters.toArray());
167+
}
168+
119169
/**
120170
* The default {@link RowMapper} that maps the current row in
121171
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
@@ -181,6 +231,34 @@ public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLExcepti
181231

182232
}
183233

234+
private static class AssertingPartyMetadataParametersMapper
235+
implements Function<AssertingPartyMetadata, List<SqlParameterValue>> {
236+
237+
private final Serializer<Object> serializer = new DefaultSerializer();
238+
239+
@Override
240+
public List<SqlParameterValue> apply(AssertingPartyMetadata record) {
241+
List<SqlParameterValue> parameters = new ArrayList<>();
242+
243+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId()));
244+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation()));
245+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn()));
246+
parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned()));
247+
ThrowingFunction<List<String>, byte[]> algorithms = this.serializer::serializeToByteArray;
248+
parameters.add(new SqlParameterValue(Types.BLOB, algorithms.apply(record.getSigningAlgorithms())));
249+
ThrowingFunction<Collection<Saml2X509Credential>, byte[]> credentials = this.serializer::serializeToByteArray;
250+
parameters
251+
.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials())));
252+
parameters.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getEncryptionX509Credentials())));
253+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation()));
254+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation()));
255+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceBinding().getUrn()));
256+
257+
return parameters;
258+
}
259+
260+
}
261+
184262
private interface GetBytes {
185263

186264
byte[] getBytes(ResultSet rs, String columnName) throws SQLException;

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java

Lines changed: 39 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,17 @@
1616

1717
package org.springframework.security.saml2.provider.service.registration;
1818

19-
import java.io.IOException;
20-
import java.io.InputStream;
21-
import java.security.cert.CertificateFactory;
22-
import java.security.cert.X509Certificate;
23-
import java.util.Collection;
2419
import java.util.Iterator;
25-
import java.util.List;
2620

2721
import org.junit.jupiter.api.AfterEach;
2822
import org.junit.jupiter.api.BeforeEach;
2923
import org.junit.jupiter.api.Test;
3024

31-
import org.springframework.core.io.ClassPathResource;
32-
import org.springframework.core.serializer.DefaultSerializer;
33-
import org.springframework.core.serializer.Serializer;
3425
import org.springframework.jdbc.core.JdbcOperations;
3526
import org.springframework.jdbc.core.JdbcTemplate;
3627
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
3728
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
3829
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
39-
import org.springframework.security.saml2.core.Saml2X509Credential;
4030

4131
import static org.assertj.core.api.Assertions.assertThat;
4232
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -48,41 +38,21 @@ class JdbcAssertingPartyMetadataRepositoryTests {
4838

4939
private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql";
5040

51-
private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata ("
52-
+ JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
53-
54-
private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php";
55-
56-
private static final String SINGLE_SIGNON_URL = "https://localhost/SSO";
57-
58-
private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
59-
60-
private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false;
61-
62-
private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO";
63-
64-
private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response";
65-
66-
private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
67-
68-
private static final List<String> SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512");
69-
70-
private X509Certificate certificate;
71-
7241
private EmbeddedDatabase db;
7342

7443
private JdbcAssertingPartyMetadataRepository repository;
7544

7645
private JdbcOperations jdbcOperations;
7746

78-
private final Serializer<Object> serializer = new DefaultSerializer();
47+
private final AssertingPartyMetadata metadata = TestRelyingPartyRegistrations.full()
48+
.build()
49+
.getAssertingPartyMetadata();
7950

8051
@BeforeEach
8152
void setUp() {
8253
this.db = createDb();
8354
this.jdbcOperations = new JdbcTemplate(this.db);
8455
this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations);
85-
this.certificate = loadCertificate("rsa.crt");
8656
}
8757

8858
@AfterEach
@@ -109,26 +79,12 @@ void findByEntityIdWhenEntityIdIsNullThenThrowIllegalArgumentException() {
10979
}
11080

11181
@Test
112-
void findByEntityId() throws IOException {
113-
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
114-
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
115-
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
116-
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
117-
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
82+
void findByEntityId() {
83+
this.repository.save(this.metadata);
11884

119-
AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID);
85+
AssertingPartyMetadata found = this.repository.findByEntityId(this.metadata.getEntityId());
12086

121-
assertThat(found).isNotNull();
122-
assertThat(found.getEntityId()).isEqualTo(ENTITY_ID);
123-
assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL);
124-
assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING);
125-
assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST);
126-
assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL);
127-
assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL);
128-
assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING);
129-
assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0));
130-
assertThat(found.getVerificationX509Credentials()).hasSize(1);
131-
assertThat(found.getEncryptionX509Credentials()).hasSize(1);
87+
assertAssertingPartyEquals(found, this.metadata);
13288
}
13389

13490
@Test
@@ -138,28 +94,30 @@ void findByEntityIdWhenNotExists() {
13894
}
13995

14096
@Test
141-
void iterator() throws IOException {
142-
this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
143-
SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
144-
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
145-
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
146-
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
147-
148-
this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL,
149-
SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST,
150-
this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
151-
this.serializer.serializeToByteArray(asCredentials(this.certificate)),
152-
this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
153-
SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
97+
void iterator() {
98+
AssertingPartyMetadata second = RelyingPartyRegistration.withAssertingPartyMetadata(this.metadata)
99+
.assertingPartyMetadata((a) -> a.entityId("https://example.org/idp"))
100+
.build()
101+
.getAssertingPartyMetadata();
102+
this.repository.save(this.metadata);
103+
this.repository.save(second);
154104

155105
Iterator<AssertingPartyMetadata> iterator = this.repository.iterator();
156-
AssertingPartyMetadata first = iterator.next();
157-
assertThat(first).isNotNull();
158-
AssertingPartyMetadata second = iterator.next();
159-
assertThat(second).isNotNull();
106+
107+
assertAssertingPartyEquals(iterator.next(), this.metadata);
108+
assertAssertingPartyEquals(iterator.next(), second);
160109
assertThat(iterator.hasNext()).isFalse();
161110
}
162111

112+
@Test
113+
void saveWhenExistingThenUpdates() {
114+
this.repository.save(this.metadata);
115+
boolean existing = this.metadata.getWantAuthnRequestsSigned();
116+
this.repository.save(this.metadata.mutate().wantAuthnRequestsSigned(!existing).build());
117+
boolean updated = this.repository.findByEntityId(this.metadata.getEntityId()).getWantAuthnRequestsSigned();
118+
assertThat(existing).isNotEqualTo(updated);
119+
}
120+
163121
private static EmbeddedDatabase createDb() {
164122
return createDb(SCHEMA_SQL_RESOURCE);
165123
}
@@ -175,19 +133,19 @@ private static EmbeddedDatabase createDb(String schema) {
175133
// @formatter:on
176134
}
177135

178-
private X509Certificate loadCertificate(String path) {
179-
try (InputStream is = new ClassPathResource(path).getInputStream()) {
180-
CertificateFactory factory = CertificateFactory.getInstance("X.509");
181-
return (X509Certificate) factory.generateCertificate(is);
182-
}
183-
catch (Exception ex) {
184-
throw new RuntimeException("Error loading certificate from " + path, ex);
185-
}
186-
}
187-
188-
private Collection<Saml2X509Credential> asCredentials(X509Certificate certificate) {
189-
return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
190-
Saml2X509Credential.Saml2X509CredentialType.VERIFICATION));
136+
private void assertAssertingPartyEquals(AssertingPartyMetadata found, AssertingPartyMetadata expected) {
137+
assertThat(found).isNotNull();
138+
assertThat(found.getEntityId()).isEqualTo(expected.getEntityId());
139+
assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(expected.getSingleSignOnServiceLocation());
140+
assertThat(found.getSingleSignOnServiceBinding()).isEqualTo(expected.getSingleSignOnServiceBinding());
141+
assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(expected.getWantAuthnRequestsSigned());
142+
assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(expected.getSingleLogoutServiceLocation());
143+
assertThat(found.getSingleLogoutServiceResponseLocation())
144+
.isEqualTo(expected.getSingleLogoutServiceResponseLocation());
145+
assertThat(found.getSingleLogoutServiceBinding()).isEqualTo(expected.getSingleLogoutServiceBinding());
146+
assertThat(found.getSigningAlgorithms()).containsAll(expected.getSigningAlgorithms());
147+
assertThat(found.getVerificationX509Credentials()).containsAll(expected.getVerificationX509Credentials());
148+
assertThat(found.getEncryptionX509Credentials()).containsAll(expected.getEncryptionX509Credentials());
191149
}
192150

193151
}

0 commit comments

Comments
 (0)