2727import os
2828import tarfile
2929import zipfile
30- from typing import TYPE_CHECKING
30+ from pathlib import Path
3131
3232import py7zr
3333import py7zr .callbacks
3434import py7zr .exceptions
35+ import pytest
3536
3637from rubisco .config import COPY_BUFSIZE
3738from rubisco .lib .archive .sevenzip import compress_7z , extract_7z
3839from rubisco .lib .archive .tar import compress_tarball , extract_tarball
3940from rubisco .lib .archive .zip import compress_zip , extract_zip
4041from rubisco .lib .exceptions import RUValueError
4142from rubisco .lib .fileutil import (
43+ TemporaryObject ,
4244 assert_rel_path ,
4345 check_file_exists ,
4446 rm_recursive ,
4850from rubisco .lib .variable import format_str
4951from rubisco .shared .ktrigger import IKernelTrigger , call_ktrigger
5052
51- if TYPE_CHECKING :
52- from pathlib import Path
53-
5453__all__ = ["compress" , "extract" ]
5554
5655
@@ -184,13 +183,14 @@ def _extract(
184183 raise UnsupportedArchiveTypeError
185184
186185
187- def extract ( # pylint: disable=too-many-branches
186+ def extract ( # pylint: disable=R0913 # noqa: PLR0913
188187 file : Path ,
189188 dest : Path ,
190189 compress_type : str | None = None ,
191190 password : str | None = None ,
192191 * ,
193192 overwrite : bool = False ,
193+ allow_absolute_dest : bool = False ,
194194) -> None :
195195 """Extract compressed file to destination.
196196
@@ -204,10 +204,13 @@ def extract( # pylint: disable=too-many-branches
204204 Defaults to None. Tarball is not supported.
205205 overwrite (bool, optional): Overwrite destination if it exists.
206206 Defaults to False.
207+ allow_absolute_dest (bool, optional): Allow absolute destination path.
208+ Defaults to False.
207209
208210 """
209211 compress_type = compress_type .lower ().strip () if compress_type else None
210- assert_rel_path (dest )
212+ if not allow_absolute_dest :
213+ assert_rel_path (dest )
211214 try :
212215 if compress_type is None :
213216 suffix1 = file .suffix
@@ -229,13 +232,13 @@ def extract( # pylint: disable=too-many-branches
229232 ),
230233 hint = _ ("Please specify the compression type explicitly." ),
231234 )
232- _extract (
233- compress_type ,
234- file ,
235- dest ,
236- password ,
237- overwrite = overwrite ,
238- )
235+ _extract (
236+ compress_type ,
237+ file ,
238+ dest ,
239+ password ,
240+ overwrite = overwrite ,
241+ )
239242 except UnsupportedArchiveTypeError :
240243 logger .error (
241244 "Unsupported compression type: '%s'" ,
@@ -440,7 +443,6 @@ def _compress( # pylint: disable=R0913, R0917 # noqa: PLR0913
440443 raise UnsupportedArchiveTypeError
441444
442445
443- # We should rewrite this ugly function later.
444446def compress ( # pylint: disable=R0913, R0917 # noqa: PLR0913
445447 src : Path ,
446448 dest : Path ,
@@ -450,6 +452,7 @@ def compress( # pylint: disable=R0913, R0917 # noqa: PLR0913
450452 compress_level : int | None = None ,
451453 * ,
452454 overwrite : bool = False ,
455+ allow_absolute_dest : bool = False ,
453456) -> None :
454457 """Compress a file or directory to destination.
455458
@@ -467,10 +470,13 @@ def compress( # pylint: disable=R0913, R0917 # noqa: PLR0913
467470 others.
468471 overwrite (bool, optional): Overwrite destination if it exists.
469472 Defaults to False.
473+ allow_absolute_dest (bool, optional): Allow absolute destination path.
474+ Defaults to False.
470475
471476 """
472477 compress_type = compress_type .lower ().strip () if compress_type else None
473- assert_rel_path (dest )
478+ if not allow_absolute_dest :
479+ assert_rel_path (dest )
474480 try :
475481 if compress_type is None :
476482 suffix1 = dest .suffix
@@ -534,3 +540,147 @@ def compress( # pylint: disable=R0913, R0917 # noqa: PLR0913
534540 fmt = {"src" : str (src ), "dest" : str (dest ), "exc" : str (exc )},
535541 ),
536542 ) from exc
543+
544+
545+ class TestArchive :
546+ """Test archive module."""
547+
548+ def _check_extract_archive (self , path : Path ) -> None :
549+ """Check extract archive.
550+
551+ Args:
552+ path (Path): Extract destination.
553+
554+ """
555+ if not path .is_dir ():
556+ raise AssertionError
557+ if not (path / "dir1" ).is_dir ():
558+ raise AssertionError
559+ if not (path / "dir1" / "file1" ).is_file ():
560+ raise AssertionError
561+ if (path / "dir1" / "file1" ).stat ().st_size != 0 :
562+ raise AssertionError
563+ with (path / "file2" ).open (encoding = "utf-8" ) as f :
564+ if f .read () != "Test\n " :
565+ raise AssertionError
566+ if not (path / "file3" ).is_file ():
567+ raise AssertionError
568+
569+ def _extract (self , compress_type : str , password : str = "" ) -> None :
570+ with TemporaryObject .new_directory (suffix = "test" ) as path :
571+ archive_file = Path (f"tests/test.{ compress_type } " )
572+ if password :
573+ archive_file = Path (f"tests/test-pwd.{ compress_type } " )
574+ extract (
575+ archive_file ,
576+ path .path / "test" ,
577+ password = password ,
578+ allow_absolute_dest = True ,
579+ )
580+ self ._check_extract_archive (path .path / "test" )
581+
582+ def test_extract_7z (self ) -> None :
583+ """Test extract 7z."""
584+ self ._extract ("7z" )
585+
586+ def test_extract_zip (self ) -> None :
587+ """Test extract zip."""
588+ self ._extract ("zip" )
589+
590+ def test_extract_tar_gz (self ) -> None :
591+ """Test extract tar.gz."""
592+ self ._extract ("tar.gz" )
593+
594+ def test_extract_tar_xz (self ) -> None :
595+ """Test extract tar.xz."""
596+ self ._extract ("tar.xz" )
597+
598+ def test_extract_tgz (self ) -> None :
599+ """Test extract tgz (alias of tar.gz)."""
600+ self ._extract ("tgz" )
601+
602+ def test_extract_password_zip (self ) -> None :
603+ """Test extract zip with password."""
604+ self ._extract ("zip" , password = "1234" ) # noqa: S106
605+
606+ def test_extract_invalid_password (self ) -> None :
607+ """Test extract zip with invalid password."""
608+ pytest .raises (
609+ RuntimeError ,
610+ self ._extract ,
611+ "zip" ,
612+ password = "0" , # noqa: S106
613+ )
614+
615+ def test_extract_to_absolute (self ) -> None :
616+ """Test extract tgz (alias of tar.gz)."""
617+ self ._extract ("tgz" )
618+
619+ def _compress (self , compress_type : str ) -> None :
620+ with TemporaryObject .new_directory (suffix = "test" ) as path :
621+ compress (
622+ Path ("tests/data" ).absolute (),
623+ path .path / f"test.{ compress_type } " ,
624+ start = Path ("tests/data" ).absolute (),
625+ compress_type = compress_type ,
626+ compress_level = 9 ,
627+ allow_absolute_dest = True ,
628+ )
629+ extract (
630+ path .path / f"test.{ compress_type } " ,
631+ path .path / "extract" ,
632+ compress_type = compress_type ,
633+ allow_absolute_dest = True ,
634+ )
635+ self ._check_extract_archive (path .path / "extract" )
636+
637+ def test_compress_7z (self ) -> None :
638+ """Test compress 7-Zip."""
639+ self ._compress ("7z" )
640+
641+ def test_compress_zip (self ) -> None :
642+ """Test compress zip."""
643+ self ._compress ("zip" )
644+
645+ def test_compress_tar_gz (self ) -> None :
646+ """Test compress tar.gz."""
647+ self ._compress ("tar.gz" )
648+
649+ def test_compress_tar_xz (self ) -> None :
650+ """Test compress tar.xz."""
651+ self ._compress ("tar.xz" )
652+
653+ def test_compress_tgz (self ) -> None :
654+ """Test compress tgz (alias of tar.gz)."""
655+ self ._compress ("tgz" )
656+
657+ def test_compress_to_absolute (self ) -> None :
658+ """Test compress to absolute path."""
659+ with TemporaryObject .new_directory (suffix = "test" ) as path :
660+ pytest .raises (
661+ RUValueError ,
662+ compress ,
663+ Path ("tests/data" ).absolute (),
664+ (path .path / "test.zip" ).absolute (),
665+ start = Path ("tests/data" ),
666+ compress_level = 9 ,
667+ allow_absolute_dest = False ,
668+ )
669+
670+ def test_compress_invalid_type (self ) -> None :
671+ """Test compress invalid type."""
672+ with TemporaryObject .new_directory (suffix = "test" ) as path :
673+ pytest .raises (
674+ RUValueError ,
675+ compress ,
676+ Path ("tests/data" ).absolute (),
677+ path .path / "test.zip" ,
678+ start = Path ("tests/data" ),
679+ compress_type = "invalid" ,
680+ compress_level = 9 ,
681+ allow_absolute_dest = True ,
682+ )
683+
684+
685+ if __name__ == "__main__" :
686+ pytest .main ([__file__ ])
0 commit comments