|
28 | 28 | const _getproperty = getproperty
|
29 | 29 | end
|
30 | 30 |
|
31 |
| -function _foreachfield(names, L) |
| 31 | +array_names_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_names_types(C) |
| 32 | +array_names_types(::Type{NamedTuple{names, types}}) where {names, types} = zip(names, types.parameters) |
| 33 | +array_names_types(::Type{T}) where {T<:Tuple} = enumerate(T.parameters) |
| 34 | + |
| 35 | +function apply_f_to_vars_fields(names_types, vars) |
| 36 | + exprs = Expr[] |
| 37 | + for (name, type) in names_types |
| 38 | + sym = QuoteNode(name) |
| 39 | + args = [Expr(:call, :_getproperty, var, sym) for var in vars] |
| 40 | + expr = if type <: StructArray |
| 41 | + apply_f_to_vars_fields(array_names_types(type), args) |
| 42 | + else |
| 43 | + Expr(:call, :f, args...) |
| 44 | + end |
| 45 | + push!(exprs, expr) |
| 46 | + end |
| 47 | + return Expr(:block, exprs...) |
| 48 | +end |
| 49 | + |
| 50 | +function _foreachfield(names_types, L) |
32 | 51 | vars = ntuple(i -> gensym(), L)
|
33 | 52 | exprs = Expr[]
|
34 | 53 | for (i, v) in enumerate(vars)
|
35 | 54 | push!(exprs, Expr(:(=), v, Expr(:call, :getfield, :xs, i)))
|
36 | 55 | end
|
37 |
| - for field in names |
38 |
| - sym = QuoteNode(field) |
39 |
| - args = [Expr(:call, :_getproperty, var, sym) for var in vars] |
40 |
| - push!(exprs, Expr(:call, :f, args...)) |
41 |
| - end |
| 56 | + push!(exprs, apply_f_to_vars_fields(names_types, vars)) |
42 | 57 | push!(exprs, :(return nothing))
|
43 | 58 | return Expr(:block, exprs...)
|
44 | 59 | end
|
45 | 60 |
|
46 |
| -@generated foreachfield_gen(::NamedTuple{names}, f, xs::Vararg{Any, L}) where {names, L} = |
47 |
| - _foreachfield(names, L) |
48 |
| -@generated foreachfield_gen(::NTuple{N, Any}, f, xs::Vararg{Any, L}) where {N, L} = |
49 |
| - _foreachfield(Base.OneTo(N), L) |
| 61 | +@generated foreachfield_gen(::S, f, xs::Vararg{Any, L}) where {S<:StructArray, L} = |
| 62 | + _foreachfield(array_names_types(S), L) |
50 | 63 |
|
51 |
| -foreachfield(f, x::StructArray, xs...) = foreachfield_gen(fieldarrays(x), f, x, xs...) |
| 64 | +foreachfield(f, x::StructArray, xs...) = foreachfield_gen(x, f, x, xs...) |
52 | 65 |
|
53 | 66 | """
|
54 | 67 | `iscompatible(::Type{S}, ::Type{V}) where {S, V<:AbstractArray}`
|
|
0 commit comments