Coverage for .nox/test-3-12/lib/python3.12/site-packages/nskit/common/configuration/sources.py: 98%

40 statements  

« prev     ^ index     » next       coverage.py v7.3.3, created at 2023-12-19 17:42 +0000

1"""Add settings sources.""" 

2from pathlib import Path 

3from typing import Any, Dict, Tuple 

4 

5from pydantic.fields import FieldInfo 

6from pydantic_settings import PydanticBaseSettingsSource 

7 

8from nskit.common.io import json, toml, yaml 

9 

10 

11class FileConfigSettingsSource(PydanticBaseSettingsSource): 

12 """A simple settings source class that loads variables from a parsed file. 

13 

14 This can parse JSON, TOML, and YAML files based on the extensions. 

15 """ 

16 

def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Create the settings source with an empty parse cache.

    Every positional and keyword argument is handed straight through to
    ``PydanticBaseSettingsSource.__init__``.
    """
    super().__init__(*args, **kwargs)
    # Holds the parsed configuration file; ``None`` means nothing is cached.
    self.__parsed_contents = None

21 

def get_field_value(
    self, field: FieldInfo, field_name: str  # noqa: U100
) -> Tuple[Any, str, bool]:
    """Get a field value from the configuration file.

    On the first call the file named by ``config_file_path`` in the settings
    config is read (with ``env_file_encoding``, default ``utf-8``) and
    parsed. The result is cached, so later calls do not touch the file again.

    The format is inferred from the file suffix (``.json``/``.jsn``,
    ``.toml``/``.tml``, ``.yaml``/``.yml``). A recognised
    ``config_file_type`` (``json``, ``toml`` or ``yaml``, case-insensitive)
    takes precedence over the suffix.

    Loading is best-effort. If any of the following occurs, the source
    behaves as if the file were empty instead of raising:

    - no path is configured;
    - the format is unknown;
    - the file cannot be read or parsed;
    - the file's top level is not a mapping.

    Args:
        field: The pydantic field definition (not used).
        field_name: Key to look up in the parsed file.

    Returns:
        ``(value, key, value_is_complex)``, where ``value`` is ``None`` when
        the key is absent, ``key`` is ``field_name``, and
        ``value_is_complex`` is always ``False``.
    """
    if self.__parsed_contents is None:
        suffix_formats = {
            '.json': 'json',
            '.jsn': 'json',
            '.toml': 'toml',
            '.tml': 'toml',
            '.yaml': 'yaml',
            '.yml': 'yaml',
        }
        supported_formats = ('json', 'toml', 'yaml')
        contents: Any = None
        try:
            raw_path = self.config.get('config_file_path')
            if raw_path is not None:
                file_path = Path(raw_path)
                file_format = suffix_formats.get(file_path.suffix.lower())
                file_type = self.config.get('config_file_type', None)
                # An explicit, recognised type overrides the suffix, so e.g. a
                # ``.json`` file can still be parsed as YAML when asked to.
                if file_type and file_type.lower() in supported_formats:
                    file_format = file_type.lower()
                # Only read the file once we know how to parse it.
                if file_format in supported_formats:
                    encoding = self.config.get('env_file_encoding', 'utf-8')
                    file_contents = file_path.read_text(encoding)
                    if file_format == 'json':
                        contents = json.loads(file_contents)
                    elif file_format == 'toml':
                        contents = toml.loads(file_contents)
                    else:
                        contents = yaml.loads(file_contents)
        except Exception:  # nosec B110
            # Deliberately best-effort: a missing, unreadable or invalid file
            # means this source contributes no values.
            contents = None
        # Cache even a failed or empty load as ``{}``. This stops the file
        # being re-read for every remaining field, and stops a non-mapping
        # document (e.g. a top-level list) from breaking ``.get`` below.
        self.__parsed_contents = contents if isinstance(contents, dict) else {}
    return self.__parsed_contents.get(field_name), field_name, False

45 

def prepare_field_value(
    self,
    field_name: str,  # noqa: U100
    field: FieldInfo,  # noqa: U100
    value: Any,
    value_is_complex: bool,  # noqa: U100
) -> Any:
    """Pass a field value through untouched.

    This source applies no per-field transformation. ``value`` is returned
    as-is whatever ``field_name``, ``field`` or ``value_is_complex`` are.
    """
    return value

51 

def __call__(self) -> Dict[str, Any]:
    """Gather the values this source provides for the settings class.

    Walks the fields declared on ``settings_cls``. For each one it:

    - fetches the raw value with ``get_field_value``;
    - runs it through ``prepare_field_value``;
    - drops it if the prepared result is ``None``.

    Returns:
        A dict mapping the key reported by ``get_field_value`` to the
        prepared, non-``None`` value.
    """
    collected: Dict[str, Any] = {}
    for name, info in self.settings_cls.model_fields.items():
        raw_value, key, is_complex = self.get_field_value(info, name)
        prepared = self.prepare_field_value(name, info, raw_value, is_complex)
        if prepared is None:
            continue
        collected[key] = prepared
    return collected

67 

68 def _load_file(self, file_path: Path, encoding: str) -> Dict[str, Any]: # noqa: U100 

69 file_path = Path(file_path)