diff --git a/.github/workflows/unittest_ci_cpu.yml b/.github/workflows/unittest_ci_cpu.yml index f95a06bf3..1cd74f820 100644 --- a/.github/workflows/unittest_ci_cpu.yml +++ b/.github/workflows/unittest_ci_cpu.yml @@ -20,24 +20,24 @@ on: jobs: build_test: strategy: - fail-fast: false - matrix: - include: - - os: linux.2xlarge - python-version: 3.9 - python-tag: "py39" - - os: linux.2xlarge - python-version: '3.10' - python-tag: "py310" - - os: linux.2xlarge - python-version: '3.11' - python-tag: "py311" - - os: linux.2xlarge - python-version: '3.12' - python-tag: "py312" - - os: linux.2xlarge - python-version: '3.13' - python-tag: "py313" + fail-fast: false + matrix: + include: + - os: linux.2xlarge + python-version: 3.9 + python-tag: "py39" + - os: linux.2xlarge + python-version: '3.10' + python-tag: "py310" + - os: linux.2xlarge + python-version: '3.11' + python-tag: "py311" + - os: linux.2xlarge + python-version: '3.12' + python-tag: "py312" + - os: linux.2xlarge + python-version: '3.13' + python-tag: "py313" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write diff --git a/torchrec/distributed/hash_mc_embedding.py b/torchrec/distributed/hash_mc_embedding.py new file mode 100644 index 000000000..4171e1092 --- /dev/null +++ b/torchrec/distributed/hash_mc_embedding.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
def sharded_zchs_buffers_spec(
    sharded_model: torch.nn.Module,
) -> Dict[str, WeightSpec]:
    """
    Collect a ``WeightSpec`` for every sharded ZCH identities buffer in a model.

    The model is scanned for modules whose type name contains
    ``ShardedQuantManagedCollisionEmbeddingCollection``. Inside each of them, every
    table held by a ``ShardedMCCRemapper`` submodule yields one entry that maps the
    sharded identities buffer FQN to the FQN of the unsharded buffer, together with
    the row offset / row count of the shard.

    Args:
        sharded_model: the sharded model to inspect.

    Returns:
        A plain dict ``{sharded_identities_fqn: WeightSpec}``. The spec's ``fqn`` is
        the empty string when no ``ManagedCollisionCollection`` owning the table is
        found (an info message is logged in that case).

    Example (sharded fqn, shard_offsets, shard_sizes)::
        "<mcec fqn>._mcec_lookup.0.0._mcc_remapper.zchs.<table>._hash_zch_identities", [0, 0], [500, 1]
        "<mcec fqn>._mcec_lookup.0.1._mcc_remapper.zchs.<table>._hash_zch_identities", [500, 0], [1000, 1]
    """

    def _get_unsharded_fqn_identities(
        sharded_module: torch.nn.Module,
        fqn: str,
        table_name: str,
    ) -> str:
        # Find the ManagedCollisionCollection that owns `table_name` and build the
        # FQN of its (unsharded) identities buffer, prefixed by the parent `fqn`.
        for module_fqn, module in sharded_module.named_modules():
            if (
                "ManagedCollisionCollection" in type(module).__name__
                and table_name in module._table_to_features
            ):
                return f"{fqn}.{module_fqn}._managed_collision_modules.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}"
        logger.info(f"did not find table {table_name} in module {fqn}")
        return ""

    # A regular dict: the previous factory-less defaultdict() behaved exactly like
    # one (missing keys raise KeyError) while suggesting otherwise.
    ret: Dict[str, WeightSpec] = {}
    sharding_type = ShardingType.ROW_WISE.value

    for module_fqn, module in sharded_model.named_modules():
        if (
            "ShardedQuantManagedCollisionEmbeddingCollection"
            not in type(module).__name__
        ):
            continue

        # Cache per collection so the module tree is searched once per table.
        table_name_to_unsharded_fqn_identities: Dict[str, str] = {}
        for subfqn, submodule in module.named_modules():
            if "ShardedMCCRemapper" not in type(submodule).__name__:
                continue

            for table_name in submodule.zchs.keys():
                # identities tensor has only one column
                shard_offsets: List[int] = [
                    submodule._shard_metadata[table_name][0],
                    0,
                ]
                shard_sizes: List[int] = [
                    submodule._shard_metadata[table_name][1],
                    1,
                ]
                if table_name not in table_name_to_unsharded_fqn_identities:
                    table_name_to_unsharded_fqn_identities[table_name] = (
                        _get_unsharded_fqn_identities(module, module_fqn, table_name)
                    )
                # subfqn contains the index of sharding, so no need to add it specifically here
                sharded_fqn_identities: str = (
                    f"{module_fqn}.{subfqn}.zchs.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}"
                )
                ret[sharded_fqn_identities] = WeightSpec(
                    fqn=table_name_to_unsharded_fqn_identities[table_name],
                    shard_offsets=shard_offsets,
                    shard_sizes=shard_sizes,
                    sharding_type=sharding_type,
                )
    return ret
class SparseArch(nn.Module):
    """
    Test model wrapping a two-table ``ManagedCollisionEmbeddingCollection``.

    Each of ``table_0`` / ``table_1`` gets its own ``HashZchManagedCollisionModule``
    configured with single-TTL eviction (``single_ttl=1``) on ``feature_0`` /
    ``feature_1`` respectively.

    Args:
        tables: embedding configs; the first two entries size the ZCH modules.
        device: device passed to the embedding collection and the ZCH modules.
        buckets: total number of buckets passed to each ZCH module.
        return_remapped: forwarded as ``return_remapped_features``.
        input_hash_size: forwarded to each ZCH module.
        is_inference: forwarded to each ZCH module.
    """

    def __init__(
        self,
        tables: List[EmbeddingConfig],
        device: torch.device,
        buckets: int,
        return_remapped: bool = False,
        input_hash_size: int = 4000,
        is_inference: bool = False,
    ) -> None:
        super().__init__()
        self._return_remapped = return_remapped

        def _make_mc_module(
            table: EmbeddingConfig, feature_name: str
        ) -> HashZchManagedCollisionModule:
            # Both tables share every setting except their size and feature name.
            return HashZchManagedCollisionModule(
                is_inference=is_inference,
                zch_size=table.num_embeddings,
                input_hash_size=input_hash_size,
                device=device,
                total_num_buckets=buckets,
                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                eviction_config=HashZchEvictionConfig(
                    features=[feature_name],
                    single_ttl=1,
                ),
            )

        mc_modules = {
            "table_0": _make_mc_module(tables[0], "feature_0"),
            "table_1": _make_mc_module(tables[1], "feature_1"),
        }

        embedding_collection = EmbeddingCollection(
            tables=tables,
            device=device,
        )
        mc_collection = ManagedCollisionCollection(
            managed_collision_modules=mc_modules,
            embedding_configs=tables,
        )
        self._mc_ec: ManagedCollisionEmbeddingCollection = (
            ManagedCollisionEmbeddingCollection(
                embedding_collection,
                mc_collection,
                return_remapped_features=self._return_remapped,
            )
        )

    def forward(
        self, kjt: KeyedJaggedTensor
    ) -> Tuple[
        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
    ]:
        """Run the managed-collision embedding collection on ``kjt``."""
        output = self._mc_ec(kjt)
        return output
