Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/mixer/utilities.py: 94%

82 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 17:38 +0000

1"""Utilities for interacting with systems etc.""" 

2import os 

3from pathlib import Path 

4import sys 

5from typing import Any 

6 

7if sys.version_info.major <= 3 and sys.version_info.minor < 9: 

8 from importlib_resources import files 

9else: 

10 from importlib.resources import files 

11 

12from jinja2 import BaseLoader, ChoiceLoader, Environment, TemplateNotFound 

13from pydantic import GetCoreSchemaHandler, TypeAdapter, ValidationError 

14from pydantic_core import core_schema, CoreSchema 

15 

16from nskit.common.extensions import ExtensionsEnum 

17 

18 

class Resource(str):
    """A type for a package resource uri.

    A resource string looks like ``<package>.<module>:<resource filename>``:
    a dotted module path, exactly one colon, then a single filename (no
    directory components) that is looked up inside that package.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler  # noqa: U100
    ) -> CoreSchema:
        """Get the schema.

        The value is validated as a plain ``str`` first, and the result is then
        passed through :meth:`_validate_resource`.
        """
        return core_schema.no_info_after_validator_function(cls._validate_resource, handler(str))

    @classmethod
    def _validate_resource(cls, value: str):
        """Check that ``value`` is a well-formed resource string and wrap it in ``cls``.

        Raises:
            ValueError: if ``value`` does not contain exactly one colon, if the
                part after the colon is not a single path component, or if the
                part before it has an empty dotted segment (e.g. ``''``,
                ``'a..b'``, ``'.a'``) or contains one of ``' '``, ``'-'``,
                ``'*'``, ``'('``, ``')'``.
        """
        resource_string_example = "<package>.<module>:<resource filename>"
        parts = value.split(':')
        # No colon yields one part and several colons yield more than two, so a
        # single length check covers both malformed cases.
        if len(parts) != 2:
            raise ValueError(f'Value should be a resource string, looking like {resource_string_example}.')
        path, filename = parts
        # filename should be a valid filename (one path component, not a sub-path)
        if len(Path(filename).parts) != 1:
            raise ValueError(f'The part after the colon ({filename}) should be a valid filename as part of the resource string ({resource_string_example}).')
        # path should be a valid module path: every dotted segment non-empty
        # (an empty path cannot be imported at all) and no characters that can
        # never appear in a module name.
        has_empty_segment = any(not segment for segment in path.split('.'))
        has_invalid_char = any(char in path for char in (' ', '-', '*', '(', ')'))
        if has_empty_segment or has_invalid_char:
            raise ValueError(f'The part before the colon ({path}) should be a valid python module path as part of the resource string ({resource_string_example}).')
        return cls(value)

    def load(self):
        """Load the resource using importlib.resources.

        Returns:
            str: The resource contents, read in text mode with ``open()``'s
            default arguments.

        Raises:
            ModuleNotFoundError: if the package part cannot be imported.
            FileNotFoundError: if the package has no such file (for
                filesystem-backed packages).
        """
        path, filename = self.split(':')
        # The Traversable returned by ``files()`` is not itself a context
        # manager, so manage the opened file instead. This also guarantees the
        # handle is closed after reading.
        with files(path).joinpath(filename).open() as handle:
            return handle.read()

    @classmethod
    def validate(cls, value):
        """Validate the input.

        Returns:
            Resource: ``value`` wrapped as a :class:`Resource`.

        Raises:
            pydantic.ValidationError: if ``value`` is not a valid resource string.
        """
        ta = TypeAdapter(Resource)
        return ta.validate_python(value)

58 

59 

class _PkgResourcesTemplateLoader(BaseLoader):
    """Load jinja templates via importlib.resources.

    Template names are resource strings of the form
    ``<package>.<module>:<resource filename>`` (see :class:`Resource`).
    """

    @staticmethod
    def get_source(environment, template):  # noqa: U100
        """Get the source using importlib.resources.

        Returns:
            tuple: ``(source, None, uptodate)``. No filename is reported, and
            ``uptodate`` always returns True, so a loaded template is never
            considered stale.

        Raises:
            TemplateNotFound: if ``template`` is not a valid resource string,
                if its package (or a parent package) cannot be found, or if
                the file does not exist. This lets a ``ChoiceLoader`` fall
                through to other loaders.
        """
        try:
            Resource.validate(template)
        except ValidationError as e:
            # Pass the validation detail explicitly as the message rather than
            # relying on the shape of ``e.args``.
            raise TemplateNotFound(template, str(e)) from e
        resource = Resource(template)
        try:
            source = resource.load()
        except FileNotFoundError as e:
            raise TemplateNotFound(template) from e
        except ModuleNotFoundError as e:
            package = template.split(':')[0]
            # Only treat the error as "template not found" when the missing
            # module is the requested package or one of its parents. A missing
            # dependency *inside* that package is a real import error, so it
            # is re-raised unchanged.
            missing = e.name
            if missing is None or not (package == missing or package.startswith(f'{missing}.')):
                raise
            raise TemplateNotFound(template) from e
        return source, None, lambda: True

76 

77 

class _EnvironmentFactory():
    """Jinja2 Environment Factory to allow for extension/customisation."""

    def __init__(self):
        """Initialise the factory; the environment itself is created lazily."""
        self._environment = None

    @property
    def environment(self) -> Environment:
        """Handle caching the environment object so it is lazily initialised.

        On first access the environment is built by :meth:`get_environment`
        and extended by :meth:`add_extensions`. Later accesses return the same
        cached object.
        """
        if self._environment is None:
            environment = self.get_environment()
            self.add_extensions(environment)
            # Cache only once fully configured. If add_extensions raises, the
            # next access retries instead of returning a half-built environment.
            self._environment = environment
        return self._environment

    def add_extensions(self, environment: Environment):
        """Add Extensions to the environment object.

        Extensions are collected from every option registered under the
        ``nskit.mixer.environment.extensions`` entrypoint group, de-duplicated,
        and added in first-seen order.
        """
        NskitMixerExtensionOptions = ExtensionsEnum.from_entrypoint('NskitMixerExtensionOptions', 'nskit.mixer.environment.extensions')
        # Assuming no risk of extension clash
        extensions = []
        for option in NskitMixerExtensionOptions:
            extensions += option.extension()
        # dict.fromkeys de-duplicates while preserving order. A set's iteration
        # order depends on hashing (randomised for str), so the registration
        # order would otherwise change from run to run.
        for extension in dict.fromkeys(extensions):
            environment.add_extension(extension)

    def get_environment(self) -> Environment:
        """Get the environment object based on the env var.

        The ``NSKIT_MIXER_ENVIRONMENT_FACTORY`` environment variable selects
        the option. When it is unset, or equals ``default`` in any case, the
        ``'default'`` option is used.

        Raises:
            ValueError: if the selected option is not registered under the
                ``nskit.mixer.environment.factory`` entrypoint group.
        """
        selected_method = os.environ.get('NSKIT_MIXER_ENVIRONMENT_FACTORY', None)
        if selected_method is None or selected_method.lower() == 'default':
            # This is our simple implementation
            selected_method = 'default'
        # We need to validate against the options
        NskitMixerEnvironmentOptions = ExtensionsEnum.from_entrypoint('NskitMixerEnvironmentOptions', 'nskit.mixer.environment.factory')
        # Check against member *values*, because the lookup below is by value.
        # Comparing to a plain list works on every Python version; Enum's own
        # ``in`` rejects non-members before 3.12.
        available_values = [option.value for option in NskitMixerEnvironmentOptions]
        if selected_method not in available_values:
            raise ValueError(f'NSKIT_MIXER_ENVIRONMENT_FACTORY value {selected_method} not installed - available options are {list(NskitMixerEnvironmentOptions)}')
        return NskitMixerEnvironmentOptions(selected_method).extension()

    @staticmethod
    def default_environment():
        """Get the default environment object.

        Templates are resolved only through the resource-string loader.
        Autoescaping is left at Jinja's default (off), which the
        ``nosec B701`` marker acknowledges.
        """
        return Environment(loader=ChoiceLoader([_PkgResourcesTemplateLoader()]))  # nosec B701

123 

124 

# Shared module-level factory. Creating it is cheap: the Jinja environment is
# only built on first access of ``JINJA_ENVIRONMENT_FACTORY.environment``.
JINJA_ENVIRONMENT_FACTORY = _EnvironmentFactory()