Skip to content

Commit d8e1f9a

Browse files
_fill_dot support general vectors (#229)
* Update fillalgebra.jl * promote_op * add breaking test * add breaking test * fix * accept round-off errors * Update test/runtests.jl Co-authored-by: Sheehan Olver <[email protected]> * update * support inf and nan * fix 1.6 * Update fillalgebra.jl * Update fillalgebra.jl * trying to fix Julia 1.6 * comments * Update runtests.jl * add @inferred --------- Co-authored-by: Sheehan Olver <[email protected]>
1 parent fea49f6 commit d8e1f9a

File tree

2 files changed

+45
-47
lines changed

2 files changed

+45
-47
lines changed

src/fillalgebra.jl

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -159,38 +159,22 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::ZerosVector{T}) where T<:Rea
159159
end
160160
*(a::Transpose{T, <:AbstractMatrix{T}}, b::ZerosVector{T}) where T<:Real = mult_zeros(a, b)
161161

162-
# treat zero separately to support ∞-vectors
163-
function _zero_dot(a, b)
164-
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
165-
zero(promote_type(eltype(a),eltype(b)))
166-
end
167-
168-
_fill_dot(a::Zeros, b::Zeros) = _zero_dot(a, b)
169-
_fill_dot(a::Zeros, b) = _zero_dot(a, b)
170-
_fill_dot(a, b::Zeros) = _zero_dot(a, b)
171-
_fill_dot(a::Zeros, b::AbstractFill) = _zero_dot(a, b)
172-
_fill_dot(a::AbstractFill, b::Zeros) = _zero_dot(a, b)
173-
174-
function _fill_dot(a::AbstractFill, b::AbstractFill)
175-
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
176-
getindex_value(a)getindex_value(b)*length(b)
177-
end
178-
179162
# support types with fast sum
180-
function _fill_dot(a::AbstractFill, b)
163+
# infinite cases should be supported in InfiniteArrays.jl
164+
# type issues of Bool dot are ignored at present.
165+
function _fill_dot(a::AbstractFillVector{T}, b::AbstractVector{V}) where {T,V}
181166
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
182-
getindex_value(a)sum(b)
167+
dot(getindex_value(a), sum(b))
183168
end
184169

185-
function _fill_dot(a, b::AbstractFill)
170+
function _fill_dot_rev(a::AbstractVector{T}, b::AbstractFillVector{V}) where {T,V}
186171
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
187-
sum(a)getindex_value(b)
172+
dot(sum(a), getindex_value(b))
188173
end
189174

190-
191175
dot(a::AbstractFillVector, b::AbstractFillVector) = _fill_dot(a, b)
192176
dot(a::AbstractFillVector, b::AbstractVector) = _fill_dot(a, b)
193-
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot(a, b)
177+
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot_rev(a, b)
194178

195179
function dot(u::AbstractVector, E::Eye, v::AbstractVector)
196180
length(u) == size(E,1) && length(v) == size(E,2) ||

test/runtests.jl

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -329,20 +329,32 @@ end
329329
# type, and produce numerically correct results.
330330
as_array(x::AbstractArray) = Array(x)
331331
as_array(x::UniformScaling) = x
332-
function test_addition_and_subtraction(As, Bs, Tout::Type)
332+
equal_or_undef(a::Number, b::Number) = (a == b) || isequal(a, b)
333+
equal_or_undef(a, b) = all(equal_or_undef.(a, b))
334+
function test_addition_subtraction_dot(As, Bs, Tout::Type)
333335
for A in As, B in Bs
334-
@testset "$(typeof(A)) ± $(typeof(B))" begin
336+
@testset "$(typeof(A)) and $(typeof(B))" begin
335337
@test A + B isa Tout{promote_type(eltype(A), eltype(B))}
336-
@test as_array(A + B) == as_array(A) + as_array(B)
338+
@test equal_or_undef(as_array(A + B), as_array(A) + as_array(B))
337339

338340
@test A - B isa Tout{promote_type(eltype(A), eltype(B))}
339-
@test as_array(A - B) == as_array(A) - as_array(B)
341+
@test equal_or_undef(as_array(A - B), as_array(A) - as_array(B))
340342

341343
@test B + A isa Tout{promote_type(eltype(B), eltype(A))}
342-
@test as_array(B + A) == as_array(B) + as_array(A)
344+
@test equal_or_undef(as_array(B + A), as_array(B) + as_array(A))
343345

