Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/common/contextmanagers/test_extensions.py: 95%

62 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 17:38 +0000

1"""Context manager for running a test with an extension.""" 

2from __future__ import annotations 

3 

4from contextlib import ContextDecorator 

5from importlib.metadata import Distribution, MetadataPathFinder 

6from pathlib import Path 

7import sys 

8from typing import Any 

9 

# Choose the EntryPoint implementation by interpreter version: before 3.11 the
# importlib_metadata backport is used, otherwise the standard library's.
# Comparing the whole version tuple avoids testing major/minor independently.
if sys.version_info < (3, 11):
    from importlib_metadata import EntryPoint
else:
    from importlib.metadata import EntryPoint

14 

15from nskit._logging import logger_factory 

16 

17 

18class _TestEntrypoint(EntryPoint): 

19 

20 def __init__(self, name: str, group: str, entrypoint: type, *, solo: bool = False): 

21 super().__init__(name, f'{entrypoint.__module__}:{entrypoint.__name__}', group) 

22 self._entrypoint = entrypoint 

23 self._sys_meta_path = None 

24 self._solo = solo 

25 

26 def __setattr__(self, __name: str, __value: Any) -> None: 

27 return object.__setattr__(self, __name, __value) 

28 

29 def load(self): 

30 return self._entrypoint 

31 

32 def start(self): 

33 self._sys_meta_path = sys.meta_path[:] 

34 nested = [u for u in sys.meta_path if isinstance(u, _TestExtensionFinder)] 

35 if self._solo: 

36 sys.meta_path = [_TestExtensionFinder(self)] 

37 elif nested: 

38 nested[0]._entrypoints.append(self) 

39 else: 

40 sys.meta_path.append(_TestExtensionFinder(self)) 

41 

42 def stop(self): 

43 sys.meta_path = self._sys_meta_path[:] 

44 

45 

46class _DummyDistribution(Distribution): 

47 

48 def __init__(self, i: int, entrypoint: EntryPoint): 

49 self._entrypoint = entrypoint 

50 self._i = i 

51 

52 @property 

53 def metadata(self): 

54 return { 

55 'Name': f'DummyDistribution{self._i}' 

56 } 

57 

58 @property 

59 def entry_points(self): 

60 return [self._entrypoint] 

61 

62 def read_text(self, filename: str) -> str | None: # noqa: U100 

63 return None 

64 

65 def locate_file(self, path: Path) -> None: # noqa: U100 

66 return None 

67 

68 

class _TestExtensionFinder(MetadataPathFinder):
    """Metadata finder serving one dummy distribution per registered test entry point."""

    def __init__(self, entrypoint: EntryPoint):
        """Create the finder with ``entrypoint`` as its first entry point."""
        # Mutable on purpose: nested _TestEntrypoint.start() calls append here.
        self._entrypoints = [entrypoint]

    @property
    def extension_distributions(self):
        """Build a fresh ``_DummyDistribution`` for each registered entry point."""
        return [_DummyDistribution(i, u) for i, u in enumerate(self._entrypoints)]

    def find_distributions(self, context=None, *args, **kwargs):  # noqa: U100
        """Return the dummy distributions that match the discovery ``context``.

        An unnamed query (``context`` is None or its ``name`` is None, as for
        ``entry_points()``/``distributions()``) gets every dummy distribution.
        A query for a specific name only gets dummies whose ``Name`` matches
        after normalization, so lookups of unrelated packages are not answered
        with a dummy distribution.
        """
        distributions = self.extension_distributions
        name = getattr(context, 'name', None)
        if name is None:
            return distributions

        import re  # local import: only needed for named lookups

        def normalize(value):
            # Collapse runs of '-', '_' and '.' and lowercase (PEP 503 style).
            return re.sub(r'[-_.]+', '-', value).lower()

        wanted = normalize(name)
        return [d for d in distributions if normalize(d.metadata['Name']) == wanted]

81 

82 

class TestExtension(ContextDecorator):
    """Context manager (or decorator) for running a test of an entrypoint."""

    # Not a test class: stop pytest from trying to collect it when it is
    # imported into a test module (its name matches the default ``Test*`` rule).
    __test__ = False

    def __init__(self, name: str, group: str, entrypoint: type, *, solo: bool = False):
        """Initialise the context manager.

        Keyword Args:
            name (str): the extension name
            group (str): the extension group
            entrypoint (type): the object/type to load in the entrypoint
            solo (bool): set so only that entrypoint will be found (can cause side-effects)
        """
        self.ep = _TestEntrypoint(name=name, group=group, entrypoint=entrypoint, solo=solo)
        # NOTE(review): not read anywhere in this module; kept for compatibility.
        self._clean = False

    def __enter__(self):
        """Add the extension so it can be loaded; returns this manager."""
        logger_factory.get_logger(__name__).info(f'Starting entrypoint for extension {self.ep.name} in {self.ep.group}')
        self.ep.start()
        return self

    def __exit__(self, *args, **kwargs):  # noqa: U100
        """Remove the extension and return to the original; exceptions propagate."""
        logger_factory.get_logger(__name__).info(f'Stopping entrypoint for extension {self.ep.name} in {self.ep.group}')
        self.ep.stop()
106 self.ep.stop()