Skip to content

Commit 5580b47

Browse files
authored
iinfo and scalar overflow detection (#2009)
1 parent bc62932 commit 5580b47

File tree

6 files changed

+112
-0
lines changed

6 files changed

+112
-0
lines changed

mlx/utils.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,4 +380,43 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
380380
}
381381
}
382382

383+
template <typename T>
384+
void set_iinfo_limits(int64_t& min, uint64_t& max) {
385+
min = std::numeric_limits<T>::min();
386+
max = std::numeric_limits<T>::max();
387+
}
388+
389+
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
390+
switch (dtype) {
391+
case int8:
392+
set_iinfo_limits<int8_t>(min, max);
393+
break;
394+
case uint8:
395+
set_iinfo_limits<uint8_t>(min, max);
396+
break;
397+
case int16:
398+
set_iinfo_limits<int16_t>(min, max);
399+
break;
400+
case uint16:
401+
set_iinfo_limits<uint16_t>(min, max);
402+
break;
403+
case int32:
404+
set_iinfo_limits<int32_t>(min, max);
405+
break;
406+
case uint32:
407+
set_iinfo_limits<uint32_t>(min, max);
408+
break;
409+
case int64:
410+
set_iinfo_limits<int64_t>(min, max);
411+
break;
412+
case uint64:
413+
set_iinfo_limits<uint64_t>(min, max);
414+
break;
415+
default:
416+
std::ostringstream msg;
417+
msg << "[iinfo] dtype " << dtype << " is not integral.";
418+
throw std::invalid_argument(msg.str());
419+
}
420+
}
421+
383422
} // namespace mlx::core

mlx/utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ struct finfo {
6767
double max;
6868
};
6969

70+
/** Holds information about integral types. */
71+
struct iinfo {
72+
explicit iinfo(Dtype dtype);
73+
Dtype dtype;
74+
int64_t min;
75+
uint64_t max;
76+
};
77+
7078
/** The type from promoting the arrays' types with one another. */
7179
inline Dtype result_type(const array& a, const array& b) {
7280
return promote_types(a.dtype(), b.dtype());

python/src/array.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,30 @@ void init_array(nb::module_& m) {
206206
return os.str();
207207
});
208208

209+
nb::class_<mx::iinfo>(
210+
m,
211+
"iinfo",
212+
R"pbdoc(
213+
Get information on integer types.
214+
)pbdoc")
215+
.def(nb::init<mx::Dtype>())
216+
.def_ro(
217+
"min",
218+
&mx::iinfo::min,
219+
R"pbdoc(The smallest representable number.)pbdoc")
220+
.def_ro(
221+
"max",
222+
&mx::iinfo::max,
223+
R"pbdoc(The largest representable number.)pbdoc")
224+
.def_ro("dtype", &mx::iinfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
225+
.def("__repr__", [](const mx::iinfo& i) {
226+
std::ostringstream os;
227+
os << "iinfo("
228+
<< "min=" << i.min << ", max=" << i.max << ", dtype=" << i.dtype
229+
<< ")";
230+
return os.str();
231+
});
232+
209233
nb::class_<ArrayAt>(
210234
m,
211235
"ArrayAt",

python/src/utils.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "python/src/utils.h"
44
#include "mlx/ops.h"
5+
#include "mlx/utils.h"
56
#include "python/src/convert.h"
67

78
mx::array to_array(
@@ -16,6 +17,16 @@ mx::array to_array(
1617
? mx::int64
1718
: mx::int32;
1819
auto out_t = dtype.value_or(default_type);
20+
if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) {
21+
auto info = mx::iinfo(out_t);
22+
if (val < info.min || val > static_cast<int64_t>(info.max)) {
23+
std::ostringstream msg;
24+
msg << "Converting " << val << " to " << out_t
25+
<< " would result in overflow.";
26+
throw std::invalid_argument(msg.str());
27+
}
28+
}
29+
1930
// bool_ is an exception and is always promoted
2031
return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);
2132
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {

python/tests/test_array.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ def test_finfo(self):
109109
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
110110
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
111111

112+
def test_iinfo(self):
113+
with self.assertRaises(ValueError):
114+
mx.iinfo(mx.float32)
115+
116+
self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min)
117+
self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max)
118+
self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32)
119+
120+
self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min)
121+
self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max)
122+
self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8)
123+
112124

113125
class TestEquality(mlx_tests.MLXTestCase):
114126
def test_array_eq_array(self):
@@ -1999,6 +2011,14 @@ def t():
19992011
used = get_mem()
20002012
self.assertEqual(expected, used)
20012013

2014+
def test_scalar_integer_conversion_overflow(self):
2015+
y = mx.array(2000000000, dtype=mx.int32)
2016+
x = 3000000000
2017+
with self.assertRaises(ValueError):
2018+
y + x
2019+
with self.assertRaises(ValueError):
2020+
mx.add(y, x)
2021+
20022022

20032023
if __name__ == "__main__":
20042024
unittest.main()

tests/utils_tests.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,13 @@ TEST_CASE("test finfo") {
5555
CHECK_EQ(finfo(float16).min, -65504);
5656
CHECK_EQ(finfo(float16).max, 65504);
5757
}
58+
59+
TEST_CASE("test iinfo") {
60+
CHECK_EQ(iinfo(int8).dtype, int8);
61+
CHECK_EQ(iinfo(int64).dtype, int64);
62+
CHECK_EQ(iinfo(int64).max, std::numeric_limits<int64_t>::max());
63+
CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
64+
CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
65+
CHECK_EQ(iinfo(uint64).min, 0);
66+
CHECK_EQ(iinfo(int64).min, std::numeric_limits<int64_t>::min());
67+
}

0 commit comments

Comments
 (0)