This commit is contained in:
Andres Morales 2024-08-28 12:53:22 -06:00 коммит произвёл GitHub
Родитель ba8d4af83f
Коммит 96e841b2a6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
11 изменённых файлов: 107 добавлений и 161 удалений

Просмотреть файл

@ -1,5 +1,5 @@
"""Main configuration API."""
from .config import Prefixed, config
from .config import Prefixed, load_config
__all__ = ["Prefixed", "config"]
__all__ = ["Prefixed", "load_config"]

Просмотреть файл

@ -1,17 +1,14 @@
"""Main configuration module for the essex_config package."""
import inspect
from collections.abc import Callable
from functools import cache
from types import UnionType
from typing import (
Protocol,
TypeVar,
Union,
cast,
get_args,
get_origin,
runtime_checkable,
)
from pydantic import BaseModel
@ -24,27 +21,43 @@ DEFAULT_SOURCE_LIST: list[Source] = [EnvSource()]
T = TypeVar("T", bound=BaseModel)
V = TypeVar("V", bound=BaseModel, covariant=True)
@runtime_checkable
class Config(Protocol[V]):
"""Protocol to define the configuration class."""
def load_config(
cls: type[T],
*,
sources: list[Source] = DEFAULT_SOURCE_LIST,
prefix: str = "",
inner: bool = False,
refresh_config: bool = False,
) -> T:
"""Instantiate the configuration and all values.
@classmethod
def load_config(
cls: type,
sources: list[Source] | None = None,
*,
prefix: str | None = None,
refresh_config: bool = False,
) -> V: # pragma: no cover
"""Load the configuration values."""
...
Parameters
----------
sources: tuple[Source], optional
A tuple of sources to use to get the values.
prefix: str, optional
The prefix name to use to look for the values in the sources, by default ""
refresh_config: bool, optional
If True, the cache is cleared before loading the configuration.
Returns
-------
T: Instance of the configuration class.
Raises
------
ValueError
If any of the values is not found.
"""
if refresh_config:
_load_config.cache_clear()
return _load_config(cls, tuple(sources), prefix, inner)
@cache
def load_config(
def _load_config(
cls: type[T],
sources: tuple[Source, ...],
prefix: str = "",
@ -105,7 +118,7 @@ def load_config(
if prefix_annotation is None:
field_prefix += f".{name}" if field_prefix != "" else name
try:
values[name] = load_config(
values[name] = _load_config(
type_, sources, prefix=field_prefix, inner=True
)
except Exception: # noqa: S112, BLE001
@ -117,7 +130,7 @@ def load_config(
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
if prefix_annotation is None:
field_prefix += f".{name}" if field_prefix != "" else name
values[name] = load_config(
values[name] = _load_config(
field_type, sources, prefix=field_prefix, inner=True
)
continue
@ -166,55 +179,3 @@ def load_config(
values[name] = value
return cls.model_validate(values)
def _add_docs(cls: type[Config], sources: list[Source], prefix: str):
if cls.__doc__ is None:
cls.__doc__ = ""
new_docs = f"*Configuration Prefix: {prefix}*\n\n" if prefix != "" else ""
new_docs += "\n\n".join([
line.strip() for line in cls.__doc__.split("\n") if line.strip() != ""
])
doc_sources = "\n".join([f"* {source!s}" for source in sources])
new_docs += f"\n\n# Sources:\n\n{doc_sources}\n"
cls.__doc__ = new_docs
def config(
*, sources: list[Source] = DEFAULT_SOURCE_LIST, prefix: str = ""
) -> Callable[[type[T]], type[Config[T]]]:
"""Add configuration loading capabilities to BaseModel pydantic class."""
def wrapper(cls: type[T]) -> type[Config[T]]:
_sources = sources
_prefix = prefix
def load(
cls: type[T],
sources: list[Source] | None = None,
*,
prefix: str | None = None,
refresh_config: bool = False,
) -> T:
if refresh_config:
load_config.cache_clear()
if sources is None:
sources = _sources
else:
sources.extend(_sources)
if prefix is None:
prefix = _prefix
return load_config(cls, tuple(sources), prefix) # type: ignore
protocol_cls = cast(type[Config[T]], cls)
protocol_cls.load_config = classmethod(load) # type: ignore
_add_docs(protocol_cls, _sources, _prefix)
return protocol_cls
return wrapper

Просмотреть файл

@ -18,7 +18,7 @@ if __name__ == "__main__":
"package", help="Package name to search for configuration classes."
)
parser.add_argument(
"classes", help="Classes to generate documentation for.", nargs="*"
"classes", help="Classes to generate documentation for.", nargs="+"
)
parser.add_argument(
"--disable_nested",

Просмотреть файл

@ -3,11 +3,9 @@
import importlib
import inspect
import pkgutil
from typing import cast
from pydantic import BaseModel
from essex_config.config import Config
from essex_config.doc_gen.printer import ConfigurationPrinter
@ -23,8 +21,8 @@ def __find_subclasses(package_name: str) -> list[type]:
# Iterate through all members of the module
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, Config) and obj is not Config:
subclasses.append(cast(type[Config[BaseModel]], obj))
if inspect.isclass(obj) and issubclass(obj, BaseModel):
subclasses.append(obj)
return subclasses

Просмотреть файл

@ -1,7 +1,7 @@
[tool.poetry]
name = "essex-config"
version = "0.0.2"
description = ""
version = "0.0.3"
description = "Python library for creating configuration objects that read from various sources, including files, environment variables, and Azure Key Vault."
authors = ["Andres Morales <andresmor@microsoft.com>"]
readme = "README.md"

Просмотреть файл

@ -8,7 +8,6 @@ from typing import Annotated
from pydantic import Field
import tests.integration.graphrag_config.defaults as defs
from essex_config import config
from essex_config.field_decorators import Parser
from essex_config.sources.utils import plain_text_list_parser
@ -26,7 +25,6 @@ from .summarize_descriptions_config import SummarizeDescriptionsConfig
from .text_embedding_config import TextEmbeddingConfig
@config()
class GraphRagConfig(LLMConfig):
"""Base class for the Default-Configuration parameterization settings."""

Просмотреть файл

@ -2,13 +2,14 @@ import os
from pathlib import Path
from unittest import mock
from essex_config import load_config
from essex_config.sources import EnvSource, FileSource
from .graphrag_config import GraphRagConfig, LLMType, ReportingType, TextEmbeddingTarget
def test_graphrag_config_defaults():
config = GraphRagConfig.load_config()
config = load_config(GraphRagConfig)
assert config.root_dir == "."
@ -20,7 +21,7 @@ def test_graphrag_config_defaults():
clear=True,
)
def test_graphrag_api_key_override():
config = GraphRagConfig.load_config(sources=[EnvSource(prefix="graphrag")])
config = load_config(GraphRagConfig, sources=[EnvSource(prefix="graphrag")])
assert config.llm.api_key == "abc123"
@ -32,12 +33,13 @@ def test_graphrag_api_key_override():
clear=True,
)
def test_graphrag_api_key_override_2():
config = GraphRagConfig.load_config(sources=[EnvSource(prefix="graphrag")])
config = load_config(GraphRagConfig, sources=[EnvSource(prefix="graphrag")])
assert config.llm.api_key == "abc123"
def test_graphrag_yml():
config = GraphRagConfig.load_config(
config = load_config(
GraphRagConfig,
sources=[
EnvSource(prefix="graphrag"),
FileSource(
@ -45,7 +47,7 @@ def test_graphrag_yml():
required=True,
prefix="",
),
]
],
)
assert config.root_dir == "/some/path"
assert config.encoding_model == "utf-812"
@ -135,7 +137,7 @@ def test_graphrag_yml():
clear=True,
)
def test_graphrag_config_env_vars():
config = GraphRagConfig.load_config(sources=[EnvSource(prefix="graphrag")])
config = load_config(GraphRagConfig, sources=[EnvSource(prefix="graphrag")])
assert config.root_dir == "root"
assert config.encoding_model == "model_xyz"
assert config.reporting.type == ReportingType.console

Просмотреть файл

@ -4,7 +4,7 @@ from typing import Annotated
from pydantic import BaseModel, Field
from essex_config.config import Prefixed, config
from essex_config.config import Prefixed
from essex_config.sources.env_source import EnvSource
from essex_config.sources.file_source import FileSource
from essex_config.sources.source import Alias
@ -22,7 +22,6 @@ class NestedConfiguration(BaseModel):
)
@config(prefix="test", sources=[EnvSource(), FileSource("config.yml")])
class MainConfiguration(BaseModel):
"""Main configuration class for the project.

Просмотреть файл

@ -6,12 +6,6 @@ from . import sample_module
def test_generate_docs():
mock_printer = MagicMock()
generate_docs(sample_module.__name__, mock_printer, [], False)
mock_printer.print.assert_called_once()
def test_generate_docs_filter():
mock_printer = MagicMock()
generate_docs(sample_module.__name__, mock_printer, ["MainConfiguration"], False)
mock_printer.print.assert_called_once()

Просмотреть файл

@ -1,4 +1,6 @@
import os
import re
from pathlib import Path
from unittest import mock
import pytest
@ -51,8 +53,11 @@ def test_env_source_file_from_env():
def test_env_source_file_invalid_file():
source = EnvSource(file_path="wrong/.env.test", required=True)
with pytest.raises(FileNotFoundError, match="File wrong/.env.test not found."):
wrong_file_path = Path("wrong/.env.test")
source = EnvSource(file_path=wrong_file_path, required=True)
with pytest.raises(
FileNotFoundError, match=re.escape(f"File {wrong_file_path!s} not found.")
):
source.get_value("TEST_VALUE", str)

Просмотреть файл

@ -6,7 +6,7 @@ from typing import Annotated, Any, TypeVar
import pytest
from pydantic import BaseModel, Field
from essex_config.config import config
from essex_config import load_config
from essex_config.field_decorators import Alias, Parser, Prefixed
from essex_config.sources import Source
from essex_config.sources.args_source import ArgSource
@ -50,69 +50,72 @@ class MockSource(Source):
def test_basic_config():
@config(sources=[MockSource()])
class BasicConfiguration(BaseModel):
hello: str
basic_config = BasicConfiguration.load_config()
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
assert basic_config.hello == "world"
assert type(basic_config) == BasicConfiguration
def test_prefixed_config():
@config(prefix="prefix", sources=[MockSource()])
class PrefixedConfiguration(BaseModel):
hello: str
basic_config = PrefixedConfiguration.load_config()
basic_config = load_config(
PrefixedConfiguration, sources=[MockSource()], prefix="prefix"
)
assert basic_config.hello == "world prefixed"
def test_prefixed_source():
@config(sources=[MockSource(prefix="prefix")])
class PrefixedConfiguration(BaseModel):
hello: str
basic_config = PrefixedConfiguration.load_config()
basic_config = load_config(
PrefixedConfiguration, sources=[MockSource(prefix="prefix")]
)
assert basic_config.hello == "world prefixed"
def test_prefixed_source_overrides_config():
@config(prefix="override_this", sources=[MockSource(prefix="prefix")])
class PrefixedConfiguration(BaseModel):
hello: str
basic_config = PrefixedConfiguration.load_config()
basic_config = load_config(
PrefixedConfiguration,
sources=[MockSource(prefix="prefix")],
prefix="override this",
)
assert basic_config.hello == "world prefixed"
def test_alias_config():
@config(sources=[MockSource()])
class AliasConfiguration(BaseModel):
hello: Annotated[str, Alias(MockSource, ["not_hello"])]
basic_config = AliasConfiguration.load_config()
basic_config = load_config(AliasConfiguration, sources=[MockSource()])
assert basic_config.hello == "not world"
def test_prefixed_field_config():
@config(sources=[MockSource()])
class PrefixedFieldConfiguration(BaseModel):
hello: Annotated[str, Prefixed("field_prefix")]
no_prefix_field: int
basic_config = PrefixedFieldConfiguration.load_config()
basic_config = load_config(PrefixedFieldConfiguration, sources=[MockSource()])
assert basic_config.hello == "world prefixed field"
assert basic_config.no_prefix_field == 1
def test_prefixed_field_config_with_source_prefix():
@config(sources=[MockSource(prefix="source_prefix")])
class PrefixedFieldConfiguration(BaseModel):
hello: Annotated[str, Prefixed("field_prefix")]
basic_config = PrefixedFieldConfiguration.load_config()
basic_config = load_config(
PrefixedFieldConfiguration, sources=[MockSource(prefix="source_prefix")]
)
assert basic_config.hello == "world prefixed field"
@ -120,12 +123,11 @@ def test_nested_config():
class Inner(BaseModel):
hello: str
@config(sources=[MockSource()])
class NestedConfiguration(BaseModel):
hello: str
nested: Inner
basic_config = NestedConfiguration.load_config()
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
assert basic_config.hello == "world"
assert basic_config.nested.hello == "nested world"
@ -134,12 +136,11 @@ def test_nested_prefixed_field_config():
class Inner(BaseModel):
hello: str
@config(sources=[MockSource()])
class NestedConfiguration(BaseModel):
hello: str
nested: Annotated[Inner, Prefixed("nested2")]
basic_config = NestedConfiguration.load_config()
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
assert basic_config.hello == "world"
assert basic_config.nested.hello == "nested2 world"
@ -150,12 +151,11 @@ def test_nested_alias():
str, Alias(MockSource, ["not_hello_nested"], include_prefix=True)
]
@config(sources=[MockSource()])
class NestedConfiguration(BaseModel):
hello: str
nested: Inner
basic_config = NestedConfiguration.load_config()
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
assert basic_config.hello == "world"
assert basic_config.nested.hello == "nested alias world"
@ -164,12 +164,11 @@ def test_nested_alias_for_different_source():
class Inner(BaseModel):
hello: Annotated[str, Alias(EnvSource, ["this_should_be_ignored"])]
@config(sources=[MockSource()])
class NestedConfiguration(BaseModel):
hello: str
nested: Inner
basic_config = NestedConfiguration.load_config()
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
assert basic_config.hello == "world"
assert basic_config.nested.hello == "nested world"
@ -177,51 +176,48 @@ def test_nested_alias_for_different_source():
def test_cache_and_refresh():
source = MockSource()
@config(sources=[source])
class BasicConfiguration(BaseModel):
hello: str
basic_config = BasicConfiguration.load_config()
basic_config = load_config(BasicConfiguration, sources=[source])
assert basic_config.hello == "world"
source.data["hello"] = "new world"
basic_config = BasicConfiguration.load_config()
basic_config = load_config(BasicConfiguration, sources=[source])
assert basic_config.hello == "world"
basic_config = BasicConfiguration.load_config(refresh_config=True)
basic_config = load_config(
BasicConfiguration, sources=[source], refresh_config=True
)
assert basic_config.hello == "new world"
def test_basic_default_no_value_config():
@config(sources=[MockSource()])
class BasicConfiguration(BaseModel):
not_a_value_in_config: str = "hello"
basic_config = BasicConfiguration.load_config()
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
assert basic_config.not_a_value_in_config == "hello"
def test_basic_default_no_value_field_config():
@config(sources=[MockSource()])
class BasicConfiguration(BaseModel):
not_a_value_in_config: str = Field(default="hello")
basic_config = BasicConfiguration.load_config()
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
assert basic_config.not_a_value_in_config == "hello"
def test_basic_default_no_value_field_factory_config():
@config(sources=[MockSource()])
class BasicConfiguration(BaseModel):
not_a_value_in_config: str = Field(default_factory=lambda: "hello")
basic_config = BasicConfiguration.load_config()
basic_config = load_config(BasicConfiguration, sources=[MockSource()])
assert basic_config.not_a_value_in_config == "hello"
def test_missing_key():
@config(sources=[MockSource()])
class KeyErrorConfig(BaseModel):
not_valid_key: str
@ -229,38 +225,38 @@ def test_missing_key():
ValueError,
match="Value for not_valid_key is required and not found in any source.",
):
KeyErrorConfig.load_config()
load_config(KeyErrorConfig, sources=[MockSource()])
def test_not_required_field():
@config(sources=[MockSource()])
class NotRequiredConfig(BaseModel):
not_required_key: str | None = None
assert NotRequiredConfig.load_config().not_required_key is None
assert (
load_config(NotRequiredConfig, sources=[MockSource()]).not_required_key is None
)
def test_union_type():
@config(sources=[MockSource()])
class UnionTypeConfig(BaseModel):
hello: int | str | None = None
no_prefix_field: int | str | None = None
assert UnionTypeConfig.load_config().hello == "world"
assert UnionTypeConfig.load_config().no_prefix_field == 1
config = load_config(UnionTypeConfig, sources=[MockSource()])
assert config.hello == "world"
assert config.no_prefix_field == 1
def test_nested_optional():
class Inner(BaseModel):
hello: str
@config(sources=[MockSource()])
class NestedConfiguration(BaseModel):
hello: str
not_valid_key: Inner | None = None
nested: Inner | None
basic_config = NestedConfiguration.load_config()
basic_config = load_config(NestedConfiguration, sources=[MockSource()])
assert basic_config.hello == "world"
assert basic_config.not_valid_key is None
assert isinstance(basic_config.nested, Inner)
@ -268,7 +264,6 @@ def test_nested_optional():
def test_wrong_type():
@config(sources=[MockSource()])
class WrongTypeConfig(BaseModel):
not_int: int
@ -278,11 +273,10 @@ def test_wrong_type():
"Cannot convert [this is not a int value] to type [<class 'int'>]."
),
):
WrongTypeConfig.load_config()
load_config(WrongTypeConfig, sources=[MockSource()])
def test_wrong_union_type():
@config(sources=[MockSource()])
class WrongTypeConfig(BaseModel):
hello: int | float | None = None
@ -292,20 +286,18 @@ def test_wrong_union_type():
"Cannot convert [world] to any of the types [(<class 'int'>, <class 'float'>, <class 'NoneType'>)]."
),
):
WrongTypeConfig.load_config()
load_config(WrongTypeConfig, sources=[MockSource()])
def test_custom_parser():
@config(sources=[MockSource()])
class CustomParserConfig(BaseModel):
custom_parser: Annotated[list[int], Parser(json_list_parser)]
basic_config = CustomParserConfig.load_config()
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
assert basic_config.custom_parser == [1, 2, 3, 4]
def test_custom_parser_malformed_json():
@config(sources=[MockSource()])
class CustomParserConfig(BaseModel):
malformed_json: Annotated[list[int], Parser(json_list_parser)]
@ -313,26 +305,24 @@ def test_custom_parser_malformed_json():
ValueError,
match=re.escape("Error parsing the value [1,2,3,4 for key malformed_json."),
):
CustomParserConfig.load_config()
load_config(CustomParserConfig, sources=[MockSource()])
def test_custom_parser_str():
@config(sources=[MockSource()])
class CustomParserConfig(BaseModel):
custom_parser2: Annotated[
list[str], Parser(lambda x, _: [str(i) for i in x.split(",")])
]
basic_config = CustomParserConfig.load_config()
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
assert basic_config.custom_parser2 == ["1", "2", "3", "4"]
def test_custom_parser_hex_values():
@config(sources=[MockSource()])
class CustomParserConfig(BaseModel):
hex_string: Annotated[int, Parser(lambda x, _: int(x, 0))]
basic_config = CustomParserConfig.load_config()
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
assert basic_config.hex_string == 0xDEADBEEF
@ -340,7 +330,6 @@ def test_add_runtime_source():
class Inner(BaseModel):
value: str
@config(sources=[MockSource()])
class RuntimeSourceConfig(BaseModel):
hello: str
runtime_source_var: str
@ -348,14 +337,16 @@ def test_add_runtime_source():
lower_false: bool
random_bool: bool
basic_config = RuntimeSourceConfig.load_config(
basic_config = load_config(
RuntimeSourceConfig,
sources=[
MockSource(),
ArgSource(
runtime_source_var="runtime",
random_bool="this is not false",
**{"nested.value": "world"},
)
]
),
],
)
assert basic_config.hello == "world"
assert basic_config.runtime_source_var == "runtime"
@ -365,20 +356,18 @@ def test_add_runtime_source():
def test_plain_text_parser():
@config(sources=[MockSource()])
class CustomParserConfig(BaseModel):
plain_text_list: Annotated[list[int], Parser(plain_text_list_parser())]
basic_config = CustomParserConfig.load_config()
basic_config = load_config(CustomParserConfig, sources=[MockSource()])
assert basic_config.plain_text_list == [1, 2, 3, 4]
def test_malformed_plaintext_list():
@config()
class RuntimeSourceConfig(BaseModel):
malformed_list: Annotated[list[int], Parser(plain_text_list_parser())]
with pytest.raises(
ValueError, match="Error parsing the value 1,2,3,a for key malformed_list."
):
RuntimeSourceConfig.load_config(sources=[ArgSource(malformed_list="1,2,3,a")])
load_config(RuntimeSourceConfig, sources=[ArgSource(malformed_list="1,2,3,a")])