Skip to content

Commit 9c3fc56

Browse files
authored
Fix the GCP Workload Identity Federation support in the GCP Service Connector (#2914)
* Fix the GCP Workload Identity Federation support in the GCP Service Connector Starting from google-auth version 2.29.0, the AWS logic has been moved to a separate `google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class from the original `google.auth.aws.Credentials` class. This breaks the ZenML hack that allowed `sts.AssumeRoleWithWebIdentity` authentication to be supported for IAM roles attacked to EKS service accounts (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a). * Be more specific about boto3 errors
1 parent 6730a16 commit 9c3fc56

File tree

1 file changed

+123
-6
lines changed

1 file changed

+123
-6
lines changed

src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py

Lines changed: 123 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,108 @@ class GCPAuthenticationMethods(StrEnum):
496496
IMPERSONATION = "impersonation"
497497

498498

499+
try:
500+
from google.auth.aws import _DefaultAwsSecurityCredentialsSupplier
501+
502+
class ZenMLAwsSecurityCredentialsSupplier(
503+
_DefaultAwsSecurityCredentialsSupplier # type: ignore[misc]
504+
):
505+
"""An improved version of the GCP external account credential supplier for AWS.
506+
507+
The original GCP external account credential supplier only provides
508+
rudimentary support for extracting AWS credentials from environment
509+
variables or the AWS metadata service. This version improves on that by
510+
using the boto3 library itself (if available), which uses the entire range
511+
of implicit authentication features packed into it.
512+
513+
Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is
514+
not supported for EKS pods and the EC2 attached role credentials are
515+
used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).
516+
"""
517+
518+
def get_aws_security_credentials(
519+
self, context: Any, request: Any
520+
) -> gcp_aws.AwsSecurityCredentials:
521+
"""Get the security credentials from the local environment.
522+
523+
This method is a copy of the original method from the
524+
`google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
525+
been modified to use the boto3 library to extract the AWS credentials
526+
from the local environment.
527+
528+
Args:
529+
context: The context to use to get the security credentials.
530+
request: The request to use to get the security credentials.
531+
532+
Returns:
533+
The AWS temporary security credentials.
534+
"""
535+
try:
536+
import boto3
537+
538+
session = boto3.Session()
539+
credentials = session.get_credentials()
540+
if credentials is not None:
541+
creds = credentials.get_frozen_credentials()
542+
return gcp_aws.AwsSecurityCredentials(
543+
creds.access_key,
544+
creds.secret_key,
545+
creds.token,
546+
)
547+
except ImportError:
548+
pass
549+
550+
logger.debug(
551+
"Failed to extract AWS credentials from the local environment "
552+
"using the boto3 library. Falling back to the original "
553+
"implementation."
554+
)
555+
556+
return super().get_aws_security_credentials(context, request)
557+
558+
def get_aws_region(self, context: Any, request: Any) -> str:
559+
"""Get the AWS region from the local environment.
560+
561+
This method is a copy of the original method from the
562+
`google.auth.aws._DefaultAwsSecurityCredentialsSupplier` class. It has
563+
been modified to use the boto3 library to extract the AWS
564+
region from the local environment.
565+
566+
Args:
567+
context: The context to use to get the security credentials.
568+
request: The request to use to get the security credentials.
569+
570+
Returns:
571+
The AWS region.
572+
"""
573+
try:
574+
import boto3
575+
576+
session = boto3.Session()
577+
if session.region_name:
578+
return session.region_name # type: ignore[no-any-return]
579+
except ImportError:
580+
pass
581+
582+
logger.debug(
583+
"Failed to extract AWS region from the local environment "
584+
"using the boto3 library. Falling back to the original "
585+
"implementation."
586+
)
587+
588+
return super().get_aws_region( # type: ignore[no-any-return]
589+
context, request
590+
)
591+
592+
except ImportError:
593+
# The `google.auth.aws._DefaultAwsSecurityCredentialsSupplier`
594+
# class has been introduced in the `google-auth` library version 2.29.0.
595+
# Before that, the AWS logic was part of the `google.auth.awsCredentials`
596+
# class itself.
597+
ZenMLAwsSecurityCredentialsSupplier = None # type: ignore[assignment,misc]
598+
pass
599+
600+
499601
class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignore[misc]
500602
"""An improved version of the GCP external account credential for AWS.
501603
@@ -508,6 +610,13 @@ class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignor
508610
Without this improvement, `sts.AssumeRoleWithWebIdentity` authentication is
509611
not supported for EKS pods and the EC2 attached role credentials are
510612
used instead (see: https://medium.com/@derek10cloud/gcp-workload-identity-federation-doesnt-yet-support-eks-irsa-in-aws-a3c71877671a).
613+
614+
IMPORTANT: subclassing this class only works with the `google-auth` library
615+
version lower than 2.29.0. Starting from version 2.29.0, the AWS logic
616+
has been moved to a separate `google.auth.aws._DefaultAwsSecurityCredentialsSupplier`
617+
class that can be subclassed instead and supplied as the
618+
`aws_security_credentials_supplier` parameter to the
619+
`google.auth.aws.Credentials` class.
511620
"""
512621

513622
def _get_security_credentials(
@@ -539,12 +648,14 @@ def _get_security_credentials(
539648
"secret_access_key": creds.secret_key,
540649
"security_token": creds.token,
541650
}
542-
except Exception:
543-
logger.debug(
544-
"Failed to extract AWS credentials from the local environment "
545-
"using the boto3 library. Falling back to the original "
546-
"implementation."
547-
)
651+
except ImportError:
652+
pass
653+
654+
logger.debug(
655+
"Failed to extract AWS credentials from the local environment "
656+
"using the boto3 library. Falling back to the original "
657+
"implementation."
658+
)
548659

549660
return super()._get_security_credentials( # type: ignore[no-any-return]
550661
request, imdsv2_session_token
@@ -1126,6 +1237,12 @@ def _authenticate(
11261237
account_info.get("subject_token_type")
11271238
== _AWS_SUBJECT_TOKEN_TYPE
11281239
):
1240+
if ZenMLAwsSecurityCredentialsSupplier is not None:
1241+
account_info["aws_security_credentials_supplier"] = (
1242+
ZenMLAwsSecurityCredentialsSupplier(
1243+
account_info.pop("credential_source"),
1244+
)
1245+
)
11291246
credentials = (
11301247
ZenMLGCPAWSExternalAccountCredentials.from_info(
11311248
account_info,

0 commit comments

Comments
 (0)