344346
@test B - A isa Tout{promote_type(eltype(B), eltype(A))}
345-
@test as_array(B - A) == as_array(B) - as_array(A)
347+
@test equal_or_undef(as_array(B - A), as_array(B) - as_array(A))
348+
349+
# Julia 1.6 doesn't support dot(UniformScaling)
350+
if VERSION < v"1.6.0" || VERSION >= v"1.8.0"
351+
d1 = dot(A, B)
352+
d2 = dot(as_array(A), as_array(B))
353+
d3 = dot(B, A)
354+
d4 = dot(as_array(B), as_array(A))
355+
@test d1 d2 || d1 d2
356+
@test d3 d4 || d3 d4
357+
end
346358
end
347359
end
348360
end
@@ -372,37 +384,37 @@ end
372384
@test -A_fill === Fill(-A_fill.value, 5)
373385

374386
# FillArray +/- FillArray should construct a new FillArray.
375-
test_addition_and_subtraction((A_fill, B_fill), (A_fill, B_fill), Fill)
387+
test_addition_subtraction_dot((A_fill, B_fill), (A_fill, B_fill), Fill)
376388
test_addition_and_subtraction_dim_mismatch(A_fill, Fill(randn(rng), 5, 2))
377389

378390
# FillArray + Array (etc) should construct a new Array using `getindex`.
379-
A_dense, B_dense = randn(rng, 5), [5, 4, 3, 2, 1]
380-
test_addition_and_subtraction((A_fill, B_fill), (A_dense, B_dense), Array)
391+
B_dense = (randn(rng, 5), [5, 4, 3, 2, 1], fill(Inf, 5), fill(NaN, 5))
392+
test_addition_subtraction_dot((A_fill, B_fill), B_dense, Array)
381393
test_addition_and_subtraction_dim_mismatch(A_fill, randn(rng, 5, 2))
382394

383395
# FillArray + StepLenRange / UnitRange (etc) should yield an AbstractRange.
384396
A_ur, B_ur = 1.0:5.0, 6:10
385-
test_addition_and_subtraction((A_fill, B_fill), (A_ur, B_ur), AbstractRange)
397+
test_addition_subtraction_dot((A_fill, B_fill), (A_ur, B_ur), AbstractRange)
386398
test_addition_and_subtraction_dim_mismatch(A_fill, 1.0:6.0)
387399
test_addition_and_subtraction_dim_mismatch(A_fill, 5:10)
388400

389401
# FillArray + UniformScaling should yield a Matrix in general
390402
As_fill_square = (Fill(randn(rng, Float64), 3, 3), Fill(5, 4, 4))
391403
Bs_us = (UniformScaling(2.3), UniformScaling(3))
392-
test_addition_and_subtraction(As_fill_square, Bs_us, Matrix)
404+
test_addition_subtraction_dot(As_fill_square, Bs_us, Matrix)
393405
As_fill_nonsquare = (Fill(randn(rng, Float64), 3, 2), Fill(5, 3, 4))
394406
for A in As_fill_nonsquare, B in Bs_us
395407
test_addition_and_subtraction_dim_mismatch(A, B)
396408
end
397409

398410
# FillArray + StaticArray should not have ambiguities
399411
A_svec, B_svec = SVector{5}(rand(5)), SVector(1, 2, 3, 4, 5)
400-
test_addition_and_subtraction((A_fill, B_fill, Zeros(5)), (A_svec, B_svec), SVector{5})
412+
test_addition_subtraction_dot((A_fill, B_fill, Zeros(5)), (A_svec, B_svec), SVector{5})
401413

402414
# Issue #224
403415
A_matmat, B_matmat = Fill(rand(3,3),5), [rand(3,3) for n=1:5]
404-
test_addition_and_subtraction((A_matmat,), (A_matmat,), Fill)
405-
test_addition_and_subtraction((B_matmat,), (A_matmat,), Vector)
416+
test_addition_subtraction_dot((A_matmat,), (A_matmat,), Fill)
417+
test_addition_subtraction_dot((B_matmat,), (A_matmat,), Vector)
406418

