This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 36a3cb8

Julia: split ndarray.jl into several snippets (#14001)
- `ndarray/type.jl`
- `ndarray/context.jl`
- `ndarray/show.jl`
- `ndarray/remap.jl`
- `ndarray/array.jl`
- `ndarray/arithmetic.jl`
- `ndarray/comparison.jl`
- `ndarray/io.jl`
- `ndarray/reduction.jl`
- `ndarray/statistic.jl`
- `ndarray/linalg.jl`
- `ndarray/trig.jl`
- `ndarray/activation.jl`
- `ndarray/autoimport.jl`
1 parent e37ff53 commit 36a3cb8

File tree

15 files changed: +2061 additions, -1813 deletions


julia/src/ndarray.jl

Lines changed: 14 additions & 1813 deletions
Large diffs are not rendered by default.
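The slimmed-down `ndarray.jl` keeps only 14 added lines, and its diff is not rendered above. Given the snippet list in the commit message, those lines are most likely just `include` statements pulling the new files back in; the sketch below is an inference from the commit message, not the rendered diff.

```julia
# Hypothetical reconstruction of the remaining body of julia/src/ndarray.jl:
# one include per snippet named in the commit message (14 lines in total).
include("ndarray/type.jl")
include("ndarray/context.jl")
include("ndarray/show.jl")
include("ndarray/remap.jl")
include("ndarray/array.jl")
include("ndarray/arithmetic.jl")
include("ndarray/comparison.jl")
include("ndarray/io.jl")
include("ndarray/reduction.jl")
include("ndarray/statistic.jl")
include("ndarray/linalg.jl")
include("ndarray/trig.jl")
include("ndarray/activation.jl")
include("ndarray/autoimport.jl")
```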

julia/src/ndarray/activation.jl

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# activation functions

@doc doc"""
    σ.(x::NDArray)
    sigmoid.(x::NDArray)

Computes sigmoid of x element-wise.

```math
σ(x) = \frac{1}{(1 + exp(-x))}
```

The storage type of `sigmoid` output is always dense.
"""
function σ end
const sigmoid = σ
_nddoc[:σ] = false
@_remap broadcasted(::typeof(σ), x::NDArray) sigmoid(x)

@doc doc"""
    relu.(x::NDArray)

Computes rectified linear.

```math
\max(x, 0)
```
"""
function relu end
_nddoc[:relu] = false
@_remap broadcasted(::typeof(relu), x::NDArray) relu(x)

@doc doc"""
    softmax.(x::NDArray, [dim = ndims(x)])

Applies the softmax function.

The resulting array contains elements in the range `(0, 1)`
and the elements along the given axis sum up to 1.

```math
softmax(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}}
```
"""
function softmax end
_nddoc[:softmax] = false
@_remap broadcasted(::typeof(softmax), x::NDArray) softmax(x; axis = -ndims(x))
@_remap broadcasted(::typeof(softmax), x::NDArray, dim::Int) softmax(x; axis = -dim)

"""
    log_softmax.(x::NDArray, [dim = ndims(x)])

Computes the log softmax of the input.
This is equivalent to computing softmax followed by log.

julia> x
2×3 mx.NDArray{Float64,2} @ CPU0:
 1.0  2.0  0.1
 0.1  2.0  1.0

julia> mx.log_softmax.(x)
2×3 mx.NDArray{Float64,2} @ CPU0:
 -1.41703  -0.41703  -2.31703
 -2.31703  -0.41703  -1.41703
"""
function log_softmax end
_nddoc[:log_softmax] = false
@_remap broadcasted(::typeof(log_softmax), x::NDArray) log_softmax(x; axis = -ndims(x))
@_remap broadcasted(::typeof(log_softmax), x::NDArray, dim::Int) log_softmax(x; axis = -dim)
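A minimal usage sketch, not part of this commit, showing how the broadcasted activations defined above are typically called. It assumes MXNet.jl is installed and loaded; the array values are arbitrary.

```julia
using MXNet  # brings the `mx` module into scope

x = mx.NDArray(Float32[-1.0 0.0 1.0; 2.0 -2.0 0.5])  # copy a Julia array into an NDArray

y1 = mx.sigmoid.(x)      # elementwise sigmoid; mx.σ.(x) is the same function
y2 = mx.relu.(x)         # elementwise max(x, 0)
y3 = mx.softmax.(x, 1)   # softmax along dimension 1; the default dim is ndims(x)
y4 = mx.log_softmax.(x)  # log(softmax(x)) computed as a single fused operator
```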

julia/src/ndarray/arithmetic.jl

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import Base: +

"""
    +(args...)
    .+(args...)

Summation. Multiple arguments of either scalar or `NDArray` could be
added together. Note at least the first or second argument needs to be an
`NDArray` to avoid ambiguity of built-in summation.
"""
+(x::NDArray) = x
+(x::NDArray, y::NDArray) = _plus(x, y)
+(x::NDArray, y::Real) = _plus_scalar(x, scalar = y)
+(y::Real, x::NDArray) = _plus_scalar(x, scalar = y)

broadcasted(::typeof(+), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
  _broadcast_add(x, y)

"""
    sub_from!(dst::NDArray, args::NDArrayOrReal...)

Subtract a bunch of arguments from `dst`. Inplace updating.
"""
function sub_from!(dst::NDArray, arg::NDArrayOrReal)
  @assert dst.writable
  if isa(arg, Real)
    _minus_scalar(dst, scalar = arg, out = dst)
  else
    _minus!(dst, arg)
  end
  dst
end

import Base: -

"""
    -(x::NDArray)
    -(x, y)
    .-(x, y)

Subtraction `x - y`, of scalar types or `NDArray`.
Or create the negative of `x`.
"""
-(x::NDArray) = _mul_scalar(x, scalar = -one(eltype(x)))
-(x::NDArray, y::NDArray) = _minus(x, y)
-(x::NDArray, y::Real) = _minus_scalar(x, scalar = y)
-(y::Real, x::NDArray) = _rminus_scalar(x, scalar = y)

broadcasted(::typeof(-), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
  _broadcast_minus(x, y)

"""
    mul_to!(dst::NDArray, arg::NDArrayOrReal)

Elementwise multiplication into `dst` of either a scalar or an `NDArray` of the same shape.
Inplace updating.
"""
function mul_to!(dst::NDArray, arg::NDArrayOrReal)
  @assert dst.writable
  if isa(arg, Real)
    _mul_scalar(dst, scalar = arg, out = dst)
  else
    _mul(dst, arg, out = dst)
  end
  dst
end

import Base: *

"""
    .*(x, y)

Elementwise multiplication for `NDArray`.
"""
*(x::NDArray, y::Real) = _mul_scalar(x, scalar = y)
*(y::Real, x::NDArray) = _mul_scalar(x, scalar = y)

broadcasted(::typeof(*), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
  _mul(x, y)
broadcasted(::typeof(*), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
  _broadcast_mul(x, y)

"""
    *(A::NDArray, B::NDArray)

Matrix/tensor multiplication.
"""
*(x::NDArray{T}, y::NDArray{T}) where T = x ⋅ y

LinearAlgebra.adjoint(x::NDArray{T,1}) where T = transpose(x)
LinearAlgebra.adjoint(x::NDArray{T,2}) where T = transpose(x)

"""
    div_from!(dst::NDArray, arg::NDArrayOrReal)

Elementwise divide a scalar or an `NDArray` of the same shape from `dst`. Inplace updating.
"""
function div_from!(dst::NDArray, arg::NDArrayOrReal)
  @assert dst.writable
  if isa(arg, Real)
    _div_scalar(dst, scalar = arg, out = dst)
  else
    _div(dst, arg, out = dst)
  end
  dst
end

function div_from!(dst::NDArray{T}, arg::Real) where {T<:Integer}
  @assert dst.writable
  @assert(round(T, arg) != zero(T), "Integer divided by zero")
  _div_scalar(dst, scalar = arg, out = dst)
  dst
end

"""
    rdiv_from!(x::Real, y::NDArray)

Elementwise divide a scalar by an `NDArray`. Inplace updating.
"""
function rdiv_from!(x::Real, y::NDArray)
  @assert y.writable
  _rdiv_scalar(y, scalar = x, out = y)
  y
end

import Base: /

"""
    ./(x::NDArray, y::NDArray)
    ./(x::NDArray, y::Real)
    ./(x::Real, y::NDArray)

* Elementwise dividing an `NDArray` by a scalar or another `NDArray`
  of the same shape.

* Elementwise divide a scalar by an `NDArray`.

* Matrix division (solving linear systems) is not implemented yet.
"""
/(x::NDArray, y::Real) = _div_scalar(x, scalar = y)

broadcasted(::typeof(/), y::Real, x::NDArray) = _rdiv_scalar(x, scalar = y)
broadcasted(::typeof(/), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
  _div(x, y)
broadcasted(::typeof(/), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
  _broadcast_div(x, y)

function broadcasted(::typeof(/), x::NDArray{T}, y::Real) where {T<:Integer}
  @assert(round(T, y) != zero(T), "Integer divided by zero")
  _div_scalar(x, scalar = y)
end

"""
    mod_from!(x::NDArray, y::NDArray)
    mod_from!(x::NDArray, y::Real)

Elementwise modulo for `NDArray`.
Inplace updating.
"""
mod_from!(x::NDArray, y::NDArray) = _mod!(x, y)
mod_from!(x::NDArray, y::Real) = _mod_scalar!(x, y)

"""
    rmod_from!(y::Real, x::NDArray)

Elementwise modulo for `NDArray`.
Inplace updating.
"""
rmod_from!(y::Real, x::NDArray) = _rmod_scalar!(x, y)

import Base: %

"""
    .%(x::NDArray, y::NDArray)
    .%(x::NDArray, y::Real)
    .%(x::Real, y::NDArray)

Elementwise modulo for `NDArray`.
"""
%(x::NDArray, y::Real) = _mod_scalar(x, y)

broadcasted(::typeof(%), y::Real, x::NDArray) = _rmod_scalar(x, y)
broadcasted(::typeof(%), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
  _mod(x, y)
broadcasted(::typeof(%), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
  _broadcast_mod(x, y)

# document of `.^` is merged into SymbolicNode's

broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::NDArray, ::Val{s}) where {s} =
  _power_scalar(x, scalar = s)
broadcasted(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar = s)
broadcasted(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar = s)

broadcasted(::typeof(^), ::Irrational{:ℯ}, x::NDArray) = exp(x)
broadcasted(::typeof(^), x::NDArray, s::Irrational) = _power_scalar(x, scalar = s)
broadcasted(::typeof(^), s::Irrational, x::NDArray) = _rpower_scalar(x, scalar = s)

broadcasted(::typeof(^), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
  _power(x, y)
broadcasted(::typeof(^), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
  _broadcast_power(x, y)

_nddoc[:clip] = _nddoc[:clip!] =
"""
    clip(x::NDArray, min, max)
    clip!(x::NDArray, min, max)

Clips (limits) the values in `NDArray`.
Given an interval, values outside the interval are clipped to the interval edges.
Clipping `x` between `min` and `max` would be:

```julia
clip(x, min_, max_) = max(min(x, max_), min_)
```

```jldoctest
julia> x = NDArray(1:9);

julia> mx.clip(x, 2, 8)'
1×9 mx.NDArray{Int64,2} @ CPU0:
 2  2  3  4  5  6  7  8  8
```

The storage type of clip output depends on storage types of inputs and the
`min`, `max` parameter values:

- clip(default) = default
- clip(row_sparse, min <= 0, max >= 0) = row_sparse
- clip(csr, min <= 0, max >= 0) = csr
- clip(row_sparse, min < 0, max < 0) = default
- clip(row_sparse, min > 0, max > 0) = default
- clip(csr, min < 0, max < 0) = csr
- clip(csr, min > 0, max > 0) = csr
"""
@_remap clip(x::NDArray, min::Real, max::Real) clip(x; a_min = min, a_max = max)
@_remap clip!(x::NDArray, min::Real, max::Real) clip(x; a_min = min, a_max = max)

################################################################################
# remapping to solve type instability
################################################################################

@_remap _plus(x::NDArray, y::NDArray) _plus(x, y)
@_remap _plus!(x::NDArray, y::NDArray) _plus(x, y)

@_remap _minus(x::NDArray, y::NDArray) _minus(x, y)
@_remap _minus!(x::NDArray, y::NDArray) _minus(x, y)

@_remap _mod(x::NDArray, y::NDArray) _mod(x, y)
@_remap _mod!(x::NDArray, y::NDArray) _mod(x, y)

@_remap _mod_scalar(x::NDArray, y::Real) _mod_scalar(x; scalar = y)
@_remap _mod_scalar!(x::NDArray, y::Real) _mod_scalar(x; scalar = y)

@_remap _rmod_scalar(x::NDArray, y::Real) _rmod_scalar(x; scalar = y)
@_remap _rmod_scalar!(x::NDArray, y::Real) _rmod_scalar(x; scalar = y)

@_remap _broadcast_add(x::NDArray, y::NDArray) broadcast_add(x, y)
@_remap _broadcast_add!(x::NDArray, y::NDArray) broadcast_add(x, y)

@_remap _broadcast_minus(x::NDArray, y::NDArray) broadcast_minus(x, y)
@_remap _broadcast_minus!(x::NDArray, y::NDArray) broadcast_minus(x, y)

@_remap _broadcast_mul(x::NDArray, y::NDArray) broadcast_mul(x, y)
@_remap _broadcast_mul!(x::NDArray, y::NDArray) broadcast_mul(x, y)

@_remap _broadcast_div(x::NDArray, y::NDArray) broadcast_div(x, y)
@_remap _broadcast_div!(x::NDArray, y::NDArray) broadcast_div(x, y)

@_remap _broadcast_mod(x::NDArray, y::NDArray) broadcast_mod(x, y)
@_remap _broadcast_mod!(x::NDArray, y::NDArray) broadcast_mod(x, y)

@_remap _broadcast_power(x::NDArray, y::NDArray) broadcast_power(x, y)
@_remap _broadcast_power!(x::NDArray, y::NDArray) broadcast_power(x, y)
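A minimal usage sketch, not part of this commit, exercising a few of the operators defined above. It assumes MXNet.jl is loaded; the values are arbitrary.

```julia
using MXNet

a = mx.NDArray(Float32[1 2 3; 4 5 6])
b = mx.NDArray(Float32[10 20 30; 40 50 60])

c = a + b             # elementwise sum            (_plus)
d = a - 1f0           # subtract a scalar          (_minus_scalar)
e = a .* b            # elementwise product        (_mul)
f = b ./ a            # elementwise division       (_div)
g = a .^ 2            # elementwise power          (literal_pow path)
h = mx.clip(a, 2, 5)  # limit every value to the interval [2, 5]
```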
