Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/common/extensions.py: 97%
39 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-25 17:38 +0000
1"""Common extension helpers."""
2from enum import Enum
3import sys
5if sys.version_info.major >= 3 and sys.version_info.minor >= 10:
6 from importlib.metadata import entry_points
7else:
8 from backports.entry_points_selectable import entry_points
10from aenum import extend_enum
12from nskit._logging import logger_factory
def get_extension_names(entrypoint: str) -> list:
    """Get all installed extension names for a given entrypoint.

    Args:
        entrypoint: The entry point group to query.

    Returns:
        A list of the names of every entry point registered under
        ``entrypoint``, in the order ``entry_points().select`` yields them.
        Names are not de-duplicated.
    """
    extensions = [ep.name for ep in entry_points().select(group=entrypoint)]
    logger_factory.get_logger(__name__).debug(f'Identified extensions {extensions} for entrypoint {entrypoint}')
    return extensions
def load_extension(entrypoint: str, extension: str):
    """Load a given extension for a given entrypoint.

    Args:
        entrypoint: The entry point group to search.
        extension: The entry point name within that group.

    Returns:
        The object produced by ``EntryPoint.load()`` for the first matching
        entry point, or ``None`` (after logging a warning) when no entry point
        with that name exists in the group. Exceptions raised by ``load()``
        are not caught and propagate to the caller.
    """
    for ep in entry_points().select(group=entrypoint, name=extension):
        # Only the first match is loaded; any further matches are ignored.
        return ep.load()
    # ``warn`` is a deprecated alias of ``warning`` on stdlib loggers.
    # NOTE(review): assumes the logger from logger_factory exposes the standard
    # ``warning`` method (it already exposes ``debug``) - confirm.
    logger_factory.get_logger(__name__).warning(f'Entrypoint {extension} not found for {entrypoint}')
    return None
def get_extensions(entrypoint: str) -> dict:
    """Collect all entry points registered for a given entrypoint group.

    The entry points are *not* loaded here. Call ``.load()`` on a value (or
    use ``load_extension``) to import the object an entry point refers to.

    Args:
        entrypoint: The entry point group to query.

    Returns:
        A dict mapping each entry point name to its ``EntryPoint`` object.
        If several entry points share a name, the last one yielded wins.
    """
    extensions = {ep.name: ep for ep in entry_points().select(group=entrypoint)}
    logger_factory.get_logger(__name__).debug(f'Identified extensions {extensions} for entrypoint {entrypoint}')
    return extensions
class ExtensionsEnum(Enum):
    """Base enum whose members mirror the extensions on an entrypoint.

    Concrete enums are generated with ``from_entrypoint``. Each member's name
    and value are both the extension name, and the ``extension`` property
    loads the corresponding entry point.
    """

    @classmethod
    def from_entrypoint(cls, name: str, entrypoint: str):
        """Generate an enum called ``name`` with one member per extension.

        The entrypoint group is recorded on the generated class as
        ``__entrypoint__`` so that members (and ``_patch``) can find it later.
        """
        found_names = get_extension_names(entrypoint)
        # Functional Enum API: mapping of member name -> member value.
        member_map = {ext_name: ext_name for ext_name in found_names}
        enum_cls = cls(name, member_map)
        enum_cls.__entrypoint__ = entrypoint
        return enum_cls

    @property
    def extension(self):
        """Load this member's extension via ``load_extension``.

        Returns ``None`` when ``load_extension`` finds no matching entry point.
        """
        group = self.__entrypoint__
        return load_extension(group, self.value)

    @classmethod
    def _patch(cls):
        """Add members for any extension names not already on the enum.

        Used for testing and patching objects. Existing members are left
        untouched; new ones are appended with ``extend_enum``.
        """
        # dict.fromkeys de-duplicates while preserving discovery order.
        for ext_name in dict.fromkeys(get_extension_names(cls.__entrypoint__)):
            if ext_name in cls._member_names_:
                continue
            extend_enum(cls, ext_name, ext_name)