407419
# Optimizations for Zeros and RectOrDiagonal{<:Any, <:AbstractFill}
408420
As_special_square = (
@@ -412,7 +424,7 @@ end
412424
RectDiagonal(Fill(randn(rng, Float64), 3), 3, 3), RectDiagonal(Fill(3, 4), 4, 4)
413425
)
414426
DiagonalAbstractFill{T} = Diagonal{T, <:AbstractFill{T, 1}}
415-
test_addition_and_subtraction(As_special_square, Bs_us, DiagonalAbstractFill)
427+
test_addition_subtraction_dot(As_special_square, Bs_us, DiagonalAbstractFill)
416428
As_special_nonsquare = (
417429
Zeros(3, 2), Zeros{Int}(3, 4),
418430
Eye(3, 2), Eye{Int}(3, 4),
@@ -537,7 +549,7 @@ end
537549
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
538550
@test_throws DimensionMismatch randn(4)' * Zeros(3)
539551
@test Zeros(5)' * randn(5,3) Zeros(5)'*Zeros(5,3) Zeros(5)'*Ones(5,3) Zeros(3)'
540-
@test Zeros(5)' * randn(5) Zeros(5)' * Zeros(5) Zeros(5)' * Ones(5) 0.0
552+
@test abs(Zeros(5)' * randn(5)) abs(Zeros(5)' * Zeros(5)) abs(Zeros(5)' * Ones(5)) 0.0
541553
@test Zeros(5) * Zeros(6)' Zeros(5,1) * Zeros(6)' Zeros(5,6)
542554
@test randn(5) * Zeros(6)' randn(5,1) * Zeros(6)' Zeros(5,6)
543555
@test Zeros(5) * randn(6)' Zeros(5,6)
@@ -552,7 +564,7 @@ end
552564
@test transpose([1, 2, 3]) * Zeros{Int}(3) === zero(Int)
553565
@test_throws DimensionMismatch transpose(randn(4)) * Zeros(3)
554566
@test transpose(Zeros(5)) * randn(5,3) transpose(Zeros(5))*Zeros(5,3) transpose(Zeros(5))*Ones(5,3) transpose(Zeros(3))
555-
@test transpose(Zeros(5)) * randn(5) transpose(Zeros(5)) * Zeros(5) transpose(Zeros(5)) * Ones(5) 0.0
567+
@test abs(transpose(Zeros(5)) * randn(5)) abs(transpose(Zeros(5)) * Zeros(5)) abs(transpose(Zeros(5)) * Ones(5)) 0.0
556568
@test randn(5) * transpose(Zeros(6)) randn(5,1) * transpose(Zeros(6)) Zeros(5,6)
557569
@test Zeros(5) * transpose(randn(6)) Zeros(5,6)
558570
@test transpose(randn(5)) * Zeros(5) 0.0
@@ -571,13 +583,13 @@ end
571583
@test +(z1) === z1
572584
@test -(z1) === z1
573585

574-
test_addition_and_subtraction((z1, z2), (z1, z2), Zeros)
586+
test_addition_subtraction_dot((z1, z2), (z1, z2), Zeros)
575587
test_addition_and_subtraction_dim_mismatch(z1, Zeros{Float64}(4, 2))
576588
end
577589

578590
# `Zeros` +/- `Fill`s should yield `Fills`.
579591
fill1, fill2 = Fill(5.0, 4), Fill(5, 4)
580-
test_addition_and_subtraction((z1, z2), (fill1, fill2), Fill)
592+
test_addition_subtraction_dot((z1, z2), (fill1, fill2), Fill)
581593
test_addition_and_subtraction_dim_mismatch(z1, Fill(5, 5))
582594

583595
X = randn(3, 5)
@@ -1326,17 +1338,19 @@ end
13261338
Random.seed!(5)
13271339
u = rand(n)
13281340
v = rand(n)
1341+
c = rand(ComplexF16, n)
13291342

13301343
@test dot(u, D, v) == dot(u, v)
13311344
@test dot(u, 2D, v) == 2dot(u, v)
13321345
@test dot(u, Z, v) == 0
13331346

1334-
@test dot(Zeros(5), Zeros{ComplexF16}(5)) zero(ComplexF64)
1335-
@test dot(Zeros(5), Ones{ComplexF16}(5)) zero(ComplexF64)
1336-
@test dot(Ones{ComplexF16}(5), Zeros(5)) zero(ComplexF64)
1337-
@test dot(randn(5), Zeros{ComplexF16}(5)) dot(Zeros{ComplexF16}(5), randn(5)) zero(ComplexF64)
1347+
@test @inferred(dot(Zeros(5), Zeros{ComplexF16}(5))) zero(ComplexF64)
1348+
@test @inferred(dot(Zeros(5), Ones{ComplexF16}(5))) zero(ComplexF64)
1349+
@test abs(@inferred(dot(Ones{ComplexF16}(5), Zeros(5)))) abs(@inferred(dot(randn(5), Zeros{ComplexF16}(5)))) abs(@inferred(dot(Zeros{ComplexF16}(5), randn(5)))) zero(Float64) # 0.0 !≡ -0.0
1350+
@test @inferred(dot(c, Fill(1 + im, 15))) (@inferred(dot(Fill(1 + im, 15), c)))' @inferred(dot(c, fill(1 + im, 15)))
13381351

1339-
@test dot(Fill(1,5), Fill(2.0,5)) 10.0
1352+
@test @inferred(dot(Fill(1,5), Fill(2.0,5))) 10.0
1353+
@test_skip dot(Fill(true,5), Fill(Int8(1),5)) isa Int8 # not working at present
13401354

13411355
let N = 2^big(1000) # fast dot for fast sum
13421356
@test dot(Fill(2,N),1:N) == dot(Fill(2,N),1:N) == dot(1:N,Fill(2,N)) == 2*sum(1:N)

0 commit comments

Comments
 (0)