Load config (#276)
This commit is contained in:
Родитель
ba8d4af83f
Коммит
96e841b2a6
|
@ -1,5 +1,5 @@
|
|||
"""Main configuration API."""
|
||||
|
||||
from .config import Prefixed, config
|
||||
from .config import Prefixed, load_config
|
||||
|
||||
__all__ = ["Prefixed", "config"]
|
||||
__all__ = ["Prefixed", "load_config"]
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
"""Main configuration module for the essex_config package."""
|
||||
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from functools import cache
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
Protocol,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -24,27 +21,43 @@ DEFAULT_SOURCE_LIST: list[Source] = [EnvSource()]
|
|||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
V = TypeVar("V", bound=BaseModel, covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Config(Protocol[V]):
|
||||
"""Protocol to define the configuration class."""
|
||||
def load_config(
|
||||
cls: type[T],
|
||||
*,
|
||||
sources: list[Source] = DEFAULT_SOURCE_LIST,
|
||||
prefix: str = "",
|
||||
inner: bool = False,
|
||||
refresh_config: bool = False,
|
||||
) -> T:
|
||||
"""Instantiate the configuration and all values.
|
||||
|
||||
@classmethod
|
||||
def load_config(
|
||||
cls: type,
|
||||
sources: list[Source] | None = None,
|
||||
*,
|
||||
prefix: str | None = None,
|
||||
refresh_config: bool = False,
|
||||
) -> V: # pragma: no cover
|
||||
"""Load the configuration values."""
|
||||
...
|
||||
Parameters
|
||||
----------
|
||||
sources: tuple[Source], optional
|
||||
A tuple of sources to use to get the values.
|
||||
prefix: str, optional
|
||||
The prefix name to use to look for the values in the sources, by default ""
|
||||
refresh_config: bool, optional
|
||||
If True, the cache is cleared before loading the configuration.
|
||||
|
||||
Returns
|
||||
-------
|
||||
T: Instance of the configuration class.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If any of the values is not found.
|
||||
"""
|
||||
if refresh_config:
|
||||
_load_config.cache_clear()
|
||||
return _load_config(cls, tuple(sources), prefix, inner)
|
||||
|
||||
|
||||
@cache
|
||||
def load_config(
|
||||
def _load_config(
|
||||
cls: type[T],
|
||||
sources: tuple[Source, ...],
|
||||
prefix: str = "",
|
||||
|
@ -105,7 +118,7 @@ def load_config(
|
|||
if prefix_annotation is None:
|
||||
field_prefix += f".{name}" if field_prefix != "" else name
|
||||
try:
|
||||
values[name] = load_config(
|
||||
values[name] = _load_config(
|
||||
type_, sources, prefix=field_prefix, inner=True
|
||||
)
|
||||
except Exception: # noqa: S112, BLE001
|
||||
|
@ -117,7 +130,7 @@ def load_config(
|
|||
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
if prefix_annotation is None:
|
||||
field_prefix += f".{name}" if field_prefix != "" else name
|
||||
values[name] = load_config(
|
||||
values[name] = _load_config(
|
||||
field_type, sources, prefix=field_prefix, inner=True
|
||||
)
|
||||
continue
|
||||
|
@ -166,55 +179,3 @@ def load_config(
|
|||
values[name] = value
|
||||
|
||||
return cls.model_validate(values)
|
||||
|
||||
|
||||
def _add_docs(cls: type[Config], sources: list[Source], prefix: str):
|
||||
if cls.__doc__ is None:
|
||||
cls.__doc__ = ""
|
||||
|
||||
new_docs = f"*Configuration Prefix: {prefix}*\n\n" if prefix != "" else ""
|
||||
new_docs += "\n\n".join([
|
||||
line.strip() for line in cls.__doc__.split("\n") if line.strip() != ""
|
||||
])
|
||||
|
||||
doc_sources = "\n".join([f"* {source!s}" for source in sources])
|
||||
new_docs += f"\n\n# Sources:\n\n{doc_sources}\n"
|
||||
|
||||
cls.__doc__ = new_docs
|
||||
|
||||
|
||||
def config(
|
||||
*, sources: list[Source] = DEFAULT_SOURCE_LIST, prefix: str = ""
|
||||
) -> Callable[[type[T]], type[Config[T]]]:
|
||||
"""Add configuration loading capabilities to BaseModel pydantic class."""
|
||||
|
||||
def wrapper(cls: type[T]) -> type[Config[T]]:
|
||||
_sources = sources
|
||||
_prefix = prefix
|
||||
|
||||
def load(
|
||||
cls: type[T],
|
||||
sources: list[Source] | None = None,
|
||||
*,
|
||||
prefix: str | None = None,
|
||||
refresh_config: bool = False,
|
||||
) -> T:
|
||||
if refresh_config:
|
||||
load_config.cache_clear()
|
||||
if sources is None:
|
||||
sources = _sources
|
||||
else:
|
||||
sources.extend(_sources)
|
||||
if prefix is None:
|
||||
prefix = _prefix
|
||||
|
||||
return load_config(cls, tuple(sources), prefix) # type: ignore
|
||||
|
||||
protocol_cls = cast(type[Config[T]], cls)
|
||||
protocol_cls.load_config = classmethod(load) # type: ignore
|
||||
|
||||
_add_docs(protocol_cls, _sources, _prefix)
|
||||
|
||||
return protocol_cls
|
||||
|
||||
return wrapper
|
||||
|
|
|
@ -18,7 +18,7 @@ if __name__ == "__main__":
|
|||
"package", help="Package name to search for configuration classes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"classes", help="Classes to generate documentation for.", nargs="*"
|
||||
"classes", help="Classes to generate documentation for.", nargs="+"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_nested",
|
||||
|
|
|
@ -3,11 +3,9 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from essex_config.config import Config
|
||||
from essex_config.doc_gen.printer import ConfigurationPrinter
|
||||
|
||||
|
||||
|
@ -23,8 +21,8 @@ def __find_subclasses(package_name: str) -> list[type]:
|
|||
|
||||
# Iterate through all members of the module
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, Config) and obj is not Config:
|
||||
subclasses.append(cast(type[Config[BaseModel]], obj))
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseModel):
|
||||
subclasses.append(obj)
|
||||
|
||||
return subclasses
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[tool.poetry]
|
||||
name = "essex-config"
|
||||
version = "0.0.2"
|
||||
description = ""
|
||||
version = "0.0.3"
|
||||
description = "Python library for creating configuration objects that read from various sources, including files, environment variables, and Azure Key Vault."
|
||||
authors = ["Andres Morales <andresmor@microsoft.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import Annotated
|
|||
from pydantic import Field
|
||||
|
||||
import tests.integration.graphrag_config.defaults as defs
|
||||
from essex_config import config
|
||||
from essex_config.field_decorators import Parser
|
||||
from essex_config.sources.utils import plain_text_list_parser
|
||||
|
||||
|
@ -26,7 +25,6 @@ from .summarize_descriptions_config import SummarizeDescriptionsConfig
|
|||
from .text_embedding_config import TextEmbeddingConfig
|
||||
|
||||
|
||||
@config()
|
||||
class GraphRagConfig(LLMConfig):
|
||||
"""Base class for the Default-Configuration parameterization settings."""
|
||||
|
||||
|
|
|
@ -2,13 +2,14 @@ import os
|
|||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from essex_config import load_config
|
||||
from essex_config.sources import EnvSource, FileSource
|
||||
|
||||
from .graphrag_config import GraphRagConfig, LLMType, ReportingType, TextEmbeddingTarget
|
||||
|
||||
|
||||
def test_graphrag_config_defaults():
|
||||
config = GraphRagConfig.load_config()
|
||||
config = load_config(GraphRagConfig)
|
||||
assert config.root_dir == "."
|
||||
|
||||
|
||||
|
@ -20,7 +21,7 @@ def test_graphrag_config_defaults():
|
|||
clear=True,
|
||||
)
|
||||
def test_graphrag_api_key_override():
|
||||
config = GraphRagConfig.load_config(sources=[EnvSource(prefix="graphrag")])
|
||||
config = load_config(GraphRagConfig, sources=[EnvSource(prefix="graphrag")])
|
||||
assert config.llm.api_key == "abc123"
|
||||
|
||||
|
||||
|
@ -32,12 +33,13 @@ def test_graphrag_api_key_override():
|
|||
clear=True,
|
||||
)
|
||||
def test_graphrag_api_key_override_2():
|
||||
config = GraphRagConfig.load_config(sources=[EnvSource(prefix="graphrag")])
|
||||
config = load_config(GraphRagConfig, sources=[EnvSource(prefix="graphrag")])
|
||||
assert config.llm.api_key == "abc123"
|
||||
|
||||
|
||||
def test_graphrag_yml():
|
||||
config = GraphRagConfig.load_config(
|
||||
config = load_config(
|
||||
GraphRagConfig,
|
||||
sources=[
|
||||
EnvSource(prefix="graphrag"),
|
||||
FileSource(
|
||||
|
@ -45,7 +47,7 @@ def test_graphrag_yml():
|
|||
required=True,
|
||||
prefix="",
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
assert config.root_dir == "/some/path"
|
||||
assert config.encoding_model == "utf-812"
|
||||
|
@ -135,7 +137,7 @@ def test_graphrag_yml():
|
|||
clear=True,
|
||||
)
|
||||
def test_graphrag_config_env_vars():
|
||||
config = GraphRagConfig.load_config(sources=[EnvSource(prefix="graphrag")])
|
||||
config = load_config(GraphRagConfig, sources=[EnvSource(prefix="graphrag")])
|
||||
assert config.root_dir == "root"
|
||||
assert config.encoding_model == "model_xyz"
|
||||
assert config.reporting.type == ReportingType.console
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Annotated
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from essex_config.config import Prefixed, config
|
||||
from essex_config.config import Prefixed
|
||||
from essex_config.sources.env_source import EnvSource
|
||||
from essex_config.sources.file_source import FileSource
|
||||
from essex_config.sources.source import Alias
|
||||
|
@ -22,7 +22,6 @@ class NestedConfiguration(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@config(prefix="test", sources=[EnvSource(), FileSource("config.yml")])
|
||||
class MainConfiguration(BaseModel):
|
||||
"""Main configuration class for the project.
|
||||
|
||||
|
|
|
@ -6,12 +6,6 @@ from . import sample_module
|
|||
|
||||
|
||||
def test_generate_docs():
|
||||
mock_printer = MagicMock()
|
||||
generate_docs(sample_module.__name__, mock_printer, [], False)
|
||||
mock_printer.print.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_docs_filter():
|
||||
mock_printer = MagicMock()
|
||||
generate_docs(sample_module.__name__, mock_printer, ["MainConfiguration"], False)
|
||||
mock_printer.print.assert_called_once()
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
@ -51,8 +53,11 @@ def test_env_source_file_from_env():
|
|||
|
||||
|
||||
def test_env_source_file_invalid_file():
|
||||
source = EnvSource(file_path="wrong/.env.test", required=True)
|
||||
with pytest.raises(FileNotFoundError, match="File wrong/.env.test not found."):
|
||||
wrong_file_path = Path("wrong/.env.test")
|
||||
source = EnvSource(file_path=wrong_file_path, required=True)
|
||||
with pytest.raises(
|
||||
FileNotFoundError, match=re.escape(f"File {wrong_file_path!s} not found.")
|
||||
):
|
||||
source.get_value("TEST_VALUE", str)
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Annotated, Any, TypeVar
|
|||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from essex_config.config import config
|
||||
from essex_config import load_config
|
||||
from essex_config.field_decorators import Alias, Parser, Prefixed
|
||||
from essex_config.sources import Source
|
||||
from essex_config.sources.args_source import ArgSource
|
||||
|
@ -50,69 +50,72 @@ class MockSource(Source):
|
|||
|
||||
|
||||
def test_basic_config():
|
||||
@config(sources=[MockSource()])
|
||||
class BasicConfiguration(BaseModel):
|
||||
hello: str
|
||||
|
||||
basic_config = BasicConfiguration.load_config()
|
||||
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "world"
|
||||
|
||||
assert type(basic_config) == BasicConfiguration
|
||||
|
||||
|
||||
def test_prefixed_config():
|
||||
@config(prefix="prefix", sources=[MockSource()])
|
||||
class PrefixedConfiguration(BaseModel):
|
||||
hello: str
|
||||
|
||||
basic_config = PrefixedConfiguration.load_config()
|
||||
basic_config = load_config(
|
||||
PrefixedConfiguration, sources=[MockSource()], prefix="prefix"
|
||||
)
|
||||
assert basic_config.hello == "world prefixed"
|
||||
|
||||
|
||||
def test_prefixed_source():
|
||||
@config(sources=[MockSource(prefix="prefix")])
|
||||
class PrefixedConfiguration(BaseModel):
|
||||
hello: str
|
||||
|
||||
basic_config = PrefixedConfiguration.load_config()
|
||||
basic_config = load_config(
|
||||
PrefixedConfiguration, sources=[MockSource(prefix="prefix")]
|
||||
)
|
||||
assert basic_config.hello == "world prefixed"
|
||||
|
||||
|
||||
def test_prefixed_source_overrides_config():
|
||||
@config(prefix="override_this", sources=[MockSource(prefix="prefix")])
|
||||
class PrefixedConfiguration(BaseModel):
|
||||
hello: str
|
||||
|
||||
basic_config = PrefixedConfiguration.load_config()
|
||||
basic_config = load_config(
|
||||
PrefixedConfiguration,
|
||||
sources=[MockSource(prefix="prefix")],
|
||||
prefix="override this",
|
||||
)
|
||||
assert basic_config.hello == "world prefixed"
|
||||
|
||||
|
||||
def test_alias_config():
|
||||
@config(sources=[MockSource()])
|
||||
class AliasConfiguration(BaseModel):
|
||||
hello: Annotated[str, Alias(MockSource, ["not_hello"])]
|
||||
|
||||
basic_config = AliasConfiguration.load_config()
|
||||
basic_config = load_config(AliasConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "not world"
|
||||
|
||||
|
||||
def test_prefixed_field_config():
|
||||
@config(sources=[MockSource()])
|
||||
class PrefixedFieldConfiguration(BaseModel):
|
||||
hello: Annotated[str, Prefixed("field_prefix")]
|
||||
no_prefix_field: int
|
||||
|
||||
basic_config = PrefixedFieldConfiguration.load_config()
|
||||
basic_config = load_config(PrefixedFieldConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "world prefixed field"
|
||||
assert basic_config.no_prefix_field == 1
|
||||
|
||||
|
||||
def test_prefixed_field_config_with_source_prefix():
|
||||
@config(sources=[MockSource(prefix="source_prefix")])
|
||||
class PrefixedFieldConfiguration(BaseModel):
|
||||
hello: Annotated[str, Prefixed("field_prefix")]
|
||||
|
||||
basic_config = PrefixedFieldConfiguration.load_config()
|
||||
basic_config = load_config(
|
||||
PrefixedFieldConfiguration, sources=[MockSource(prefix="source_prefix")]
|
||||
)
|
||||
assert basic_config.hello == "world prefixed field"
|
||||
|
||||
|
||||
|
@ -120,12 +123,11 @@ def test_nested_config():
|
|||
class Inner(BaseModel):
|
||||
hello: str
|
||||
|
||||
@config(sources=[MockSource()])
|
||||
class NestedConfiguration(BaseModel):
|
||||
hello: str
|
||||
nested: Inner
|
||||
|
||||
basic_config = NestedConfiguration.load_config()
|
||||
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "world"
|
||||
assert basic_config.nested.hello == "nested world"
|
||||
|
||||
|
@ -134,12 +136,11 @@ def test_nested_prefixed_field_config():
|
|||
class Inner(BaseModel):
|
||||
hello: str
|
||||
|
||||
@config(sources=[MockSource()])
|
||||
class NestedConfiguration(BaseModel):
|
||||
hello: str
|
||||
nested: Annotated[Inner, Prefixed("nested2")]
|
||||
|
||||
basic_config = NestedConfiguration.load_config()
|
||||
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "world"
|
||||
assert basic_config.nested.hello == "nested2 world"
|
||||
|
||||
|
@ -150,12 +151,11 @@ def test_nested_alias():
|
|||
str, Alias(MockSource, ["not_hello_nested"], include_prefix=True)
|
||||
]
|
||||
|
||||
@config(sources=[MockSource()])
|
||||
class NestedConfiguration(BaseModel):
|
||||
hello: str
|
||||
nested: Inner
|
||||
|
||||
basic_config = NestedConfiguration.load_config()
|
||||
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "world"
|
||||
assert basic_config.nested.hello == "nested alias world"
|
||||
|
||||
|
@ -164,12 +164,11 @@ def test_nested_alias_for_different_source():
|
|||
class Inner(BaseModel):
|
||||
hello: Annotated[str, Alias(EnvSource, ["this_should_be_ignored"])]
|
||||
|
||||
@config(sources=[MockSource()])
|
||||
class NestedConfiguration(BaseModel):
|
||||
hello: str
|
||||
nested: Inner
|
||||
|
||||
basic_config = NestedConfiguration.load_config()
|
||||
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "world"
|
||||
assert basic_config.nested.hello == "nested world"
|
||||
|
||||
|
@ -177,51 +176,48 @@ def test_nested_alias_for_different_source():
|
|||
def test_cache_and_refresh():
|
||||
source = MockSource()
|
||||
|
||||
@config(sources=[source])
|
||||
class BasicConfiguration(BaseModel):
|
||||
hello: str
|
||||
|
||||
basic_config = BasicConfiguration.load_config()
|
||||
basic_config = load_config(BasicConfiguration, sources=[source])
|
||||
assert basic_config.hello == "world"
|
||||
|
||||
source.data["hello"] = "new world"
|
||||
|
||||
basic_config = BasicConfiguration.load_config()
|
||||
basic_config = load_config(BasicConfiguration, sources=[source])
|
||||
assert basic_config.hello == "world"
|
||||
|
||||
basic_config = BasicConfiguration.load_config(refresh_config=True)
|
||||
basic_config = load_config(
|
||||
BasicConfiguration, sources=[source], refresh_config=True
|
||||
)
|
||||
assert basic_config.hello == "new world"
|
||||
|
||||
|
||||
def test_basic_default_no_value_config():
|
||||
@config(sources=[MockSource()])
|
||||
class BasicConfiguration(BaseModel):
|
||||
not_a_value_in_config: str = "hello"
|
||||
|
||||
basic_config = BasicConfiguration.load_config()
|
||||
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
|
||||
assert basic_config.not_a_value_in_config == "hello"
|
||||
|
||||
|
||||
def test_basic_default_no_value_field_config():
|
||||
@config(sources=[MockSource()])
|
||||
class BasicConfiguration(BaseModel):
|
||||
not_a_value_in_config: str = Field(default="hello")
|
||||
|
||||
basic_config = BasicConfiguration.load_config()
|
||||
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
|
||||
assert basic_config.not_a_value_in_config == "hello"
|
||||
|
||||
|
||||
def test_basic_default_no_value_field_factory_config():
|
||||
@config(sources=[MockSource()])
|
||||
class BasicConfiguration(BaseModel):
|
||||
not_a_value_in_config: str = Field(default_factory=lambda: "hello")
|
||||
|
||||
basic_config = BasicConfiguration.load_config()
|
||||
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
|
||||
assert basic_config.not_a_value_in_config == "hello"
|
||||
|
||||
|
||||
def test_missing_key():
|
||||
@config(sources=[MockSource()])
|
||||
class KeyErrorConfig(BaseModel):
|
||||
not_valid_key: str
|
||||
|
||||
|
@ -229,38 +225,38 @@ def test_missing_key():
|
|||
ValueError,
|
||||
match="Value for not_valid_key is required and not found in any source.",
|
||||
):
|
||||
KeyErrorConfig.load_config()
|
||||
load_config(KeyErrorConfig, sources=[MockSource()])
|
||||
|
||||
|
||||
def test_not_required_field():
|
||||
@config(sources=[MockSource()])
|
||||
class NotRequiredConfig(BaseModel):
|
||||
not_required_key: str | None = None
|
||||
|
||||
assert NotRequiredConfig.load_config().not_required_key is None
|
||||
assert (
|
||||
load_config(NotRequiredConfig, sources=[MockSource()]).not_required_key is None
|
||||
)
|
||||
|
||||
|
||||
def test_union_type():
|
||||
@config(sources=[MockSource()])
|
||||
class UnionTypeConfig(BaseModel):
|
||||
hello: int | str | None = None
|
||||
no_prefix_field: int | str | None = None
|
||||
|
||||
assert UnionTypeConfig.load_config().hello == "world"
|
||||
assert UnionTypeConfig.load_config().no_prefix_field == 1
|
||||
config = load_config(UnionTypeConfig, sources=[MockSource()])
|
||||
assert config.hello == "world"
|
||||
assert config.no_prefix_field == 1
|
||||
|
||||
|
||||
def test_nested_optional():
|
||||
class Inner(BaseModel):
|
||||
hello: str
|
||||
|
||||
@config(sources=[MockSource()])
|
||||
class NestedConfiguration(BaseModel):
|
||||
hello: str
|
||||
not_valid_key: Inner | None = None
|
||||
nested: Inner | None
|
||||
|
||||
basic_config = NestedConfiguration.load_config()
|
||||
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
|
||||
assert basic_config.hello == "world"
|
||||
assert basic_config.not_valid_key is None
|
||||
assert isinstance(basic_config.nested, Inner)
|
||||
|
@ -268,7 +264,6 @@ def test_nested_optional():
|
|||
|
||||
|
||||
def test_wrong_type():
|
||||
@config(sources=[MockSource()])
|
||||
class WrongTypeConfig(BaseModel):
|
||||
not_int: int
|
||||
|
||||
|
@ -278,11 +273,10 @@ def test_wrong_type():
|
|||
"Cannot convert [this is not a int value] to type [<class 'int'>]."
|
||||
),
|
||||
):
|
||||
WrongTypeConfig.load_config()
|
||||
load_config(WrongTypeConfig, sources=[MockSource()])
|
||||
|
||||
|
||||
def test_wrong_union_type():
|
||||
@config(sources=[MockSource()])
|
||||
class WrongTypeConfig(BaseModel):
|
||||
hello: int | float | None = None
|
||||
|
||||
|
@ -292,20 +286,18 @@ def test_wrong_union_type():
|
|||
"Cannot convert [world] to any of the types [(<class 'int'>, <class 'float'>, <class 'NoneType'>)]."
|
||||
),
|
||||
):
|
||||
WrongTypeConfig.load_config()
|
||||
load_config(WrongTypeConfig, sources=[MockSource()])
|
||||
|
||||
|
||||
def test_custom_parser():
|
||||
@config(sources=[MockSource()])
|
||||
class CustomParserConfig(BaseModel):
|
||||
custom_parser: Annotated[list[int], Parser(json_list_parser)]
|
||||
|
||||
basic_config = CustomParserConfig.load_config()
|
||||
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
|
||||
assert basic_config.custom_parser == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_custom_parser_malformed_json():
|
||||
@config(sources=[MockSource()])
|
||||
class CustomParserConfig(BaseModel):
|
||||
malformed_json: Annotated[list[int], Parser(json_list_parser)]
|
||||
|
||||
|
@ -313,26 +305,24 @@ def test_custom_parser_malformed_json():
|
|||
ValueError,
|
||||
match=re.escape("Error parsing the value [1,2,3,4 for key malformed_json."),
|
||||
):
|
||||
CustomParserConfig.load_config()
|
||||
load_config(CustomParserConfig, sources=[MockSource()])
|
||||
|
||||
|
||||
def test_custom_parser_str():
|
||||
@config(sources=[MockSource()])
|
||||
class CustomParserConfig(BaseModel):
|
||||
custom_parser2: Annotated[
|
||||
list[str], Parser(lambda x, _: [str(i) for i in x.split(",")])
|
||||
]
|
||||
|
||||
basic_config = CustomParserConfig.load_config()
|
||||
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
|
||||
assert basic_config.custom_parser2 == ["1", "2", "3", "4"]
|
||||
|
||||
|
||||
def test_custom_parser_hex_values():
|
||||
@config(sources=[MockSource()])
|
||||
class CustomParserConfig(BaseModel):
|
||||
hex_string: Annotated[int, Parser(lambda x, _: int(x, 0))]
|
||||
|
||||
basic_config = CustomParserConfig.load_config()
|
||||
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
|
||||
assert basic_config.hex_string == 0xDEADBEEF
|
||||
|
||||
|
||||
|
@ -340,7 +330,6 @@ def test_add_runtime_source():
|
|||
class Inner(BaseModel):
|
||||
value: str
|
||||
|
||||
@config(sources=[MockSource()])
|
||||
class RuntimeSourceConfig(BaseModel):
|
||||
hello: str
|
||||
runtime_source_var: str
|
||||
|
@ -348,14 +337,16 @@ def test_add_runtime_source():
|
|||
lower_false: bool
|
||||
random_bool: bool
|
||||
|
||||
basic_config = RuntimeSourceConfig.load_config(
|
||||
basic_config = load_config(
|
||||
RuntimeSourceConfig,
|
||||
sources=[
|
||||
MockSource(),
|
||||
ArgSource(
|
||||
runtime_source_var="runtime",
|
||||
random_bool="this is not false",
|
||||
**{"nested.value": "world"},
|
||||
)
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
assert basic_config.hello == "world"
|
||||
assert basic_config.runtime_source_var == "runtime"
|
||||
|
@ -365,20 +356,18 @@ def test_add_runtime_source():
|
|||
|
||||
|
||||
def test_plain_text_parser():
|
||||
@config(sources=[MockSource()])
|
||||
class CustomParserConfig(BaseModel):
|
||||
plain_text_list: Annotated[list[int], Parser(plain_text_list_parser())]
|
||||
|
||||
basic_config = CustomParserConfig.load_config()
|
||||
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
|
||||
assert basic_config.plain_text_list == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_malformed_plaintext_list():
|
||||
@config()
|
||||
class RuntimeSourceConfig(BaseModel):
|
||||
malformed_list: Annotated[list[int], Parser(plain_text_list_parser())]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Error parsing the value 1,2,3,a for key malformed_list."
|
||||
):
|
||||
RuntimeSourceConfig.load_config(sources=[ArgSource(malformed_list="1,2,3,a")])
|
||||
load_config(RuntimeSourceConfig, sources=[ArgSource(malformed_list="1,2,3,a")])
|
||||
|
|
Загрузка…
Ссылка в новой задаче