Skip to content

JWK endpoint as filter #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ group = 'org.springframework.security.experimental'
description = 'Spring Authorization Server'
version = '0.0.1-SNAPSHOT'

ext['junit-jupiter.version'] = '5.4.0'

repositories {
mavenCentral()
}
Expand Down
12 changes: 12 additions & 0 deletions samples/boot/minimal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## Minimal Authorization Server Sample

#### How to run

```
./gradlew spring-authorization-server-samples-boot-minimal:bootRun
```

```
curl http://localhost:8080/.well-known/jwk_uris
```

Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
apply plugin: 'io.spring.convention.spring-sample-boot'

dependencies {
implementation 'org.springframework.boot:spring-boot-starter'
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'org.springframework.boot:spring-boot-starter-security'

implementation 'com.nimbusds:oauth2-oidc-sdk'

testImplementation('org.springframework.boot:spring-boot-starter-test') {
exclude group: 'org.junit.vintage', module: 'junit-vintage-engine'
}

testImplementation 'org.springframework.security:spring-security-test'

testRuntime("org.junit.platform:junit-platform-runner")
testRuntime("org.junit.jupiter:junit-jupiter-engine")
}

test {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sample;

import static org.springframework.http.HttpMethod.GET;
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;

import java.io.IOException;
import java.io.Writer;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UrlPathHelper;

import com.nimbusds.jose.jwk.JWKSet;

public class JwkSetEndpointFilter extends OncePerRequestFilter {

static final String WELL_KNOWN_JWK_URIS = "/.well-known/jwk_uris";

private final RequestMatcher requestMatcher = new AntPathRequestMatcher(WELL_KNOWN_JWK_URIS, GET.name(), true,
new UrlPathHelper());

private final JWKSet jwkSet;

public JwkSetEndpointFilter(JWKSet jwkSet) {
Assert.notNull(jwkSet, "jwkSet cannot be null");
this.jwkSet = jwkSet;
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {

if (ifRequestMatches(request)) {
respond(response);
} else {
filterChain.doFilter(request, response);
}
}

private void respond(HttpServletResponse response) throws IOException {
response.setContentType(APPLICATION_JSON_VALUE);
try (Writer writer = response.getWriter()) {
writer.write(jwkSet.toPublicJWKSet().toJSONObject().toJSONString());
}
}

private boolean ifRequestMatches(HttpServletRequest request) {
return this.requestMatcher.matches(request);
}

}
43 changes: 43 additions & 0 deletions samples/boot/minimal/src/main/java/sample/SecurityConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sample;

import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.web.access.channel.ChannelProcessingFilter;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;

@EnableWebSecurity
public class SecurityConfig extends WebSecurityConfigurerAdapter {

@Override
protected void configure(HttpSecurity http) throws Exception {
http.addFilterBefore(new JwkSetEndpointFilter(generateJwkSet()), ChannelProcessingFilter.class);
}

protected JWKSet generateJwkSet() throws JOSEException {
JWK jwk = new RSAKeyGenerator(2048).keyID("minimal-ASA").keyUse(KeyUse.SIGNATURE).generate();
return new JWKSet(jwk);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sample;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.only;
import static org.mockito.Mockito.verify;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static sample.JwkSetEndpointFilter.WELL_KNOWN_JWK_URIS;

import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.mockito.Mockito;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;

@TestInstance(Lifecycle.PER_CLASS)
public class JwkSetEndpointFilterTest {

private MockMvc mvc;
private JWKSet jwkSet;
private JWK jwk;
private JwkSetEndpointFilter filter;

@BeforeAll
void setup() throws JOSEException {
this.jwk = new RSAKeyGenerator(2048).keyID("endpoint-test").keyUse(KeyUse.SIGNATURE).generate();
this.jwkSet = new JWKSet(jwk);
this.filter = new JwkSetEndpointFilter(jwkSet);
this.mvc = MockMvcBuilders.standaloneSetup(new FakeController()).addFilters(filter).alwaysDo(print()).build();
}

@Test
void constructorWhenJsonWebKeySetIsNullThrowIllegalArgumentException() {
assertThatThrownBy(() -> new JwkSetEndpointFilter(null)).isInstanceOf(IllegalArgumentException.class);
}

@Test
void doFilterWhenPathMatches() throws Exception {
String requestUri = WELL_KNOWN_JWK_URIS;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);

MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);

this.filter.doFilter(request, response, filterChain);

verify(filterChain, never()).doFilter(Mockito.any(HttpServletRequest.class),
Mockito.any(HttpServletResponse.class));
}

@Test
void doFilterWhenPathDoesNotMatch() throws Exception {
String requestUri = "/stuff/" + WELL_KNOWN_JWK_URIS;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);

MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);

this.filter.doFilter(request, response, filterChain);

verify(filterChain, only()).doFilter(Mockito.any(HttpServletRequest.class),
Mockito.any(HttpServletResponse.class));
}

@Test
void testResponseIfRequestMatches() throws Exception {
mvc.perform(get(WELL_KNOWN_JWK_URIS)).andDo(print()).andExpect(status().isOk())
.andExpect(jsonPath("$.keys").isArray()).andExpect(jsonPath("$.keys").isNotEmpty())
.andExpect(jsonPath("$.keys[0].kid").value(jwk.getKeyID()))
.andExpect(jsonPath("$.keys[0].kty").value(jwk.getKeyType().toString()));
}

@Test
void testResponseIfNotRequestMatches() throws Exception {
mvc.perform(get("/fake")).andDo(print()).andExpect(status().isOk())
.andExpect(content().string(is("fake")));
}

@RestController
class FakeController {

@RequestMapping("/fake")
public String hello() {
return "fake";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,29 @@
*/
package sample;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.HttpStatus.OK;

import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.ApplicationContext;

import static org.assertj.core.api.Assertions.assertThat;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;

@SpringBootTest
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
public class MinimalAuthorizationServerApplicationTests {

private RestTemplate rest = new RestTemplate();

@LocalServerPort
private int serverPort;

@Test
public void loadContext(ApplicationContext context) {
assertThat(context).isNotNull();
void verifyJwkSetEndpointFilterAccessibleWithoutAuthentication() {
ResponseEntity<String> responseEntity = rest.getForEntity(
"http://localhost:" + serverPort + JwkSetEndpointFilter.WELL_KNOWN_JWK_URIS, String.class);
assertThat(responseEntity.getStatusCode()).isEqualTo(OK);
}

}