Skip to content

Commit 98a7724

Browse files
authored
[ITensors] Fix broken broadcast operation on GPU (#1532)
1 parent eba5e17 commit 98a7724

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

NDTensors/test/test_dense.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,20 @@ NDTensors.dim(i::MyInd) = i.dim
8383
@test A[2, 2] == Aview[1, 1]
8484
end
8585

86+
## Testing A .= α .* B .+ β .* A
87+
C = copy(A)
88+
@allowscalar fill!(B, zero(elt))
89+
β = elt(2)
90+
α = elt(1)
91+
permutedims!!(A, B, (1, 2), (a, b) -> +(*(β, a), *(α, b)))
92+
@allowscalar 2 .* C == A
93+
randn!(B)
94+
C = copy(A)
95+
A = permutedims!!(A, B, (1, 2), (a, b) -> +(*(β, a), *(α, b)))
96+
@allowscalar for i in 1:3, j in 1:4
97+
@test A[i, j] == α * B[i, j] + β * C[i, j]
98+
end
99+
86100
## add elt around 2.0 to preserve the eltype of A.
87101
@test data(A * elt(2.0)) == data(elt(2.0) * A)
88102

src/broadcast.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,13 @@ end
395395
# C .= β .* C .+ α .* A .* B
396396
#
397397

398+
struct axpby{Alpha,Beta} <: Function
399+
alpha::Alpha
400+
beta::Beta
401+
end
402+
403+
(f::axpby)(y, x) = x * f.alpha + y * f.beta
404+
398405
## TODO this code doesn't actually get called
399406
function Base.copyto!(
400407
T::ITensor,
@@ -414,7 +421,9 @@ function Base.copyto!(
414421
A, C = C, A
415422
end
416423
if !isnothing(A) && !isnothing(C) && !isnothing(α) && !isnothing(β)
417-
map!((r, t) -> β * r + α * t, T, T, A)
424+
# The following fails to compile on some GPU backends.
425+
# map!((r, t) -> β * r + α * t, T, T, A)
426+
map!(axpby(α, β), T, T, A)
418427
else
419428
bc_bc_α = find_type(Broadcasted, bc_α.args)
420429
if isnothing(α)

0 commit comments

Comments
 (0)