Skip to content

Commit c6281f4

Browse files
committed
fix: fix code style
1 parent 5d778f9 commit c6281f4

13 files changed

+84
-149
lines changed

DPF/connectors/s3_connector.py

Lines changed: 5 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -37,29 +37,12 @@ def _preprocess_filepath(path: str) -> str:
3737

3838
def read_file(self, filepath: str, binary: bool) -> io.BytesIO:
3939
mode = "rb" if binary else "rt"
40-
if '.tar' in filepath and '?tar_offset=' in filepath and '?size=' in filepath:
41-
filepath = self._preprocess_filepath(filepath)
42-
offset = filepath.split('?tar_offset=')[1].split('?size=')[0]
43-
size = filepath.split('?size=')[1]
44-
filepath = filepath.split('?')[0]
45-
offset = int(offset)
46-
size = int(size)
47-
s3 = self.s3client._get_client()
48-
range_header = "bytes=%d-%d" % (offset, offset + size - 1)
49-
bucket_name = filepath.split('/')[0]
50-
tar_key = filepath.replace(bucket_name, '')[1:]
51-
video_obj = s3.get_object(Bucket=bucket_name, Key=tar_key, Range=range_header)
52-
res = video_obj["Body"].read()
40+
with self.s3client.open(self._preprocess_filepath(filepath), mode=mode) as f:
5341
if mode == "rb":
54-
res = io.BytesIO(res)
42+
res = io.BytesIO(f.read())
5543
res.seek(0)
56-
else:
57-
with self.s3client.open(self._preprocess_filepath(filepath), mode=mode) as f:
58-
if mode == "rb":
59-
res = io.BytesIO(f.read())
60-
res.seek(0)
61-
else:
62-
res = f.read()
44+
else:
45+
res = f.read()
6346
return res
6447

