Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/mixer/components/recipe.py: 93%
113 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-25 17:38 +0000
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-25 17:38 +0000
1"""The base recipe object."""
2import datetime as dt
3import inspect
4from pathlib import Path
5import sys
6from typing import List, Optional
8from pydantic import BaseModel, Field
10from nskit import __version__
11from nskit.common.extensions import get_extension_names, load_extension
12from nskit.common.io import yaml
13from nskit.mixer.components.folder import Folder
14from nskit.mixer.components.hook import Hook
# Extension group name handed to ``load_extension`` / ``get_extension_names`` when looking up recipes.
RECIPE_ENTRYPOINT = "nskit.recipes"
class Recipe(Folder):
    """The base Recipe object.

    A Recipe is a folder, with additional methods for handling context for the Jinja templating.
    It also includes hooks that can be run before rendering (``pre-hooks``) e.g. checking or changing values,
    and after (``post-hooks``) e.g. running post-generation steps.
    """
    name: str = Field(None, validate_default=True, description='The repository name')
    version: Optional[str] = Field(None, description='The recipe version')  # type: ignore
    pre_hooks: List[Hook] = Field(
        default_factory=list,
        validate_default=True,
        description='Hooks that can be used to modify a recipe path and context before writing'
    )
    post_hooks: List[Hook] = Field(
        default_factory=list,
        validate_default=True,
        description='Hooks that can be used to modify a recipe path and context after writing'
    )
    extension_name: Optional[str] = Field(None, description="The name of the recipe as an extension to load.")

    @property
    def recipe(self):
        """Recipe context.

        Returns a dict with the recipe's import path (``module:ClassName``), its ``version`` and its
        ``extension_name`` (falling back to the class name when no extension name was set).
        """
        extension_name = self.extension_name
        if extension_name is None:
            extension_name = self.__class__.__name__
        return {'name': f'{self.__class__.__module__}:{self.__class__.__name__}', 'version': self.version, 'extension_name': extension_name}

    def create(self, base_path: Optional[Path] = None, override_path: Optional[Path] = None, **additional_context):
        """Create the recipe.

        Use the configured parameters and any additional context as kwargs to create the recipe at the
        base path (or current directory if not provided).

        Args:
            base_path: directory to create the recipe in; defaults to the current working directory.
            override_path: optional override forwarded to ``get_path`` for the recipe's own path.
            **additional_context: extra values merged over the recipe's own context.

        Returns:
            A single-entry dict mapping the recipe path (as returned by the last post-hook, converted
            to ``Path``) to the content written for it.
        """
        base_path = Path.cwd() if base_path is None else Path(base_path)
        context = self.context
        context.update(additional_context)
        recipe_path = self.get_path(base_path, context, override_path=override_path)
        for hook in self.pre_hooks:
            recipe_path, context = hook(recipe_path, context)
        # A hook may hand back a plain string; normalise so ``.parent`` / ``.name`` are available.
        recipe_path = Path(recipe_path)
        content = self.write(recipe_path.parent, context, override_path=recipe_path.name)
        # ``write`` returns a mapping keyed by the written path; the recipe folder is its first entry.
        recipe_path, written = list(content.items())[0]
        for hook in self.post_hooks:
            recipe_path, context = hook(recipe_path, context)
        recipe_path = Path(recipe_path)
        self._write_batch(recipe_path)
        return {recipe_path: written}

    def _write_batch(self, folder_path: Path):
        """Write out the parameters used.

        When we use this we want to keep track of what parameters were used to enable rerunning.
        This methods writes this into the generated folder as a YAML file, appending to any batch
        list already stored in ``.recipe-batch.yaml``.
        """
        batch_path = Path(folder_path)/'.recipe-batch.yaml'
        batch = []
        if batch_path.exists():
            with batch_path.open() as f:
                # NOTE(review): an empty file presumably loads as ``None``; treat that as "no previous batches".
                batch = yaml.loads(f.read()) or []
        batch.append(self.recipe_batch)
        with batch_path.open('w') as f:
            f.write(yaml.dumps(batch))

    @property
    def recipe_batch(self):
        """Get information about the specific info of this recipe.

        Returns a dict with the JSON-serialisable context, the nskit version, the creation time as
        a UTC ISO-8601 string, and the ``recipe`` descriptor.
        """
        # ``dt.timezone.utc`` exists on all supported Python versions (``dt.UTC`` is only an alias added
        # in 3.11), so one code path yields the same type (str) and timezone (UTC) everywhere.
        creation_time = dt.datetime.now(dt.timezone.utc).isoformat()
        return {'context': self.__dump_context(ser=True),
                'nskit_version': __version__,
                'creation_time': creation_time,
                'recipe': self.recipe}

    @property
    def context(self):
        """Get the context on the initialised recipe (a fresh dict on every access)."""
        # This inherits (via FileSystemObject) from nskit.common.configuration:BaseConfiguration, which includes properties in model dumps
        return self.__dump_context()

    def __dump_context(self, ser=False):
        """Dump the model fields usable as template context.

        Args:
            ser: when True dump in pydantic ``json`` mode so every value is serialisable.
        """
        # Make sure it is serialisable if required
        mode = 'json' if ser else 'python'
        context = self.model_dump(
            mode=mode,
            # Structural / bookkeeping fields are not template variables.
            exclude={
                'context',
                'contents',
                'name',
                'id_',
                'post_hooks',
                'pre_hooks',
                'version',
                'recipe_batch',
                'recipe',
                'extension_name',
            }
        )
        context.update({'recipe': self.recipe})
        return context

    def __repr__(self):
        """Repr(x) == x.__repr__."""
        context = self.context
        return f'{self._repr(context=context)}\n\nContext: {context}'

    def dryrun(
            self,
            base_path: Optional[Path] = None,
            override_path: Optional[Path] = None,
            **additional_context):
        """See the recipe as a dry run.

        ``base_path`` defaults to the current working directory and is converted to ``Path`` the same
        way ``create`` does; ``additional_context`` is merged over the recipe context.
        """
        combined_context = self.context
        combined_context.update(additional_context)
        base_path = Path.cwd() if base_path is None else Path(base_path)
        return super().dryrun(base_path=base_path, context=combined_context, override_path=override_path)

    def validate(
            self,
            base_path: Optional[Path] = None,
            override_path: Optional[Path] = None,
            **additional_context):
        """Validate the created repo.

        ``base_path`` defaults to the current working directory and is converted to ``Path`` the same
        way ``create`` does; ``additional_context`` is merged over the recipe context.
        """
        combined_context = self.context
        combined_context.update(additional_context)
        base_path = Path.cwd() if base_path is None else Path(base_path)
        return super().validate(base_path=base_path, context=combined_context, override_path=override_path)

    @staticmethod
    def _load_recipe_klass(recipe_name: str):
        """Load the recipe class registered under ``recipe_name`` or raise ``ValueError``."""
        recipe_klass = load_extension(RECIPE_ENTRYPOINT, recipe_name)
        if recipe_klass is None:
            raise ValueError(f'Recipe {recipe_name} not found, it may be mis-spelt or not installed. Available recipes: {get_extension_names(RECIPE_ENTRYPOINT)}')
        return recipe_klass

    @staticmethod
    def load(recipe_name: str, **kwargs):
        """Load a recipe as an extension.

        Args:
            recipe_name: name of the recipe extension.
            **kwargs: passed to the recipe class constructor.

        Raises:
            ValueError: if no recipe with that name can be loaded.
        """
        recipe = Recipe._load_recipe_klass(recipe_name)(**kwargs)
        # Remember how it was loaded so ``recipe`` / batch records report the extension name.
        recipe.extension_name = recipe_name
        return recipe

    @staticmethod
    def inspect(recipe_name: str, include_private: bool = False, include_folder: bool = False, include_base: bool = False):
        """Get the fields on a recipe as an extension.

        Args:
            recipe_name: name of the recipe extension to inspect.
            include_private: keep parameters whose names start with ``_``.
            include_folder: keep parameters that also appear on ``Folder``'s signature.
            include_base: keep parameters that also appear on ``Recipe``'s own signature.

        Returns:
            An ``inspect.Signature`` of the recipe class; ``name`` is always kept.

        Raises:
            ValueError: if no recipe with that name can be loaded.
        """
        recipe_klass = Recipe._load_recipe_klass(recipe_name)
        sig = Recipe._inspect_basemodel(recipe_klass, include_private=include_private)
        if not include_folder:
            sig = Recipe._drop_inherited_params(sig, Folder)
        if not include_base:
            sig = Recipe._drop_inherited_params(sig, Recipe)
        return sig

    @staticmethod
    def _drop_inherited_params(sig, base):
        """Return ``sig`` without parameters that ``base``'s signature also has (``name`` is always kept)."""
        base_params = inspect.signature(base).parameters
        params = [v for u, v in sig.parameters.items() if u not in base_params or u == 'name']
        return sig.replace(parameters=params)

    @staticmethod
    def _inspect_basemodel(kls, include_private: bool = False):
        """Build the signature of ``kls``, recursing into pydantic-model-typed parameters.

        Parameters annotated with a ``BaseModel`` subclass get that model's signature as their default.
        """
        sig = inspect.signature(kls)
        # we need to drop the private params
        params = []
        for u, v in sig.parameters.items():
            if not include_private and u.startswith('_'):
                continue
            if isinstance(v.annotation, type) and issubclass(v.annotation, BaseModel):
                params.append(v.replace(default=Recipe._inspect_basemodel(v.annotation, include_private=include_private)))
            else:
                params.append(v)
        return sig.replace(parameters=params, return_annotation=kls)