From 999abefa42cc0a09552c45c0f34c8371fc743509 Mon Sep 17 00:00:00 2001 From: Keith Smiley Date: Wed, 23 Apr 2025 17:44:03 +0000 Subject: [PATCH] feat(pypi): allow scoping pip.override to single hubs If you use multiple pip.parse hubs for different platforms, it's possible you want to override wheels for a single platform, where previously pip.override would apply to all hubs. This is useful for GPU vs CPU dependencies on the same OS / arch combinations, where the wheel name isn't enough to differentiate. --- python/private/pypi/extension.bzl | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/python/private/pypi/extension.bzl b/python/private/pypi/extension.bzl index d1895ca211..a76024e46f 100644 --- a/python/private/pypi/extension.bzl +++ b/python/private/pypi/extension.bzl @@ -193,7 +193,10 @@ def _create_whl_repos( python_interpreter_target = python_interpreter_target, whl_patches = { p: json.encode(args) - for p, args in whl_overrides.get(whl_name, {}).items() + for p, args in ( + whl_overrides.get(hub_name, {}).get(whl_name, {}).items() + + whl_overrides.get("", {}).get(whl_name, {}).items() # Overrides without a hub name apply to all hubs + ) }, ) whl_library_args.update({k: v for k, v in maybe_args.items() if v}) @@ -384,16 +387,18 @@ You cannot use both the additive_build_content and additive_build_content_file a _overriden_whl_set[attr.file] = None for patch in attr.patches: - if whl_name not in whl_overrides: - whl_overrides[whl_name] = {} + if attr.hub_name not in whl_overrides: + whl_overrides[attr.hub_name] = {} + if whl_name not in whl_overrides[attr.hub_name]: + whl_overrides[attr.hub_name][whl_name] = {} - if patch not in whl_overrides[whl_name]: - whl_overrides[whl_name][patch] = struct( + if patch not in whl_overrides[attr.hub_name][whl_name]: + whl_overrides[attr.hub_name][whl_name][patch] = struct( patch_strip = attr.patch_strip, whls = [], ) - whl_overrides[whl_name][patch].whls.append(attr.file) + whl_overrides[attr.hub_name][whl_name][patch].whls.append(attr.file) # Used to track all the different pip hubs and the spoke pip Python # versions. @@ -845,6 +850,12 @@ applied to all repositories that setup this distribution via the pip.parse tag class.""", mandatory = True, ), + "hub_name": attr.string( + doc = """ +The name of the pip repo to override the wheel in. If this is not provided the +wheel is overridden in all pip hubs. +""", + ), "patch_strip": attr.int( default = 0, doc = """\