def _train_model(
    tables: List[EmbeddingConfig],
    num_buckets: int,
    rank: int,
    world_size: int,
    kjt_input_per_rank: List[KeyedJaggedTensor],
    sharder: ModuleSharder[nn.Module],
    backend: str,
    return_dict: Dict[str, Any],
    local_size: Optional[int] = None,
) -> None:
    """
    Per-rank worker: shard a ``SparseArch`` row-wise, run one forward pass and
    publish the resulting state into ``return_dict``.

    Managed-collision state is stored under ``mc_{key}_{rank}``; embedding state is
    stored under ``ec_{key}_{rank}`` as the local shards concatenated along dim 0.
    All stored tensors are moved to CPU first.
    """
    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
        device = ctx.device
        kjt_input = kjt_input_per_rank[rank].to(device)

        unsharded_model = SparseArch(
            tables=tables,
            device=torch.device("cuda"),
            input_hash_size=0,
            return_remapped=True,
            buckets=num_buckets,
        )
        train_sharding_plan = construct_module_sharding_plan(
            unsharded_model._mc_ec,
            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
            local_size=local_size,
            world_size=world_size,
            device_type="cuda",
            sharder=sharder,
        )
        print(f"train_sharding_plan: {train_sharding_plan}")

        sharded_model = _shard_modules(
            module=copy.deepcopy(unsharded_model),
            plan=ShardingPlan({"_mc_ec": train_sharding_plan}),
            env=ShardingEnv.from_process_group(none_throws(ctx.pg)),
            sharders=[sharder],
            device=device,
        )

        # train
        sharded_model(kjt_input.to(device))

        # pyre-ignore
        mc_state = sharded_model._mc_ec._managed_collision_collection._managed_collision_modules.state_dict()
        for key, value in mc_state.items():
            return_dict[f"mc_{key}_{rank}"] = value.cpu()

        # pyre-ignore
        ec_state = sharded_model._mc_ec._embedding_collection.state_dict()
        for key, value in ec_state.items():
            local_tensors = [shard.tensor.cpu() for shard in value.local_shards()]
            return_dict[f"ec_{key}_{rank}"] = torch.cat(local_tensors, dim=0)
@torch.fx.wrap
def get_kernel_from_policy(
    policy_name: Optional[HashZchEvictionPolicyName],
) -> int:
    """
    Translate an eviction policy into an integer code.

    Args:
        policy_name: the eviction policy, or None when eviction is not configured.

    Returns:
        1 for ``LRU_EVICTION``; 0 for every other policy and for None.
        (Presumably this code selects the kernel variant - confirm against caller.)
    """
    if policy_name is None:
        return 0
    if policy_name == HashZchEvictionPolicyName.LRU_EVICTION:
        return 1
    return 0
@torch.fx.wrap
def get_eviction_scorer(
    policy_name: HashZchEvictionPolicyName, config: HashZchEvictionConfig
) -> HashZchEvictionScorer:
    """
    Create the eviction scorer that matches an eviction policy.

    Args:
        policy_name: the eviction policy. This must be a ``HashZchEvictionPolicyName``
            member: the comparisons below are against enum members, so a plain
            string never matches and would fall through to the base scorer.
        config: the eviction config handed to the scorer.

    Returns:
        ``HashZchSingleTtlScorer`` for ``SINGLE_TTL_EVICTION`` and ``LRU_EVICTION``,
        ``HashZchPerFeatureTtlScorer`` for ``PER_FEATURE_TTL_EVICTION``, and the base
        ``HashZchEvictionScorer`` (empty score, threshold -1) for anything else.
    """
    if policy_name == HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION:
        return HashZchPerFeatureTtlScorer(config)
    if policy_name in (
        HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
        # LRU shares the single-TTL scorer.
        HashZchEvictionPolicyName.LRU_EVICTION,
    ):
        return HashZchSingleTtlScorer(config)
    return HashZchEvictionScorer(config)
class HashZchOptEvictionModule(torch.nn.Module):
    """
    Eviction module for policies that need no per-ID eviction score.

    The module only remembers the policy it was built for; its forward pass is a
    no-op that always yields ``(None, -1)``.

    Args:
        policy_name: an enum value that indicates the eviction policy to use.

    Example::
        module = HashZchOptEvictionModule(policy_name=HashZchEvictionPolicyName.LRU_EVICTION)
    """

    def __init__(
        self,
        policy_name: HashZchEvictionPolicyName,
    ) -> None:
        super().__init__()

        self._policy_name: HashZchEvictionPolicyName = policy_name

    def forward(self, feature: JaggedTensor, device: torch.device) -> Tuple[None, int]:
        """
        No-op for this kind of eviction policy.

        Args:
            feature: unused.
            device: unused.

        Returns:
            The pair ``(None, -1)``: no score tensor and no threshold.
        """
        return (None, -1)
