Skip to content

Commit 33349ec

Browse files
monofbfacebook-github-bot
authored andcommitted
Support Hindsight PR in TorchRec (#2627)
Summary: Pull Request resolved: #2627 ### Overview This diff implements HindsightTargetPR metric into TorchRec. This will also include a bucketized version. Thrift changes submitted ahead in D66216486. ### Implementation 1) Create X-wide granular array to store metric states where each index represents the threshold. For bucketization, each bucket will be stacked in the next dimension within the state tensor. 2) Calculate minimum threshold that meets target_precision. 3) Calculate precision and recall points with target threshold. ### Metrics This metric will return the following curves: * hindsight_target_pr: this is the calculated threshold for the window state to maximize recall while achieving the target precision. * hindsight_target_precision: this is the achieved precision with hindsight_target_pr. * hindsight_target_recall: this is the achieved recall with hindsight_target_pr. ### Usage Hindsight PR metrics are primarily useful to mimic the calibration system within identity team. Please adjust the bucketization and window size accordingly to best approximate this. Note: since the states are stored as a dimensional tensor, multiple tasks will not be supported for this metric. Reviewed By: iamzainhuda Differential Revision: D65867461 fbshipit-source-id: 5aa813f7d1ac1beb7e3fff9d840a2968e33f2541
1 parent b34ccb9 commit 33349ec

File tree

5 files changed

+399
-0
lines changed

5 files changed

+399
-0
lines changed
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from typing import Any, cast, Dict, List, Optional, Type
11+
12+
import torch
13+
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
14+
from torchrec.metrics.rec_metric import (
15+
MetricComputationReport,
16+
RecMetric,
17+
RecMetricComputation,
18+
RecMetricException,
19+
)
20+
21+
22+
TARGET_PRECISION = "target_precision"
23+
THRESHOLD_GRANULARITY = 1000
24+
25+
26+
def compute_precision(
27+
num_true_positives: torch.Tensor, num_false_positives: torch.Tensor
28+
) -> torch.Tensor:
29+
return torch.where(
30+
num_true_positives + num_false_positives == 0.0,
31+
0.0,
32+
num_true_positives / (num_true_positives + num_false_positives).double(),
33+
)
34+
35+
36+
def compute_recall(
37+
num_true_positives: torch.Tensor, num_false_negitives: torch.Tensor
38+
) -> torch.Tensor:
39+
return torch.where(
40+
num_true_positives + num_false_negitives == 0.0,
41+
0.0,
42+
num_true_positives / (num_true_positives + num_false_negitives),
43+
)
44+
45+
46+
def compute_threshold_idx(
47+
num_true_positives: torch.Tensor,
48+
num_false_positives: torch.Tensor,
49+
target_precision: float,
50+
) -> int:
51+
for i in range(THRESHOLD_GRANULARITY):
52+
if (
53+
compute_precision(num_true_positives[i], num_false_positives[i])
54+
>= target_precision
55+
):
56+
return i
57+
58+
return THRESHOLD_GRANULARITY - 1
59+
60+
61+
def compute_true_pos_sum(
62+
labels: torch.Tensor,
63+
predictions: torch.Tensor,
64+
weights: torch.Tensor,
65+
) -> torch.Tensor:
66+
predictions = predictions.double()
67+
tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
68+
thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
69+
for i, threshold in enumerate(thresholds):
70+
tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1)
71+
return tp_sum
72+
73+
74+
def compute_false_pos_sum(
75+
labels: torch.Tensor,
76+
predictions: torch.Tensor,
77+
weights: torch.Tensor,
78+
) -> torch.Tensor:
79+
predictions = predictions.double()
80+
fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
81+
thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
82+
for i, threshold in enumerate(thresholds):
83+
fp_sum[i] = torch.sum(weights * ((predictions >= threshold) * (1 - labels)), -1)
84+
return fp_sum
85+
86+
87+
def compute_false_neg_sum(
88+
labels: torch.Tensor,
89+
predictions: torch.Tensor,
90+
weights: torch.Tensor,
91+
) -> torch.Tensor:
92+
predictions = predictions.double()
93+
fn_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
94+
thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
95+
for i, threshold in enumerate(thresholds):
96+
fn_sum[i] = torch.sum(weights * ((predictions <= threshold) * labels), -1)
97+
return fn_sum
98+
99+
100+
def get_pr_states(
101+
labels: torch.Tensor,
102+
predictions: torch.Tensor,
103+
weights: Optional[torch.Tensor],
104+
) -> Dict[str, torch.Tensor]:
105+
if weights is None:
106+
weights = torch.ones_like(predictions)
107+
return {
108+
"true_pos_sum": compute_true_pos_sum(labels, predictions, weights),
109+
"false_pos_sum": compute_false_pos_sum(labels, predictions, weights),
110+
"false_neg_sum": compute_false_neg_sum(labels, predictions, weights),
111+
}
112+
113+
114+
class HindsightTargetPRMetricComputation(RecMetricComputation):
115+
r"""
116+
This class implements the RecMetricComputation for Hingsight Target PR.
117+
118+
The constructor arguments are defined in RecMetricComputation.
119+
See the docstring of RecMetricComputation for more detail.
120+
121+
Args:
122+
target_precision (float): If provided, computes the minimum threshold to achieve the target precision.
123+
"""
124+
125+
def __init__(
126+
self, *args: Any, target_precision: float = 0.5, **kwargs: Any
127+
) -> None:
128+
super().__init__(*args, **kwargs)
129+
self._add_state(
130+
"true_pos_sum",
131+
torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double),
132+
add_window_state=True,
133+
dist_reduce_fx="sum",
134+
persistent=True,
135+
)
136+
self._add_state(
137+
"false_pos_sum",
138+
torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double),
139+
add_window_state=True,
140+
dist_reduce_fx="sum",
141+
persistent=True,
142+
)
143+
self._add_state(
144+
"false_neg_sum",
145+
torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double),
146+
add_window_state=True,
147+
dist_reduce_fx="sum",
148+
persistent=True,
149+
)
150+
self._target_precision: float = target_precision
151+
152+
def update(
153+
self,
154+
*,
155+
predictions: Optional[torch.Tensor],
156+
labels: torch.Tensor,
157+
weights: Optional[torch.Tensor],
158+
**kwargs: Dict[str, Any],
159+
) -> None:
160+
if predictions is None:
161+
raise RecMetricException(
162+
"Inputs 'predictions' should not be None for HindsightTargetPRMetricComputation update"
163+
)
164+
states = get_pr_states(labels, predictions, weights)
165+
num_samples = predictions.shape[-1]
166+
167+
for state_name, state_value in states.items():
168+
state = getattr(self, state_name)
169+
state += state_value
170+
self._aggregate_window_state(state_name, state_value, num_samples)
171+
172+
def _compute(self) -> List[MetricComputationReport]:
173+
true_pos_sum = cast(torch.Tensor, self.true_pos_sum)
174+
false_pos_sum = cast(torch.Tensor, self.false_pos_sum)
175+
false_neg_sum = cast(torch.Tensor, self.false_neg_sum)
176+
threshold_idx = compute_threshold_idx(
177+
true_pos_sum,
178+
false_pos_sum,
179+
self._target_precision,
180+
)
181+
window_threshold_idx = compute_threshold_idx(
182+
self.get_window_state("true_pos_sum"),
183+
self.get_window_state("false_pos_sum"),
184+
self._target_precision,
185+
)
186+
reports = [
187+
MetricComputationReport(
188+
name=MetricName.HINDSIGHT_TARGET_PR,
189+
metric_prefix=MetricPrefix.LIFETIME,
190+
value=torch.Tensor(threshold_idx),
191+
),
192+
MetricComputationReport(
193+
name=MetricName.HINDSIGHT_TARGET_PR,
194+
metric_prefix=MetricPrefix.WINDOW,
195+
value=torch.Tensor(window_threshold_idx),
196+
),
197+
MetricComputationReport(
198+
name=MetricName.HINDSIGHT_TARGET_PRECISION,
199+
metric_prefix=MetricPrefix.LIFETIME,
200+
value=compute_precision(
201+
true_pos_sum[threshold_idx],
202+
false_pos_sum[threshold_idx],
203+
),
204+
),
205+
MetricComputationReport(
206+
name=MetricName.HINDSIGHT_TARGET_PRECISION,
207+
metric_prefix=MetricPrefix.WINDOW,
208+
value=compute_precision(
209+
self.get_window_state("true_pos_sum")[window_threshold_idx],
210+
self.get_window_state("false_pos_sum")[window_threshold_idx],
211+
),
212+
),
213+
MetricComputationReport(
214+
name=MetricName.HINDSIGHT_TARGET_RECALL,
215+
metric_prefix=MetricPrefix.LIFETIME,
216+
value=compute_recall(
217+
true_pos_sum[threshold_idx],
218+
false_neg_sum[threshold_idx],
219+
),
220+
),
221+
MetricComputationReport(
222+
name=MetricName.HINDSIGHT_TARGET_RECALL,
223+
metric_prefix=MetricPrefix.WINDOW,
224+
value=compute_recall(
225+
self.get_window_state("true_pos_sum")[window_threshold_idx],
226+
self.get_window_state("false_neg_sum")[window_threshold_idx],
227+
),
228+
),
229+
]
230+
return reports
231+
232+
233+
class HindsightTargetPRMetric(RecMetric):
234+
_namespace: MetricNamespace = MetricNamespace.HINDSIGHT_TARGET_PR
235+
_computation_class: Type[RecMetricComputation] = HindsightTargetPRMetricComputation

