Skip to content

Commit 554a9c0

Browse files
authored
Tracker.data overload (#260)
1 parent 5515922 commit 554a9c0

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.15.12"
4+
version = "0.15.13"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

ext/ComponentArraysTrackerExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ end
1010

1111
Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca))
1212

13+
Tracker.data(ca::ComponentArray) = ComponentArray(Tracker.data(getdata(ca)), getaxes(ca))
14+
1315
function Base.materialize(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Nothing,
1416
typeof(zero), <:Tuple{<:ComponentVector}})
1517
ca = first(bc.args)

test/autodiff_tests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,13 @@ end
117117

118118
@test Δ isa AbstractVector{Float64}
119119
end
120+
121+
@testset "Tracker untrack" begin
122+
ps = Tracker.param(ComponentArray(; a = rand(2)))
123+
@test eltype(getdata(ps)) <: Tracker.TrackedReal{Float64}
124+
125+
ps_data = Tracker.data(ps)
126+
@test !(eltype(getdata(ps_data)) <: Tracker.TrackedReal{Float64})
127+
@test eltype(getdata(ps_data)) <: Float64
128+
end
129+

0 commit comments

Comments
 (0)