class ScalarLogger(torch.nn.Module):
    """
    A logger to report various metrics related to multi-probe ZCH.

    Counters are accumulated by ``update`` and emitted (then reset) by ``forward``
    whenever ``should_report`` is true. The step counter is kept in a persistent
    buffer so it survives checkpoint save/restore.

    Args:
        name: name of the embedding table.
        zch_size: size of the sharded embedding table.
        frequency: frequency of reporting metrics, in calls to ``forward``.
        start_bucket: start bucket of the rank.
        log_file_path: optional path of a file that additionally receives the logs.

    Example::
        logger = ScalarLogger(...)
        logger(run_type, identities)
    """

    STEPS_BUFFER: str = "_scalar_logger_steps"
    SECONDS_IN_HOUR: int = 3600
    MAX_HOURS: int = 2**31 - 1

    def __init__(
        self,
        name: str,
        zch_size: int,
        frequency: int,
        start_bucket: int,
        log_file_path: str = "",
    ) -> None:
        super().__init__()

        # persist scalar logger steps in checkpoint to make sure it is not reset after training job restarted
        self.register_buffer(
            ScalarLogger.STEPS_BUFFER,
            torch.tensor(1, dtype=torch.int64),
            persistent=True,
        )

        self._name: str = name
        self._zch_size: int = zch_size
        self._frequency: int = frequency
        self._start_bucket: int = start_bucket

        self._dtype_checked: bool = False
        self._total_cnt: int = 0
        self._hit_cnt: int = 0
        self._insert_cnt: int = 0
        self._collision_cnt: int = 0
        self._eviction_cnt: int = 0
        self._opt_in_cnt: int = 0
        self._sum_eviction_age: float = 0.0

        # NOTE(review): this is the *root* logger, so a file handler added below
        # receives every record that reaches the root logger, and each instance
        # created with a log_file_path adds one more handler.
        self.logger: logging.Logger = logging.getLogger()
        if log_file_path != "":
            # a log file path was provided: also send the logs to that file
            file_handler = logging.FileHandler(log_file_path, mode="w")
            self.logger.addHandler(file_handler)

    def should_report(self) -> bool:
        """
        Tell whether metrics should be emitted on this step.

        Returns:
            True only when this rank owns bucket 0, at least one ID has been
            counted since the last report, and the step counter is a multiple of
            the reporting frequency. Always a Python ``bool``.
        """
        # We only need to report metrics from rank0 (start_bucket = 0)
        if self._start_bucket != 0 or self._total_cnt <= 0:
            return False

        # The modulo is evaluated on the steps buffer and yields a 0-dim tensor;
        # convert it so the declared `bool` return type is honoured.
        # pyre-fixme[29]: `Union[(self: TensorBase, other: Any) -> Tensor, Tensor,
        #  Module]` is not a function.
        return bool(self._scalar_logger_steps % self._frequency == 0)

    def update(
        self,
        identities_0: torch.Tensor,
        identities_1: torch.Tensor,
        values: torch.Tensor,
        remapped_ids: torch.Tensor,
        evicted_emb_indices: Optional[torch.Tensor],
        metadata: Optional[torch.Tensor],
        num_reserved_slots: int,
        eviction_config: Optional["HashZchEvictionConfig"] = None,
    ) -> None:
        """
        Accumulate hit / insert / collision / eviction / opt-in counters for a batch.

        Args:
            identities_0: identities tensor before the batch was processed.
            identities_1: identities tensor after the batch was processed.
            values: input IDs of the batch; dtype must match ``identities_0``.
            remapped_ids: slot index each input ID was remapped to.
            evicted_emb_indices: slots evicted while processing the batch, if any.
            metadata: metadata tensor; required when ``evicted_emb_indices`` is
                non-empty (column 0 of the evicted rows is read).
            num_reserved_slots: number of trailing slots excluded from the opt-in range.
            eviction_config: eviction config; eviction age is only accumulated
                when its ``single_ttl`` is set.
        """
        if not self._dtype_checked:
            assert (
                identities_0.dtype == values.dtype
            ), "identity type and feature type must match for meaningful metrics collection."
            self._dtype_checked = True

        # Identity stored in each touched slot, before and after processing.
        remapped_identities_0 = torch.index_select(identities_0, 0, remapped_ids)[:, 0]
        remapped_identities_1 = torch.index_select(identities_1, 0, remapped_ids)[:, 0]
        # -1 marks an empty slot; slots that stopped being empty are inserts.
        empty_slot_cnt_before_process = remapped_identities_0 == -1
        empty_slot_cnt_after_process = remapped_identities_1 == -1
        insert_cnt = int(torch.sum(empty_slot_cnt_before_process).item()) - int(
            torch.sum(empty_slot_cnt_after_process).item()
        )

        self._insert_cnt += insert_cnt
        self._total_cnt += values.numel()
        hits = torch.eq(remapped_identities_0, values)
        hit_cnt = int(torch.sum(hits).item())
        self._hit_cnt += hit_cnt
        # Whatever is neither a hit nor an insert is counted as a collision.
        self._collision_cnt += values.numel() - hit_cnt - insert_cnt

        opt_in_range = self._zch_size - num_reserved_slots
        opt_in_ids = torch.lt(remapped_ids, opt_in_range)
        self._opt_in_cnt += int(torch.sum(opt_in_ids).item())

        if evicted_emb_indices is not None and evicted_emb_indices.numel() > 0:
            deduped_evicted_indices = torch.unique(evicted_emb_indices)
            self._eviction_cnt += deduped_evicted_indices.numel()

            assert (
                metadata is not None
            ), "metadata cannot be None when evicted_emb_indices has values"
            now_c = int(time.time())
            cur_hour = now_c / ScalarLogger.SECONDS_IN_HOUR % ScalarLogger.MAX_HOURS
            if eviction_config is not None and eviction_config.single_ttl is not None:
                # NOTE(review): presumably metadata column 0 holds
                # "hour last seen + ttl", making this the age in hours - confirm.
                self._sum_eviction_age += int(
                    torch.sum(
                        cur_hour
                        + eviction_config.single_ttl
                        - metadata[deduped_evicted_indices, 0]
                    ).item()
                )

    def forward(
        self,
        run_type: str,
        identities: torch.Tensor,
    ) -> None:
        """
        Emit the accumulated metrics when due, then advance the step counter.

        Args:
            run_type: type of the run (train, eval, etc).
            identities: the identities tensor for metrics computation.

        Returns:
            None
        """

        if self.should_report():
            hit_rate = round(self._hit_cnt / self._total_cnt, 3)
            insert_rate = round(self._insert_cnt / self._total_cnt, 3)
            collision_rate = round(self._collision_cnt / self._total_cnt, 3)
            eviction_rate = round(self._eviction_cnt / self._total_cnt, 3)
            total_unused_slots = int(torch.sum(identities[:, 0] == -1).item())
            table_usage_ratio = round(
                (self._zch_size - total_unused_slots) / self._zch_size, 3
            )
            # should_report() guarantees _total_cnt > 0 here.
            opt_in_rate = round(self._opt_in_cnt / self._total_cnt, 3)
            avg_eviction_age = (
                round(self._sum_eviction_age / self._eviction_cnt, 3)
                if self._eviction_cnt > 0
                else 0
            )

            # log the metrics to console (if no log file path is provided) or to the file (if a log file path is provided)
            self.logger.info(
                f"{self._name=}, {run_type=}, "
                f"{self._total_cnt=}, {self._hit_cnt=}, {hit_rate=}, "
                f"{self._insert_cnt=}, {insert_rate=}, "
                f"{self._collision_cnt=}, {collision_rate=}, "
                f"{self._eviction_cnt=}, {eviction_rate=}, {avg_eviction_age=}, "
                f"{self._opt_in_cnt=}, {opt_in_rate=}, "
                f"{total_unused_slots=}, {table_usage_ratio=}"
            )

            # reset the counter after reporting
            self._total_cnt = 0
            self._hit_cnt = 0
            self._insert_cnt = 0
            self._collision_cnt = 0
            self._eviction_cnt = 0
            self._opt_in_cnt = 0
            self._sum_eviction_age = 0.0

        # The step counter advances on every call, whether or not we reported.
        # pyre-ignore[16]: `ScalarLogger` has no attribute `_scalar_logger_steps`.
        # pyre-ignore[29]: `Union[(self: TensorBase, other: Any) -> Tensor, Tensor, Module]` is not a function.
        self._scalar_logger_steps += 1
+ + Args: + input_hash_size: the max size of input IDs + total_num_buckets: the total number of buckets across all ranks at training time + size_per_rank: the size of the identity tensor/embedding size per rank + train_rank_offsets: the offset of the embedding table indices per rank + inference_dispatch_div_train_world_size: the flag to control whether to divide input by + world_size https://fburl.com/code/c9x98073 + name: the name of the embedding table + + Example:: + mapper = TrainInputMapper(...) + mapper(values, output_offset) + """ + + def __init__( + self, + input_hash_size: int, + total_num_buckets: int, + size_per_rank: torch.Tensor, + train_rank_offsets: torch.Tensor, + inference_dispatch_div_train_world_size: bool = False, + name: Optional[str] = None, + ) -> None: + super().__init__() + + self._input_hash_size = input_hash_size + assert total_num_buckets > 0, f"{total_num_buckets=} must be positive" + self._buckets = total_num_buckets + self._inference_dispatch_div_train_world_size = ( + inference_dispatch_div_train_world_size + ) + self._name = name + self.register_buffer( + "_zch_size_per_training_rank", size_per_rank, persistent=False + ) + self.register_buffer( + "_train_rank_offsets", train_rank_offsets, persistent=False + ) + logger.info( + f"TrainInputMapper: {self._name=}, {self._input_hash_size=}, {self._zch_size_per_training_rank=}, " + f"{self._train_rank_offsets=}, {self._inference_dispatch_div_train_world_size=}" + ) + + # TODO: make a kernel + def _get_values_sizes_offsets( + self, x: torch.Tensor, output_offset: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + zch_size_per_training_rank, _ = _tensor_may_to_device( + self._zch_size_per_training_rank, x.device + ) + train_rank_offsets, _ = _tensor_may_to_device( + self._train_rank_offsets, x.device + ) + + # NOTE: This assumption has to be the same as TorchRec input_dist logic + # https://fburl.com/code/c9x98073. 
@torch.fx.wrap
def _get_device(hash_zch_identities: torch.Tensor) -> torch.device:
    """Return the device of the identities tensor (wrapped so torch.fx treats it as a leaf call)."""
    return hash_zch_identities.device


