Skip to content

Commit 30b8e0c

Browse files
mcabbott authored and dkarrasch committed
Add 3-arg * methods (JuliaLang#37898)
This addresses the simplest part of #12065 (optimizing * for optimal matrix order), by adding some methods for * with 3 arguments, where this can be done more efficiently than working left-to-right. Co-authored-by: Daniel Karrasch <[email protected]>
1 parent 203f7ec commit 30b8e0c

File tree

2 files changed

+237
-0
lines changed

2 files changed

+237
-0
lines changed

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,3 +1081,141 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat
10811081
end # inbounds
10821082
C
10831083
end
1084+
1085+
const RealOrComplex = Union{Real,Complex}
1086+
1087+
# Three-argument *
1088+
"""
1089+
*(A, B::AbstractMatrix, C)
1090+
A * B * C * D
1091+
1092+
Chained multiplication of 3 or 4 matrices is done in the most efficient sequence,
1093+
based on the sizes of the arrays. That is, the number of scalar multiplications needed
1094+
for `(A * B) * C` (with 3 dense matrices) is compared to that for `A * (B * C)`
1095+
to choose which of these to execute.
1096+
1097+
If the last factor is a vector, or the first a transposed vector, then it is efficient
1098+
to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
1099+
for an ordinary column-major `B::Matrix`. Unlike `dot(x, B, y)`, this
1100+
allocates an intermediate array.
1101+
1102+
If the first or last factor is a number, this will be fused with the matrix
1103+
multiplication, using 5-arg [`mul!`](@ref).
1104+
1105+
See also [`muladd`](@ref), [`dot`](@ref).
1106+
1107+
!!! compat "Julia 1.7"
1108+
These optimisations require at least Julia 1.7.
1109+
"""
1110+
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector) = A * (B*x)
1111+
1112+
*(tu::AdjOrTransAbsVec, B::AbstractMatrix, v::AbstractVector) = (tu*B) * v
1113+
*(tu::AdjOrTransAbsVec, B::AdjOrTransAbsMat, v::AbstractVector) = tu * (B*v)
1114+
1115+
*(A::AbstractMatrix, x::AbstractVector, γ::Number) = mat_vec_scalar(A,x,γ)
1116+
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number) = mat_mat_scalar(A,B,γ)
1117+
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractVector{<:RealOrComplex}) =
1118+
mat_vec_scalar(B,C,α)
1119+
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}) =
1120+
mat_mat_scalar(B,C,α)
1121+
1122+
*(α::Number, u::AbstractVector, tv::AdjOrTransAbsVec) = broadcast(*, α, u, tv)
1123+
*(u::AbstractVector, tv::AdjOrTransAbsVec, γ::Number) = broadcast(*, u, tv, γ)
1124+
*(u::AbstractVector, tv::AdjOrTransAbsVec, C::AbstractMatrix) = u * (tv*C)
1125+
1126+
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix) = _tri_matmul(A,B,C)
1127+
*(tv::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix) = (tv*B) * C
1128+
1129+
function _tri_matmul(A,B,C,δ=nothing)
1130+
n,m = size(A)
1131+
# m,k == size(B)
1132+
k,l = size(C)
1133+
costAB_C = n*m*k + n*k*l # multiplications, allocations n*k + n*l
1134+
costA_BC = m*k*l + n*m*l # m*l + n*l
1135+
if costA_BC < costAB_C
1136+
isnothing(δ) ? A * (B*C) : A * mat_mat_scalar(B,C,δ)
1137+
else
1138+
isnothing(δ) ? (A*B) * C : mat_mat_scalar(A*B, C, δ)
1139+
end
1140+
end
1141+
1142+
# Fast path for two arrays * one scalar is opt-in, via mat_vec_scalar and mat_mat_scalar.
1143+
1144+
mat_vec_scalar(A, x, γ) = A * (x .* γ) # fallback
1145+
mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ) = _mat_vec_scalar(A, x, γ)
1146+
mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ) = (A * x) * γ
1147+
1148+
function _mat_vec_scalar(A, x, γ)
1149+
T = promote_type(eltype(A), eltype(x), typeof(γ))
1150+
C = similar(A, T, axes(A,1))
1151+
mul!(C, A, x, γ, false)
1152+
end
1153+
1154+
mat_mat_scalar(A, B, γ) = (A*B) .* γ # fallback
1155+
mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ) =
1156+
_mat_mat_scalar(A, B, γ)
1157+
1158+
function _mat_mat_scalar(A, B, γ)
1159+
T = promote_type(eltype(A), eltype(B), typeof(γ))
1160+
C = similar(A, T, axes(A,1), axes(B,2))
1161+
mul!(C, A, B, γ, false)
1162+
end
1163+
1164+
mat_mat_scalar(A::AdjointAbsVec, B, γ) = (γ' .* (A * B)')' # preserving order, adjoint reverses
1165+
mat_mat_scalar(A::AdjointAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
1166+
mat_vec_scalar(B', A', γ')'
1167+
1168+
mat_mat_scalar(A::TransposeAbsVec, B, γ) = transpose(γ .* transpose(A * B))
1169+
mat_mat_scalar(A::TransposeAbsVec{<:RealOrComplex}, B::StridedMaybeAdjOrTransMat{<:RealOrComplex}, γ::RealOrComplex) =
1170+
transpose(mat_vec_scalar(transpose(B), transpose(A), γ))
1171+
1172+
1173+
# Four-argument *, by type
1174+
*(α::Number, β::Number, C::AbstractMatrix, x::AbstractVector) = (α*β) * C * x
1175+
*(α::Number, β::Number, C::AbstractMatrix, D::AbstractMatrix) = (α*β) * C * D
1176+
*(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = α * B * (C*x)
1177+
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector) = α * (vt*C*x)
1178+
*(α::RealOrComplex, vt::AdjOrTransAbsVec{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
1179+
(α*vt*C) * D # solves an ambiguity
1180+
1181+
*(A::AbstractMatrix, x::AbstractVector, γ::Number, δ::Number) = A * x * (γ*δ)
1182+
*(A::AbstractMatrix, B::AbstractMatrix, γ::Number, δ::Number) = A * B * (γ*δ)
1183+
*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector, δ::Number, ) = A * (B*x*δ)
1184+
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, x::AbstractVector, δ::Number) = (vt*B*x) * δ
1185+
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = (vt*B) * C * δ
1186+
1187+
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = A * B * (C*x)
1188+
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) = (vt*B) * C * D
1189+
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = vt * B * (C*x)
1190+
1191+
# Four-argument *, by size
1192+
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, δ::Number) = _tri_matmul(A,B,C,δ)
1193+
*(α::RealOrComplex, B::AbstractMatrix{<:RealOrComplex}, C::AbstractMatrix{<:RealOrComplex}, D::AbstractMatrix{<:RealOrComplex}) =
1194+
_tri_matmul(B,C,D,α)
1195+
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) =
1196+
_quad_matmul(A,B,C,D)
1197+
1198+
function _quad_matmul(A,B,C,D)
1199+
c1 = _mul_cost((A,B),(C,D))
1200+
c2 = _mul_cost(((A,B),C),D)
1201+
c3 = _mul_cost(A,(B,(C,D)))
1202+
c4 = _mul_cost((A,(B,C)),D)
1203+
c5 = _mul_cost(A,((B,C),D))
1204+
cmin = min(c1,c2,c3,c4,c5)
1205+
if c1 == cmin
1206+
(A*B) * (C*D)
1207+
elseif c2 == cmin
1208+
((A*B) * C) * D
1209+
elseif c3 == cmin
1210+
A * (B * (C*D))
1211+
elseif c4 == cmin
1212+
(A * (B*C)) * D
1213+
else
1214+
A * ((B*C) * D)
1215+
end
1216+
end
1217+
@inline _mul_cost(A::AbstractMatrix) = 0
1218+
@inline _mul_cost((A,B)::Tuple) = _mul_cost(A,B)
1219+
@inline _mul_cost(A,B) = _mul_cost(A) + _mul_cost(B) + *(_mul_sizes(A)..., last(_mul_sizes(B)))
1220+
@inline _mul_sizes(A::AbstractMatrix) = size(A)
1221+
@inline _mul_sizes((A,B)::Tuple) = first(_mul_sizes(A)), last(_mul_sizes(B))

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,4 +766,103 @@ end
766766
@test Matrix{Int}(undef, 2, 0) * Matrix{Int}(undef, 0, 3) == zeros(Int, 2, 3)
767767
end
768768

