@@ -23,14 +23,15 @@ def run_one_process(
23
23
results : list [pd .DataFrame ],
24
24
filter_class : Optional [type [DataFilter ]],
25
25
filter_kwargs : Optional [dict [str , Any ]],
26
- datafilter_init_fn : Optional [Callable [[int , Union [str , torch .device ]], DataFilter ]],
26
+ datafilter_init_fn : Optional [Callable [[int , Union [str , torch .device ], dict [str , Any ]], DataFilter ]],
27
+ datafilter_init_fn_kwargs : dict [str , Any ],
27
28
device : Union [str , torch .device ],
28
29
filter_run_kwargs : dict [str , Any ]
29
30
) -> None :
30
31
reader = DatasetReader (connector = connector )
31
32
processor = reader .from_df (config , df )
32
33
if datafilter_init_fn :
33
- datafilter = datafilter_init_fn (i , device )
34
+ datafilter = datafilter_init_fn (i , device , datafilter_init_fn_kwargs )
34
35
else :
35
36
datafilter = filter_class (** filter_kwargs , _pbar_position = i , device = device ) # type: ignore
36
37
@@ -51,7 +52,8 @@ def __init__(
51
52
devices : list [Union [torch .device , str ]],
52
53
datafilter_class : Optional [type [DataFilter ]] = None ,
53
54
datafilter_params : Optional [dict [str , Any ]] = None ,
54
- datafilter_init_fn : Optional [Callable [[int , Union [str , torch .device ]], DataFilter ]] = None
55
+ datafilter_init_fn : Optional [Callable [[int , Union [str , torch .device ], dict [str , Any ]], DataFilter ]] = None ,
56
+ datafilter_init_fn_kwargs : Optional [dict [str , Any ]] = None ,
55
57
):
56
58
"""
57
59
Parameters
@@ -62,19 +64,22 @@ def __init__(
62
64
Class of datafilter to use
63
65
datafilter_params: Optional[dict[str, Any]] = None
64
66
Parameters for datafilter_class initialization
65
- datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
67
+ datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any] ], DataFilter]] = None
66
68
Initialization function for a datafilter. Takes _pbar_position as first arg and device as a second arg
69
+ datafilter_init_fn_kwargs: Optional[dict[str, Any]] = None
70
+ Additional parameters for datafilter_init_fn
67
71
"""
68
72
self .filter_class = datafilter_class
69
73
self .filter_params = datafilter_params
70
74
self .datafilter_init_fn = datafilter_init_fn
75
+ self .datafilter_init_fn_kwargs = datafilter_init_fn_kwargs if datafilter_init_fn_kwargs is not None else {}
71
76
assert self .datafilter_init_fn or self .filter_class , "One method of filter initialization should be specified"
72
77
self .devices = devices
73
78
self .num_parts = len (devices )
74
79
75
80
# getting result columns names
76
81
if self .datafilter_init_fn :
77
- datafilter = self .datafilter_init_fn (0 , devices [0 ])
82
+ datafilter = self .datafilter_init_fn (0 , devices [0 ], self . datafilter_init_fn_kwargs )
78
83
else :
79
84
datafilter = self .filter_class (** self .filter_params , device = devices [0 ]) # type: ignore
80
85
self ._result_columns = datafilter .result_columns
@@ -127,6 +132,7 @@ def run(
127
132
self .filter_class ,
128
133
self .filter_params ,
129
134
self .datafilter_init_fn ,
135
+ self .datafilter_init_fn_kwargs ,
130
136
self .devices [i ],
131
137
filter_run_kwargs
132
138
)
0 commit comments