class HashZchManagedCollisionModule(ManagedCollisionModule):
    """
    Module to manage multi-probe ZCH (MPZCH), including lookup (remapping), eviction, metrics collection, and required auxiliary tensors.

    Args:
        zch_size: local size of the embedding table
        device: the compute device
        total_num_buckets: logical shard within each rank for resharding purpose, note that
            1) zch_size must be a multiple of total_num_buckets, and 2) total_num_buckets must be a multiple of world size
        max_probe: the number of times MPZCH kernel attempts to run linear search for lookup or insertion
        input_hash_size: the max size of input IDs (default to 0)
        output_segments: the index range of each bucket, which is computed before sharding and typically not provided by user
        is_inference: the flag to indicate if the module is used in inference as opposed to train or eval
        name: the name of the embedding table
        tb_logging_frequency: the frequency of emitting metrics to TensorBoard, measured by the number of batches
        eviction_policy_name: the specific policy to be used for eviction operations
        eviction_config: the config associated with the selected eviction policy
        inference_dispatch_div_train_world_size: the flag to control whether to divide input by
            world_size https://fburl.com/code/c9x98073
        start_bucket: start bucket of the current rank, typically not provided by user
        end_bucket: end bucket of the current rank, typically not provided by user
        opt_in_prob: the probability of an ID to be opted in from a statistical aspect
        percent_reserved_slots: percentage of slots to be reserved when opt-in is enabled, the value must be in [0, 100)

    Example::
        module = HashZchManagedCollisionModule(...)
        module(features)
    """

    # Declared at class level so the (initially empty) list has a known element type.
    _evicted_indices: List[torch.Tensor]

    # Attribute / state-dict names of the two ZCH tensors.
    IDENTITY_BUFFER: str = "_hash_zch_identities"
    METADATA_BUFFER: str = "_hash_zch_metadata"

    def __init__(
        self,
        zch_size: int,
        device: torch.device,
        total_num_buckets: int,
        max_probe: int = 128,
        input_hash_size: int = 0,
        output_segments: Optional[List[int]] = None,
        is_inference: bool = False,
        name: Optional[str] = None,
        tb_logging_frequency: int = 0,
        eviction_policy_name: Optional[HashZchEvictionPolicyName] = None,
        eviction_config: Optional[HashZchEvictionConfig] = None,
        inference_dispatch_div_train_world_size: bool = False,
        start_bucket: int = 0,
        end_bucket: Optional[int] = None,
        opt_in_prob: int = -1,
        percent_reserved_slots: float = 0,
    ) -> None:
        # Without explicit segments, split the table into equally sized buckets:
        # segment boundaries are [0, step, 2 * step, ..., zch_size].
        if output_segments is None:
            assert (
                zch_size % total_num_buckets == 0
            ), f"please pass output segments if not uniform buckets {zch_size=}, {total_num_buckets=}"
            output_segments = [
                (zch_size // total_num_buckets) * bucket
                for bucket in range(total_num_buckets + 1)
            ]

        super().__init__(
            device=device,
            output_segments=output_segments,
            skip_state_validation=True,  # avoid persistent buffers for TGIF Publishing
        )

        self._zch_size: int = zch_size
        self._output_segments: List[int] = output_segments
        self._start_bucket: int = start_bucket
        self._end_bucket: int = (
            end_bucket if end_bucket is not None else total_num_buckets
        )
        # Global offset of this module's first slot; only materialized when it
        # is non-zero. A "meta" device is replaced by CPU for this tensor.
        self._output_global_offset_tensor: Optional[torch.Tensor] = None
        if output_segments[start_bucket] > 0:
            self._output_global_offset_tensor = torch.tensor(
                [output_segments[start_bucket]],
                dtype=torch.int64,
                device=device if device.type != "meta" else torch.device("cpu"),
            )

        self._device: torch.device = device
        self._input_hash_size: int = input_hash_size
        self._is_inference: bool = is_inference
        self._name: Optional[str] = name
        self._tb_logging_frequency: int = tb_logging_frequency
        self._scalar_logger: Optional[ScalarLogger] = None
        self._eviction_policy_name: Optional[HashZchEvictionPolicyName] = (
            eviction_policy_name
        )
        self._eviction_config: Optional[HashZchEvictionConfig] = eviction_config
        # The eviction module is only built when a policy is given and the
        # module is in training mode.
        self._eviction_module: Optional[HashZchEvictionModule] = (
            HashZchEvictionModule(
                policy_name=self._eviction_policy_name,
                device=self._device,
                config=self._eviction_config,
            )
            if self._eviction_policy_name is not None and self.training
            else None
        )
        self._opt_in_prob: int = opt_in_prob
        assert (
            percent_reserved_slots >= 0 and percent_reserved_slots < 100
        ), "percent_reserved_slots must be in [0, 100)"
        self._percent_reserved_slots: float = percent_reserved_slots
        if self._opt_in_prob > 0:
            assert (
                self._percent_reserved_slots > 0
            ), "percent_reserved_slots must be positive when opt_in_prob is positive"
            assert (
                self._eviction_policy_name is None
                or self._eviction_policy_name != HashZchEvictionPolicyName.LRU_EVICTION
            ), "LRU eviction is not compatible with opt-in at this time"

        # Metrics logging is turned off when scripting, in inference, or when
        # the table has no name.
        if torch.jit.is_scripting() or self._is_inference or self._name is None:
            self._tb_logging_frequency = 0

        if self._tb_logging_frequency > 0 and self._device.type != "meta":
            assert self._name is not None
            self._scalar_logger = ScalarLogger(
                name=self._name,
                zch_size=self._zch_size,
                frequency=self._tb_logging_frequency,
                start_bucket=self._start_bucket,
            )
        else:
            logger.info(
                f"ScalarLogger is disabled because {self._tb_logging_frequency=} and {self._device.type=}"
            )

        identities, metadata = torch.ops.fbgemm.create_zch_buffer(
            size=self._zch_size,
            support_evict=self._eviction_module is not None,
            device=self._device,
            long_type=True,  # deprecated, always True
        )

        # Identities are stored as a non-trainable parameter (see named_buffers),
        # metadata as a regular buffer.
        self._hash_zch_identities = torch.nn.Parameter(identities, requires_grad=False)
        self.register_buffer(HashZchManagedCollisionModule.METADATA_BUFFER, metadata)

        self._max_probe = max_probe
        self._buckets = total_num_buckets
        # Do not need to store in buffer since this is created and consumed
        # at each step https://fburl.com/code/axzimmbx
        self._evicted_indices = []

        # do not pass device, so its initialized on default physical device ('meta' will result in silent failure)
        size_per_rank = torch.diff(
            torch.tensor(self._output_segments, dtype=torch.int64)
        )

        self.input_mapper: torch.nn.Module = TrainInputMapper(
            input_hash_size=self._input_hash_size,
            total_num_buckets=total_num_buckets,
            size_per_rank=size_per_rank,
            train_rank_offsets=torch.tensor(
                torch.ops.fbgemm.asynchronous_exclusive_cumsum(size_per_rank)
            ),
            # be consistent with https://fburl.com/code/p4mj4mc1
            inference_dispatch_div_train_world_size=inference_dispatch_div_train_world_size,
            name=self._name,
        )

        if self._is_inference is True:
            self.reset_inference_mode()

        logger.info(
            f"HashZchManagedCollisionModule: {self._name=}, {self.device=}, "
            f"{self._zch_size=}, {self._input_hash_size=}, {self._max_probe=}, "
            f"{self._is_inference=}, {self._tb_logging_frequency=}, "
            f"{self._eviction_policy_name=}, {self._eviction_config=}, "
            f"{self._buckets=}, {self._start_bucket=}, {self._end_bucket=}, "
            f"{self._output_global_offset_tensor=}, {self._output_segments=}, "
            f"{inference_dispatch_div_train_world_size=}, "
            f"{self._opt_in_prob=}, {self._percent_reserved_slots=}"
        )

    @property
    def device(self) -> torch.device:
        """Device on which the identities tensor currently lives."""
        return _get_device(self._hash_zch_identities)

    def buckets(self) -> int:
        """Return the total number of buckets this module was configured with."""
        return self._buckets

    # TODO: This is hacky as we are using parameters to go through publishing.
    # Can remove once working out buffer solution.
    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        """Yield the regular buffers, then the identities parameter's data under IDENTITY_BUFFER."""
        yield from super().named_buffers(prefix, recurse, remove_duplicate)
        key: str = HashZchManagedCollisionModule.IDENTITY_BUFFER
        if prefix:
            key = f"{prefix}.{key}"
        yield (key, self._hash_zch_identities.data)

    def validate_state(self) -> None:
        """Not supported by this module; always raises NotImplementedError."""
        raise NotImplementedError()

    def reset_inference_mode(
        self,
    ) -> None:
        """
        Switch the module to inference mode: put it in eval mode and drop the
        metadata buffer, pending evictions and the eviction policy/module.
        Also registers a load-state-dict pre-hook that removes this module's
        metadata entry from incoming state dicts. This is not revertable.
        """
        logger.info("HashZchManagedCollisionModule resetting inference mode")
        # not revertable
        self.eval()
        self._is_inference = True
        self._hash_zch_metadata = None
        self._evicted_indices = []
        self._eviction_policy_name = None
        self._eviction_module = None

        def _load_state_dict_pre_hook(
            module: "HashZchManagedCollisionModule",
            state_dict: Dict[str, Any],
            prefix: str,
            *args: Any,
        ) -> None:
            logger.info("HashZchManagedCollisionModule loading state dict")
            # We store the full identity in checkpoint and predictor, cut it at inference loading
            if not module._is_inference:
                return
            # `state_dict` is the full dict being loaded and `prefix` is this
            # module's key prefix within it, so the metadata entry must be
            # looked up under the prefixed key; the bare name only matches
            # when this module is the root (prefix == "").
            state_dict.pop(
                f"{prefix}{HashZchManagedCollisionModule.METADATA_BUFFER}", None
            )

        self._register_load_state_dict_pre_hook(
            _load_state_dict_pre_hook, with_module=True
        )

    def preprocess(
        self,
        features: Dict[str, JaggedTensor],
    ) -> Dict[str, JaggedTensor]:
        """Return ``features`` unchanged."""
        return features

    def evict(self) -> Optional[torch.Tensor]:
        """
        Return the de-duplicated indices evicted since the last call and clear
        the pending list.

        Returns:
            None when nothing was evicted; otherwise the unique evicted indices,
            shifted by the global output offset when one is set.
        """
        if len(self._evicted_indices) == 0:
            return None
        out = torch.unique(torch.cat(self._evicted_indices))
        self._evicted_indices = []
        # Test for presence explicitly rather than relying on tensor truthiness.
        offset = self._output_global_offset_tensor
        return out + offset if offset is not None else out

    def profile(
        self,
        features: Dict[str, JaggedTensor],
    ) -> Dict[str, JaggedTensor]:
        """Return ``features`` unchanged."""
        return features

    def get_reserved_slots_per_bucket(self) -> int:
        """
        Return the number of reserved slots per local bucket, or -1 when
        ``opt_in_prob`` is -1.
        """
        if self._opt_in_prob == -1:
            return -1

        # percent of the local table, spread evenly over this rank's buckets
        return math.floor(
            self._zch_size
            * self._percent_reserved_slots
            / 100
            / (self._end_bucket - self._start_bucket)
        )

    def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
        """
        Remap the values of every feature through the ZCH kernel.

        Outside training mode the lookup is read-only and no metadata is
        passed. Evicted indices reported by the kernel are collected in
        ``_evicted_indices`` until ``evict()`` is called.

        Args:
            features: feature name to JaggedTensor of input IDs.

        Returns:
            A dict with the same keys whose JaggedTensors carry the remapped
            IDs (plus the global output offset when set) and the original
            lengths, offsets and weights.
        """
        metadata: Optional[torch.Tensor] = self._hash_zch_metadata
        readonly: bool = False
        if self._output_global_offset_tensor is not None:
            self._output_global_offset_tensor, _ = _tensor_may_to_device(
                self._output_global_offset_tensor, self.device
            )
        if not self.training:
            readonly = True
            metadata = None

        # _evicted_indices will be reset in evict(): https://fburl.com/code/r3fxcs1y
        assert len(self._evicted_indices) == 0

        # `torch.no_grad()` annotation prevents torchscripting `JaggedTensor` for some reason...
        with torch.no_grad():
            remapped_features: Dict[str, JaggedTensor] = {}
            # Snapshot of the identities before remapping, only needed for metrics.
            identities_0 = (
                self._hash_zch_identities.data.clone()
                if self._tb_logging_frequency > 0
                else None
            )

            for name, feature in features.items():
                values = feature.values()
                input_metadata, eviction_threshold = (
                    self._eviction_module(feature)
                    if self._eviction_module is not None
                    else (None, -1)
                )

                # Random integers in [0, 100) per value, only when opt-in is
                # enabled and the module is training.
                opt_in_rands = (
                    (torch.rand_like(values, dtype=torch.float) * 100).to(torch.int32)
                    if self._opt_in_prob != -1 and self.training
                    else None
                )

                values, orig_device = _tensor_may_to_device(values, self.device)
                values, local_sizes, offsets = self.input_mapper(
                    values=values,
                    output_offset=self._output_global_offset_tensor,
                )
                num_reserved_slots = self.get_reserved_slots_per_bucket()
                remapped_ids, evictions = torch.ops.fbgemm.zero_collision_hash(
                    input=values,
                    identities=self._hash_zch_identities,
                    max_probe=self._max_probe,
                    circular_probe=True,
                    exp_hours=-1,  # deprecated, always -1
                    readonly=readonly,
                    local_sizes=local_sizes,
                    offsets=offsets,
                    metadata=metadata,
                    # Use self._is_inference to turn on writing to pinned
                    # CPU memory directly. But may not have perf benefit.
                    output_on_uvm=False,  # self._is_inference,
                    disable_fallback=False,
                    _modulo_identity_DPRECATED=False,  # deprecated, always False
                    input_metadata=input_metadata,
                    eviction_threshold=eviction_threshold,
                    eviction_policy=get_kernel_from_policy(self._eviction_policy_name),
                    opt_in_prob=self._opt_in_prob,
                    num_reserved_slots=num_reserved_slots,
                    opt_in_rands=opt_in_rands,
                )

                if self._scalar_logger is not None:
                    assert identities_0 is not None
                    self._scalar_logger.update(
                        identities_0=identities_0,
                        identities_1=self._hash_zch_identities,
                        values=values,
                        remapped_ids=remapped_ids,
                        evicted_emb_indices=evictions,
                        metadata=metadata,
                        num_reserved_slots=num_reserved_slots,
                        eviction_config=self._eviction_config,
                    )

                output_global_offset_tensor = self._output_global_offset_tensor
                if output_global_offset_tensor is not None:
                    remapped_ids = remapped_ids + output_global_offset_tensor

                _append_eviction_indice(self._evicted_indices, evictions)
                # Hand the result back on the device the input arrived on.
                remapped_ids, _ = _tensor_may_to_device(remapped_ids, orig_device)

                remapped_features[name] = JaggedTensor(
                    values=remapped_ids,
                    lengths=feature.lengths(),
                    offsets=feature.offsets(),
                    weights=feature.weights_or_none(),
                )

            if self._scalar_logger is not None:
                self._scalar_logger(
                    run_type="train" if self.training else "eval",
                    identities=self._hash_zch_identities.data,
                )

        return remapped_features

    def forward(
        self,
        features: Dict[str, JaggedTensor],
    ) -> Dict[str, JaggedTensor]:
        """Alias for ``remap``."""
        return self.remap(features)

    def output_size(self) -> int:
        """Return the local table size (``zch_size``)."""
        return self._zch_size

    def input_size(self) -> int:
        """Return the configured ``input_hash_size``."""
        return self._input_hash_size

    def open_slots(self) -> torch.Tensor:
        """Return a constant ``tensor([0])``; open slots are not tracked here."""
        return torch.tensor([0])

    def rebuild_with_output_id_range(
        self,
        output_id_range: Tuple[int, int],
        output_segments: Optional[List[int]] = None,
        device: Optional[torch.device] = None,
    ) -> "HashZchManagedCollisionModule":
        """
        Build a new module covering the slice ``output_id_range`` of this
        module's output space.

        Args:
            output_id_range: (start, end) output IDs; both must be existing
                bucket boundaries in this module's output segments.
            output_segments: ignored; the existing segments are reused.
            device: device for the new module, defaults to this module's device.

        Returns:
            A new instance of the same class with ``zch_size = end - start``.

        Raises:
            RuntimeError: if either end of the range is not a bucket boundary.
        """
        # rebuild should use existing output_segments instead of the input one and should not
        # recalculate since the output segments are calculated based on the original embedding
        # table size, total bucket number, which might not be available for the rebuild caller
        try:
            start_idx = self._output_segments.index(output_id_range[0])
            end_idx = self._output_segments.index(output_id_range[1])
        except ValueError as err:
            raise RuntimeError(
                f"Attempting to shard HashZchManagedCollisionModule, but rank {device} does not align with bucket boundaries;"
                + f" please check kwarg total_num_buckets={self._buckets} is a multiple of world size."
            ) from err
        new_zch_size = output_id_range[1] - output_id_range[0]

        # NOTE(review): inference_dispatch_div_train_world_size is not
        # forwarded, so the rebuilt module uses the constructor default --
        # confirm this is intended.
        return self.__class__(
            zch_size=new_zch_size,
            device=device or self.device,
            max_probe=self._max_probe,
            total_num_buckets=self._buckets,
            input_hash_size=self._input_hash_size,
            is_inference=self._is_inference,
            start_bucket=start_idx,
            end_bucket=end_idx,
            output_segments=self._output_segments,
            name=self._name,
            tb_logging_frequency=self._tb_logging_frequency,
            eviction_policy_name=self._eviction_policy_name,
            eviction_config=self._eviction_config,
            opt_in_prob=self._opt_in_prob,
            percent_reserved_slots=self._percent_reserved_slots,
        )


@torch.fx.wrap
def _append_eviction_indice(
    evicted_indices: List[torch.Tensor],
    evictions: Optional[torch.Tensor],
) -> None:
    """Append ``evictions`` to ``evicted_indices`` unless it is None or empty."""
    if evictions is not None and evictions.numel() > 0:
        evicted_indices.append(evictions)
class TestEvictionScorer(unittest.TestCase):
    """Checks the TTL eviction scorers on a CUDA device with ``time.time`` patched to a fixed value."""

    # pyre-ignore [56]
    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "This test requires CUDA device",
    )
    def test_single_ttl_scorer(self) -> None:
        """Every value gets the same score: patched hour 10000 plus the single TTL of 24."""
        config = HashZchEvictionConfig(features=["f1"], single_ttl=24)
        scorer = HashZchSingleTtlScorer(config=config)

        jt = JaggedTensor(
            values=torch.arange(0, 5, dtype=torch.int64),
            lengths=torch.tensor([2, 2, 1], dtype=torch.int64),
        )
        expected = torch.tensor([10024, 10024, 10024, 10024, 10024], device="cuda")

        # hour 10000
        with patch("time.time", return_value=36000000):
            score = scorer.gen_score(jt, device=torch.device("cuda"))
            self.assertTrue(torch.equal(score, expected), f"{torch.unique(score)=}")

    # pyre-ignore [56]
    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "This test requires CUDA device",
    )
    def test_per_feature_ttl_scorer(self) -> None:
        """The first four values score 10024 and the last one 10048 for the input below."""
        config = HashZchEvictionConfig(
            features=["f1", "f2"], per_feature_ttl=[24, 48]
        )
        scorer = HashZchPerFeatureTtlScorer(config=config)

        jt = JaggedTensor(
            values=torch.arange(0, 5, dtype=torch.int64),
            lengths=torch.tensor([2, 2, 1], dtype=torch.int64),
            weights=torch.tensor([4, 1], dtype=torch.int64),
        )
        expected = torch.tensor([10024, 10024, 10024, 10024, 10048], device="cuda")

        # hour 10000
        with patch("time.time", return_value=36000000):
            score = scorer.gen_score(jt, device=torch.device("cuda"))
            self.assertTrue(torch.equal(score, expected), f"{torch.unique(score)=}")
