diff --git a/Project.toml b/Project.toml index 873eca7..94a67ba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,23 +1,29 @@ name = "AxisArrays" uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9" -version = "0.4.6" +version = "0.4.7" [deps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" RangeArrays = "b3c3ace0-ae52-54e7-9d0b-2c1406fd6b9d" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [compat] +ArrayInterface = "6" IntervalSets = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7" IterTools = "1" RangeArrays = "0.3" +Static = "0.7" julia = "1" [extras] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/src/AxisArrays.jl b/src/AxisArrays.jl index 2a04b99..c2462de 100644 --- a/src/AxisArrays.jl +++ b/src/AxisArrays.jl @@ -2,11 +2,13 @@ VERSION < v"0.7.0-beta2.199" && __precompile__() module AxisArrays +using ArrayInterface using Base: tail import Base.Iterators: repeated using RangeArrays, IntervalSets using IterTools using Dates +using Static function axes end diff --git a/src/core.jl b/src/core.jl index 3871157..16cd37a 100644 --- a/src/core.jl +++ b/src/core.jl @@ -178,6 +178,8 @@ struct AxisArray{T,N,D,Ax} <: AbstractArray{T,N} AxisArray{T,N,D,Ax}(data::AbstractArray{T,N}, axs::Tuple{Vararg{Axis,N}}) where {T,N,D,Ax} = new{T,N,D,Ax}(data, axs) end +ArrayInterface.is_forwarding_wrapper(@nospecialize T::Type{<:AxisArray}) = true + """ AxisMatrix{T} Alias for [`AxisArray{T,2,D,Ax}`](@ref AxisArray). @@ -281,6 +283,13 @@ axisnames() = () axisname(::Union{Type{<:Axis{name}},Axis{name}}) where {name} = name +ArrayInterface.known_dimnames(@nospecialize T::Type{<:Axis}) = (axisname(T),) +function ArrayInterface.known_dimnames(@nospecialize T::Type{<:AxisArray}) + ArrayInterface.map_tuple_type(axisname, fieldtype(T, :axes)) +end +ArrayInterface.dimnames(::Axis{name}) where {name} = (StaticSymbol(name),) +ArrayInterface.dimnames(x::AxisArray) = static(ArrayInterface.known_dimnames(x)) + # Axis definitions """ axisdim(::AxisArray, ::Axis) -> Int diff --git a/test/core.jl b/test/core.jl index 93da117..f8b9f17 100644 --- a/test/core.jl +++ b/test/core.jl @@ -347,3 +347,6 @@ end C = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}([:a,:b,:c]), Axis{:x}(["a","b","c","d","e"])) @test occursin(r"axes:\n\s+:y,", summary(C)) + +# ensure that ArrayInterface.is_forwarding_wrapper is properly propagating across wrapper +@test @inferred(ArrayInterface.strides(AxisArray([1 2 3]'))) === (1, static(1)) diff --git a/test/runtests.jl b/test/runtests.jl index a6f1a0e..83071b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,9 @@ using AxisArrays +using ArrayInterface using Dates using Test using Random +using Static import IterTools @testset "AxisArrays" begin