Skip to content

Commit 567f04c

Browse files
committed
Integrate ArrayInterface
1 parent bbf1f27 commit 567f04c

File tree

5 files changed

+23
-1
lines changed

5 files changed

+23
-1
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
name = "AxisArrays"
22
uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
3-
version = "0.4.6"
3+
version = "0.4.7"
44

55
[deps]
6+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
67
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
78
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
89
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
910
RangeArrays = "b3c3ace0-ae52-54e7-9d0b-2c1406fd6b9d"
11+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1012

1113
[compat]
14+
ArrayInterface = "6"
1215
IntervalSets = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7"
1316
IterTools = "1"
1417
RangeArrays = "0.3"
18+
Static = "0.7"
1519
julia = "1"
1620

1721
[extras]
22+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1823
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1924
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2025
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
26+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2127
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2228
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2329

src/AxisArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ VERSION < v"0.7.0-beta2.199" && __precompile__()
22

33
module AxisArrays
44

5+
using ArrayInterface
56
using Base: tail
67
import Base.Iterators: repeated
78
using RangeArrays, IntervalSets
89
using IterTools
910
using Dates
11+
using Static
1012

1113
function axes end
1214

src/core.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ struct AxisArray{T,N,D,Ax} <: AbstractArray{T,N}
178178
AxisArray{T,N,D,Ax}(data::AbstractArray{T,N}, axs::Tuple{Vararg{Axis,N}}) where {T,N,D,Ax} = new{T,N,D,Ax}(data, axs)
179179
end
180180

181+
ArrayInterface.is_forwarding_wrapper(@nospecialize T::Type{<:AxisArray}) = true
182+
181183
"""
182184
AxisMatrix{T}
183185
Alias for [`AxisArray{T,2,D,Ax}`](@ref AxisArray).
@@ -281,6 +283,13 @@ axisnames() = ()
281283

282284
axisname(::Union{Type{<:Axis{name}},Axis{name}}) where {name} = name
283285

286+
ArrayInterface.known_dimnames(@nospecialize T::Type{<:Axis}) = (axisname(T),)
287+
function ArrayInterface.known_dimnames(@nospecialize T::Type{<:AxisArray})
288+
ArrayInterface.map_tuple_type(axisname, fieldtype(T, :axes))
289+
end
290+
ArrayInterface.dimnames(::Axis{name}) where {name} = (StaticSymbol(name),)
291+
ArrayInterface.dimnames(x::AxisArray) = static(ArrayInterface.known_dimnames(x))
292+
284293
# Axis definitions
285294
"""
286295
axisdim(::AxisArray, ::Axis) -> Int

test/core.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,6 @@ end
347347

348348
C = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}([:a,:b,:c]), Axis{:x}(["a","b","c","d","e"]))
349349
@test occursin(r"axes:\n\s+:y,", summary(C))
350+
351+
# ensure that ArrayInterface.is_forwarding_wrapper is properly propagating across wrapper
352+
@test @inferred(ArrayInterface.strides(AxisArray([1 2 3]'))) === (1, static(1))

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using AxisArrays
2+
using ArrayInterface
23
using Dates
34
using Test
45
using Random
6+
using Static
57
import IterTools
68

79
@testset "AxisArrays" begin

0 commit comments

Comments
 (0)