Skip to content

Commit b1f814e

Browse files
author
Pietro Vertechi
authored
Do recursive bit during code generation in foreachfield (#141)
* do recursive bit inside foreachfield codegen * minor cleanup
1 parent 68eaf79 commit b1f814e

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/utils.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,40 @@ else
2828
const _getproperty = getproperty
2929
end
3030

31-
function _foreachfield(names, L)
31+
array_names_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_names_types(C)
32+
array_names_types(::Type{NamedTuple{names, types}}) where {names, types} = zip(names, types.parameters)
33+
array_names_types(::Type{T}) where {T<:Tuple} = enumerate(T.parameters)
34+
35+
function apply_f_to_vars_fields(names_types, vars)
36+
exprs = Expr[]
37+
for (name, type) in names_types
38+
sym = QuoteNode(name)
39+
args = [Expr(:call, :_getproperty, var, sym) for var in vars]
40+
expr = if type <: StructArray
41+
apply_f_to_vars_fields(array_names_types(type), args)
42+
else
43+
Expr(:call, :f, args...)
44+
end
45+
push!(exprs, expr)
46+
end
47+
return Expr(:block, exprs...)
48+
end
49+
50+
function _foreachfield(names_types, L)
3251
vars = ntuple(i -> gensym(), L)
3352
exprs = Expr[]
3453
for (i, v) in enumerate(vars)
3554
push!(exprs, Expr(:(=), v, Expr(:call, :getfield, :xs, i)))
3655
end
37-
for field in names
38-
sym = QuoteNode(field)
39-
args = [Expr(:call, :_getproperty, var, sym) for var in vars]
40-
push!(exprs, Expr(:call, :f, args...))
41-
end
56+
push!(exprs, apply_f_to_vars_fields(names_types, vars))
4257
push!(exprs, :(return nothing))
4358
return Expr(:block, exprs...)
4459
end
4560

46-
@generated foreachfield_gen(::NamedTuple{names}, f, xs::Vararg{Any, L}) where {names, L} =
47-
_foreachfield(names, L)
48-
@generated foreachfield_gen(::NTuple{N, Any}, f, xs::Vararg{Any, L}) where {N, L} =
49-
_foreachfield(Base.OneTo(N), L)
61+
@generated foreachfield_gen(::S, f, xs::Vararg{Any, L}) where {S<:StructArray, L} =
62+
_foreachfield(array_names_types(S), L)
5063

51-
foreachfield(f, x::StructArray, xs...) = foreachfield_gen(fieldarrays(x), f, x, xs...)
64+
foreachfield(f, x::StructArray, xs...) = foreachfield_gen(x, f, x, xs...)
5265

5366
"""
5467
`iscompatible(::Type{S}, ::Type{V}) where {S, V<:AbstractArray}`

0 commit comments

Comments
 (0)