-
Notifications
You must be signed in to change notification settings - Fork 43
Move StaticArrays
support to extension
#265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c81c308
7171bab
b0d509b
58a54f9
bc629da
0542152
02257b9
b4de96b
60a8c8c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
module StructArraysAdaptExt | ||
# Use Adapt allows for automatic conversion of CPU to GPU StructArrays | ||
using Adapt, StructArrays | ||
Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s) | ||
N5N3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
module StructArraysGPUArraysCoreExt | ||
|
||
using StructArrays | ||
using StructArrays: map_params, array_types | ||
|
||
using Base: tail | ||
|
||
import GPUArraysCore | ||
|
||
# for GPU broadcast | ||
import GPUArraysCore | ||
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} | ||
backends = map_params(GPUArraysCore.backend, array_types(T)) | ||
backend, others = backends[1], tail(backends) | ||
isconsistent = mapfoldl(isequal(backend), &, others; init=true) | ||
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) | ||
return backend | ||
end | ||
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true | ||
|
||
end # module |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
module StructArraysStaticArraysExt | ||
|
||
using StructArrays | ||
using StaticArrays: StaticArray, FieldArray, tuple_prod | ||
|
||
""" | ||
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | ||
|
||
The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`. | ||
```julia | ||
julia> StructArrays.staticschema(SVector{2, Float64}) | ||
Tuple{Float64, Float64} | ||
``` | ||
The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a | ||
struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct | ||
which subtypes `FieldArray`. | ||
""" | ||
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | ||
return quote | ||
Base.@_inline_meta | ||
return NTuple{$(tuple_prod(S)), T} | ||
end | ||
end | ||
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args) | ||
StructArrays.component(s::StaticArray, i) = getindex(s, i) | ||
|
||
# invoke general fallbacks for a `FieldArray` type. | ||
@inline function StructArrays.staticschema(T::Type{<:FieldArray}) | ||
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T) | ||
end | ||
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) | ||
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) | ||
|
||
# Broadcast overload | ||
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo | ||
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype | ||
using StructArrays: isnonemptystructtype | ||
using Base.Broadcast: Broadcasted, _broadcast_getindex | ||
|
||
# StaticArrayStyle has no similar defined. | ||
# Overload `try_struct_copy` instead. | ||
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} | ||
flat = broadcast_flatten(bc); as = flat.args; f = flat.f | ||
argsizes = broadcast_sizes(as...) | ||
ax = axes(bc) | ||
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.") | ||
return _broadcast(f, Size(map(length, ax)), argsizes, as...) | ||
end | ||
|
||
# A functor generates the ith component of StructStaticBroadcast. | ||
struct Similar_ith{SA, E<:Tuple} | ||
elements::E | ||
Similar_ith{SA}(elements::Tuple) where {SA} = new{SA, typeof(elements)}(elements) | ||
end | ||
function (s::Similar_ith{SA})(i::Int) where {SA} | ||
ith_elements = ntuple(Val(length(s.elements))) do j | ||
getfield(s.elements[j], i) | ||
end | ||
ith_SA = similar_type(SA, fieldtype(eltype(SA), i)) | ||
return @inbounds ith_SA(ith_elements) | ||
end | ||
|
||
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize} | ||
N5N3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
first_staticarray = first_statictype(a...) | ||
elements, ET = if prod(newsize) == 0 | ||
N5N3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Use inference to get eltype in empty case (following StaticBroadcast defined in StaticArrays.jl) | ||
eltys = Tuple{map(eltype, a)...} | ||
(), Core.Compiler.return_type(f, eltys) | ||
else | ||
temp = __broadcast(f, sz, s, a...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part worries me a little bit, we are using something explicitly marked as internal in StaticArrays. Is there no way to achieve this using only public methods? Or maybe we could check over at StaticArrays if they can offer some solution (maybe add a public method that does what we need). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see, thanks for pointing out that discussion. In that case, maybe one could just add a small docstring in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense too. |
||
temp, eltype(temp) | ||
end | ||
if isnonemptystructtype(ET) | ||
SA = similar_type(first_staticarray, ET, sz) | ||
arrs = ntuple(Similar_ith{SA}(elements), Val(fieldcount(ET))) | ||
return StructArray{ET}(arrs) | ||
else | ||
@inbounds return similar_type(first_staticarray, ET, sz)(elements) | ||
end | ||
end | ||
|
||
# The `__broadcast` kernal is copied from `StaticArrays.jl`. | ||
# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl | ||
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize | ||
sizes = [sz.parameters[1] for sz ∈ s.parameters] | ||
|
||
indices = CartesianIndices(newsize) | ||
exprs = similar(indices, Expr) | ||
for (j, current_ind) ∈ enumerate(indices) | ||
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) | ||
exprs[j] = :(f($(exprs_vals...))) | ||
end | ||
|
||
return quote | ||
Base.@_inline_meta | ||
return tuple($(exprs...)) | ||
end | ||
end | ||
|
||
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I)) | ||
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex) | ||
li = LinearIndices(oldsize) | ||
ind = _broadcast_getindex(li, newindex) | ||
return :(a[$i][$ind]) | ||
end | ||
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.