Coverage for .nox/test-3-9/lib/python3.9/site-packages/nskit/mixer/components/recipe.py: 93%

113 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 17:38 +0000

1"""The base recipe object.""" 

2import datetime as dt 

3import inspect 

4from pathlib import Path 

5import sys 

6from typing import List, Optional 

7 

8from pydantic import BaseModel, Field 

9 

10from nskit import __version__ 

11from nskit.common.extensions import get_extension_names, load_extension 

12from nskit.common.io import yaml 

13from nskit.mixer.components.folder import Folder 

14from nskit.mixer.components.hook import Hook 

15 

# Extension group name passed to ``load_extension`` / ``get_extension_names`` when looking up recipes.
RECIPE_ENTRYPOINT = "nskit.recipes"

17 

18 

class Recipe(Folder):
    """The base Recipe object.

    A Recipe is a folder, with additional methods for handling context for the Jinja templating.
    It also includes hooks that can be run before rendering (``pre-hooks``) e.g. checking or changing values,
    and after (``post-hooks``) e.g. running post-generation steps.
    """
    name: str = Field(None, validate_default=True, description='The repository name')
    version: Optional[str] = Field(None, description='The recipe version')  # type: ignore
    pre_hooks: List[Hook] = Field(
        default_factory=list,
        validate_default=True,
        description='Hooks that can be used to modify a recipe path and context before writing'
    )
    post_hooks: List[Hook] = Field(
        default_factory=list,
        validate_default=True,
        description='Hooks that can be used to modify a recipe path and context after writing'
    )
    extension_name: Optional[str] = Field(None, description="The name of the recipe as an extension to load.")

    @property
    def recipe(self):
        """Recipe context.

        Returns a dict with the recipe class's ``module:ClassName`` name, the configured ``version``,
        and the ``extension_name`` (falling back to the class name when it is ``None``).
        """
        extension_name = self.extension_name
        if extension_name is None:
            extension_name = self.__class__.__name__
        return {'name': f'{self.__class__.__module__}:{self.__class__.__name__}', 'version': self.version, 'extension_name': extension_name}

    @staticmethod
    def _resolve_base_path(base_path: Optional[Path]) -> Path:
        """Return ``base_path`` as a ``Path``, defaulting to the current working directory when ``None``."""
        if base_path is None:
            return Path.cwd()
        return Path(base_path)

    def _combined_context(self, additional_context: dict) -> dict:
        """Return the recipe context updated with ``additional_context`` (later keys win)."""
        context = self.context
        context.update(additional_context)
        return context

    def create(self, base_path: Optional[Path] = None, override_path: Optional[Path] = None, **additional_context):
        """Create the recipe.

        Use the configured parameters and any additional context as kwargs to create the recipe at the
        base path (or current directory if not provided).

        Pre-hooks are applied to the target path and context before writing, post-hooks after writing,
        and the parameters used are appended to ``.recipe-batch.yaml`` in the resulting folder.

        Returns:
            A single-entry dict mapping the (possibly hook-modified) recipe path to the value that
            ``write`` returned for it.
        """
        base_path = self._resolve_base_path(base_path)
        context = self._combined_context(additional_context)
        recipe_path = self.get_path(base_path, context, override_path=override_path)
        for hook in self.pre_hooks:
            recipe_path, context = hook(recipe_path, context)
        # Write into the parent directory, using the (possibly hook-modified) final path component as the name.
        content = self.write(recipe_path.parent, context, override_path=recipe_path.name)
        recipe_path = list(content.keys())[0]
        for hook in self.post_hooks:
            recipe_path, context = hook(recipe_path, context)
        self._write_batch(Path(recipe_path))
        return {Path(recipe_path): list(content.values())[0]}

    def _write_batch(self, folder_path: Path):
        """Write out the parameters used.

        When we use this we want to keep track of what parameters were used to enable rerunning.
        This method appends the current ``recipe_batch`` to ``.recipe-batch.yaml`` in ``folder_path``,
        preserving any entries already in that file.
        """
        batch_path = Path(folder_path) / '.recipe-batch.yaml'
        batch = []
        if batch_path.exists():
            # Guard against a falsy parse result (e.g. an empty file) so ``append`` below always has a list.
            batch = yaml.loads(batch_path.read_text()) or []
        batch.append(self.recipe_batch)
        batch_path.write_text(yaml.dumps(batch))

    @property
    def recipe_batch(self):
        """Get information about the specific info of this recipe.

        Returns a dict with the serialisable context, the nskit version, the creation time
        (an ISO-8601 UTC timestamp string) and the ``recipe`` metadata.
        """
        # ``dt.timezone.utc`` exists on every supported Python version (``dt.UTC`` is only a 3.11+ alias),
        # so no version branch is needed and the timestamp has the same type and zone everywhere.
        creation_time = dt.datetime.now(dt.timezone.utc).isoformat()
        return {'context': self.__dump_context(ser=True),
                'nskit_version': __version__,
                'creation_time': creation_time,
                'recipe': self.recipe}

    @property
    def context(self):
        """Get the context on the initialised recipe."""
        return self.__dump_context()

    def __dump_context(self, ser=False):
        """Dump the model fields as template context, plus the ``recipe`` metadata.

        Args:
            ser: when True dump in pydantic ``json`` mode so the result is serialisable.
        """
        mode = 'json' if ser else 'python'
        # This inherits (via FileSystemObject) from nskit.common.configuration:BaseConfiguration, which includes
        # properties in model dumps, so properties (context, recipe, recipe_batch) must be excluded explicitly
        # alongside the structural fields.
        context = self.model_dump(
            mode=mode,
            exclude={
                'context',
                'contents',
                'name',
                'id_',
                'post_hooks',
                'pre_hooks',
                'version',
                'recipe_batch',
                'recipe',
                'extension_name',
            }
        )
        context.update({'recipe': self.recipe})
        return context

    def __repr__(self):
        """Repr(x) == x.__repr__."""
        context = self.context
        return f'{self._repr(context=context)}\n\nContext: {context}'

    def dryrun(
            self,
            base_path: Optional[Path] = None,
            override_path: Optional[Path] = None,
            **additional_context):
        """See the recipe as a dry run.

        ``base_path`` defaults to the current directory and is normalised to a ``Path`` (as in ``create``).
        """
        combined_context = self._combined_context(additional_context)
        base_path = self._resolve_base_path(base_path)
        return super().dryrun(base_path=base_path, context=combined_context, override_path=override_path)

    def validate(
            self,
            base_path: Optional[Path] = None,
            override_path: Optional[Path] = None,
            **additional_context):
        """Validate the created repo.

        ``base_path`` defaults to the current directory and is normalised to a ``Path`` (as in ``create``).
        """
        combined_context = self._combined_context(additional_context)
        base_path = self._resolve_base_path(base_path)
        return super().validate(base_path=base_path, context=combined_context, override_path=override_path)

    @staticmethod
    def _load_recipe_klass(recipe_name: str):
        """Return the recipe class registered under ``recipe_name``.

        Raises:
            ValueError: if no extension with that name is found.
        """
        recipe_klass = load_extension(RECIPE_ENTRYPOINT, recipe_name)
        if recipe_klass is None:
            raise ValueError(f'Recipe {recipe_name} not found, it may be mis-spelt or not installed. Available recipes: {get_extension_names(RECIPE_ENTRYPOINT)}')
        return recipe_klass

    @staticmethod
    def load(recipe_name: str, **kwargs):
        """Load a recipe as an extension.

        Instantiates the registered class with ``kwargs`` and records ``recipe_name`` as its ``extension_name``.

        Raises:
            ValueError: if no recipe with that name is found.
        """
        recipe_klass = Recipe._load_recipe_klass(recipe_name)
        recipe = recipe_klass(**kwargs)
        recipe.extension_name = recipe_name
        return recipe

    @staticmethod
    def _drop_inherited_params(sig, base_klass):
        """Return ``sig`` without parameters that also appear on ``base_klass`` (always keeping ``name``)."""
        base_params = inspect.signature(base_klass).parameters
        params = [v for u, v in sig.parameters.items() if u not in base_params or u == 'name']
        return sig.replace(parameters=params)

    @staticmethod
    def inspect(recipe_name: str, include_private: bool = False, include_folder: bool = False, include_base: bool = False):
        """Get the fields on a recipe as an extension.

        Args:
            recipe_name: the registered extension name of the recipe.
            include_private: keep parameters whose names start with ``_``.
            include_folder: keep parameters inherited from ``Folder``.
            include_base: keep parameters inherited from ``Recipe``.

        Returns:
            An ``inspect.Signature`` describing the recipe's constructor parameters.

        Raises:
            ValueError: if no recipe with that name is found.
        """
        recipe_klass = Recipe._load_recipe_klass(recipe_name)
        sig = Recipe._inspect_basemodel(recipe_klass, include_private=include_private)
        if not include_folder:
            sig = Recipe._drop_inherited_params(sig, Folder)
        if not include_base:
            sig = Recipe._drop_inherited_params(sig, Recipe)
        return sig

    @staticmethod
    def _inspect_basemodel(kls, include_private: bool = False):
        """Return the signature of ``kls``.

        Parameters whose names start with ``_`` are dropped unless ``include_private`` is set. Parameters annotated
        with a pydantic ``BaseModel`` subclass get that model's signature (built recursively) as their default.
        """
        sig = inspect.signature(kls)
        params = []
        for u, v in sig.parameters.items():
            if not include_private and u.startswith('_'):
                continue
            if isinstance(v.annotation, type) and issubclass(v.annotation, BaseModel):
                params.append(v.replace(default=Recipe._inspect_basemodel(v.annotation, include_private=include_private)))
            else:
                params.append(v)
        return sig.replace(parameters=params, return_annotation=kls)