Skip to content

Commit ee5344e

Browse files
committed
feat: add params in init datafilter fn
1 parent 5d14d57 commit ee5344e

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

DPF/filters/multigpu_filter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@ def run_one_process(
2323
results: list[pd.DataFrame],
2424
filter_class: Optional[type[DataFilter]],
2525
filter_kwargs: Optional[dict[str, Any]],
26-
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]],
26+
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any]], DataFilter]],
27+
datafilter_init_fn_kwargs: dict[str, Any],
2728
device: Union[str, torch.device],
2829
filter_run_kwargs: dict[str, Any]
2930
) -> None:
3031
reader = DatasetReader(connector=connector)
3132
processor = reader.from_df(config, df)
3233
if datafilter_init_fn:
33-
datafilter = datafilter_init_fn(i, device)
34+
datafilter = datafilter_init_fn(i, device, datafilter_init_fn_kwargs)
3435
else:
3536
datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device) # type: ignore
3637

@@ -51,7 +52,8 @@ def __init__(
5152
devices: list[Union[torch.device, str]],
5253
datafilter_class: Optional[type[DataFilter]] = None,
5354
datafilter_params: Optional[dict[str, Any]] = None,
54-
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
55+
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any]], DataFilter]] = None,
56+
datafilter_init_fn_kwargs: Optional[dict[str, Any]] = None,
5557
):
5658
"""
5759
Parameters
@@ -62,19 +64,22 @@ def __init__(
6264
Class of datafilter to use
6365
datafilter_params: Optional[dict[str, Any]] = None
6466
Parameters for datafilter_class initialization
65-
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
67+
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any]], DataFilter]] = None
6668
Initialization function for a datafilter. Takes _pbar_position as first arg and device as a second arg
69+
datafilter_init_fn_kwargs: Optional[dict[str, Any]] = None
70+
Additional parameters for datafilter_init_fn
6771
"""
6872
self.filter_class = datafilter_class
6973
self.filter_params = datafilter_params
7074
self.datafilter_init_fn = datafilter_init_fn
75+
self.datafilter_init_fn_kwargs = datafilter_init_fn_kwargs if datafilter_init_fn_kwargs is not None else {}
7176
assert self.datafilter_init_fn or self.filter_class, "One method of filter initialization should be specified"
7277
self.devices = devices
7378
self.num_parts = len(devices)
7479

7580
# getting result columns names
7681
if self.datafilter_init_fn:
77-
datafilter = self.datafilter_init_fn(0, devices[0])
82+
datafilter = self.datafilter_init_fn(0, devices[0], self.datafilter_init_fn_kwargs)
7883
else:
7984
datafilter = self.filter_class(**self.filter_params, device=devices[0]) # type: ignore
8085
self._result_columns = datafilter.result_columns
@@ -127,6 +132,7 @@ def run(
127132
self.filter_class,
128133
self.filter_params,
129134
self.datafilter_init_fn,
135+
self.datafilter_init_fn_kwargs,
130136
self.devices[i],
131137
filter_run_kwargs
132138
)

docs/multi_gpu_filter.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,18 @@ To run a complex datafilter or if you want to manually create a datafilter class
2626
from DPF.filters.images.llava_captioning_filter import LLaVaCaptioningFilter
2727
from DPF.filters.multigpu_filter import MultiGPUDataFilter
2828

29-
def init_fn(pbar_pos: int, device: str):
30-
print('INIT FN', pbar_pos, device)
29+
def init_fn(pbar_pos: int, device: str, params: dict):
30+
print('INIT FN', pbar_pos, device, params)
3131

3232
return LLaVaCaptioningFilter(
33-
workers=8, prompt='short', batch_size=16,
33+
workers=8, prompt=params['prompt'], batch_size=16,
3434
device=device, _pbar_position=pbar_pos
3535
)
3636

3737
multigpufilter = MultiGPUDataFilter(
3838
['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'],
39-
datafilter_init_fn=init_fn
39+
datafilter_init_fn=init_fn,
40+
datafilter_init_fn_kwargs={'prompt': 'short'}
4041
)
4142
processor.apply_multi_gpu_data_filter(multigpufilter)
4243
```

0 commit comments

Comments
 (0)