class TestMCH(unittest.TestCase):
    """CUDA-gated tests for HashZchManagedCollisionModule: remapping, scripting, resharding and offsets."""

    # pyre-ignore[56]
    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "This test requires CUDA device",
    )
    def test_zch_hash_inference(self) -> None:
        """Train one module on two batches, then check inference modules reproduce the mapping."""
        # prepare
        m1 = HashZchManagedCollisionModule(
            zch_size=20,
            device=torch.device("cuda"),
            total_num_buckets=2,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=[],
                single_ttl=10,
            ),
        )
        self.assertEqual(m1._hash_zch_identities.dtype, torch.int64)
        in1 = {
            "f": JaggedTensor(
                values=torch.arange(0, 20, 2, dtype=torch.int64, device="cuda"),
                lengths=torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
            ),
        }
        o1 = m1(in1)["f"].values()
        self.assertTrue(
            torch.equal(torch.unique(o1), torch.arange(0, 10, device="cuda")),
            f"{torch.unique(o1)=}",
        )

        in2 = {
            "f": JaggedTensor(
                values=torch.arange(1, 20, 2, dtype=torch.int64, device="cuda"),
                lengths=torch.tensor([8, 2], dtype=torch.int64, device="cuda"),
            ),
        }
        o2 = m1(in2)["f"].values()
        self.assertTrue(
            torch.equal(torch.unique(o2), torch.arange(10, 20, device="cuda")),
            f"{torch.unique(o2)=}",
        )

        for device_str in ["cpu", "cuda"]:
            # Inference
            m_infer = HashZchManagedCollisionModule(
                zch_size=20,
                device=torch.device(device_str),
                total_num_buckets=2,
            )

            m_infer.reset_inference_mode()
            m_infer.to(device_str)

            self.assertTrue(
                torch.equal(
                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
                    #  `Union[Tensor, Module]`.
                    none_throws(m_infer.input_mapper._zch_size_per_training_rank),
                    torch.tensor([10, 10], dtype=torch.int64, device=device_str),
                )
            )
            self.assertTrue(
                torch.equal(
                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
                    #  `Union[Tensor, Module]`.
                    none_throws(m_infer.input_mapper._train_rank_offsets),
                    torch.tensor([0, 10], dtype=torch.int64, device=device_str),
                )
            )

            # Reuse the identities learned by m1 (first column only).
            m_infer._hash_zch_identities = torch.nn.Parameter(
                m1._hash_zch_identities[:, :1],
                requires_grad=False,
            )
            in12 = {
                "f": JaggedTensor(
                    values=torch.arange(0, 20, dtype=torch.int64, device=device_str),
                    lengths=torch.tensor(
                        [4, 6, 8, 2], dtype=torch.int64, device=device_str
                    ),
                ),
            }
            m_infer = torch.jit.script(m_infer)
            o_infer = m_infer(in12)["f"].values()
            # in12 interleaves the even ids of in1 with the odd ids of in2.
            o12 = torch.stack([o1, o2], dim=1).view(-1).to(device_str)
            self.assertTrue(torch.equal(o_infer, o12), f"{o_infer=} vs {o12=}")

        m3 = HashZchManagedCollisionModule(
            zch_size=10,
            device=torch.device("cuda"),
            total_num_buckets=2,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=[],
                single_ttl=10,
            ),
        )
        self.assertEqual(m3._hash_zch_identities.dtype, torch.int64)
        in3 = {
            "f": JaggedTensor(
                values=torch.arange(10, 20, dtype=torch.int64, device="cuda"),
                lengths=torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
            ),
        }
        o3 = m3(in3)["f"].values()
        self.assertTrue(
            torch.equal(torch.unique(o3), torch.arange(0, 10, device="cuda")),
            f"{torch.unique(o3)=}",
        )
        # validate that original ids are assigned to identities
        self.assertTrue(
            torch.equal(
                torch.unique(m3._hash_zch_identities),
                torch.arange(10, 20, device="cuda"),
            ),
            f"{torch.unique(m3._hash_zch_identities)=}",
        )

    # pyre-ignore[56]
    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "This test requires CUDA device",
    )
    def test_scriptability(self) -> None:
        """A collection holding a single-TTL module must be TorchScript-scriptable."""
        zch_size = 10
        mc_modules = {
            "t1": cast(
                ManagedCollisionModule,
                HashZchManagedCollisionModule(
                    zch_size=zch_size,
                    device=torch.device("cpu"),
                    eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                    eviction_config=HashZchEvictionConfig(
                        features=["feature"],
                    ),
                    total_num_buckets=2,
                ),
            )
        }

        embedding_configs = [
            EmbeddingConfig(
                name="t1",
                embedding_dim=8,
                num_embeddings=zch_size,
                feature_names=["f1", "f2"],
            ),
        ]

        mcc_ec = ManagedCollisionCollection(
            managed_collision_modules=mc_modules,
            embedding_configs=embedding_configs,
        )
        torch.jit.script(mcc_ec)

    # pyre-ignore[56]
    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "This test requires CUDA device",
    )
    def test_scriptability_lru(self) -> None:
        """A collection holding an LRU-eviction module must be TorchScript-scriptable."""
        zch_size = 10
        mc_modules = {
            "t1": cast(
                ManagedCollisionModule,
                HashZchManagedCollisionModule(
                    zch_size=zch_size,
                    device=torch.device("cpu"),
                    total_num_buckets=2,
                    eviction_policy_name=HashZchEvictionPolicyName.LRU_EVICTION,
                    eviction_config=HashZchEvictionConfig(
                        features=["feature"],
                        single_ttl=12,
                    ),
                ),
            )
        }

        embedding_configs = [
            EmbeddingConfig(
                name="t1",
                embedding_dim=8,
                num_embeddings=zch_size,
                feature_names=["f1", "f2"],
            ),
        ]

        mcc_ec = ManagedCollisionCollection(
            managed_collision_modules=mc_modules,
            embedding_configs=embedding_configs,
        )
        torch.jit.script(mcc_ec)

    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "Not enough GPUs, this test requires at least one GPUs",
    )
    # pyre-ignore [56]
    @given(hash_size=st.sampled_from([0, 80]), keep_original_indices=st.booleans())
    @settings(max_examples=6, deadline=None)
    def test_zch_hash_train_to_inf_block_bucketize(
        self, hash_size: int, keep_original_indices: bool
    ) -> None:
        """Train two per-rank shards on a bucketized KJT, then run the merged module in inference."""
        # rank 0
        world_size = 2
        kjt = KeyedJaggedTensor(
            keys=["f"],
            values=torch.cat(
                [
                    torch.arange(0, 20, 2, dtype=torch.int64, device="cuda"),
                    torch.arange(30, 60, 3, dtype=torch.int64, device="cuda"),
                ]
            ),
            lengths=torch.cat(
                [
                    torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
                    torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
                ]
            ),
        )
        block_sizes = torch.tensor(
            [(size + world_size - 1) // world_size for size in [hash_size]],
            dtype=torch.int64,
            device="cuda",
        )

        bucketized_kjt, _ = bucketize_kjt_before_all2all(
            kjt,
            num_buckets=world_size,
            block_sizes=block_sizes,
            keep_original_indices=keep_original_indices,
        )
        in1, in2 = bucketized_kjt.split([len(kjt.keys())] * world_size)
        in1 = in1.to_dict()
        in2 = in2.to_dict()
        m0 = HashZchManagedCollisionModule(
            zch_size=20,
            device=torch.device("cuda"),
            input_hash_size=hash_size,
            total_num_buckets=2,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=[],
                single_ttl=10,
            ),
        )
        m1 = m0.rebuild_with_output_id_range((0, 10))
        m2 = m0.rebuild_with_output_id_range((10, 20))

        # simulate calls to each rank
        o1 = m1(in1)
        o2 = m2(in2)

        m0.reset_inference_mode()
        full_zch_identities = torch.cat(
            [
                m1.state_dict()["_hash_zch_identities"],
                m2.state_dict()["_hash_zch_identities"],
            ]
        )
        state_dict = m0.state_dict()
        state_dict["_hash_zch_identities"] = full_zch_identities
        m0.load_state_dict(state_dict)

        # now pass in original kjt
        inf_input = kjt.to_dict()
        inf_output = m0(inf_input)

        # NOTE(review): the result of torch.allclose is discarded, so this test
        # currently only verifies that the pipeline runs without raising.
        # Confirm whether an assertion was intended before adding one.
        torch.allclose(
            inf_output["f"].values(), torch.cat([o1["f"].values(), o2["f"].values()])
        )

    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "Not enough GPUs, this test requires at least one GPUs",
    )
    # pyre-ignore [56]
    @given(hash_size=st.sampled_from([0, 80]))
    @settings(max_examples=5, deadline=None)
    def test_zch_hash_train_rescales_two(self, hash_size: int) -> None:
        """Bucketize twice (2 ranks x 2 sub-buckets), train four shards, then run the merged module."""
        keep_original_indices = False
        # rank 0
        world_size = 2
        kjt = KeyedJaggedTensor(
            keys=["f"],
            values=torch.cat(
                [
                    torch.randint(
                        0,
                        hash_size if hash_size > 0 else 1000,
                        (20,),
                        dtype=torch.int64,
                        device="cuda",
                    ),
                ]
            ),
            lengths=torch.cat(
                [
                    torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
                    torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
                ]
            ),
        )
        block_sizes = torch.tensor(
            [(size + world_size - 1) // world_size for size in [hash_size]],
            dtype=torch.int64,
            device="cuda",
        )
        sub_block_sizes = torch.tensor(
            [(size + 2 - 1) // 2 for size in [block_sizes[0]]],
            dtype=torch.int64,
            device="cuda",
        )
        bucketized_kjt, _ = bucketize_kjt_before_all2all(
            kjt,
            num_buckets=world_size,
            block_sizes=block_sizes,
            keep_original_indices=keep_original_indices,
        )
        in1, in2 = bucketized_kjt.split([len(kjt.keys())] * world_size)

        bucketized_in1, _ = bucketize_kjt_before_all2all(
            in1,
            num_buckets=2,
            block_sizes=sub_block_sizes,
            keep_original_indices=keep_original_indices,
        )
        bucketized_in2, _ = bucketize_kjt_before_all2all(
            in2,
            num_buckets=2,
            block_sizes=sub_block_sizes,
            keep_original_indices=keep_original_indices,
        )
        in1_1, in1_2 = bucketized_in1.split([len(kjt.keys())] * 2)
        in2_1, in2_2 = bucketized_in2.split([len(kjt.keys())] * 2)

        in1_1, in1_2 = in1_1.to_dict(), in1_2.to_dict()
        in2_1, in2_2 = in2_1.to_dict(), in2_2.to_dict()

        m0 = HashZchManagedCollisionModule(
            zch_size=20,
            device=torch.device("cuda"),
            input_hash_size=hash_size,
            total_num_buckets=4,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=[],
                single_ttl=10,
            ),
        )

        m1_1 = m0.rebuild_with_output_id_range((0, 5))
        m1_2 = m0.rebuild_with_output_id_range((5, 10))
        m2_1 = m0.rebuild_with_output_id_range((10, 15))
        m2_2 = m0.rebuild_with_output_id_range((15, 20))

        # simulate calls to each rank
        o1_1 = m1_1(in1_1)
        o1_2 = m1_2(in1_2)
        o2_1 = m2_1(in2_1)
        o2_2 = m2_2(in2_2)

        m0.reset_inference_mode()

        full_zch_identities = torch.cat(
            [
                m1_1.state_dict()["_hash_zch_identities"],
                m1_2.state_dict()["_hash_zch_identities"],
                m2_1.state_dict()["_hash_zch_identities"],
                m2_2.state_dict()["_hash_zch_identities"],
            ]
        )
        state_dict = m0.state_dict()
        state_dict["_hash_zch_identities"] = full_zch_identities
        m0.load_state_dict(state_dict)

        # now pass in original kjt
        inf_input = kjt.to_dict()
        inf_output = m0(inf_input)
        # NOTE(review): the result of torch.allclose is discarded, so this test
        # currently only verifies that the pipeline runs without raising.
        # Confirm whether an assertion was intended before adding one.
        torch.allclose(
            inf_output["f"].values(),
            torch.cat([x["f"].values() for x in [o1_1, o1_2, o2_1, o2_2]]),
        )

    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "Not enough GPUs, this test requires at least one GPUs",
    )
    # pyre-ignore [56]
    @given(hash_size=st.sampled_from([0, 80]))
    @settings(max_examples=5, deadline=None)
    def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
        """Merge four single-bucket shards into two, train them, and compare eval output with inference."""
        keep_original_indices = True
        kjt = KeyedJaggedTensor(
            keys=["f"],
            values=torch.cat(
                [
                    torch.randint(
                        0,
                        hash_size if hash_size > 0 else 1000,
                        (20,),
                        dtype=torch.int64,
                        device="cuda",
                    ),
                ]
            ),
            lengths=torch.cat(
                [
                    torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
                    torch.tensor([4, 6], dtype=torch.int64, device="cuda"),
                ]
            ),
        )

        # initialize mch with 4 buckets
        m0 = HashZchManagedCollisionModule(
            zch_size=40,
            device=torch.device("cuda"),
            input_hash_size=hash_size,
            total_num_buckets=4,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=[],
                single_ttl=10,
            ),
        )

        # start with world_size = 4
        world_size = 4
        block_sizes = torch.tensor(
            [(size + world_size - 1) // world_size for size in [hash_size]],
            dtype=torch.int64,
            device="cuda",
        )

        m1_1 = m0.rebuild_with_output_id_range((0, 10))
        m2_1 = m0.rebuild_with_output_id_range((10, 20))
        m3_1 = m0.rebuild_with_output_id_range((20, 30))
        m4_1 = m0.rebuild_with_output_id_range((30, 40))

        # shard, now world size 2!
        if hash_size > 0:
            world_size = 2
            block_sizes = torch.tensor(
                [(size + world_size - 1) // world_size for size in [hash_size]],
                dtype=torch.int64,
                device="cuda",
            )
            # simulate kjt call
            bucketized_kjt, permute = bucketize_kjt_before_all2all(
                kjt,
                num_buckets=world_size,
                block_sizes=block_sizes,
                keep_original_indices=keep_original_indices,
                output_permute=True,
            )
            in1_2, in2_2 = bucketized_kjt.split([len(kjt.keys())] * world_size)
        else:
            # hash_size == 0: bucketize with the original world size of 4 and
            # merge neighbouring pairs of buckets by hand.
            bucketized_kjt, permute = bucketize_kjt_before_all2all(
                kjt,
                num_buckets=world_size,
                block_sizes=block_sizes,
                keep_original_indices=keep_original_indices,
                output_permute=True,
            )
            kjts = bucketized_kjt.split([len(kjt.keys())] * world_size)
            # rebuild kjt
            in1_2 = KeyedJaggedTensor(
                keys=kjts[0].keys(),
                values=torch.cat([kjts[0].values(), kjts[1].values()], dim=0),
                lengths=torch.cat([kjts[0].lengths(), kjts[1].lengths()], dim=0),
            )
            in2_2 = KeyedJaggedTensor(
                keys=kjts[2].keys(),
                values=torch.cat([kjts[2].values(), kjts[3].values()], dim=0),
                lengths=torch.cat([kjts[2].lengths(), kjts[3].lengths()], dim=0),
            )

        m1_2 = m0.rebuild_with_output_id_range((0, 20))
        m2_2 = m0.rebuild_with_output_id_range((20, 40))
        m1_zch_identities = torch.cat(
            [
                m1_1.state_dict()["_hash_zch_identities"],
                m2_1.state_dict()["_hash_zch_identities"],
            ]
        )
        m1_zch_metadata = torch.cat(
            [
                m1_1.state_dict()["_hash_zch_metadata"],
                m2_1.state_dict()["_hash_zch_metadata"],
            ]
        )
        state_dict = m1_2.state_dict()
        state_dict["_hash_zch_identities"] = m1_zch_identities
        state_dict["_hash_zch_metadata"] = m1_zch_metadata
        m1_2.load_state_dict(state_dict)

        m2_zch_identities = torch.cat(
            [
                m3_1.state_dict()["_hash_zch_identities"],
                m4_1.state_dict()["_hash_zch_identities"],
            ]
        )
        m2_zch_metadata = torch.cat(
            [
                m3_1.state_dict()["_hash_zch_metadata"],
                m4_1.state_dict()["_hash_zch_metadata"],
            ]
        )
        state_dict = m2_2.state_dict()
        state_dict["_hash_zch_identities"] = m2_zch_identities
        state_dict["_hash_zch_metadata"] = m2_zch_metadata
        m2_2.load_state_dict(state_dict)

        _ = m1_2(in1_2.to_dict())
        _ = m2_2(in2_2.to_dict())

        m0.reset_inference_mode()  # just clears out training state
        full_zch_identities = torch.cat(
            [
                m1_2.state_dict()["_hash_zch_identities"],
                m2_2.state_dict()["_hash_zch_identities"],
            ]
        )
        state_dict = m0.state_dict()
        state_dict["_hash_zch_identities"] = full_zch_identities
        m0.load_state_dict(state_dict)

        # now set all models to eval, and run kjt
        m1_2.eval()
        m2_2.eval()
        assert m0.training is False

        inf_input = kjt.to_dict()
        inf_output = m0(inf_input)

        o1_2 = m1_2(in1_2.to_dict())
        o2_2 = m2_2(in2_2.to_dict())
        # The per-rank outputs are in bucketized order; `permute` is used to
        # index them back before comparing with the un-bucketized inference run.
        self.assertTrue(
            torch.allclose(
                inf_output["f"].values(),
                torch.index_select(
                    torch.cat([x["f"].values() for x in [o1_2, o2_2]]),
                    dim=0,
                    index=cast(torch.Tensor, permute),
                ),
            )
        )

    # pyre-ignore[56]
    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "This test requires CUDA device",
    )
    def test_output_global_offset_tensor(self) -> None:
        """The global offset tensor is None for bucket 0 and holds the start id of rebuilt shards."""
        m = HashZchManagedCollisionModule(
            zch_size=20,
            device=torch.device("cpu"),
            total_num_buckets=4,
        )
        self.assertIsNone(m._output_global_offset_tensor)

        bucket2 = m.rebuild_with_output_id_range((5, 10))
        self.assertIsNotNone(bucket2._output_global_offset_tensor)
        self.assertTrue(
            # pyre-ignore [6]
            torch.equal(bucket2._output_global_offset_tensor, torch.tensor([5]))
        )
        self.assertEqual(bucket2._start_bucket, 1)

        m.reset_inference_mode()
        bucket3 = m.rebuild_with_output_id_range((10, 15))
        self.assertIsNotNone(bucket3._output_global_offset_tensor)
        self.assertTrue(
            # pyre-ignore [6]
            torch.equal(bucket3._output_global_offset_tensor, torch.tensor([10]))
        )
        self.assertEqual(bucket3._start_bucket, 2)
        self.assertEqual(
            # pyre-ignore [16]
            bucket3._output_global_offset_tensor.device.type,
            "cpu",
        )

        remapped_indices = bucket3.remap(
            {
                "test": JaggedTensor(
                    values=torch.tensor(
                        [6, 10, 14, 18, 22], dtype=torch.int64, device="cpu"
                    ),
                    lengths=torch.tensor([5], dtype=torch.int64, device="cpu"),
                )
            }
        )
        self.assertTrue(
            torch.allclose(
                remapped_indices["test"].values(), torch.tensor([14, 10, 10, 11, 10])
            )
        )

        gpu_zch = HashZchManagedCollisionModule(
            zch_size=20,
            device=torch.device("cuda"),
            total_num_buckets=4,
        )
        bucket4 = gpu_zch.rebuild_with_output_id_range((15, 20))
        self.assertIsNotNone(bucket4._output_global_offset_tensor)
        self.assertTrue(bucket4._output_global_offset_tensor.device.type == "cuda")
        # Compare tensors with torch.equal, as the other checks in this file
        # do, instead of relying on the truthiness of an elementwise `==`.
        self.assertTrue(
            torch.equal(
                bucket4._output_global_offset_tensor,
                torch.tensor([15], device="cuda"),
            )
        )

        meta_zch = HashZchManagedCollisionModule(
            zch_size=20,
            device=torch.device("meta"),
            total_num_buckets=4,
        )
        meta_zch.reset_inference_mode()
        self.assertIsNone(meta_zch._output_global_offset_tensor)
        bucket5 = meta_zch.rebuild_with_output_id_range((15, 20))
        self.assertIsNotNone(bucket5._output_global_offset_tensor)
        self.assertTrue(bucket5._output_global_offset_tensor.device.type == "cpu")
        self.assertTrue(
            torch.equal(bucket5._output_global_offset_tensor, torch.tensor([15]))
        )