Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 9900f18

Browse files
committed
NumPy-compatible std and var
1 parent 3baa6eb commit 9900f18

File tree

7 files changed

+585
-6
lines changed

7 files changed

+585
-6
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3535
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
36-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax']
36+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
3737

3838

3939
@set_module('mxnet.ndarray.numpy')
@@ -2201,3 +2201,140 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable
22012201
array(0.55)
22022202
"""
22032203
return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
2204+
2205+
2206+
@set_module('mxnet.ndarray.numpy')
2207+
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
2208+
"""
2209+
Compute the standard deviation along the specified axis.
2210+
Returns the standard deviation, a measure of the spread of a distribution,
2211+
of the array elements. The standard deviation is computed for the
2212+
flattened array by default, otherwise over the specified axis.
2213+
2214+
Parameters
2215+
----------
2216+
a : array_like
2217+
Calculate the standard deviation of these values.
2218+
axis : None or int or tuple of ints, optional
2219+
Axis or axes along which the standard deviation is computed. The
2220+
default is to compute the standard deviation of the flattened array.
2221+
.. versionadded:: 1.7.0
2222+
If this is a tuple of ints, a standard deviation is performed over
2223+
multiple axes, instead of a single axis or all the axes as before.
2224+
dtype : dtype, optional
2225+
Type to use in computing the standard deviation. For arrays of
2226+
integer type the default is float64, for arrays of float types it is
2227+
the same as the array type.
2228+
out : ndarray, optional
2229+
Alternative output array in which to place the result. It must have
2230+
the same shape as the expected output but the type (of the calculated
2231+
values) will be cast if necessary.
2232+
ddof : int, optional
2233+
Means Delta Degrees of Freedom. The divisor used in calculations
2234+
is ``N - ddof``, where ``N`` represents the number of elements.
2235+
By default `ddof` is zero.
2236+
keepdims : bool, optional
2237+
If this is set to True, the axes which are reduced are left
2238+
in the result as dimensions with size one. With this option,
2239+
the result will broadcast correctly against the input array.
2240+
If the default value is passed, then `keepdims` will not be
2241+
passed through to the `std` method of sub-classes of
2242+
`ndarray`, however any non-default value will be. If the
2243+
sub-class' method does not implement `keepdims` any
2244+
exceptions will be raised.
2245+
2246+
Returns
2247+
-------
2248+
standard_deviation : ndarray, see dtype parameter above.
2249+
If `out` is None, return a new array containing the standard deviation,
2250+
otherwise return a reference to the output array.
2251+
2252+
Examples
2253+
--------
2254+
>>> a = np.array([[1, 2], [3, 4]])
2255+
>>> np.std(a)
2256+
1.1180339887498949 # may vary
2257+
>>> np.std(a, axis=0)
2258+
array([1., 1.])
2259+
>>> np.std(a, axis=1)
2260+
array([0.5, 0.5])
2261+
In single precision, std() can be inaccurate:
2262+
>>> a = np.zeros((2, 512*512), dtype=np.float32)
2263+
>>> a[0, :] = 1.0
2264+
>>> a[1, :] = 0.1
2265+
>>> np.std(a)
2266+
array(0.45)
2267+
>>> np.std(a, dtype=np.float64)
2268+
array(0.45, dtype=float64)
2269+
"""
2270+
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
2271+
2272+
2273+
@set_module('mxnet.ndarray.numpy')
2274+
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
2275+
"""
2276+
Compute the variance along the specified axis.
2277+
Returns the variance of the array elements, a measure of the spread of a
2278+
distribution. The variance is computed for the flattened array by
2279+
default, otherwise over the specified axis.
2280+
2281+
Parameters
2282+
----------
2283+
a : array_like
2284+
Array containing numbers whose variance is desired. If `a` is not an
2285+
array, a conversion is attempted.
2286+
axis : None or int or tuple of ints, optional
2287+
Axis or axes along which the variance is computed. The default is to
2288+
compute the variance of the flattened array.
2289+
.. versionadded:: 1.7.0
2290+
If this is a tuple of ints, a variance is performed over multiple axes,
2291+
instead of a single axis or all the axes as before.
2292+
dtype : data-type, optional
2293+
Type to use in computing the variance. For arrays of integer type
2294+
the default is `float32`; for arrays of float types it is the same as
2295+
the array type.
2296+
out : ndarray, optional
2297+
Alternate output array in which to place the result. It must have
2298+
the same shape as the expected output, but the type is cast if
2299+
necessary.
2300+
ddof : int, optional
2301+
"Delta Degrees of Freedom": the divisor used in the calculation is
2302+
``N - ddof``, where ``N`` represents the number of elements. By
2303+
default `ddof` is zero.
2304+
keepdims : bool, optional
2305+
If this is set to True, the axes which are reduced are left
2306+
in the result as dimensions with size one. With this option,
2307+
the result will broadcast correctly against the input array.
2308+
If the default value is passed, then `keepdims` will not be
2309+
passed through to the `var` method of sub-classes of
2310+
`ndarray`, however any non-default value will be. If the
2311+
sub-class' method does not implement `keepdims` any
2312+
exceptions will be raised.
2313+
2314+
Returns
2315+
-------
2316+
variance : ndarray, see dtype parameter above
2317+
If ``out=None``, returns a new array containing the variance;
2318+
otherwise, a reference to the output array is returned.
2319+
2320+
Examples
2321+
--------
2322+
>>> a = np.array([[1, 2], [3, 4]])
2323+
>>> np.var(a)
2324+
array(1.25)
2325+
>>> np.var(a, axis=0)
2326+
array([1., 1.])
2327+
>>> np.var(a, axis=1)
2328+
array([0.25, 0.25])
2329+
2330+
>>> a = np.zeros((2, 512*512), dtype=np.float32)
2331+
>>> a[0, :] = 1.0
2332+
>>> a[1, :] = 0.1
2333+
>>> np.var(a)
2334+
array(0.2025)
2335+
>>> np.var(a, dtype=np.float64)
2336+
array(0.2025, dtype=float64)
2337+
>>> ((1-0.55)**2 + (0.1-0.55)**2)/2
2338+
0.2025
2339+
"""
2340+
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)

python/mxnet/numpy/multiarray.py

Lines changed: 139 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
55-
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax']
55+
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
5656

5757
# Return code for dispatching indexing function call
5858
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -1172,11 +1172,9 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disa
11721172
"""Returns the average of the array elements along given axis."""
11731173
raise NotImplementedError
11741174

1175-
# TODO(junwu): Use mxnet std op instead of onp.std
11761175
def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=arguments-differ
11771176
"""Returns the standard deviation of the array elements along given axis."""
1178-
ret_np = self.asnumpy().std(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims)
1179-
return array(ret_np, dtype=ret_np.dtype, ctx=self.context)
1177+
return _mx_np_op.std(self, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
11801178

11811179
def cumsum(self, axis=None, dtype=None, out=None):
11821180
"""Return the cumulative sum of the elements along the given axis."""
@@ -3644,3 +3642,140 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable
36443642
array(0.55)
36453643
"""
36463644
return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
3645+
3646+
3647+
@set_module('mxnet.numpy')
3648+
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
3649+
"""
3650+
Compute the standard deviation along the specified axis.
3651+
Returns the standard deviation, a measure of the spread of a distribution,
3652+
of the array elements. The standard deviation is computed for the
3653+
flattened array by default, otherwise over the specified axis.
3654+
3655+
Parameters
3656+
----------
3657+
a : array_like
3658+
Calculate the standard deviation of these values.
3659+
axis : None or int or tuple of ints, optional
3660+
Axis or axes along which the standard deviation is computed. The
3661+
default is to compute the standard deviation of the flattened array.
3662+
.. versionadded:: 1.7.0
3663+
If this is a tuple of ints, a standard deviation is performed over
3664+
multiple axes, instead of a single axis or all the axes as before.
3665+
dtype : dtype, optional
3666+
Type to use in computing the standard deviation. For arrays of
3667+
integer type the default is float64, for arrays of float types it is
3668+
the same as the array type.
3669+
out : ndarray, optional
3670+
Alternative output array in which to place the result. It must have
3671+
the same shape as the expected output but the type (of the calculated
3672+
values) will be cast if necessary.
3673+
ddof : int, optional
3674+
Means Delta Degrees of Freedom. The divisor used in calculations
3675+
is ``N - ddof``, where ``N`` represents the number of elements.
3676+
By default `ddof` is zero.
3677+
keepdims : bool, optional
3678+
If this is set to True, the axes which are reduced are left
3679+
in the result as dimensions with size one. With this option,
3680+
the result will broadcast correctly against the input array.
3681+
If the default value is passed, then `keepdims` will not be
3682+
passed through to the `std` method of sub-classes of
3683+
`ndarray`, however any non-default value will be. If the
3684+
sub-class' method does not implement `keepdims` any
3685+
exceptions will be raised.
3686+
3687+
Returns
3688+
-------
3689+
standard_deviation : ndarray, see dtype parameter above.
3690+
If `out` is None, return a new array containing the standard deviation,
3691+
otherwise return a reference to the output array.
3692+
3693+
Examples
3694+
--------
3695+
>>> a = np.array([[1, 2], [3, 4]])
3696+
>>> np.std(a)
3697+
1.1180339887498949 # may vary
3698+
>>> np.std(a, axis=0)
3699+
array([1., 1.])
3700+
>>> np.std(a, axis=1)
3701+
array([0.5, 0.5])
3702+
In single precision, std() can be inaccurate:
3703+
>>> a = np.zeros((2, 512*512), dtype=np.float32)
3704+
>>> a[0, :] = 1.0
3705+
>>> a[1, :] = 0.1
3706+
>>> np.std(a)
3707+
array(0.45)
3708+
>>> np.std(a, dtype=np.float64)
3709+
array(0.45, dtype=float64)
3710+
"""
3711+
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
3712+
3713+
3714+
@set_module('mxnet.numpy')
3715+
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
3716+
"""
3717+
Compute the variance along the specified axis.
3718+
Returns the variance of the array elements, a measure of the spread of a
3719+
distribution. The variance is computed for the flattened array by
3720+
default, otherwise over the specified axis.
3721+
3722+
Parameters
3723+
----------
3724+
a : array_like
3725+
Array containing numbers whose variance is desired. If `a` is not an
3726+
array, a conversion is attempted.
3727+
axis : None or int or tuple of ints, optional
3728+
Axis or axes along which the variance is computed. The default is to
3729+
compute the variance of the flattened array.
3730+
.. versionadded:: 1.7.0
3731+
If this is a tuple of ints, a variance is performed over multiple axes,
3732+
instead of a single axis or all the axes as before.
3733+
dtype : data-type, optional
3734+
Type to use in computing the variance. For arrays of integer type
3735+
the default is `float32`; for arrays of float types it is the same as
3736+
the array type.
3737+
out : ndarray, optional
3738+
Alternate output array in which to place the result. It must have
3739+
the same shape as the expected output, but the type is cast if
3740+
necessary.
3741+
ddof : int, optional
3742+
"Delta Degrees of Freedom": the divisor used in the calculation is
3743+
``N - ddof``, where ``N`` represents the number of elements. By
3744+
default `ddof` is zero.
3745+
keepdims : bool, optional
3746+
If this is set to True, the axes which are reduced are left
3747+
in the result as dimensions with size one. With this option,
3748+
the result will broadcast correctly against the input array.
3749+
If the default value is passed, then `keepdims` will not be
3750+
passed through to the `var` method of sub-classes of
3751+
`ndarray`, however any non-default value will be. If the
3752+
sub-class' method does not implement `keepdims` any
3753+
exceptions will be raised.
3754+
3755+
Returns
3756+
-------
3757+
variance : ndarray, see dtype parameter above
3758+
If ``out=None``, returns a new array containing the variance;
3759+
otherwise, a reference to the output array is returned.
3760+
3761+
Examples
3762+
--------
3763+
>>> a = np.array([[1, 2], [3, 4]])
3764+
>>> np.var(a)
3765+
array(1.25)
3766+
>>> np.var(a, axis=0)
3767+
array([1., 1.])
3768+
>>> np.var(a, axis=1)
3769+
array([0.25, 0.25])
3770+
3771+
>>> a = np.zeros((2, 512*512), dtype=np.float32)
3772+
>>> a[0, :] = 1.0
3773+
>>> a[1, :] = 0.1
3774+
>>> np.var(a)
3775+
array(0.2025)
3776+
>>> np.var(a, dtype=np.float64)
3777+
array(0.2025, dtype=float64)
3778+
>>> ((1-0.55)**2 + (0.1-0.55)**2)/2
3779+
0.2025
3780+
"""
3781+
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3737
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
38-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax']
38+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
3939

4040

4141
def _num_outputs(sym):
@@ -2569,4 +2569,18 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable
25692569
return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
25702570

25712571

2572+
@set_module('mxnet.symbol.numpy')
2573+
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
2574+
"""
2575+
"""
2576+
return _npi.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
2577+
2578+
2579+
@set_module('mxnet.symbol.numpy')
2580+
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
2581+
"""
2582+
"""
2583+
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
2584+
2585+
25722586
_set_np_symbol_class(_Symbol)

0 commit comments

Comments
 (0)