Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/common/extensions.py: 97%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 17:38 +0000

1"""Common extension helpers.""" 

2from enum import Enum 

3import sys 

4 

5if sys.version_info.major >= 3 and sys.version_info.minor >= 10: 

6 from importlib.metadata import entry_points 

7else: 

8 from backports.entry_points_selectable import entry_points 

9 

10from aenum import extend_enum 

11 

12from nskit._logging import logger_factory 

13 

14 

def get_extension_names(entrypoint: str):
    """Return the names of all entry points registered under ``entrypoint``.

    Args:
        entrypoint: The entry point group to query.

    Returns:
        A list holding the ``name`` of each entry point selected for the group.
    """
    extensions = [candidate.name for candidate in entry_points().select(group=entrypoint)]
    log = logger_factory.get_logger(__name__)
    log.debug(f'Identified extensions {extensions} for entrypoint {entrypoint}')
    return extensions

22 

23 

def load_extension(entrypoint: str, extension: str):
    """Load a given extension for a given entrypoint.

    Args:
        entrypoint: The entry point group to search.
        extension: The name of the entry point within that group.

    Returns:
        Whatever ``load()`` returns for the first matching entry point, or
        ``None`` (after a warning is logged) when nothing matches.
    """
    for ep in entry_points().select(group=entrypoint, name=extension):
        # Only the first match is loaded; any further matches are ignored.
        return ep.load()
    # ``warning`` replaces ``warn``: the latter is a deprecated alias that the
    # standard library's Logger no longer provides on Python 3.13+.
    # NOTE(review): assumes get_logger returns a logging.Logger-compatible
    # object - confirm against nskit._logging.
    logger_factory.get_logger(__name__).warning(f'Entrypoint {extension} not found for {entrypoint}')
    return None

29 

30 

def get_extensions(entrypoint: str):
    """Collect the entry points registered under ``entrypoint``, keyed by name.

    The entry point objects are returned as selected; this function does not
    call ``load()`` on them.

    Args:
        entrypoint: The entry point group to query.

    Returns:
        A dict mapping each entry point's ``name`` to the entry point object.
    """
    extensions = {found.name: found for found in entry_points().select(group=entrypoint)}
    log = logger_factory.get_logger(__name__)
    log.debug(f'Identified extensions {extensions} for entrypoint {entrypoint}')
    return extensions

38 

39 

class ExtensionsEnum(Enum):
    """Enum whose members mirror the extension names found on an entrypoint."""

    @classmethod
    def from_entrypoint(cls, name: str, entrypoint: str):
        """Build an enum called ``name`` from the extensions on ``entrypoint``.

        Each extension name becomes a member whose value is that same name.
        The entrypoint is stored on the new enum class as ``__entrypoint__``.
        """
        members = {ext: ext for ext in get_extension_names(entrypoint)}
        kls = cls(name, members)
        kls.__entrypoint__ = entrypoint
        return kls

    @property
    def extension(self):
        """Load and return the extension this member names."""
        group = self.__entrypoint__
        return load_extension(group, self.value)

    @classmethod
    def _patch(cls):
        """Used for testing and patching objects.

        Re-reads the extension names for the stored entrypoint and adds a
        member for every name the enum does not have yet.
        """
        current = {ext: ext for ext in get_extension_names(cls.__entrypoint__)}
        for key in current:
            # Names that are already members are left untouched.
            if key in cls._member_names_:
                continue
            extend_enum(cls, key, key)