File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -29,3 +29,8 @@ def test_scalar_to_float(self) -> None:
29
29
30
30
valid_ndarray = np .array ([[[float_x ]]])
31
31
self .assertAlmostEqual (scalar_to_float (valid_ndarray ), float_x )
32
+
33
+ def test_scalar_to_float_bf16 (self ) -> None :
34
+ float_x = 3.45
35
+ valid_tensor = torch .Tensor ([float_x ]).to (torch .bfloat16 )
36
+ self .assertAlmostEqual (scalar_to_float (valid_tensor ), float_x , delta = 0.01 )
Original file line number Diff line number Diff line change @@ -20,7 +20,7 @@ def scalar_to_float(scalar: Scalar) -> float:
20
20
f"Scalar tensor must contain a single item, { numel } given."
21
21
)
22
22
23
- return float (scalar .cpu ().detach ().numpy ().item ())
23
+ return float (scalar .cpu ().detach ().float (). numpy ().item ())
24
24
elif isinstance (scalar , ndarray ):
25
25
numel = scalar .size
26
26
if numel != 1 :
You can’t perform that action at this time.
0 commit comments