Skip to content

Commit 71d838f

Browse files
KumoLiujuampatronics
authored andcommitted
Upgrade the version of transformers (Project-MONAI#7343)
Fixes Project-MONAI#7338 ### Description transformers' version is pinned to v4.22 since Project-MONAI#5157. Updated the version refer to huggingface/transformers#21678. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <[email protected]> Signed-off-by: Juan Pablo de la Cruz Gutiérrez <[email protected]>
1 parent 3f3e03c commit 71d838f

File tree

3 files changed

+15
-39
lines changed

3 files changed

+15
-39
lines changed

monai/networks/nets/transchex.py

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,17 @@
1212
from __future__ import annotations
1313

1414
import math
15-
import os
16-
import shutil
17-
import tarfile
18-
import tempfile
1915
from collections.abc import Sequence
2016

2117
import torch
2218
from torch import nn
2319

20+
from monai.config.type_definitions import PathLike
2421
from monai.utils import optional_import
2522

2623
transformers = optional_import("transformers")
2724
load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0]
28-
cached_path = optional_import("transformers.file_utils", name="cached_path")[0]
25+
cached_file = optional_import("transformers.utils", name="cached_file")[0]
2926
BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0]
3027
BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0]
3128

@@ -63,44 +60,16 @@ def from_pretrained(
6360
state_dict=None,
6461
cache_dir=None,
6562
from_tf=False,
63+
path_or_repo_id="bert-base-uncased",
64+
filename="pytorch_model.bin",
6665
*inputs,
6766
**kwargs,
6867
):
69-
archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
70-
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
71-
tempdir = None
72-
if os.path.isdir(resolved_archive_file) or from_tf:
73-
serialization_dir = resolved_archive_file
74-
else:
75-
tempdir = tempfile.mkdtemp()
76-
with tarfile.open(resolved_archive_file, "r:gz") as archive:
77-
78-
def is_within_directory(directory, target):
79-
abs_directory = os.path.abspath(directory)
80-
abs_target = os.path.abspath(target)
81-
82-
prefix = os.path.commonprefix([abs_directory, abs_target])
83-
84-
return prefix == abs_directory
85-
86-
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
87-
for member in tar.getmembers():
88-
member_path = os.path.join(path, member.name)
89-
if not is_within_directory(path, member_path):
90-
raise Exception("Attempted Path Traversal in Tar File")
91-
92-
tar.extractall(path, members, numeric_owner=numeric_owner)
93-
94-
safe_extract(archive, tempdir)
95-
serialization_dir = tempdir
68+
weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
9669
model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
9770
if state_dict is None and not from_tf:
98-
weights_path = os.path.join(serialization_dir, "pytorch_model.bin")
9971
state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
100-
if tempdir:
101-
shutil.rmtree(tempdir)
10272
if from_tf:
103-
weights_path = os.path.join(serialization_dir, "model.ckpt")
10473
return load_tf_weights_in_bert(model, weights_path)
10574
old_keys = []
10675
new_keys = []
@@ -304,6 +273,8 @@ def __init__(
304273
chunk_size_feed_forward: int = 0,
305274
is_decoder: bool = False,
306275
add_cross_attention: bool = False,
276+
path_or_repo_id: str | PathLike = "bert-base-uncased",
277+
filename: str = "pytorch_model.bin",
307278
) -> None:
308279
"""
309280
Args:
@@ -315,6 +286,10 @@ def __init__(
315286
num_vision_layers: number of vision transformer layers.
316287
num_mixed_layers: number of mixed transformer layers.
317288
drop_out: fraction of the input units to drop.
289+
path_or_repo_id: This can be either:
290+
- a string, the *model id* of a model repo on huggingface.co.
291+
- a path to a *directory* potentially containing the file.
292+
filename: The name of the file to locate in `path_or_repo`.
318293
319294
The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.
320295
@@ -369,6 +344,8 @@ def __init__(
369344
num_vision_layers=num_vision_layers,
370345
num_mixed_layers=num_mixed_layers,
371346
bert_config=bert_config,
347+
path_or_repo_id=path_or_repo_id,
348+
filename=filename,
372349
)
373350

374351
self.patch_size = patch_size

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin"
3333
pandas
3434
requests
3535
einops
36-
transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157
36+
transformers>=4.36.0
3737
mlflow>=1.28.0
3838
clearml>=1.10.0rc0
3939
matplotlib!=3.5.0

tests/test_transchex.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from monai.networks import eval_mode
2020
from monai.networks.nets.transchex import Transchex
21-
from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick
21+
from tests.utils import skip_if_quick
2222

2323
TEST_CASE_TRANSCHEX = []
2424
for drop_out in [0.4]:
@@ -46,7 +46,6 @@
4646

4747

4848
@skip_if_quick
49-
@SkipIfAtLeastPyTorchVersion((1, 10))
5049
class TestTranschex(unittest.TestCase):
5150
@parameterized.expand(TEST_CASE_TRANSCHEX)
5251
def test_shape(self, input_param, expected_shape):

0 commit comments

Comments
 (0)