Skip to content

Commit 344100b

Browse files
authored
Merge pull request #66 from aidy1991/fix-suffix-check
Fix suffix check for archive extraction
2 parents 3a66ae5 + 5906228 commit 344100b

File tree

2 files changed

+64
-31
lines changed

2 files changed

+64
-31
lines changed

jupyter_archive/handlers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ def make_writer(handler, archive_format="zip"):
7272

7373
def make_reader(archive_path):
7474

75-
archive_format = "".join(archive_path.suffixes)[1:]
75+
archive_format = "".join(archive_path.suffixes)
7676

77-
if archive_format.endswith("zip"):
77+
if archive_format.endswith(".zip"):
7878
archive_file = zipfile.ZipFile(archive_path, mode="r")
79-
elif any([archive_format.endswith(ext) for ext in ["tgz", "tar.gz"]]):
79+
elif any([archive_format.endswith(ext) for ext in [".tgz", ".tar.gz"]]):
8080
archive_file = tarfile.open(archive_path, mode="r|gz")
81-
elif any([archive_format.endswith(ext) for ext in ["tbz", "tbz2", "tar.bz", "tar.bz2"]]):
81+
elif any([archive_format.endswith(ext) for ext in [".tbz", ".tbz2", ".tar.bz", ".tar.bz2"]]):
8282
archive_file = tarfile.open(archive_path, mode="r|bz2")
83-
elif any([archive_format.endswith(ext) for ext in ["txz", "tar.xz"]]):
83+
elif any([archive_format.endswith(ext) for ext in [".txz", ".tar.xz"]]):
8484
archive_file = tarfile.open(archive_path, mode="r|xz")
8585
else:
8686
raise ValueError("'{}' is not a valid archive format.".format(archive_format))
@@ -141,8 +141,7 @@ async def get(self, archive_path, include_body=False):
141141
raise web.HTTPError(400)
142142

143143
archive_path = pathlib.Path(cm.root_dir) / url2path(archive_path)
144-
archive_name = archive_path.name
145-
archive_filename = archive_path.with_suffix(".{}".format(archive_format)).name
144+
archive_filename = f"{archive_path.name}.{archive_format}"
146145

147146
self.log.info("Prepare {} for archiving and downloading.".format(archive_filename))
148147
self.set_header("content-type", "application/octet-stream")

jupyter_archive/tests/test_archive_handler.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import pytest
77

8+
from tornado.httpclient import HTTPClientError
9+
810

911
@pytest.mark.parametrize(
1012
"followSymlinks, download_hidden, file_list",
@@ -111,6 +113,36 @@ async def test_download(jp_fetch, jp_root_dir, followSymlinks, download_hidden,
111113
assert set(map(lambda m: m.name, tf.getmembers())) == file_list
112114

113115

116+
def _create_archive_file(root_dir, file_name, format, mode):
117+
# Create a dummy directory.
118+
archive_dir_path = root_dir / file_name
119+
archive_dir_path.mkdir(parents=True)
120+
121+
(archive_dir_path / "extract-test1.txt").write_text("hello1")
122+
(archive_dir_path / "extract-test2.txt").write_text("hello2")
123+
(archive_dir_path / "extract-test3.md").write_text("hello3")
124+
125+
# Make an archive
126+
archive_dir_path = root_dir / file_name
127+
# The request should fail when the extension has an unnecessary prefix.
128+
archive_path = archive_dir_path.parent / f"{archive_dir_path.name}.{format}"
129+
if format == "zip":
130+
with zipfile.ZipFile(archive_path, mode=mode) as writer:
131+
for file_path in archive_dir_path.rglob("*"):
132+
if file_path.is_file():
133+
writer.write(file_path, file_path.relative_to(root_dir))
134+
else:
135+
with tarfile.open(str(archive_path), mode=mode) as writer:
136+
for file_path in archive_dir_path.rglob("*"):
137+
if file_path.is_file():
138+
writer.add(file_path, file_path.relative_to(root_dir))
139+
140+
# Remove the directory
141+
shutil.rmtree(archive_dir_path)
142+
143+
return archive_dir_path, archive_path
144+
145+
114146
@pytest.mark.parametrize(
115147
"file_name",
116148
[
@@ -134,34 +166,36 @@ async def test_download(jp_fetch, jp_root_dir, followSymlinks, download_hidden,
134166
],
135167
)
136168
async def test_extract(jp_fetch, jp_root_dir, file_name, format, mode):
137-
# Create a dummy directory.
138-
archive_dir_path = jp_root_dir / file_name
139-
archive_dir_path.mkdir(parents=True)
140-
141-
(archive_dir_path / "extract-test1.txt").write_text("hello1")
142-
(archive_dir_path / "extract-test2.txt").write_text("hello2")
143-
(archive_dir_path / "extract-test3.md").write_text("hello3")
144-
145-
# Make an archive
146-
archive_dir_path = jp_root_dir / file_name
147-
archive_path = archive_dir_path.with_suffix("." + format)
148-
if format == "zip":
149-
with zipfile.ZipFile(archive_path, mode=mode) as writer:
150-
for file_path in archive_dir_path.rglob("*"):
151-
if file_path.is_file():
152-
writer.write(file_path, file_path.relative_to(jp_root_dir))
153-
else:
154-
with tarfile.open(str(archive_path), mode=mode) as writer:
155-
for file_path in archive_dir_path.rglob("*"):
156-
if file_path.is_file():
157-
writer.add(file_path, file_path.relative_to(jp_root_dir))
158-
159-
# Remove the directory
160-
shutil.rmtree(archive_dir_path)
169+
archive_dir_path, archive_path = _create_archive_file(jp_root_dir, file_name, format, mode)
161170

162171
r = await jp_fetch("extract-archive", archive_path.relative_to(jp_root_dir).as_posix(), method="GET")
163172
assert r.code == 200
164173
assert archive_dir_path.is_dir()
165174

166175
n_files = len(list(archive_dir_path.glob("*")))
167176
assert n_files == 3
177+
178+
179+
@pytest.mark.parametrize(
180+
"format, mode",
181+
[
182+
("zip", "w"),
183+
("tgz", "w|gz"),
184+
("tar.gz", "w|gz"),
185+
("tbz", "w|bz2"),
186+
("tbz2", "w|bz2"),
187+
("tar.bz", "w|bz2"),
188+
("tar.bz2", "w|bz2"),
189+
("txz", "w|xz"),
190+
("tar.xz", "w|xz"),
191+
],
192+
)
193+
async def test_extract_failure(jp_fetch, jp_root_dir, format, mode):
194+
# The request should fail when the extension has an unnecessary prefix.
195+
prefixed_format = f"prefix{format}"
196+
archive_dir_path, archive_path = _create_archive_file(jp_root_dir, "extract-archive-dir", prefixed_format, mode)
197+
198+
with pytest.raises(Exception) as e:
199+
await jp_fetch("extract-archive", archive_path.relative_to(jp_root_dir).as_posix(), method="GET")
200+
assert e.type == HTTPClientError
201+
assert not archive_dir_path.exists()

0 commit comments

Comments
 (0)