File tree Expand file tree Collapse file tree 6 files changed +112
-0
lines changed Expand file tree Collapse file tree 6 files changed +112
-0
lines changed Original file line number Diff line number Diff line change @@ -380,4 +380,43 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
380
380
}
381
381
}
382
382
383
+ template <typename T>
384
+ void set_iinfo_limits (int64_t & min, uint64_t & max) {
385
+ min = std::numeric_limits<T>::min ();
386
+ max = std::numeric_limits<T>::max ();
387
+ }
388
+
389
+ iinfo::iinfo (Dtype dtype) : dtype(dtype) {
390
+ switch (dtype) {
391
+ case int8:
392
+ set_iinfo_limits<int8_t >(min, max);
393
+ break ;
394
+ case uint8:
395
+ set_iinfo_limits<uint8_t >(min, max);
396
+ break ;
397
+ case int16:
398
+ set_iinfo_limits<int16_t >(min, max);
399
+ break ;
400
+ case uint16:
401
+ set_iinfo_limits<uint16_t >(min, max);
402
+ break ;
403
+ case int32:
404
+ set_iinfo_limits<int32_t >(min, max);
405
+ break ;
406
+ case uint32:
407
+ set_iinfo_limits<uint32_t >(min, max);
408
+ break ;
409
+ case int64:
410
+ set_iinfo_limits<int64_t >(min, max);
411
+ break ;
412
+ case uint64:
413
+ set_iinfo_limits<uint64_t >(min, max);
414
+ break ;
415
+ default :
416
+ std::ostringstream msg;
417
+ msg << " [iinfo] dtype " << dtype << " is not integral." ;
418
+ throw std::invalid_argument (msg.str ());
419
+ }
420
+ }
421
+
383
422
} // namespace mlx::core
Original file line number Diff line number Diff line change @@ -67,6 +67,14 @@ struct finfo {
67
67
double max;
68
68
};
69
69
70
+ /* * Holds information about integral types. */
71
+ struct iinfo {
72
+ explicit iinfo (Dtype dtype);
73
+ Dtype dtype;
74
+ int64_t min;
75
+ uint64_t max;
76
+ };
77
+
70
78
/* * The type from promoting the arrays' types with one another. */
71
79
inline Dtype result_type (const array& a, const array& b) {
72
80
return promote_types (a.dtype (), b.dtype ());
Original file line number Diff line number Diff line change @@ -206,6 +206,30 @@ void init_array(nb::module_& m) {
206
206
return os.str ();
207
207
});
208
208
209
+ nb::class_<mx::iinfo>(
210
+ m,
211
+ " iinfo" ,
212
+ R"pbdoc(
213
+ Get information on integer types.
214
+ )pbdoc" )
215
+ .def (nb::init<mx::Dtype>())
216
+ .def_ro (
217
+ " min" ,
218
+ &mx::iinfo::min,
219
+ R"pbdoc( The smallest representable number.)pbdoc" )
220
+ .def_ro (
221
+ " max" ,
222
+ &mx::iinfo::max,
223
+ R"pbdoc( The largest representable number.)pbdoc" )
224
+ .def_ro (" dtype" , &mx::iinfo::dtype, R"pbdoc( The :obj:`Dtype`.)pbdoc" )
225
+ .def (" __repr__" , [](const mx::iinfo& i) {
226
+ std::ostringstream os;
227
+ os << " iinfo("
228
+ << " min=" << i.min << " , max=" << i.max << " , dtype=" << i.dtype
229
+ << " )" ;
230
+ return os.str ();
231
+ });
232
+
209
233
nb::class_<ArrayAt>(
210
234
m,
211
235
" ArrayAt" ,
Original file line number Diff line number Diff line change 2
2
3
3
#include " python/src/utils.h"
4
4
#include " mlx/ops.h"
5
+ #include " mlx/utils.h"
5
6
#include " python/src/convert.h"
6
7
7
8
mx::array to_array (
@@ -16,6 +17,16 @@ mx::array to_array(
16
17
? mx::int64
17
18
: mx::int32;
18
19
auto out_t = dtype.value_or (default_type);
20
+ if (mx::issubdtype (out_t , mx::integer) && out_t .size () < 8 ) {
21
+ auto info = mx::iinfo (out_t );
22
+ if (val < info.min || val > static_cast <int64_t >(info.max )) {
23
+ std::ostringstream msg;
24
+ msg << " Converting " << val << " to " << out_t
25
+ << " would result in overflow." ;
26
+ throw std::invalid_argument (msg.str ());
27
+ }
28
+ }
29
+
19
30
// bool_ is an exception and is always promoted
20
31
return mx::array (val, (out_t == mx::bool_) ? mx::int32 : out_t );
21
32
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
Original file line number Diff line number Diff line change @@ -109,6 +109,18 @@ def test_finfo(self):
109
109
self .assertEqual (mx .finfo (mx .float16 ).max , np .finfo (np .float16 ).max )
110
110
self .assertEqual (mx .finfo (mx .float16 ).dtype , mx .float16 )
111
111
112
+ def test_iinfo (self ):
113
+ with self .assertRaises (ValueError ):
114
+ mx .iinfo (mx .float32 )
115
+
116
+ self .assertEqual (mx .iinfo (mx .int32 ).min , np .iinfo (np .int32 ).min )
117
+ self .assertEqual (mx .iinfo (mx .int32 ).max , np .iinfo (np .int32 ).max )
118
+ self .assertEqual (mx .iinfo (mx .int32 ).dtype , mx .int32 )
119
+
120
+ self .assertEqual (mx .iinfo (mx .uint32 ).min , np .iinfo (np .uint32 ).min )
121
+ self .assertEqual (mx .iinfo (mx .uint32 ).max , np .iinfo (np .uint32 ).max )
122
+ self .assertEqual (mx .iinfo (mx .int8 ).dtype , mx .int8 )
123
+
112
124
113
125
class TestEquality (mlx_tests .MLXTestCase ):
114
126
def test_array_eq_array (self ):
@@ -1999,6 +2011,14 @@ def t():
1999
2011
used = get_mem ()
2000
2012
self .assertEqual (expected , used )
2001
2013
2014
+ def test_scalar_integer_conversion_overflow (self ):
2015
+ y = mx .array (2000000000 , dtype = mx .int32 )
2016
+ x = 3000000000
2017
+ with self .assertRaises (ValueError ):
2018
+ y + x
2019
+ with self .assertRaises (ValueError ):
2020
+ mx .add (y , x )
2021
+
2002
2022
2003
2023
if __name__ == "__main__" :
2004
2024
unittest .main ()
Original file line number Diff line number Diff line change @@ -55,3 +55,13 @@ TEST_CASE("test finfo") {
55
55
CHECK_EQ (finfo (float16).min , -65504 );
56
56
CHECK_EQ (finfo (float16).max , 65504 );
57
57
}
58
+
59
+ TEST_CASE (" test iinfo" ) {
60
+ CHECK_EQ (iinfo (int8).dtype , int8);
61
+ CHECK_EQ (iinfo (int64).dtype , int64);
62
+ CHECK_EQ (iinfo (int64).max , std::numeric_limits<int64_t >::max ());
63
+ CHECK_EQ (iinfo (uint64).max , std::numeric_limits<uint64_t >::max ());
64
+ CHECK_EQ (iinfo (uint64).max , std::numeric_limits<uint64_t >::max ());
65
+ CHECK_EQ (iinfo (uint64).min , 0 );
66
+ CHECK_EQ (iinfo (int64).min , std::numeric_limits<int64_t >::min ());
67
+ }
You can’t perform that action at this time.
0 commit comments