Skip to content

Commit 9511c8a

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
Support BF16 in TorchTNT logger
Reviewed By: galrotem Differential Revision: D70537893
1 parent fb2f350 commit 9511c8a

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

tests/utils/loggers/test_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,8 @@ def test_scalar_to_float(self) -> None:
2929

3030
valid_ndarray = np.array([[[float_x]]])
3131
self.assertAlmostEqual(scalar_to_float(valid_ndarray), float_x)
32+
33+
def test_scalar_to_float_bf16(self) -> None:
34+
float_x = 3.45
35+
valid_tensor = torch.Tensor([float_x]).to(torch.bfloat16)
36+
self.assertAlmostEqual(scalar_to_float(valid_tensor), float_x, delta=0.01)

torchtnt/utils/loggers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def scalar_to_float(scalar: Scalar) -> float:
2020
f"Scalar tensor must contain a single item, {numel} given."
2121
)
2222

23-
return float(scalar.cpu().detach().numpy().item())
23+
return float(scalar.cpu().detach().float().numpy().item())
2424
elif isinstance(scalar, ndarray):
2525
numel = scalar.size
2626
if numel != 1:

0 commit comments

Comments
 (0)