Skip to content

Commit ce15a14

Browse files
author
Pietro Vertechi
committed
test adapt
1 parent f9c4c96 commit ce15a14

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ Tables = "1"
1212
julia = "1"
1313

1414
[extras]
15+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1516
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1617
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1819
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
1920

2021
[targets]
21-
test = ["Test", "OffsetArrays", "PooledArrays", "WeakRefStrings"]
22+
test = ["Test", "GPUArrays", "OffsetArrays", "PooledArrays", "WeakRefStrings"]

src/compatibility.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# GPU storage
2+
import Adapt
3+
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
4+
5+
# Table interface
6+
import Tables
7+
8+
Tables.istable(::Type{<:StructVector}) = true
9+
Tables.rowaccess(::Type{<:StructVector}) = true
10+
Tables.columnaccess(::Type{<:StructVector}) = true
11+
12+
Tables.rows(s::StructVector) = s
13+
Tables.columns(s::StructVector) = fieldarrays(s)
14+
15+
Tables.schema(s::StructVector) = Tables.Schema(staticschema(eltype(s)))
16+
17+
# refarray interface
18+
import DataAPI: refarray, refvalue
19+
20+
refarray(s::StructArray) = StructArray(map(refarray, fieldarrays(s)))
21+
22+
function refvalue(s::StructArray{T}, v::Tup) where {T}
23+
createinstance(T, map(refvalue, fieldarrays(s), v)...)
24+
end

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using StructArrays: staticschema, iscompatible, _promote_typejoin, append!!
33
using OffsetArrays: OffsetArray
44
import Tables, PooledArrays, WeakRefStrings
55
using DataAPI: refarray, refvalue
6+
using Adapt: adapt
7+
import GPUArrays
68
using Test
79

810
@testset "index" begin
@@ -700,3 +702,13 @@ end
700702
@test vcat(dest, StructVector(makeitr())) == append!!(copy(dest), makeitr())
701703
end
702704
end
705+
706+
@testset "adapt" begin
707+
s = StructArray(a = 1:10, b = StructArray(c = 1:10, d = 1:10))
708+
t = adapt(Array, s)
709+
@test propertynames(t) == (:a, :b)
710+
@test s == t
711+
@test t.a isa Array
712+
@test t.b.c isa Array
713+
@test t.b.d isa Array
714+
end

0 commit comments

Comments
 (0)