Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/mixer/utilities.py: 94%
82 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-25 17:38 +0000
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-25 17:38 +0000
1"""Utilities for interacting with systems etc."""
2import os
3from pathlib import Path
4import sys
5from typing import Any
7if sys.version_info.major <= 3 and sys.version_info.minor < 9:
8 from importlib_resources import files
9else:
10 from importlib.resources import files
12from jinja2 import BaseLoader, ChoiceLoader, Environment, TemplateNotFound
13from pydantic import GetCoreSchemaHandler, TypeAdapter, ValidationError
14from pydantic_core import core_schema, CoreSchema
16from nskit.common.extensions import ExtensionsEnum
class Resource(str):
    """A type for a package resource uri.

    A resource string looks like ``<package>.<module>:<resource filename>``.
    The part before the colon is a dotted python module path. The part after
    it is a single filename within that package. Instances are plain ``str``
    subclasses, and the class can be used directly as a pydantic field type.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler  # noqa: U100
    ) -> CoreSchema:
        """Get the schema.

        The input is first validated as a ``str`` and then passed to
        ``_validate_resource``.
        """
        return core_schema.no_info_after_validator_function(cls._validate_resource, handler(str))

    @classmethod
    def _validate_resource(cls, value: str) -> 'Resource':
        """Check ``value`` is a well-formed resource string and wrap it in ``cls``.

        Raises:
            ValueError: if ``value`` does not contain exactly one colon, if the
                filename part is not a single path component, or if the module
                path part is empty or contains characters that cannot appear
                in a python module path.
        """
        resource_string_example = "<package>.<module>:<resource filename>"
        # A value without any colon splits into a single part, so this also
        # covers the "no colon" case.
        parts = value.split(':')
        if len(parts) != 2:
            raise ValueError(f'Value should be a resource string, looking like {resource_string_example}.')
        path, filename = parts
        # filename should be a valid filename (exactly one path component)
        if len(Path(filename).parts) != 1:
            raise ValueError(f'The part after the colon ({filename}) should be a valid filename as part of the resource string ({resource_string_example}).')
        # path should be a valid module path; an empty path cannot be imported
        # by importlib.resources.files, so reject it here rather than at load time.
        if not path or any(char in path for char in (' ', '-', '*', '(', ')')):
            raise ValueError(f'The part before the colon ({path}) should be a valid python module path as part of the resource string ({resource_string_example}).')
        return cls(value)

    def load(self) -> str:
        """Load the resource's text content using importlib.resources.

        The file is read without an explicit encoding, and the file handle is
        closed before returning.

        Raises:
            ModuleNotFoundError: if the package part cannot be imported.
            FileNotFoundError: if the resource does not exist in the package.
        """
        path, filename = self.split(':')
        # Traversables are not context managers. ``with`` on pathlib.Path was
        # deprecated in 3.11 and removed in 3.13, and never worked for
        # zip-imported packages. Unlike ``open().read()``, ``read_text`` closes
        # the file it opens.
        return files(path).joinpath(filename).read_text()

    @classmethod
    def validate(cls, value):
        """Validate ``value`` as a resource string and return it as a ``cls`` instance.

        Raises:
            pydantic.ValidationError: if ``value`` is not a valid resource string.
        """
        # Use ``cls`` so that subclasses validate to instances of themselves.
        return TypeAdapter(cls).validate_python(value)
class _PkgResourcesTemplateLoader(BaseLoader):
    """Load jinja templates via importlib.resources.

    Template names are resource strings of the form
    ``<package>.<module>:<resource filename>`` (see ``Resource``).
    """

    @staticmethod
    def get_source(environment, template):  # noqa: U100
        """Get the template source using importlib.resources.

        Returns:
            A ``(source, filename, uptodate)`` tuple: the template text, ``None``
            for the filename, and an ``uptodate`` callable that always returns
            ``True``.

        Raises:
            TemplateNotFound: if ``template`` is not a valid resource string,
                or if its package or resource cannot be found. This lets a
                surrounding ``ChoiceLoader`` fall through to other loaders.
        """
        try:
            # ``validate`` already returns the validated Resource instance.
            resource = Resource.validate(template)
        except ValidationError as e:
            # TemplateNotFound takes (name, message); pass the validation
            # details as the message rather than relying on ``e.args`` layout.
            raise TemplateNotFound(template, str(e)) from e
        try:
            source = resource.load()
        except (FileNotFoundError, ModuleNotFoundError) as e:
            # A missing package is just as much "template not found" as a
            # missing file within an existing package.
            raise TemplateNotFound(template) from e
        return source, None, lambda: True
class _EnvironmentFactory:
    """Jinja2 Environment Factory to allow for extension/customisation.

    The environment implementation is chosen from the
    ``nskit.mixer.environment.factory`` entrypoints using the
    ``NSKIT_MIXER_ENVIRONMENT_FACTORY`` environment variable. Extensions from
    the ``nskit.mixer.environment.extensions`` entrypoints are then added to it.
    """

    def __init__(self):
        """Initialise the factory."""
        # Built lazily on first access of ``environment``.
        self._environment = None

    @property
    def environment(self) -> Environment:
        """Return the cached environment, creating it on first access.

        The environment is only cached once extensions have been added
        successfully. A failure during setup is therefore retried on the next
        access, rather than leaving a half-configured environment cached.
        """
        if self._environment is None:
            environment = self.get_environment()
            self.add_extensions(environment)
            self._environment = environment
        return self._environment

    def add_extensions(self, environment: Environment):
        """Add Extensions to the environment object.

        Collects the extensions provided by every installed
        ``nskit.mixer.environment.extensions`` entrypoint. Each distinct
        extension is added once, in discovery order.
        """
        NskitMixerExtensionOptions = ExtensionsEnum.from_entrypoint('NskitMixerExtensionOptions', 'nskit.mixer.environment.extensions')
        # Assuming no risk of extension clash
        extensions = []
        for ext in NskitMixerExtensionOptions:
            extensions += ext.extension()
        # De-duplicate while preserving order. ``set`` iteration order is
        # arbitrary (and varies between runs for strings due to hash
        # randomisation), which made the add order non-deterministic.
        for extension in dict.fromkeys(extensions):
            environment.add_extension(extension)

    def get_environment(self) -> Environment:
        """Get the environment object based on the env var.

        ``NSKIT_MIXER_ENVIRONMENT_FACTORY`` selects the installed
        ``nskit.mixer.environment.factory`` option. If the variable is unset,
        empty, or ``default`` (case-insensitive), the ``default`` option is used.

        Raises:
            ValueError: if the selected option is not installed.
        """
        selected_method = os.environ.get('NSKIT_MIXER_ENVIRONMENT_FACTORY', None)
        if not selected_method or selected_method.lower() == 'default':
            # This is our simple implementation
            selected_method = 'default'
        # We need to validate against the options
        NskitMixerEnvironmentOptions = ExtensionsEnum.from_entrypoint('NskitMixerEnvironmentOptions', 'nskit.mixer.environment.factory')
        # Look the option up by value directly instead of a pre-check.
        # ``in`` on an Enum checks names on <3.12 but values on 3.12+, while
        # the call below always looks up by value.
        try:
            option = NskitMixerEnvironmentOptions(selected_method)
        except ValueError:
            raise ValueError(f'NSKIT_MIXER_ENVIRONMENT_FACTORY value {selected_method} not installed - available options are {list(NskitMixerEnvironmentOptions)}') from None
        return option.extension()

    @staticmethod
    def default_environment():
        """Get the default environment object.

        The environment uses a ``ChoiceLoader`` wrapping the importlib.resources
        based template loader.
        """
        # nosec B701: autoescape is left at jinja2's default (disabled).
        return Environment(loader=ChoiceLoader([_PkgResourcesTemplateLoader()]))  # nosec B701


# Module-level factory; its jinja environment is only built on first access of
# ``JINJA_ENVIRONMENT_FACTORY.environment``.
JINJA_ENVIRONMENT_FACTORY = _EnvironmentFactory()