769+
@testset "3-arg *, order by type" begin
770+
x = [1, 2im]
771+
y = [im, 20, 30+40im]
772+
z = [-1, 200+im, -3]
773+
A = [1 2 3im; 4 5 6+im]
774+
B = [-10 -20; -30 -40]
775+
a = 3 + im * round(Int, 10^6*(pi-3))
776+
b = 123
777+
778+
@test x'*A*y == (x'*A)*y == x'*(A*y)
779+
@test y'*A'*x == (y'*A')*x == y'*(A'*x)
780+
@test y'*transpose(A)*x == (y'*transpose(A))*x == y'*(transpose(A)*x)
781+
782+
@test B*A*y == (B*A)*y == B*(A*y)
783+
784+
@test a*A*y == (a*A)*y == a*(A*y)
785+
@test A*y*a == (A*y)*a == A*(y*a)
786+
787+
@test a*B*A == (a*B)*A == a*(B*A)
788+
@test B*A*a == (B*A)*a == B*(A*a)
789+
790+
@test a*y'*z == (a*y')*z == a*(y'*z)
791+
@test y'*z*a == (y'*z)*a == y'*(z*a)
792+
793+
@test a*y*z' == (a*y)*z' == a*(y*z')
794+
@test y*z'*a == (y*z')*a == y*(z'*a)
795+
796+
@test a*x'*A == (a*x')*A == a*(x'*A)
797+
@test x'*A*a == (x'*A)*a == x'*(A*a)
798+
@test a*x'*A isa Adjoint{<:Any, <:Vector}
799+
800+
@test a*transpose(x)*A == (a*transpose(x))*A == a*(transpose(x)*A)
801+
@test transpose(x)*A*a == (transpose(x)*A)*a == transpose(x)*(A*a)
802+
@test a*transpose(x)*A isa Transpose{<:Any, <:Vector}
803+
804+
@test x'*B*A == (x'*B)*A == x'*(B*A)
805+
@test x'*B*A isa Adjoint{<:Any, <:Vector}
806+
807+
@test y*x'*A == (y*x')*A == y*(x'*A)
808+
y31 = reshape(y,3,1)
809+
@test y31*x'*A == (y31*x')*A == y31*(x'*A)
810+
811+
vm = [rand(1:9,2,2) for _ in 1:3]
812+
Mm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]
813+
814+
@test vm' * Mm * vm == (vm' * Mm) * vm == vm' * (Mm * vm)
815+
@test Mm * Mm' * vm == (Mm * Mm') * vm == Mm * (Mm' * vm)
816+
@test vm' * Mm * Mm == (vm' * Mm) * Mm == vm' * (Mm * Mm)
817+
@test Mm * Mm' * Mm == (Mm * Mm') * Mm == Mm * (Mm' * Mm)
818+
end
819+
820+
@testset "3-arg *, order by size" begin
821+
M44 = randn(4,4)
822+
M24 = randn(2,4)
823+
M42 = randn(4,2)
824+
@test M44*M44*M44 ≈ (M44*M44)*M44 ≈ M44*(M44*M44)
825+
@test M42*M24*M44 ≈ (M42*M24)*M44 ≈ M42*(M24*M44)
826+
@test M44*M42*M24 ≈ (M44*M42)*M24 ≈ M44*(M42*M24)
827+
end
828+
829+
@testset "4-arg *, by type" begin
830+
y = [im, 20, 30+40im]
831+
z = [-1, 200+im, -3]
832+
a = 3 + im * round(Int, 10^6*(pi-3))
833+
b = 123
834+
M = rand(vcat(1:9, im.*[1,2,3]), 3,3)
835+
N = rand(vcat(1:9, im.*[1,2,3]), 3,3)
836+
837+
@test a * b * M * y == (a*b) * (M*y)
838+
@test a * b * M * N == (a*b) * (M*N)
839+
@test a * M * N * y == (a*M) * (N*y)
840+
@test a * y' * M * z == (a*y') * (M*z)
841+
@test a * y' * M * N == (a*y') * (M*N)
842+
843+
@test M * y * a * b == (M*y) * (a*b)
844+
@test M * N * a * b == (M*N) * (a*b)
845+
@test M * N * y * a == (a*M) * (N*y)
846+
@test y' * M * z * a == (a*y') * (M*z)
847+
@test y' * M * N * a == (a*y') * (M*N)
848+
849+
@test M * N * conj(M) * y == (M*N) * (conj(M)*y)
850+
@test y' * M * N * conj(M) == (y'*M) * (N*conj(M))
851+
@test y' * M * N * z == (y'*M) * (N*z)
852+
end
853+
854+
@testset "4-arg *, by size" begin
855+
for shift in 1:5
856+
s1,s2,s3,s4,s5 = circshift(3:7, shift)
857+
a=randn(s1,s2); b=randn(s2,s3); c=randn(s3,s4); d=randn(s4,s5)
858+
859+
# _quad_matmul
860+
@test *(a,b,c,d) ≈ (a*b) * (c*d)
861+
862+
# _tri_matmul(A,B,C,δ)
863+
@test *(11.1,b,c,d) ≈ (11.1*b) * (c*d)
864+
@test *(a,b,c,99.9) ≈ (a*b) * (c*99.9)
865+
end
866+
end
867+
769868
end # module TestMatmul

0 commit comments

Comments
 (0)