Skip to content

Commit 97bc100

Browse files
colesburyezyang
authored andcommitted
Fix handling of inf and nan (#153)
1 parent f1c5d8c commit 97bc100

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

aten/src/ATen/Half.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <limits>
44
#include <string>
55
#include <stdint.h>
6+
#include <cmath>
67
#ifdef AT_CUDA_ENABLED
78
#include <cuda.h>
89
#include <cuda_runtime.h>
@@ -17,6 +18,12 @@ template<typename To, typename From> To convert(From f) {
1718

1819
template<typename To, typename From> bool overflows(From f) {
1920
using limit = std::numeric_limits<To>;
21+
if (limit::has_infinity && std::isinf(f)) {
22+
return false;
23+
}
24+
if (!limit::has_quiet_NaN && std::isnan(f)) {
25+
return true;
26+
}
2027
return f < limit::lowest() || f > limit::max();
2128
}
2229

aten/src/ATen/test/scalar_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,26 @@ void test_overflow() {
5757
threw = true;
5858
}
5959
ASSERT(threw);
60+
61+
s1 = Scalar(NAN);
62+
ASSERT(std::isnan(s1.toFloat()));
63+
threw = false;
64+
try {
65+
s1.toInt();
66+
} catch (std::domain_error& e) {
67+
threw = true;
68+
}
69+
ASSERT(threw);
70+
71+
s1 = Scalar(INFINITY);
72+
ASSERT(std::isinf(s1.toFloat()));
73+
threw = false;
74+
try {
75+
s1.toInt();
76+
} catch (std::domain_error& e) {
77+
threw = true;
78+
}
79+
ASSERT(threw);
6080
}
6181

6282
int main() {

0 commit comments

Comments
 (0)