6548
def save_file(
@@ -95,4 +78,4 @@ def join(self, *args: str) -> str:
9578
path += arg
9679
else:
9780
path += arg+'/'
98-
return path[:-1]
81+
return path[:-1]

DPF/dataset_reader.py

Lines changed: 1 addition & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
from typing import Optional, Union
33

44
import pandas as pd
5-
import numpy as np
65
from tqdm.contrib.concurrent import process_map
76

87
from DPF.configs import (
@@ -43,7 +42,6 @@ def __init__(self, connector: Optional[Connector] = None):
4342
if connector is None:
4443
connector = LocalConnector()
4544
self.connector = connector
46-
self.local_connector = LocalConnector()
4745

4846
def _read_and_validate_dataframes(
4947
self,
@@ -272,10 +270,7 @@ def read_files(
272270
Instance of FilesDatasetProcessor dataset
273271
"""
274272
table_path = config.table_path.rstrip("/")
275-
try:
276-
df = self.connector.read_dataframe(table_path)
277-
except:
278-
df = self.local_connector.read_dataframe(table_path)
273+
df = self.connector.read_dataframe(table_path)
279274

280275
required_columns = list(config.user_column2default_column.keys())
281276
column_set = set(df.columns.tolist())
@@ -293,12 +288,6 @@ def read_files(
293288
path_col = datatype.modality.path_column
294289
df[path_col] = df[path_col].apply(lambda x: self.connector.join(config.base_path, x))
295290

296-
# process .tar files with offsets
297-
for i, row in df.iterrows():
298-
if isinstance(df.at[i,'tar_offset'], np.int64) and isinstance(df.at[i,'size'], np.int64):
299-
df.at[i, path_col] += f'?tar_offset={df.at[i,"tar_offset"]}?size={df.at[i,"size"]}'
300-
301-
302291
return FilesDatasetProcessor(
303292
connector=self.connector,
304293
df=df,

DPF/filters/complex_filter.py

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,21 +16,14 @@ class ComplexDataFilter(DataFilter):
1616

1717
def __init__(
1818
self,
19-
datafilters,
20-
kwargs,
19+
datafilters: list[DataFilter],
2120
workers: int,
2221
pbar: bool = True,
23-
_pbar_position: int = 0,
24-
device = 'cuda:0'
22+
_pbar_position: int = 0
2523
):
2624
super().__init__(pbar, _pbar_position)
27-
self.datafilters = []
25+
self.datafilters = datafilters
2826
self.workers = workers
29-
self.device = device
30-
31-
for filter, kwarg in zip(datafilters, kwargs):
32-
kwarg['device'] = self.device
33-
self.datafilters.append(filter(**kwarg))
3427

3528
assert len(self.datafilters) > 0
3629
assert all(

DPF/filters/images/complexity_filter.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,15 @@
11
import os
22
from typing import Any
33
from urllib.request import urlretrieve
4+
45
import numpy as np
56
import torch
7+
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
68

7-
from ...types import ModalityToDataMapping
89
from DPF.utils import read_image_rgb_from_bytes
9-
from .img_filter import ImageFilter
10-
11-
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
1210

11+
from ...types import ModalityToDataMapping
12+
from .img_filter import ImageFilter
1313

1414
WEIGHTS_URL = {'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
1515
'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
@@ -57,7 +57,7 @@ def __init__(
5757
self.model_name = model_name
5858
self.weights_folder = weights_folder
5959
self.points_per_side = points_per_side
60-
60+
6161
# Download checkpoints
6262
path_to_model = os.path.join(self.weights_folder, self.model_name + '.pth')
6363
if not os.path.exists(path_to_model):
@@ -67,7 +67,7 @@ def __init__(
6767
sam = sam_model_registry[self.model_name](checkpoint=path_to_model)
6868
sam = sam.to(torch.device(self.device))
6969
self.mask_generator = SamAutomaticMaskGenerator(
70-
sam, points_per_batch=batch_size,
70+
sam, points_per_batch=batch_size,
7171
points_per_side=points_per_side
7272
)
7373

@@ -111,7 +111,7 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
111111
mean_area = np.mean(areas) / hw
112112
else:
113113
max_area = mean_area = 0
114-
114+
115115
df_batch_labels["complexity_num_segments"].extend([num_segments])
116116
df_batch_labels["complexity_max_segment_area"].extend([max_area])
117117
df_batch_labels["complexity_mean_segment_area"].extend([mean_area])

DPF/filters/multigpu_filter.py

Lines changed: 4 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import multiprocessing
22
from multiprocessing import Manager
3-
from typing import Any, Union, Optional, Callable
3+
from typing import Any, Callable, Optional, Union
44

55
import numpy as np
66
import pandas as pd
@@ -60,9 +60,9 @@ def __init__(
6060
----------
6161
devices: list[Union[torch.device, str]]
6262
List of devices to run datafilter on
63-
datafilter_class: type[DataFilter]
63+
datafilter_class: Optional[type[DataFilter]] = None
6464
Class of datafilter to use
65-
datafilter_params: dict[str, Any]
65+
datafilter_params: Optional[dict[str, Any]] = None
6666
Parameters for datafilter_class initialization
6767
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device], dict[str, Any]], DataFilter]] = None
6868
Initialization function for a datafilter. Takes _pbar_position as first arg and device as a second arg
@@ -77,11 +77,6 @@ def __init__(
7777
self.devices = devices
7878
self.num_parts = len(devices)
7979

80-
self.filters = []
81-
for i in range(self.num_parts):
82-
self.filters.append(datafilter_class(**datafilter_params, _pbar_position=i, device=devices[i]))
83-
self.filters[i]._created_by_multigpu_data_filter = True
84-
8580
# getting result columns names
8681
if self.datafilter_init_fn:
8782
datafilter = self.datafilter_init_fn(0, devices[0], self.datafilter_init_fn_kwargs)
@@ -146,7 +141,7 @@ def run(
146141
processes = []
147142
context = multiprocessing.get_context('spawn')
148143
for param in params:
149-
p = context.Process(target=self.run_one_process, args=param)
144+
p = context.Process(target=run_one_process, args=param)
150145
p.start()
151146
processes.append(p)
152147

@@ -156,21 +151,3 @@ def run(
156151
res_df = pd.concat(shared_results)
157152
res_df.sort_index(inplace=True)
158153
return res_df
159-
160-
161-
def run_one_process(
162-
self,
163-
config: DatasetConfig,
164-
connector: Connector,
165-
df: pd.DataFrame,
166-
i: int,
167-
index: pd.Series,
168-
results: list[pd.DataFrame],
169-
filter_run_kwargs: dict[str, Any]
170-
) -> None:
171-
reader = DatasetReader(connector=connector)
172-
processor = reader.from_df(config, df)
173-
processor.apply_data_filter(self.filters[i], **filter_run_kwargs)
174-
res = processor.df
175-
res.set_index(index, inplace=True)
176-
results.append(res)

DPF/filters/videos/cogvlm2_filter.py

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,15 @@
1+
import re
12
from io import BytesIO
23
from typing import Any
34

4-
from DPF.types import ModalityToDataMapping
5-
6-
from .video_filter import VideoFilter
75
import numpy as np
86
import torch
97
from decord import VideoReader, bridge
108
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
11-
import re
129

10+
from DPF.types import ModalityToDataMapping
11+
12+
from .video_filter import VideoFilter
1313

1414
prompt_templates = {
1515
'detailed_video': 'Describe this video and its style in a very detailed manner',
@@ -33,25 +33,25 @@
3333
]
3434

3535

36-
def clean_with_regex(caption):
37-
lower_caption = str(caption).lower().strip()
38-
for re_compiled, replacement in compiled_regexs:
39-
iterator = reversed(list(re_compiled.finditer(lower_caption)))
40-
for match in iterator:
41-
pos = list(match.span())
36+
def clean_with_regex(caption: str) -> str:
37+
lower_caption = str(caption).lower().strip()
38+
for re_compiled, replacement in compiled_regexs:
39+
iterator = reversed(list(re_compiled.finditer(lower_caption)))
40+
for match in iterator:
41+
pos = list(match.span())
4242
caption = caption[:pos[0]] + replacement + caption[pos[1]:]
4343
lower_caption = str(caption).lower().strip()
44-
44+
4545
if caption.count('-') > 2:
4646
split_captions = []
4747
for split_caption in caption.split():
4848
if split_caption.count('-') > 2:
4949
split_caption = re.sub(r'-', ' ', split_caption)
5050
split_captions.append(split_caption)
5151
caption = ' '.join(split_captions)
52-
52+
5353
caption = caption.strip('—-:/+=|@#&*')
54-
54+
5555
return caption.strip()
5656

5757

@@ -156,8 +156,8 @@ def preprocess_data(
156156
) -> Any:
157157
key = metadata[self.key_column]
158158
video_file = BytesIO(modality2data['video'])
159-
video_file = self.load_video(video_file, strategy=self.strategy)
160-
return key, video_file
159+
loaded_video_file = self.load_video(video_file, strategy=self.strategy)
160+
return key, loaded_video_file
161161

162162
def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
163163
df_batch_labels = self._get_dict_from_schema()
@@ -196,26 +196,27 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
196196
return df_batch_labels
197197

198198

199-
def load_video(self, video_path, strategy='chat'):
199+
def load_video(self, video_path: BytesIO, strategy: str = 'chat') -> torch.Tensor:
200200
bridge.set_bridge('torch')
201201
num_frames = self.num_frames
202202

203203
decord_vr = VideoReader(uri=video_path)
204-
frame_id_list = None
205204
total_frames = len(decord_vr)
206205
if strategy == 'base':
207206
frame_id_list = np.linspace(0, total_frames - 1, num_frames, dtype=int)
208207
elif strategy == 'chat':
209208
timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
210209
timestamps = [i[0] for i in timestamps]
211210
max_second = round(max(timestamps)) + 1
212-
frame_id_list = []
211+
frame_id_list = [] # type: ignore
213212
for second in range(max_second):
214213
closest_num = min(timestamps, key=lambda x: abs(x - second))
215214
index = timestamps.index(closest_num)
216-
frame_id_list.append(index)
215+
frame_id_list.append(index) # type: ignore
217216
if len(frame_id_list) >= num_frames:
218217
break
219-
video_data = decord_vr.get_batch(frame_id_list)
218+
else:
219+
frame_id_list = None
220+
video_data: torch.Tensor = decord_vr.get_batch(frame_id_list)
220221
video_data = video_data.permute(3, 0, 1, 2)
221222
return video_data

0 commit comments

Comments
 (0)