torchrec/metrics/metric_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
2525
from torchrec.metrics.calibration import CalibrationMetric
2626
from torchrec.metrics.ctr import CTRMetric
27+
from torchrec.metrics.hindsight_target_pr import HindsightTargetPRMetric
2728
from torchrec.metrics.mae import MAEMetric
2829
from torchrec.metrics.metrics_config import (
2930
BatchSizeStage,
@@ -94,6 +95,7 @@
9495
RecMetricEnum.TENSOR_WEIGHTED_AVG: TensorWeightedAvgMetric,
9596
RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric,
9697
RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric,
98+
RecMetricEnum.HINDSIGHT_TARGET_PR: HindsightTargetPRMetric,
9799
}
98100

99101

torchrec/metrics/metrics_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class RecMetricEnum(RecMetricEnumBase):
4747
TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"
4848
CALI_FREE_NE = "cali_free_ne"
4949
UNWEIGHTED_NE = "unweighted_ne"
50+
HINDSIGHT_TARGET_PR = "hindsight_target_pr"
5051

5152

5253
@dataclass(unsafe_hash=True, eq=True)

torchrec/metrics/metrics_namespace.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ class MetricName(MetricNameBase):
8282
CALI_FREE_NE = "cali_free_ne"
8383
UNWEIGHTED_NE = "unweighted_ne"
8484

85+
HINDSIGHT_TARGET_PR = "hindsight_target_pr"
86+
HINDSIGHT_TARGET_PRECISION = "hindsight_target_precision"
87+
HINDSIGHT_TARGET_RECALL = "hindsight_target_recall"
88+
8589

8690
class MetricNamespaceBase(StrValueMixin, Enum):
8791
pass
@@ -131,6 +135,8 @@ class MetricNamespace(MetricNamespaceBase):
131135
CALI_FREE_NE = "cali_free_ne"
132136
UNWEIGHTED_NE = "unweighted_ne"
133137

138+
HINDSIGHT_TARGET_PR = "hindsight_target_pr"
139+
134140

135141
class MetricPrefix(StrValueMixin, Enum):
136142
DEFAULT = ""

0 commit comments

Comments
 (0)