diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index 41687817c47..b5568efe400 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -16,8 +16,13 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; + import com.google.auth.oauth2.ComputeEngineCredentials; import com.google.auth.oauth2.IdTokenCredentials; +import com.google.common.annotations.VisibleForTesting; import com.google.common.primitives.UnsignedLongs; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; @@ -34,8 +39,11 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; import io.grpc.xds.MetadataRegistry.MetadataValueParser; +import io.grpc.xds.XdsConfig.XdsClusterConfig; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; import java.util.LinkedHashMap; import java.util.Map; @@ -52,6 +60,13 @@ final class GcpAuthenticationFilter implements Filter { static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; + final String filterInstanceName; + + GcpAuthenticationFilter(String name) { + filterInstanceName = checkNotNull(name, "name"); + } + + static final class Provider implements Filter.Provider { @Override public String[] typeUrls() { @@ -64,8 +79,8 @@ public boolean isClientFilter() { } @Override - public GcpAuthenticationFilter newInstance() { - return new GcpAuthenticationFilter(); + public GcpAuthenticationFilter newInstance(String name) { + return new GcpAuthenticationFilter(name); } @Override @@ -119,34 +134,57 @@ public ClientInterceptor buildClientInterceptor(FilterConfig config, public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { - /*String clusterName = callOptions.getOption(XdsAttributes.ATTR_CLUSTER_NAME); + String clusterName = callOptions.getOption(CLUSTER_SELECTION_KEY); if (clusterName == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s does not contain cluster resource", filterInstanceName))); + } + + if (!clusterName.startsWith("cluster:")) { return next.newCall(method, callOptions); - }*/ - - // TODO: Fetch the CDS resource for the cluster. - // If the CDS resource is not available, fail the RPC with Status.UNAVAILABLE. - - // TODO: Extract the audience from the CDS resource metadata. - // If the audience is not found or is in the wrong format, fail the RPC. - String audience = "TEST_AUDIENCE"; - - try { - CallCredentials existingCallCredentials = callOptions.getCredentials(); - CallCredentials newCallCredentials = - getCallCredentials(callCredentialsCache, audience, credentials); - if (existingCallCredentials != null) { - callOptions = callOptions.withCallCredentials( - new CompositeCallCredentials(existingCallCredentials, newCallCredentials)); - } else { - callOptions = callOptions.withCallCredentials(newCallCredentials); - } } - catch (Exception e) { - // If we fail to attach CallCredentials due to any reason, return a FailingClientCall - return new FailingClientCall<>(Status.UNAUTHENTICATED - .withDescription("Failed to attach CallCredentials.") - .withCause(e)); + XdsConfig xdsConfig = callOptions.getOption(XDS_CONFIG_CALL_OPTION_KEY); + if (xdsConfig == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s does not contain xds configuration", + filterInstanceName, clusterName))); + } + StatusOr xdsCluster = + xdsConfig.getClusters().get(clusterName.substring("cluster:".length())); + if (xdsCluster == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s - xds cluster config does not contain xds cluster", + filterInstanceName, clusterName))); + } + if (!xdsCluster.hasValue()) { + return new FailingClientCall<>(xdsCluster.getStatus()); + } + Object audienceObj = + xdsCluster.getValue().getClusterResource().parsedMetadata().get(filterInstanceName); + if (audienceObj == null) { + return next.newCall(method, callOptions); + } + if (!(audienceObj instanceof AudienceWrapper)) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format("GCP Authn found wrong type in %s metadata: %s=%s", + clusterName, filterInstanceName, audienceObj.getClass()))); + } + AudienceWrapper audience = (AudienceWrapper) audienceObj; + CallCredentials existingCallCredentials = callOptions.getCredentials(); + CallCredentials newCallCredentials = + getCallCredentials(callCredentialsCache, audience.audience, credentials); + if (existingCallCredentials != null) { + callOptions = callOptions.withCallCredentials( + new CompositeCallCredentials(existingCallCredentials, newCallCredentials)); + } else { + callOptions = callOptions.withCallCredentials(newCallCredentials); } return next.newCall(method, callOptions); } @@ -186,9 +224,11 @@ public String typeUrl() { } /** An implementation of {@link ClientCall} that fails when started. */ - private static final class FailingClientCall extends ClientCall { + @VisibleForTesting + static final class FailingClientCall extends ClientCall { - private final Status error; + @VisibleForTesting + final Status error; public FailingClientCall(Status error) { this.error = error; @@ -235,13 +275,21 @@ V getOrInsert(K key, Function create) { static class AudienceMetadataParser implements MetadataValueParser { + static final class AudienceWrapper { + final String audience; + + AudienceWrapper(String audience) { + this.audience = checkNotNull(audience); + } + } + @Override public String getTypeUrl() { return "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.Audience"; } @Override - public String parse(Any any) throws ResourceInvalidException { + public AudienceWrapper parse(Any any) throws ResourceInvalidException { Audience audience; try { audience = any.unpack(Audience.class); @@ -253,7 +301,7 @@ public String parse(Any any) throws ResourceInvalidException { throw new ResourceInvalidException( "Audience URL is empty. Metadata value must contain a valid URL."); } - return url; + return new AudienceWrapper(url); } } } diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java index 52efaf9bd7b..a5e142b4094 100644 --- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -17,25 +17,60 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; +import static io.grpc.xds.XdsTestUtils.CLUSTER_NAME; +import static io.grpc.xds.XdsTestUtils.EDS_NAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_HOSTNAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_PORT; +import static io.grpc.xds.XdsTestUtils.RDS_NAME; +import static io.grpc.xds.XdsTestUtils.buildRouteConfiguration; +import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.Empty; import com.google.protobuf.Message; import com.google.protobuf.UInt64Value; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.Endpoints.LbEndpoint; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.GcpAuthenticationFilter.FailingClientCall; import io.grpc.xds.GcpAuthenticationFilter.GcpAuthenticationConfig; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -46,6 +81,17 @@ public class GcpAuthenticationFilterTest { private static final GcpAuthenticationFilter.Provider FILTER_PROVIDER = new GcpAuthenticationFilter.Provider(); + private static final String serverName = InProcessServerBuilder.generateName(); + private static final LdsUpdate ldsUpdate = getLdsUpdate(); + private static final EdsUpdate edsUpdate = getEdsUpdate(); + private static final RdsUpdate rdsUpdate = getRdsUpdate(); + private static final CdsUpdate cdsUpdate = getCdsUpdate(); + + @Test + public void testNewFilterInstancesPerFilterName() { + assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1")) + .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1")); + } @Test public void filterType_clientOnly() { @@ -92,35 +138,258 @@ public void testParseFilterConfig_withInvalidMessageType() { } @Test - public void testClientInterceptor_createsAndReusesCachedCredentials() { + public void testClientInterceptor_success() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); - - // Create interceptor + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); - // Mock channel and capture CallOptions + verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions = callOptionsCaptor.getAllValues().get(0); + assertNotNull(capturedOptions.getCredentials()); + } + + @Test + public void testClientInterceptor_createsAndReusesCachedCredentials() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); - // Execute interception twice to check caching - interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel); - interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); - // Capture and verify CallOptions for CallCredentials presence - Mockito.verify(mockChannel, Mockito.times(2)) + verify(mockChannel, Mockito.times(2)) .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); - - // Retrieve the CallOptions captured from both calls CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0); CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1); - - // Ensure that CallCredentials was added assertNotNull(firstCapturedOptions.getCredentials()); assertNotNull(secondCapturedOptions.getCredentials()); - - // Ensure that the CallCredentials from both calls are the same, indicating caching assertSame(firstCapturedOptions.getCredentials(), secondCapturedOptions.getCredentials()); } + + @Test + public void testClientInterceptor_withoutClusterSelectionKey() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT; + + ClientCall call = interceptor.interceptCall( + methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain cluster resource"); + } + + @Test + public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + Channel mockChannel = mock(Channel.class); + + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel).newCall(methodDescriptor, callOptionsWithXds); + } + + @Test + public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0"); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain xds configuration"); + } + + @Test + public void testClientInterceptor_incorrectClusterName() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster("custer0", StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain xds cluster"); + } + + @Test + public void testClientInterceptor_statusOrError() throws Exception { + StatusOr errorCluster = + StatusOr.fromStatus(Status.NOT_FOUND.withDescription("Cluster resource not found")); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, errorCluster).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("Cluster resource not found"); + } + + @Test + public void testClientInterceptor_notAudienceWrapper() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + getCdsUpdateWithIncorrectAudienceWrapper(), + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type"); + } + + private static LdsUpdate getLdsUpdate() { + Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig( + serverName, RouterFilter.ROUTER_CONFIG); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( + 0L, RDS_NAME, Collections.singletonList(routerFilterConfig)); + return XdsListenerResource.LdsUpdate.forApiListener(httpConnectionManager); + } + + private static RdsUpdate getRdsUpdate() { + RouteConfiguration routeConfiguration = + buildRouteConfiguration(serverName, RDS_NAME, CLUSTER_NAME); + XdsResourceType.Args args = new XdsResourceType.Args(null, "0", "0", null, null, null); + try { + return XdsRouteConfigureResource.getInstance().doParse(args, routeConfiguration); + } catch (ResourceInvalidException ex) { + return null; + } + } + + private static EdsUpdate getEdsUpdate() { + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + serverName, ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + return new XdsEndpointResource.EdsUpdate(EDS_NAME, lbEndpointsMap, Collections.emptyList()); + } + + private static CdsUpdate getCdsUpdate() { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("TEST_AUDIENCE")); + try { + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } catch (IOException ex) { + return null; + } + } + + private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE"); + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java index 7fac666f983..588f058998e 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java @@ -129,6 +129,7 @@ import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig; import io.grpc.xds.VirtualHost.Route; @@ -2417,8 +2418,7 @@ public Object parse(Any value) { } @Test - public void processCluster_parsesAudienceMetadata() - throws ResourceInvalidException, InvalidProtocolBufferException { + public void processCluster_parsesAudienceMetadata() throws Exception { MetadataRegistry.getInstance(); Audience audience = Audience.newBuilder() @@ -2462,7 +2462,10 @@ public void processCluster_parsesAudienceMetadata() "FILTER_METADATA", ImmutableMap.of( "key1", "value1", "key2", 42.0)); - assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + assertThat(update.parsedMetadata().get("FILTER_METADATA")) + .isEqualTo(expectedParsedMetadata.get("FILTER_METADATA")); + assertThat(update.parsedMetadata().get("AUDIENCE_METADATA")) + .isInstanceOf(AudienceWrapper.class); } @Test @@ -2519,8 +2522,7 @@ public void processCluster_parsesAddressMetadata() throws Exception { } @Test - public void processCluster_metadataKeyCollision_resolvesToTypedMetadata() - throws ResourceInvalidException, InvalidProtocolBufferException { + public void processCluster_metadataKeyCollision_resolvesToTypedMetadata() throws Exception { MetadataRegistry metadataRegistry = MetadataRegistry.getInstance(); MetadataValueParser testParser = @@ -2575,8 +2577,7 @@ public Object parse(Any value) { } @Test - public void parseNonAggregateCluster_withHttp11ProxyTransportSocket() - throws ResourceInvalidException, InvalidProtocolBufferException { + public void parseNonAggregateCluster_withHttp11ProxyTransportSocket() throws Exception { XdsClusterResource.isEnabledXdsHttpConnect = true; Http11ProxyUpstreamTransport http11ProxyUpstreamTransport = diff --git a/xds/src/test/java/io/grpc/xds/XdsTestUtils.java b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java index d0580ae2667..33109ac012b 100644 --- a/xds/src/test/java/io/grpc/xds/XdsTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java @@ -282,7 +282,7 @@ static XdsConfig getDefaultXdsConfig(String serverHostName) } @SuppressWarnings("unchecked") - private static ImmutableMap getWrrLbConfigAsMap() throws IOException { + static ImmutableMap getWrrLbConfigAsMap() throws IOException { String lbConfigStr = "{\"wrr_locality_experimental\" : " + "{ \"childPolicy\" : [{\"round_robin\" : {}}]}}";