Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/common/contextmanagers/chdir.py: 90%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 17:38 +0000

1"""Context manager for running in a specific directory.""" 

2from contextlib import ContextDecorator 

3import os 

4from pathlib import Path 

5import tempfile 

6from typing import Optional 

7import warnings 

8 

9from nskit._logging import logger_factory 

10 

11 

class ChDir(ContextDecorator):
    """Context manager for running in a specified (or temporary) directory.

    The optional argument is a path to a specified target directory. If it isn't
    provided, a temporary directory is created and removed again on exit.

    Because it subclasses ``ContextDecorator`` it can also be used as a decorator.
    """

    def __init__(self, target_dir: Optional[Path] = None):
        """Initialise the context manager.

        Keyword Args:
            target_dir (Optional[Path]): the target directory. If falsy (e.g. ``None``
                or an empty string), a ``tempfile.TemporaryDirectory`` is created and
                used as the target instead.
        """
        self._temp_dir = None
        if not target_dir:
            # Handling circular imports with LoggingConfig
            logger_factory.get_logger(__name__).debug('No target_dir provided, using a temporary directory')
            self._temp_dir = tempfile.TemporaryDirectory()
            target_dir = self._temp_dir.name
        # Kept for backwards compatibility; refreshed in __enter__ so that exit returns
        # to the directory that was current when the block was actually entered.
        self.cwd = Path.cwd()
        self.target_dir = Path(target_dir)

    def __enter__(self):
        """Change to the target directory, creating it (and any parents) if missing.

        Returns:
            str | None: the temporary directory's path when one was created by
            ``__init__``, otherwise ``None``.
        """
        # Handling circular imports with LoggingConfig
        logger_factory.get_logger(__name__).info(f'Changing to {self.target_dir}')
        # Record the directory to restore at entry time, not construction time.
        self.cwd = Path.cwd()
        if not self.target_dir.exists():
            # parents: allow nested targets; exist_ok: tolerate a concurrent creator.
            self.target_dir.mkdir(parents=True, exist_ok=True)
        os.chdir(str(self.target_dir))
        if self._temp_dir:
            return self._temp_dir.__enter__()
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Return to the original directory and remove any temporary directory.

        A ``PermissionError`` raised while deleting the temporary directory is logged
        and re-emitted as a ``UserWarning`` instead of propagating. Exceptions raised
        inside the managed block are not suppressed (this returns ``None``).
        """
        os.chdir(str(self.cwd))
        if self._temp_dir:
            try:
                # Exit the TemporaryDirectory itself: ``self.target_dir`` is a plain
                # Path and exiting it would not delete anything.
                self._temp_dir.__exit__(exc_type, exc_val, exc_tb)
            except PermissionError as e:
                # Handling circular imports with LoggingConfig
                logger_factory.get_logger(__name__).warning('Unable to delete temporary directory.')
                # Pass a string: a non-str message breaks warnings filters that match on message text.
                warnings.warn(str(e), stacklevel=2)