@@ -1081,3 +1081,141 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat
1081
1081
end # inbounds
1082
1082
C
1083
1083
end
1084
+
1085
+ const RealOrComplex = Union{Real,Complex}
1086
+
1087
+ # Three-argument *
1088
+ """
1089
+ *(A, B::AbstractMatrix, C)
1090
+ A * B * C * D
1091
+
1092
+ Chained multiplication of 3 or 4 matrices is done in the most efficient sequence,
1093
+ based on the sizes of the arrays. That is, the number of scalar multiplications needed
1094
+ for `(A * B) * C` (with 3 dense matrices) is compared to that for `A * (B * C)`
1095
+ to choose which of these to execute.
1096
+
1097
+ If the last factor is a vector, or the first a transposed vector, then it is efficient
1098
+ to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
1099
+ for an ordinary column-major `B::Matrix`. Unlike `dot(x, B, y)`, this
1100
+ allocates an intermediate array.
1101
+
1102
+ If the first or last factor is a number, this will be fused with the matrix
1103
+ multiplication, using 5-arg [`mul!`](@ref).
1104
+
1105
+ See also [`muladd`](@ref), [`dot`](@ref).
1106
+
1107
+ !!! compat "Julia 1.7"
1108
+ These optimisations require at least Julia 1.7.
1109
+ """
1110
+ * (A:: AbstractMatrix , B:: AbstractMatrix , x:: AbstractVector ) = A * (B* x)
1111
+
1112
+ * (tu:: AdjOrTransAbsVec , B:: AbstractMatrix , v:: AbstractVector ) = (tu* B) * v
1113
+ * (tu:: AdjOrTransAbsVec , B:: AdjOrTransAbsMat , v:: AbstractVector ) = tu * (B* v)
1114
+
1115
+ * (A:: AbstractMatrix , x:: AbstractVector , γ:: Number ) = mat_vec_scalar (A,x,γ)
1116
+ * (A:: AbstractMatrix , B:: AbstractMatrix , γ:: Number ) = mat_mat_scalar (A,B,γ)
1117
+ * (α:: RealOrComplex , B:: AbstractMatrix{<:RealOrComplex} , C:: AbstractVector{<:RealOrComplex} ) =
1118
+ mat_vec_scalar (B,C,α)
1119
+ * (α:: RealOrComplex , B:: AbstractMatrix{<:RealOrComplex} , C:: AbstractMatrix{<:RealOrComplex} ) =
1120
+ mat_mat_scalar (B,C,α)
1121
+
1122
+ * (α:: Number , u:: AbstractVector , tv:: AdjOrTransAbsVec ) = broadcast (* , α, u, tv)
1123
+ * (u:: AbstractVector , tv:: AdjOrTransAbsVec , γ:: Number ) = broadcast (* , u, tv, γ)
1124
+ * (u:: AbstractVector , tv:: AdjOrTransAbsVec , C:: AbstractMatrix ) = u * (tv* C)
1125
+
1126
+ * (A:: AbstractMatrix , B:: AbstractMatrix , C:: AbstractMatrix ) = _tri_matmul (A,B,C)
1127
+ * (tv:: AdjOrTransAbsVec , B:: AbstractMatrix , C:: AbstractMatrix ) = (tv* B) * C
1128
+
1129
+ function _tri_matmul (A,B,C,δ= nothing )
1130
+ n,m = size (A)
1131
+ # m,k == size(B)
1132
+ k,l = size (C)
1133
+ costAB_C = n* m* k + n* k* l # multiplications, allocations n*k + n*l
1134
+ costA_BC = m* k* l + n* m* l # m*l + n*l
1135
+ if costA_BC < costAB_C
1136
+ isnothing (δ) ? A * (B* C) : A * mat_mat_scalar (B,C,δ)
1137
+ else
1138
+ isnothing (δ) ? (A* B) * C : mat_mat_scalar (A* B, C, δ)
1139
+ end
1140
+ end
1141
+
1142
+ # Fast path for two arrays * one scalar is opt-in, via mat_vec_scalar and mat_mat_scalar.
1143
+
1144
+ mat_vec_scalar (A, x, γ) = A * (x .* γ) # fallback
1145
+ mat_vec_scalar (A:: StridedMaybeAdjOrTransMat , x:: StridedVector , γ) = _mat_vec_scalar (A, x, γ)
1146
+ mat_vec_scalar (A:: AdjOrTransAbsVec , x:: StridedVector , γ) = (A * x) * γ
1147
+
1148
+ function _mat_vec_scalar (A, x, γ)
1149
+ T = promote_type (eltype (A), eltype (x), typeof (γ))
1150
+ C = similar (A, T, axes (A,1 ))
1151
+ mul! (C, A, x, γ, false )
1152
+ end
1153
+
1154
+ mat_mat_scalar (A, B, γ) = (A* B) .* γ # fallback
1155
+ mat_mat_scalar (A:: StridedMaybeAdjOrTransMat , B:: StridedMaybeAdjOrTransMat , γ) =
1156
+ _mat_mat_scalar (A, B, γ)
1157
+
1158
+ function _mat_mat_scalar (A, B, γ)
1159
+ T = promote_type (eltype (A), eltype (B), typeof (γ))
1160
+ C = similar (A, T, axes (A,1 ), axes (B,2 ))
1161
+ mul! (C, A, B, γ, false )
1162
+ end
1163
+
1164
+ mat_mat_scalar (A:: AdjointAbsVec , B, γ) = (γ' .* (A * B)' )' # preserving order, adjoint reverses
1165
+ mat_mat_scalar (A:: AdjointAbsVec{<:RealOrComplex} , B:: StridedMaybeAdjOrTransMat{<:RealOrComplex} , γ:: RealOrComplex ) =
1166
+ mat_vec_scalar (B' , A' , γ' )'
1167
+
1168
+ mat_mat_scalar (A:: TransposeAbsVec , B, γ) = transpose (γ .* transpose (A * B))
1169
+ mat_mat_scalar (A:: TransposeAbsVec{<:RealOrComplex} , B:: StridedMaybeAdjOrTransMat{<:RealOrComplex} , γ:: RealOrComplex ) =
1170
+ transpose (mat_vec_scalar (transpose (B), transpose (A), γ))
1171
+
1172
+
1173
+ # Four-argument *, by type
1174
+ * (α:: Number , β:: Number , C:: AbstractMatrix , x:: AbstractVector ) = (α* β) * C * x
1175
+ * (α:: Number , β:: Number , C:: AbstractMatrix , D:: AbstractMatrix ) = (α* β) * C * D
1176
+ * (α:: Number , B:: AbstractMatrix , C:: AbstractMatrix , x:: AbstractVector ) = α * B * (C* x)
1177
+ * (α:: Number , vt:: AdjOrTransAbsVec , C:: AbstractMatrix , x:: AbstractVector ) = α * (vt* C* x)
1178
+ * (α:: RealOrComplex , vt:: AdjOrTransAbsVec{<:RealOrComplex} , C:: AbstractMatrix{<:RealOrComplex} , D:: AbstractMatrix{<:RealOrComplex} ) =
1179
+ (α* vt* C) * D # solves an ambiguity
1180
+
1181
+ * (A:: AbstractMatrix , x:: AbstractVector , γ:: Number , δ:: Number ) = A * x * (γ* δ)
1182
+ * (A:: AbstractMatrix , B:: AbstractMatrix , γ:: Number , δ:: Number ) = A * B * (γ* δ)
1183
+ * (A:: AbstractMatrix , B:: AbstractMatrix , x:: AbstractVector , δ:: Number , ) = A * (B* x* δ)
1184
+ * (vt:: AdjOrTransAbsVec , B:: AbstractMatrix , x:: AbstractVector , δ:: Number ) = (vt* B* x) * δ
1185
+ * (vt:: AdjOrTransAbsVec , B:: AbstractMatrix , C:: AbstractMatrix , δ:: Number ) = (vt* B) * C * δ
1186
+
1187
+ * (A:: AbstractMatrix , B:: AbstractMatrix , C:: AbstractMatrix , x:: AbstractVector ) = A * B * (C* x)
1188
+ * (vt:: AdjOrTransAbsVec , B:: AbstractMatrix , C:: AbstractMatrix , D:: AbstractMatrix ) = (vt* B) * C * D
1189
+ * (vt:: AdjOrTransAbsVec , B:: AbstractMatrix , C:: AbstractMatrix , x:: AbstractVector ) = vt * B * (C* x)
1190
+
1191
+ # Four-argument *, by size
1192
+ * (A:: AbstractMatrix , B:: AbstractMatrix , C:: AbstractMatrix , δ:: Number ) = _tri_matmul (A,B,C,δ)
1193
+ * (α:: RealOrComplex , B:: AbstractMatrix{<:RealOrComplex} , C:: AbstractMatrix{<:RealOrComplex} , D:: AbstractMatrix{<:RealOrComplex} ) =
1194
+ _tri_matmul (B,C,D,α)
1195
+ * (A:: AbstractMatrix , B:: AbstractMatrix , C:: AbstractMatrix , D:: AbstractMatrix ) =
1196
+ _quad_matmul (A,B,C,D)
1197
+
1198
+ function _quad_matmul (A,B,C,D)
1199
+ c1 = _mul_cost ((A,B),(C,D))
1200
+ c2 = _mul_cost (((A,B),C),D)
1201
+ c3 = _mul_cost (A,(B,(C,D)))
1202
+ c4 = _mul_cost ((A,(B,C)),D)
1203
+ c5 = _mul_cost (A,((B,C),D))
1204
+ cmin = min (c1,c2,c3,c4,c5)
1205
+ if c1 == cmin
1206
+ (A* B) * (C* D)
1207
+ elseif c2 == cmin
1208
+ ((A* B) * C) * D
1209
+ elseif c3 == cmin
1210
+ A * (B * (C* D))
1211
+ elseif c4 == cmin
1212
+ (A * (B* C)) * D
1213
+ else
1214
+ A * ((B* C) * D)
1215
+ end
1216
+ end
1217
+ @inline _mul_cost (A:: AbstractMatrix ) = 0
1218
+ @inline _mul_cost ((A,B):: Tuple ) = _mul_cost (A,B)
1219
+ @inline _mul_cost (A,B) = _mul_cost (A) + _mul_cost (B) + * (_mul_sizes (A)... , last (_mul_sizes (B)))
1220
+ @inline _mul_sizes (A:: AbstractMatrix ) = size (A)
1221
+ @inline _mul_sizes ((A,B):: Tuple ) = first (_mul_sizes (A)), last (_mul_sizes (B))
0 commit comments