Skip to content

Commit 5be668f

Browse files
authored
Move OneElement from Zygote and overload setindex (#161) (#235)
* Add Zeros(T, n...) and Ones(T, n...) constructors (#94( (#233) * Add Zeros(T, n...) and Ones(T, n...) constructors (#94( * increase coverage * Update README.md * Move over OneElement from Zygote * Add tests * Update oneelement.jl * add tests * Update runtests.jl * add docs
1 parent 4498570 commit 5be668f

File tree

5 files changed

+83
-11
lines changed

5 files changed

+83
-11
lines changed

README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,8 @@ as well as identity matrices. This package exports the following types:
1414

1515

1616
The primary purpose of this package is to present a unified way of constructing
17-
matrices. For example, to construct a 5-by-5 `CLArray` of all zeros, one would use
18-
```julia
19-
julia> CLArray(Zeros(5,5))
20-
```
21-
Because `Zeros` is lazy, this can be accomplished on the GPU with no memory transfer.
22-
Similarly, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
17+
matrices.
18+
For example, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
2319
```julia
2420
julia> BandedMatrix(Zeros(5,5), (1, 2))
2521
```

src/FillArrays.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
1818
import Statistics: mean, std, var, cov, cor
1919

2020

21-
export Zeros, Ones, Fill, Eye, Trues, Falses
21+
export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement
2222

2323
import Base: oneto
2424

@@ -263,6 +263,7 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
263263
@inline $Typ{T,N}(A::AbstractArray{V,N}) where{T,V,N} = $Typ{T,N}(size(A))
264264
@inline $Typ{T}(A::AbstractArray) where{T} = $Typ{T}(size(A))
265265
@inline $Typ(A::AbstractArray) = $Typ{eltype(A)}(A)
266+
@inline $Typ(::Type{T}, m...) where T = $Typ{T}(m...)
266267

267268
@inline axes(Z::$Typ) = Z.axes
268269
@inline size(Z::$Typ) = length.(Z.axes)
@@ -728,4 +729,6 @@ Base.@propagate_inbounds function view(A::AbstractFill{<:Any,N}, I::Vararg{Real,
728729
fillsimilar(A)
729730
end
730731

732+
include("oneelement.jl")
733+
731734
end # module

src/fillalgebra.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ end
8686
*(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
8787
*(a::AbstractMatrix, b::ZerosVector) = mult_zeros(a, b)
8888
*(a::AbstractMatrix, b::ZerosMatrix) = mult_zeros(a, b)
89-
*(a::ZerosVector, b::AbstractVector) = mult_zeros(a, b)
9089
*(a::ZerosMatrix, b::AbstractVector) = mult_zeros(a, b)
9190
*(a::AbstractVector, b::ZerosMatrix) = mult_zeros(a, b)
9291

src/oneelement.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
OneElement(val, ind, axesorsize) <: AbstractArray
3+
4+
Represents an array with the specified axes (if its a tuple of `AbstractUnitRange`s)
5+
or size (if its a tuple of `Integer`s), with a single entry set to `val` and all others equal to zero,
6+
specified by `ind``.
7+
"""
8+
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
9+
val::T
10+
ind::I
11+
axes::A
12+
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
13+
end
14+
15+
OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz))
16+
"""
17+
OneElement(val, ind::Int, n::Int)
18+
19+
Creates a length `n` vector where the `ind` entry is equal to `val`, and all other entries are zero.
20+
"""
21+
OneElement(val, ind::Int, len::Int) = OneElement(val, (ind,), (len,))
22+
"""
23+
OneElement(ind::Int, n::Int)
24+
25+
Creates a length `n` vector where the `ind` entry is equal to `1`, and all other entries are zero.
26+
"""
27+
OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz)
28+
OneElement{T}(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where {T,N} = OneElement(convert(T,val), inds, oneto.(sz))
29+
OneElement{T}(val, inds::Int, sz::Int) where T = OneElement{T}(val, (inds,), (sz,))
30+
31+
"""
32+
OneElement{T}(val, ind::Int, n::Int)
33+
34+
Creates a length `n` vector where the `ind` entry is equal to `one(T)`, and all other entries are zero.
35+
"""
36+
OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)
37+
38+
Base.size(A::OneElement) = map(length, A.axes)
39+
Base.axes(A::OneElement) = A.axes
40+
function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N}
41+
@boundscheck checkbounds(A, kj...)
42+
ifelse(kj == A.ind, A.val, zero(T))
43+
end
44+
45+
Base.replace_in_print_matrix(o::OneElement{<:Any,2}, k::Integer, j::Integer, s::AbstractString) =
46+
o.ind == (k,j) ? s : Base.replace_with_centered_mark(s)
47+
48+
function Base.setindex(A::Zeros{T,N}, v, kj::Vararg{Int,N}) where {T,N}
49+
@boundscheck checkbounds(A, kj...)
50+
OneElement(convert(T, v), kj, axes(A))
51+
end

test/runtests.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ include("infinitearrays.jl")
2020

2121
for T in (Int, Float64)
2222
Z = $Typ{T}(5)
23+
@test $Typ(T, 5) Z
2324
@test eltype(Z) == T
2425
@test Array(Z) == $funcs(T,5)
2526
@test Array{T}(Z) == $funcs(T,5)
@@ -34,6 +35,7 @@ include("infinitearrays.jl")
3435
@test $Typ(2ones(T,5)) == Z
3536

3637
Z = $Typ{T}(5, 5)
38+
@test $Typ(T, 5, 5) Z
3739
@test eltype(Z) == T
3840
@test Array(Z) == $funcs(T,5,5)
3941
@test Array{T}(Z) == $funcs(T,5,5)
@@ -525,9 +527,9 @@ end
525527
@test_throws MethodError [1,2,3]*Zeros(3) # Not defined for [1,2,3]*[0,0,0] either
526528

527529
@testset "Check multiplication by Adjoint vectors works as expected." begin
528-
@test randn(4, 3)' * Zeros(4) === Zeros(3)
529-
@test randn(4)' * Zeros(4) === zero(Float64)
530-
@test [1, 2, 3]' * Zeros{Int}(3) === zero(Int)
530+
@test randn(4, 3)' * Zeros(4) Zeros(3)
531+
@test randn(4)' * Zeros(4) transpose(randn(4)) * Zeros(4) zero(Float64)
532+
@test [1, 2, 3]' * Zeros{Int}(3) zero(Int)
531533
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
532534
@test_throws DimensionMismatch randn(4)' * Zeros(3)
533535
@test Zeros(5)' * randn(5,3) Zeros(5)'*Zeros(5,3) Zeros(5)'*Ones(5,3) Zeros(3)'
@@ -1503,4 +1505,25 @@ end
15031505
@test Zeros(5,5) .+ D isa Diagonal
15041506
f = (x,y) -> x+1
15051507
@test f.(D, Zeros(5,5)) isa Matrix
1508+
end
1509+
1510+
@testset "OneElement" begin
1511+
e₁ = OneElement(2, 5)
1512+
@test e₁ == [0,1,0,0,0]
1513+
@test_throws BoundsError e₁[6]
1514+
1515+
e₁ = OneElement{Float64}(2, 5)
1516+
@test e₁ == [0,1,0,0,0]
1517+
1518+
v = OneElement{Float64}(2, 3, 4)
1519+
@test v == [0,0,2,0]
1520+
1521+
V = OneElement(2, (2,3), (3,4))
1522+
@test V == [0 0 0 0; 0 0 2 0; 0 0 0 0]
1523+
1524+
@test stringmime("text/plain", V) == "3×4 OneElement{$Int, 2, Tuple{$Int, $Int}, Tuple{Base.OneTo{$Int}, Base.OneTo{$Int}}}:\n ⋅ ⋅ ⋅ ⋅\n ⋅ ⋅ 2 ⋅\n ⋅ ⋅ ⋅ ⋅"
1525+
1526+
@test Base.setindex(Zeros(5), 2, 2) OneElement(2.0, 2, 5)
1527+
@test Base.setindex(Zeros(5,3), 2, 2, 3) OneElement(2.0, (2,3), (5,3))
1528+
@test_throws BoundsError Base.setindex(Zeros(5), 2, 6)
15061529
end

0 commit comments

Comments
 (0)