Refactor to use pydantic models for schema (#512)

* 📦. Move pydantic to main dependencies

* 🐞 Remove pydantic from mypy due to existing bug around inits

*  Reimplement Serializable using BaseModel

* 🚧 Create BaseElement using BaseModel

* 🚧 Slight Edits

* 📦 Update Pydantic to 1.9.0

* Handle Serialization

*  Fix mypy issues with linting

*  Add schema functionality, fix factories, fix tests

*  Add add and sub to BaseElement

* 🧹Boyscout test files for mypy

*  Test Chaum Pedersen Adjustment

* ♻️ Refactor from_file_to_dataclass to to_file

* fix path change

Co-authored-by: Matt Wilhelm <github@addressxception.com>
This commit is contained in:
Keith Fung 2022-01-25 11:03:50 -05:00 коммит произвёл GitHub
Родитель a4d6cb43cb
Коммит 922fd68a75
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
19 изменённых файлов: 1073 добавлений и 810 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -145,6 +145,7 @@ election_record
election_private_data
election_record.zip
election_private_data.zip
schemas
# VS Code
.vscode/settings.json

1146
poetry.lock сгенерированный

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -31,12 +31,13 @@ packages = [{ include = "electionguard", from = "src" }]
[tool.poetry.dependencies]
python = "^3.9.5"
gmpy2 = ">=2.0.8"
gmpy2 = "2.0.8"
# For Windows Builds
# gmpy2 = { path = "./packages/gmpy2-2.0.8-cp39-cp39-win_amd64.whl" } # 64 bit
# gmpy2 = { path = "./packages/gmpy2-2.0.8-cp39-cp39-win32.whl" } # 32 bit
cryptography = ">=3.2"
psutil = ">=5.7.2"
pydantic = "1.9.0"
[tool.poetry.dev-dependencies]
atomicwrites = "*"
@ -53,7 +54,6 @@ mkinit = "^0.3.3"
mypy = "^0.910"
pydeps = "*"
pylint = "*"
pydantic = "^1.8.2"
pytest = "*"
secretstorage = "*"
twine = "*"
@ -101,7 +101,6 @@ build-backend = "poetry.core.masonry.api"
[tool.mypy]
python_version = 3.9
plugins = ["pydantic.mypy"]
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true

Просмотреть файл

@ -34,6 +34,7 @@ from electionguard import proof
from electionguard import rsa
from electionguard import scheduler
from electionguard import schnorr
from electionguard import serialize
from electionguard import singleton
from electionguard import tally
from electionguard import type
@ -171,6 +172,7 @@ from electionguard.discrete_log import (
DiscreteLog,
compute_discrete_log,
compute_discrete_log_cache,
discrete_log_async,
)
from electionguard.election import (
CiphertextElectionContext,
@ -347,6 +349,17 @@ from electionguard.schnorr import (
SchnorrProof,
make_schnorr_proof,
)
from electionguard.serialize import (
Private,
Serializable,
construct_path,
from_file,
from_list_in_file,
from_raw,
get_schema,
to_file,
to_raw,
)
from electionguard.singleton import (
Singleton,
)
@ -488,6 +501,7 @@ __all__ = [
"PlaintextTallyContest",
"PlaintextTallySelection",
"PrimeOption",
"Private",
"PrivateGuardianRecord",
"Proof",
"ProofOrRecovery",
@ -507,6 +521,7 @@ __all__ = [
"SecretCoefficient",
"SelectionDescription",
"SelectionId",
"Serializable",
"Singleton",
"SubmittedBallot",
"VerifierId",
@ -545,6 +560,7 @@ __all__ = [
"compute_polynomial_coordinate",
"compute_recovery_public_key",
"constants",
"construct_path",
"contest_description_with_placeholders_from",
"contest_from",
"contest_is_valid_for_style",
@ -568,6 +584,7 @@ __all__ = [
"decryption_mediator",
"decryption_share",
"discrete_log",
"discrete_log_async",
"div_p",
"div_q",
"election",
@ -589,6 +606,9 @@ __all__ = [
"expand_compact_submitted_ballot",
"flatmap_optional",
"from_ciphertext_ballot",
"from_file",
"from_list_in_file",
"from_raw",
"g_pow_p",
"generate_device_uuid",
"generate_election_key_pair",
@ -610,6 +630,7 @@ __all__ = [
"get_optional",
"get_or_else_optional",
"get_or_else_optional_func",
"get_schema",
"get_shares_for_selection",
"get_small_prime",
"get_stream_handler",
@ -672,12 +693,15 @@ __all__ = [
"selection_from",
"selection_is_valid_for_style",
"sequence_order_sort",
"serialize",
"singleton",
"space_between_capitals",
"tally",
"tally_ballot",
"tally_ballots",
"to_file",
"to_iso_date_string",
"to_raw",
"to_ticks",
"type",
"utils",

Просмотреть файл

@ -95,11 +95,11 @@ def compute_polynomial_coordinate(
:return: Polynomial used to share election keys
"""
exponent_modifier = ElementModQ(exponent_modifier)
exponent_modifier_mod_q = ElementModQ(exponent_modifier)
computed_value = ZERO_MOD_Q
for (i, coefficient) in enumerate(polynomial.coefficients):
exponent = pow_q(exponent_modifier, i)
exponent = pow_q(exponent_modifier_mod_q, i)
factor = mult_q(coefficient.value, exponent)
computed_value = add_q(computed_value, factor)
return computed_value
@ -144,11 +144,11 @@ def verify_polynomial_coordinate(
:return: True if verified on polynomial
"""
exponent_modifier = ElementModQ(exponent_modifier)
exponent_modifier_mod_q = ElementModQ(exponent_modifier)
commitment_output = ONE_MOD_P
for (i, commitment) in enumerate(commitments):
exponent = pow_p(exponent_modifier, i)
exponent = pow_p(exponent_modifier_mod_q, i)
factor = pow_p(commitment, exponent)
commitment_output = mult_p(commitment_output, factor)

Просмотреть файл

@ -5,8 +5,8 @@ in the sense that performance may be less than hand-optimized C code, and no gua
made about timing or other side-channels.
"""
from abc import ABC
from typing import Any, Final, Optional, Union
from abc import ABC, abstractmethod
from typing import Any, Final, Optional, Tuple, Union
from base64 import b16decode
from secrets import randbelow
from sys import maxsize
@ -14,53 +14,137 @@ from sys import maxsize
# pylint: disable=no-name-in-module
from gmpy2 import mpz, powmod, invert
from .serialize import Serializable, Private
from .constants import get_large_prime, get_small_prime, get_generator
class BaseElement(ABC, int):
def hex_to_int(input: str) -> int:
"""Given a hex string representing bytes, returns an int."""
return int(input, 16)
def int_to_hex(input: int) -> str:
"""Given an int, returns a hex string representing bytes."""
hex = format(input, "02X")
if len(hex) % 2:
hex = "0" + hex
return hex
_zero = mpz(0)
def _mpz_zero() -> mpz:
return _zero
def _convert_to_element(data: Union[int, str]) -> Tuple[str, int]:
"""Convert element to consistent types"""
if isinstance(data, str):
hex = data
integer = hex_to_int(data)
else:
hex = int_to_hex(data)
integer = data
return (hex, integer)
class BaseElement(Serializable, ABC):
"""An element limited by mod T within [0, T) where T is determined by an upper_bound function."""
def __new__(cls, elem: Union[int, str], check_within_bounds: bool = True): # type: ignore
"""Instantiate ElementModT where elem is an int or its hex representation or mpz."""
if isinstance(elem, str):
elem = hex_to_int(elem)
data: str
_value: mpz = Private(default_factory=_mpz_zero)
"""Internal math representation of element"""
def __init__(self, data: Union[int, str], check_within_bounds: bool = True) -> None:
"""Instantiate element mod T where element is an int or its hex representation."""
(hex, integer) = _convert_to_element(data)
super().__init__(data=hex)
self._value = mpz(integer)
if check_within_bounds:
if not 0 <= elem < cls.get_upper_bound():
if not self.is_in_bounds():
raise OverflowError
return super(BaseElement, cls).__new__(cls, elem)
def __str__(self) -> str:
"""Overload string representation"""
return self.data
def __repr__(self) -> str:
"""Overload object representation"""
return self.data
def __int__(self) -> int:
"""Overload int conversion."""
return int(self.get_value())
def __eq__(self, other: Any) -> bool:
"""Overload == (equal to) operator."""
return (
isinstance(other, BaseElement)
and int(self.get_value()) == int(other.get_value())
) or (isinstance(other, int) and int(self.get_value()) == other)
def __ne__(self, other: Any) -> bool:
"""Overload != (not equal to) operator."""
return not self == other
def __eq__(self, other: Any) -> bool:
"""Overload == (equal to) operator."""
return isinstance(other, (BaseElement, int)) and int(self) == other
def __lt__(self, other: Any) -> bool:
"""Overload <= (less than) operator."""
return (
isinstance(other, BaseElement)
and int(self.get_value()) < int(other.get_value())
) or (isinstance(other, int) and int(self.get_value()) < other)
def __le__(self, other: Any) -> bool:
"""Overload <= (less than or equal) operator."""
return self.__lt__(other) or self.__eq__(other)
def __gt__(self, other: Any) -> bool:
"""Overload > (greater than) operator."""
return (
isinstance(other, BaseElement)
and int(self.get_value()) > int(other.get_value())
) or (isinstance(other, int) and int(self.get_value()) > other)
def __ge__(self, other: Any) -> bool:
"""Overload >= (greater than or equal) operator."""
return self.__gt__(other) or self.__eq__(other)
def __add__(self, other: Any) -> Any:
"""Overload addition operator."""
return self.get_value() + other
def __sub__(self, other: Any) -> Any:
"""Overload subtraction operator."""
return self.get_value() - other
def __hash__(self) -> int:
"""Overload the hashing function."""
return hash(self.__int__())
return hash(self.get_value())
@classmethod
def get_upper_bound(cls) -> int: # pylint: disable=no-self-use
@abstractmethod
def get_upper_bound(self) -> int:
"""Get the upper bound for the element."""
return maxsize
def get_value(self) -> mpz:
"""Get internal value for math calculations"""
return self._value
def to_hex(self) -> str:
"""
Convert from the element to the hex representation of bytes.
This is preferable to directly accessing `elem`, whose representation might change.
"""
return int_to_hex(self.__int__())
return self.data
def to_hex_bytes(self) -> bytes:
"""
Convert from the element to the representation of bytes by first going through hex.
This is preferable to directly accessing `elem`, whose representation might change.
"""
return b16decode(self.to_hex())
return b16decode(self.data)
def is_in_bounds(self) -> bool:
"""
@ -68,7 +152,7 @@ class BaseElement(ABC, int):
Returns true if all is good, false if something's wrong.
"""
return 0 <= self.__int__() < self.get_upper_bound()
return 0 <= self.get_value() < self.get_upper_bound()
def is_in_bounds_no_zero(self) -> bool:
"""
@ -76,14 +160,13 @@ class BaseElement(ABC, int):
Returns true if all is good, false if something's wrong.
"""
return 1 <= self.__int__() < self.get_upper_bound()
return 1 <= self.get_value() < self.get_upper_bound()
class ElementModQ(BaseElement):
"""An element of the smaller `mod q` space, i.e., in [0, Q), where Q is a 256-bit prime."""
@classmethod
def get_upper_bound(cls) -> int:
def get_upper_bound(self) -> int:
"""Get the upper bound for the element."""
return get_small_prime()
@ -91,8 +174,7 @@ class ElementModQ(BaseElement):
class ElementModP(BaseElement):
"""An element of the larger `mod p` space, i.e., in [0, P), where P is a 4096-bit prime."""
@classmethod
def get_upper_bound(cls) -> int:
def get_upper_bound(self) -> int:
"""Get the upper bound for the element."""
return get_large_prime()
@ -119,22 +201,11 @@ ElementModPorInt = Union[ElementModP, int]
def _get_mpz(input: Union[BaseElement, int]) -> mpz:
"""Get BaseElement or integer as mpz."""
if isinstance(input, BaseElement):
return input.get_value()
return mpz(input)
def hex_to_int(input: str) -> int:
"""Given a hex string representing bytes, returns an int."""
return int(input, 16)
def int_to_hex(input: int) -> str:
"""Given an int, returns a hex string representing bytes."""
hex = format(input, "02X")
if len(hex) % 2:
hex = "0" + hex
return hex
def hex_to_q(input: str) -> Optional[ElementModQ]:
"""
Given a hex string representing bytes, returns an ElementModQ.

Просмотреть файл

@ -0,0 +1,92 @@
import json
import os
from pathlib import Path
from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel, PrivateAttr
from pydantic.json import pydantic_encoder
from pydantic.tools import parse_raw_as, parse_obj_as, schema_json_of
Private = PrivateAttr
class Serializable(BaseModel):
"""Serializable data object intended for exporting and importing"""
class Config:
"""Model config to handle private properties"""
underscore_attrs_are_private = True
_T = TypeVar("_T")
_indent = 2
_encoding = "utf-8"
_file_extension = "json"
def construct_path(
target_file_name: str,
target_path: str = "",
target_file_extension: str = _file_extension,
) -> str:
"""Construct path from file name, path, and extension."""
target_file = f"{target_file_name}.{target_file_extension}"
return os.path.join(target_path, target_file)
def from_raw(type_: Type[_T], obj: Any) -> _T:
"""Deserialize raw as type."""
return parse_raw_as(type_, obj)
def to_raw(data: Any) -> Any:
"""Serialize data to raw json format."""
return json.dumps(data, indent=_indent, default=pydantic_encoder)
def from_file(type_: Type[_T], path: Union[str, Path]) -> _T:
"""Deserialize json file as type."""
with open(path, "r", encoding=_encoding) as json_file:
data = json.load(json_file)
return parse_obj_as(type_, data)
def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]:
"""Deserialize json file that has an array of certain type."""
with open(path, "r", encoding=_encoding) as json_file:
data = json.load(json_file)
ls: List[_T] = []
for item in data:
ls.append(parse_obj_as(type_, item))
return ls
def to_file(
data: Any,
target_file_name: str,
target_path: str = "",
) -> None:
"""Serialize object to JSON"""
if not os.path.exists(target_path):
os.makedirs(target_path)
with open(
construct_path(target_file_name, target_path),
"w",
encoding=_encoding,
) as outfile:
json.dump(data, outfile, indent=_indent, default=pydantic_encoder)
def get_schema(_type: Any) -> str:
"""Get JSON Schema for type"""
return schema_json_of(_type)

Просмотреть файл

@ -34,8 +34,6 @@ from electionguard_tools.helpers import (
GUARDIAN_PREFIX,
KeyCeremonyOrchestrator,
MANIFEST_FILE_NAME,
NumberEncodeOption,
OPTION,
PLAINTEXT_BALLOT_PREFIX,
PRIVATE_DATA_DIR,
PRIVATE_GUARDIAN_PREFIX,
@ -43,28 +41,17 @@ from electionguard_tools.helpers import (
SPOILED_BALLOT_PREFIX,
SUBMITTED_BALLOTS_DIR,
SUBMITTED_BALLOT_PREFIX,
T,
TALLY_FILE_NAME,
TallyCeremonyOrchestrator,
accumulate_plaintext_ballots,
banlist,
construct_path,
custom_decoder,
custom_encoder,
export,
export_private_data,
from_file_to_dataclass,
from_list_in_file_to_dataclass,
from_raw,
identity_auxiliary_decrypt,
identity_auxiliary_encrypt,
identity_encrypt,
key_ceremony_orchestrator,
serialize,
tally_accumulate,
tally_ceremony_orchestrator,
to_file,
to_raw,
)
from electionguard_tools.scripts import (
DEFAULT_NUMBER_OF_BALLOTS,
@ -140,8 +127,6 @@ __all__ = [
"KeyCeremonyOrchestrator",
"MANIFEST_FILE_NAME",
"NUMBER_OF_GUARDIANS",
"NumberEncodeOption",
"OPTION",
"PLAINTEXT_BALLOT_PREFIX",
"PRIVATE_DATA_DIR",
"PRIVATE_GUARDIAN_PREFIX",
@ -150,7 +135,6 @@ __all__ = [
"SPOILED_BALLOT_PREFIX",
"SUBMITTED_BALLOTS_DIR",
"SUBMITTED_BALLOT_PREFIX",
"T",
"TALLY_FILE_NAME",
"TallyCeremonyOrchestrator",
"accumulate_plaintext_ballots",
@ -158,16 +142,12 @@ __all__ = [
"annotated_strings",
"ballot_factory",
"ballot_styles",
"banlist",
"candidate_contest_descriptions",
"candidates",
"ciphertext_elections",
"construct_path",
"contact_infos",
"contest_descriptions",
"contest_descriptions_room_for_overvoting",
"custom_decoder",
"custom_encoder",
"data",
"election",
"election_descriptions",
@ -183,9 +163,6 @@ __all__ = [
"export",
"export_private_data",
"factories",
"from_file_to_dataclass",
"from_list_in_file_to_dataclass",
"from_raw",
"geopolitical_units",
"get_contest_description_well_formed",
"get_selection_description_well_formed",
@ -209,12 +186,9 @@ __all__ = [
"reporting_unit_types",
"sample_generator",
"scripts",
"serialize",
"strategies",
"tally_accumulate",
"tally_ceremony_orchestrator",
"to_file",
"to_raw",
"two_letter_codes",
]

Просмотреть файл

@ -23,10 +23,7 @@ from electionguard.manifest import (
SelectionDescription,
InternalManifest,
)
from electionguard_tools.helpers.serialize import (
from_file_to_dataclass,
from_list_in_file_to_dataclass,
)
from electionguard.serialize import from_file, from_list_in_file
_T = TypeVar("_T")
@ -146,13 +143,11 @@ class BallotFactory:
@staticmethod
def _get_ballot_from_file(filename: str) -> PlaintextBallot:
return from_file_to_dataclass(PlaintextBallot, os.path.join(data, filename))
return from_file(PlaintextBallot, os.path.join(data, filename))
@staticmethod
def _get_ballots_from_file(filename: str) -> List[PlaintextBallot]:
return from_list_in_file_to_dataclass(
PlaintextBallot, os.path.join(data, filename)
)
return from_list_in_file(PlaintextBallot, os.path.join(data, filename))
@composite

Просмотреть файл

@ -40,12 +40,13 @@ from electionguard.manifest import (
CandidateContestDescription,
ReferendumContestDescription,
)
from electionguard.serialize import from_file
from electionguard.utils import get_optional
from electionguard_tools.helpers.key_ceremony_orchestrator import (
KeyCeremonyOrchestrator,
)
from electionguard_tools.helpers.serialize import from_file_to_dataclass
_T = TypeVar("_T")
_DrawType = Callable[[SearchStrategy[_T]], _T]
@ -86,17 +87,22 @@ class ElectionFactory:
@staticmethod
def get_manifest_from_file(spec_version: str, sample_manifest: str) -> Manifest:
"""Get simple manifest from json file."""
return from_file_to_dataclass(
return from_file(
Manifest,
os.path.join(
data, spec_version, "sample", sample_manifest, "manifest.json"
data,
spec_version,
"sample",
sample_manifest,
"election_record",
"manifest.json",
),
)
@staticmethod
def get_hamilton_manifest_from_file() -> Manifest:
"""Get Hamilton County manifest from json file."""
return from_file_to_dataclass(
return from_file(
Manifest,
os.path.join(
data, os.path.join(data, "hamilton-county", "election_manifest.json")
@ -253,7 +259,7 @@ class ElectionFactory:
@staticmethod
def _get_manifest_from_file(filename: str) -> Manifest:
return from_file_to_dataclass(Manifest, os.path.join(data, filename))
return from_file(Manifest, os.path.join(data, filename))
@staticmethod
def get_encryption_device() -> EncryptionDevice:

Просмотреть файл

@ -1,7 +1,6 @@
from electionguard_tools.helpers import export
from electionguard_tools.helpers import identity_encrypt
from electionguard_tools.helpers import key_ceremony_orchestrator
from electionguard_tools.helpers import serialize
from electionguard_tools.helpers import tally_accumulate
from electionguard_tools.helpers import tally_ceremony_orchestrator
@ -35,20 +34,6 @@ from electionguard_tools.helpers.identity_encrypt import (
from electionguard_tools.helpers.key_ceremony_orchestrator import (
KeyCeremonyOrchestrator,
)
from electionguard_tools.helpers.serialize import (
NumberEncodeOption,
OPTION,
T,
banlist,
construct_path,
custom_decoder,
custom_encoder,
from_file_to_dataclass,
from_list_in_file_to_dataclass,
from_raw,
to_file,
to_raw,
)
from electionguard_tools.helpers.tally_accumulate import (
accumulate_plaintext_ballots,
)
@ -69,8 +54,6 @@ __all__ = [
"GUARDIAN_PREFIX",
"KeyCeremonyOrchestrator",
"MANIFEST_FILE_NAME",
"NumberEncodeOption",
"OPTION",
"PLAINTEXT_BALLOT_PREFIX",
"PRIVATE_DATA_DIR",
"PRIVATE_GUARDIAN_PREFIX",
@ -78,26 +61,15 @@ __all__ = [
"SPOILED_BALLOT_PREFIX",
"SUBMITTED_BALLOTS_DIR",
"SUBMITTED_BALLOT_PREFIX",
"T",
"TALLY_FILE_NAME",
"TallyCeremonyOrchestrator",
"accumulate_plaintext_ballots",
"banlist",
"construct_path",
"custom_decoder",
"custom_encoder",
"export",
"export_private_data",
"from_file_to_dataclass",
"from_list_in_file_to_dataclass",
"from_raw",
"identity_auxiliary_decrypt",
"identity_auxiliary_encrypt",
"identity_encrypt",
"key_ceremony_orchestrator",
"serialize",
"tally_accumulate",
"tally_ceremony_orchestrator",
"to_file",
"to_raw",
]

Просмотреть файл

@ -18,9 +18,9 @@ from electionguard.guardian import GuardianRecord, PrivateGuardianRecord
from electionguard.election import CiphertextElectionContext
from electionguard.encrypt import EncryptionDevice
from electionguard.manifest import Manifest
from electionguard.serialize import to_file
from electionguard.tally import PlaintextTally, PublishedCiphertextTally
from .serialize import to_file
# Public
ELECTION_RECORD_DIR = "election_record"

Просмотреть файл

@ -1,142 +0,0 @@
"""
Testing tool to validate serialization and deserialization.
WARNING: Not for production use.
Specifically constructed to assist with json loading and dumping within the library.
As a secondary case, this displays how many python serializers/deserializers should be able to take
advantage of the dataclass usage.
"""
from dataclasses import asdict, is_dataclass
from enum import Enum
import json
import os
from pathlib import Path
from typing import Any, Callable, List, Optional, Type, TypeVar, Union, cast
from pydantic.json import pydantic_encoder
from pydantic.tools import parse_raw_as, parse_obj_as
from electionguard.group import hex_to_int, int_to_hex
T = TypeVar("T")
def construct_path(
target_file_name: str,
target_path: Optional[Path] = None,
target_file_extension="json",
) -> Path:
"""Construct path from file name, path, and extension."""
target_file = f"{target_file_name}.{target_file_extension}"
return os.path.join(target_path, target_file)
def from_raw(type_: Type[T], obj: Any) -> T:
"""Deserialize raw as type."""
obj = custom_decoder(obj)
return cast(type_, parse_raw_as(type_, obj))
def to_raw(data: Any) -> Any:
"""Serialize data to raw json format."""
return json.dumps(data, indent=4, default=custom_encoder)
def from_file_to_dataclass(dataclass_type_: Type[T], path: Union[str, Path]) -> T:
"""Deserialize file as dataclass type."""
with open(path, "r") as json_file:
data = json.load(json_file)
data = custom_decoder(data)
return parse_obj_as(dataclass_type_, data)
def from_list_in_file_to_dataclass(
dataclass_type_: Type[T], path: Union[str, Path]
) -> T:
"""Deserialize list of objects in file as dataclass type."""
with open(path, "r") as json_file:
data = json.load(json_file)
data = custom_decoder(data)
return cast(dataclass_type_, parse_obj_as(List[dataclass_type_], data))
def to_file(
data: Any,
target_file_name: str,
target_path: Optional[Path] = None,
target_file_extension="json",
) -> None:
"""Serialize object to file (defaultly json)."""
if not os.path.exists(target_path):
os.makedirs(target_path)
with open(
construct_path(target_file_name, target_path, target_file_extension), "w"
) as outfile:
json.dump(data, outfile, indent=4, default=custom_encoder)
# Color and abbreviation can both be of type hex but should not be converted
banlist = ["color", "abbreviation", "is_write_in"]
def _recursive_replace(object, type_: Type, replace: Callable[[Any], Any]):
"""Iterate through object to replace."""
if isinstance(object, dict):
for key, item in object.items():
if isinstance(item, (dict, list)):
object[key] = _recursive_replace(item, type_, replace)
if isinstance(item, type_) and key not in banlist:
object[key] = replace(item)
if isinstance(object, list):
for index, item in enumerate(object):
if isinstance(item, (dict, list)):
object[index] = _recursive_replace(item, type_, replace)
if isinstance(item, type_):
object[index] = replace(item)
return object
class NumberEncodeOption(Enum):
"""Option for encoding numbers."""
Int = "int"
Hex = "hex"
# Base64 = "base64"
OPTION = NumberEncodeOption.Hex
def _get_int_encoder() -> Callable[[Any], Any]:
if OPTION is NumberEncodeOption.Hex:
return int_to_hex
return lambda x: x
def custom_encoder(obj: Any) -> Any:
"""Integer encoder to convert int representations to type for json."""
if is_dataclass(obj):
new_dict = asdict(obj)
obj = _recursive_replace(new_dict, int, _get_int_encoder())
return obj
return pydantic_encoder(obj)
def _get_int_decoder() -> Callable[[Any], Any]:
def safe_hex_to_int(input: str) -> Union[int, str]:
try:
return hex_to_int(input)
except ValueError:
return input
if OPTION is NumberEncodeOption.Hex:
return safe_hex_to_int
return lambda x: x
def custom_decoder(obj: Any) -> Any:
"""Integer decoder to convert json stored int back to int representations."""
return _recursive_replace(obj, str, _get_int_decoder())

Просмотреть файл

@ -0,0 +1,107 @@
import json
from unittest import TestCase
from shutil import rmtree
from electionguard.constants import ElectionConstants
from electionguard.election import CiphertextElectionContext
from electionguard.guardian import GuardianRecord
from electionguard.manifest import Manifest
from electionguard.ballot import (
CiphertextBallot,
PlaintextBallot,
SubmittedBallot,
)
from electionguard.ballot_compact import CompactPlaintextBallot, CompactSubmittedBallot
from electionguard.election_polynomial import LagrangeCoefficientsRecord
from electionguard.encrypt import EncryptionDevice
from electionguard.tally import (
PublishedCiphertextTally,
PlaintextTally,
)
from electionguard.serialize import construct_path, get_schema, to_file
class TestCreateSchema(TestCase):
"""Test creating schema."""
schema_dir = "schemas"
remove_schema = False
# TODO Fix Pydantic errors with json schema
resolve_pydantic_errors = False
def test_create_schema(self) -> None:
to_file(
json.loads((get_schema(CiphertextElectionContext))),
construct_path("context_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(ElectionConstants)),
construct_path("constants_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(EncryptionDevice)),
construct_path("device_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(LagrangeCoefficientsRecord)),
construct_path("coefficients_schema"),
self.schema_dir,
)
if self.resolve_pydantic_errors:
to_file(
json.loads(get_schema(Manifest)),
construct_path("manifest_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(GuardianRecord)),
construct_path("guardian_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(PlaintextBallot)),
construct_path("plaintext_ballot_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(CiphertextBallot)),
construct_path("ciphertext_ballot_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(SubmittedBallot)),
construct_path("submitted_ballot_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(CompactPlaintextBallot)),
construct_path("compact_plaintext_ballot_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(CompactSubmittedBallot)),
construct_path("compact_submitted_ballot_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(PlaintextTally)),
construct_path("plaintext_tally_schema"),
self.schema_dir,
)
to_file(
json.loads(get_schema(PublishedCiphertextTally)),
construct_path("ciphertext_tally_schema"),
self.schema_dir,
)
if self.remove_schema:
rmtree(self.schema_dir)

Просмотреть файл

@ -46,6 +46,7 @@ from electionguard.decryption_mediator import DecryptionMediator
from electionguard.election_polynomial import LagrangeCoefficientsRecord
# Step 5 - Publish and Verify
from electionguard.serialize import from_file, construct_path
from electionguard_tools.helpers.export import (
COEFFICIENTS_FILE_NAME,
DEVICES_DIR,
@ -66,7 +67,6 @@ from electionguard_tools.helpers.export import (
TALLY_FILE_NAME,
export_private_data,
)
from electionguard_tools.helpers.serialize import from_file_to_dataclass, construct_path
from electionguard_tools.factories.ballot_factory import BallotFactory
from electionguard_tools.factories.election_factory import (
@ -554,30 +554,30 @@ class TestEndToEndElection(BaseTestCase):
"""Ensure published data can be deserialized."""
# Deserialize
manifest_from_file = from_file_to_dataclass(
manifest_from_file = from_file(
Manifest,
construct_path(MANIFEST_FILE_NAME, ELECTION_RECORD_DIR),
)
self.assertEqualAsDicts(self.manifest, manifest_from_file)
context_from_file = from_file_to_dataclass(
context_from_file = from_file(
CiphertextElectionContext,
construct_path(CONTEXT_FILE_NAME, ELECTION_RECORD_DIR),
)
self.assertEqualAsDicts(self.context, context_from_file)
constants_from_file = from_file_to_dataclass(
constants_from_file = from_file(
ElectionConstants, construct_path(CONSTANTS_FILE_NAME, ELECTION_RECORD_DIR)
)
self.assertEqualAsDicts(self.constants, constants_from_file)
coefficients_from_file = from_file_to_dataclass(
coefficients_from_file = from_file(
LagrangeCoefficientsRecord,
construct_path(COEFFICIENTS_FILE_NAME, ELECTION_RECORD_DIR),
)
self.assertEqualAsDicts(self.lagrange_coefficients, coefficients_from_file)
device_from_file = from_file_to_dataclass(
device_from_file = from_file(
EncryptionDevice,
construct_path(
DEVICE_PREFIX + str(self.device.device_id), devices_directory
@ -586,7 +586,7 @@ class TestEndToEndElection(BaseTestCase):
self.assertEqualAsDicts(self.device, device_from_file)
for ballot in self.ballot_store.all():
ballot_from_file = from_file_to_dataclass(
ballot_from_file = from_file(
SubmittedBallot,
construct_path(
SUBMITTED_BALLOT_PREFIX + ballot.object_id,
@ -596,7 +596,7 @@ class TestEndToEndElection(BaseTestCase):
self.assertEqualAsDicts(ballot, ballot_from_file)
for spoiled_ballot in self.plaintext_spoiled_ballots.values():
spoiled_ballot_from_file = from_file_to_dataclass(
spoiled_ballot_from_file = from_file(
PlaintextTally,
construct_path(
SPOILED_BALLOT_PREFIX + spoiled_ballot.object_id,
@ -605,7 +605,7 @@ class TestEndToEndElection(BaseTestCase):
)
self.assertEqualAsDicts(spoiled_ballot, spoiled_ballot_from_file)
published_ciphertext_tally_from_file = from_file_to_dataclass(
published_ciphertext_tally_from_file = from_file(
PublishedCiphertextTally,
construct_path(ENCRYPTED_TALLY_FILE_NAME, ELECTION_RECORD_DIR),
)
@ -614,13 +614,13 @@ class TestEndToEndElection(BaseTestCase):
published_ciphertext_tally_from_file,
)
plainttext_tally_from_file = from_file_to_dataclass(
plainttext_tally_from_file = from_file(
PlaintextTally, construct_path(TALLY_FILE_NAME, ELECTION_RECORD_DIR)
)
self.assertEqualAsDicts(self.plaintext_tally, plainttext_tally_from_file)
for guardian_record in self.guardian_records:
guardian_record_from_file = from_file_to_dataclass(
guardian_record_from_file = from_file(
GuardianRecord,
construct_path(
GUARDIAN_PREFIX + guardian_record.guardian_id, guardians_directory
@ -639,7 +639,7 @@ class TestEndToEndElection(BaseTestCase):
print(f"{name}: {message}: {result}")
self.assertTrue(result)
def assertEqualAsDicts(self, first: object, second: object):
def assertEqualAsDicts(self, first: object, second: object) -> None:
"""
Specialty assertEqual to compare dataclasses as dictionaries.

Просмотреть файл

@ -1,5 +1,5 @@
from datetime import timedelta
from hypothesis import given, settings, HealthCheck
from hypothesis import given, settings, HealthCheck, Phase
from hypothesis.strategies import integers
@ -155,6 +155,7 @@ class TestChaumPedersen(BaseTestCase):
deadline=timedelta(milliseconds=2000),
suppress_health_check=[HealthCheck.too_slow],
max_examples=10,
phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
)
@given(
elgamal_keypairs(),

Просмотреть файл

@ -37,7 +37,7 @@ from electionguard_tools.strategies.group import elements_mod_q_no_zero
class TestElGamal(BaseTestCase):
"""ElGamal tests"""
def test_simple_elgamal_encryption_decryption(self):
def test_simple_elgamal_encryption_decryption(self) -> None:
nonce = ONE_MOD_Q
secret_key = TWO_MOD_Q
keypair = get_optional(elgamal_keypair_from_secret(secret_key))
@ -50,12 +50,12 @@ class TestElGamal(BaseTestCase):
ciphertext = get_optional(elgamal_encrypt(0, nonce, keypair.public_key))
self.assertEqual(get_generator(), ciphertext.pad)
self.assertEqual(
pow(ciphertext.pad, secret_key, get_large_prime()),
pow(public_key, nonce, get_large_prime()),
pow(ciphertext.pad.get_value(), secret_key.get_value(), get_large_prime()),
pow(public_key.get_value(), nonce.get_value(), get_large_prime()),
)
self.assertEqual(
ciphertext.data,
pow(public_key, nonce, get_large_prime()),
ciphertext.data.get_value(),
pow(public_key.get_value(), nonce.get_value(), get_large_prime()),
)
plaintext = ciphertext.decrypt(keypair.secret_key)
@ -65,17 +65,17 @@ class TestElGamal(BaseTestCase):
@given(integers(0, 100), elgamal_keypairs())
def test_elgamal_encrypt_requires_nonzero_nonce(
self, message: int, keypair: ElGamalKeyPair
):
) -> None:
self.assertEqual(None, elgamal_encrypt(message, ZERO_MOD_Q, keypair.public_key))
def test_elgamal_keypair_from_secret_requires_key_greater_than_one(self):
def test_elgamal_keypair_from_secret_requires_key_greater_than_one(self) -> None:
self.assertEqual(None, elgamal_keypair_from_secret(ZERO_MOD_Q))
self.assertEqual(None, elgamal_keypair_from_secret(ONE_MOD_Q))
@given(integers(0, 100), elements_mod_q_no_zero(), elgamal_keypairs())
def test_elgamal_encryption_decryption_inverses(
self, message: int, nonce: ElementModQ, keypair: ElGamalKeyPair
):
) -> None:
ciphertext = get_optional(elgamal_encrypt(message, nonce, keypair.public_key))
plaintext = ciphertext.decrypt(keypair.secret_key)
@ -84,14 +84,16 @@ class TestElGamal(BaseTestCase):
@given(integers(0, 100), elements_mod_q_no_zero(), elgamal_keypairs())
def test_elgamal_encryption_decryption_with_known_nonce_inverses(
self, message: int, nonce: ElementModQ, keypair: ElGamalKeyPair
):
) -> None:
ciphertext = get_optional(elgamal_encrypt(message, nonce, keypair.public_key))
plaintext = ciphertext.decrypt_known_nonce(keypair.public_key, nonce)
self.assertEqual(message, plaintext)
@given(elgamal_keypairs())
def test_elgamal_generated_keypairs_are_within_range(self, keypair: ElGamalKeyPair):
def test_elgamal_generated_keypairs_are_within_range(
self, keypair: ElGamalKeyPair
) -> None:
self.assertLess(keypair.public_key, get_large_prime())
self.assertLess(keypair.secret_key, get_small_prime())
self.assertEqual(g_pow_p(keypair.secret_key), keypair.public_key)
@ -110,7 +112,7 @@ class TestElGamal(BaseTestCase):
r1: ElementModQ,
m2: int,
r2: ElementModQ,
):
) -> None:
c1 = get_optional(elgamal_encrypt(m1, r1, keypair.public_key))
c2 = get_optional(elgamal_encrypt(m2, r2, keypair.public_key))
c_sum = elgamal_add(c1, c2)
@ -118,14 +120,14 @@ class TestElGamal(BaseTestCase):
self.assertEqual(total, m1 + m2)
def test_elgamal_add_requires_args(self):
def test_elgamal_add_requires_args(self) -> None:
self.assertRaises(Exception, elgamal_add)
@given(elgamal_keypairs())
def test_elgamal_keypair_produces_valid_residue(self, keypair):
def test_elgamal_keypair_produces_valid_residue(self, keypair) -> None:
self.assertTrue(keypair.public_key.is_valid_residue())
def test_elgamal_keypair_random(self):
def test_elgamal_keypair_random(self) -> None:
# Act
random_keypair = elgamal_keypair_random()
random_keypair_two = elgamal_keypair_random()
@ -136,7 +138,7 @@ class TestElGamal(BaseTestCase):
self.assertIsNotNone(random_keypair.secret_key)
self.assertNotEqual(random_keypair, random_keypair_two)
def test_elgamal_combine_public_keys(self):
def test_elgamal_combine_public_keys(self) -> None:
# Arrange
random_keypair = elgamal_keypair_random()
random_keypair_two = elgamal_keypair_random()
@ -150,7 +152,7 @@ class TestElGamal(BaseTestCase):
self.assertNotEqual(joint_key, random_keypair.public_key)
self.assertNotEqual(joint_key, random_keypair_two.public_key)
def test_gmpy2_parallelism_is_safe(self):
def test_gmpy2_parallelism_is_safe(self) -> None:
"""
Ensures running lots of parallel exponentiations still yields the correct answer.
This verifies that nothing incorrect is happening in the GMPY2 library

Просмотреть файл

@ -46,11 +46,11 @@ class TestEquality(BaseTestCase):
"""Math equality tests"""
@given(elements_mod_q(), elements_mod_q())
def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ):
def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ) -> None:
i = int(q)
i2 = int(q2)
p = ElementModP(q)
p2 = ElementModP(q2)
p = ElementModP(q.get_value())
p2 = ElementModP(q2.get_value())
# same value should imply they're equal
self.assertEqual(p, q)
@ -78,54 +78,54 @@ class TestModularArithmetic(BaseTestCase):
"""Math Modular Arithmetic tests"""
@given(elements_mod_q())
def test_add_q(self, q: ElementModQ):
def test_add_q(self, q: ElementModQ) -> None:
as_int = add_q(q, 1)
as_elem = add_q(q, ElementModQ(1))
self.assertEqual(as_int, as_elem)
@given(elements_mod_q())
def test_a_plus_bc_q(self, q: ElementModQ):
def test_a_plus_bc_q(self, q: ElementModQ) -> None:
as_int = a_plus_bc_q(q, 1, 1)
as_elem = a_plus_bc_q(q, ElementModQ(1), ElementModQ(1))
self.assertEqual(as_int, as_elem)
@given(elements_mod_q())
def test_a_minus_b_q(self, q: ElementModQ):
def test_a_minus_b_q(self, q: ElementModQ) -> None:
as_int = a_minus_b_q(q, 1)
as_elem = a_minus_b_q(q, ElementModQ(1))
self.assertEqual(as_int, as_elem)
@given(elements_mod_q())
def test_div_q(self, q: ElementModQ):
def test_div_q(self, q: ElementModQ) -> None:
as_int = div_q(q, 1)
as_elem = div_q(q, ElementModQ(1))
self.assertEqual(as_int, as_elem)
@given(elements_mod_p())
def test_div_p(self, p: ElementModQ):
def test_div_p(self, p: ElementModQ) -> None:
as_int = div_p(p, 1)
as_elem = div_p(p, ElementModP(1))
self.assertEqual(as_int, as_elem)
def test_no_mult_inv_of_zero(self):
def test_no_mult_inv_of_zero(self) -> None:
self.assertRaises(Exception, mult_inv_p, ZERO_MOD_P)
@given(elements_mod_p_no_zero())
def test_mult_inverses(self, elem: ElementModP):
def test_mult_inverses(self, elem: ElementModP) -> None:
inv = mult_inv_p(elem)
self.assertEqual(mult_p(elem, inv), ONE_MOD_P)
@given(elements_mod_p())
def test_mult_identity(self, elem: ElementModP):
def test_mult_identity(self, elem: ElementModP) -> None:
self.assertEqual(elem, mult_p(elem))
def test_mult_noargs(self):
def test_mult_noargs(self) -> None:
self.assertEqual(ONE_MOD_P, mult_p())
def test_add_noargs(self):
def test_add_noargs(self) -> None:
self.assertEqual(ZERO_MOD_Q, add_q())
def test_properties_for_constants(self):
def test_properties_for_constants(self) -> None:
self.assertNotEqual(get_generator(), 1)
self.assertEqual(
(get_cofactor() * get_small_prime()) % get_large_prime(),
@ -135,13 +135,13 @@ class TestModularArithmetic(BaseTestCase):
self.assertLess(get_generator(), get_large_prime())
self.assertLess(get_cofactor(), get_large_prime())
def test_simple_powers(self):
def test_simple_powers(self) -> None:
gp = int_to_p(get_generator())
self.assertEqual(gp, g_pow_p(ONE_MOD_Q))
self.assertEqual(ONE_MOD_P, g_pow_p(ZERO_MOD_Q))
@given(elements_mod_q())
def test_in_bounds_q(self, q: ElementModQ):
def test_in_bounds_q(self, q: ElementModQ) -> None:
self.assertTrue(q.is_in_bounds())
too_big = q + get_small_prime()
too_small = q - get_small_prime()
@ -155,7 +155,7 @@ class TestModularArithmetic(BaseTestCase):
ElementModQ(too_small)
@given(elements_mod_p())
def test_in_bounds_p(self, p: ElementModP):
def test_in_bounds_p(self, p: ElementModP) -> None:
self.assertTrue(p.is_in_bounds())
too_big = p + get_large_prime()
too_small = p - get_large_prime()
@ -180,7 +180,7 @@ class TestModularArithmetic(BaseTestCase):
)
@given(elements_mod_p_no_zero())
def test_in_bounds_p_no_zero(self, p: ElementModP):
def test_in_bounds_p_no_zero(self, p: ElementModP) -> None:
self.assertTrue(p.is_in_bounds_no_zero())
self.assertFalse(ZERO_MOD_P.is_in_bounds_no_zero())
self.assertFalse(
@ -191,7 +191,7 @@ class TestModularArithmetic(BaseTestCase):
)
@given(elements_mod_q())
def test_large_values_rejected_by_int_to_q(self, q: ElementModQ):
def test_large_values_rejected_by_int_to_q(self, q: ElementModQ) -> None:
oversize = q + get_small_prime()
self.assertEqual(None, int_to_q(oversize))
@ -199,28 +199,28 @@ class TestModularArithmetic(BaseTestCase):
class TestOptionalFunctions(BaseTestCase):
"""Math Optional Functions tests"""
def test_unwrap(self):
def test_unwrap(self) -> None:
good: Optional[int] = 3
bad: Optional[int] = None
self.assertEqual(get_optional(good), 3)
self.assertRaises(Exception, get_optional, bad)
def test_match(self):
def test_match(self) -> None:
good: Optional[int] = 3
bad: Optional[int] = None
self.assertEqual(5, match_optional(good, lambda: 1, lambda x: x + 2))
self.assertEqual(1, match_optional(bad, lambda: 1, lambda x: x + 2))
def test_get_or_else(self):
def test_get_or_else(self) -> None:
good: Optional[int] = 3
bad: Optional[int] = None
self.assertEqual(3, get_or_else_optional(good, 5))
self.assertEqual(5, get_or_else_optional(bad, 5))
def test_flatmap(self):
def test_flatmap(self) -> None:
good: Optional[int] = 3
bad: Optional[int] = None

Просмотреть файл

@ -9,9 +9,10 @@ from electionguard.manifest import (
SelectionDescription,
VoteVariationType,
)
from electionguard.serialize import from_raw, to_raw
import electionguard_tools.factories.election_factory as ElectionFactory
import electionguard_tools.factories.ballot_factory as BallotFactory
from electionguard_tools.helpers.serialize import from_raw, to_raw
election_factory = ElectionFactory.ElectionFactory()
ballot_factory = BallotFactory.BallotFactory()
@ -20,7 +21,7 @@ ballot_factory = BallotFactory.BallotFactory()
class TestManifest(BaseTestCase):
"""Manifest tests"""
def test_simple_manifest_is_valid(self):
def test_simple_manifest_is_valid(self) -> None:
# Act
subject = election_factory.get_simple_manifest_from_file()
@ -30,7 +31,7 @@ class TestManifest(BaseTestCase):
self.assertEqual(subject.election_scope_id, "jefferson-county-primary")
self.assertTrue(subject.is_valid())
def test_simple_manifest_can_serialize(self):
def test_simple_manifest_can_serialize(self) -> None:
# Arrange
subject = election_factory.get_simple_manifest_from_file()
intermediate = to_raw(subject)
@ -42,7 +43,7 @@ class TestManifest(BaseTestCase):
self.assertIsNotNone(result.election_scope_id)
self.assertEqual(result.election_scope_id, "jefferson-county-primary")
def test_manifest_has_deterministic_hash(self):
def test_manifest_has_deterministic_hash(self) -> None:
# Act
subject1 = election_factory.get_simple_manifest_from_file()
@ -51,7 +52,7 @@ class TestManifest(BaseTestCase):
# Assert
self.assertEqual(subject1.crypto_hash(), subject2.crypto_hash())
def test_manifest_hash_is_consistent_regardless_of_format(self):
def test_manifest_hash_is_consistent_regardless_of_format(self) -> None:
# Act
subject1 = election_factory.get_simple_manifest_from_file()
@ -72,7 +73,7 @@ class TestManifest(BaseTestCase):
def test_manifest_from_file_generates_consistent_internal_description_contest_hashes(
self,
):
) -> None:
# Arrange
comparator = election_factory.get_simple_manifest_from_file()
subject = InternalManifest(comparator)
@ -84,7 +85,7 @@ class TestManifest(BaseTestCase):
if expected.object_id == actual.object_id:
self.assertEqual(expected.crypto_hash(), actual.crypto_hash())
def test_contest_description_valid_input_succeeds(self):
def test_contest_description_valid_input_succeeds(self) -> None:
description = ContestDescriptionWithPlaceholders(
object_id="0@A.com-contest",
electoral_district_id="0@A.com-gp-unit",
@ -118,7 +119,7 @@ class TestManifest(BaseTestCase):
self.assertTrue(description.is_valid())
def test_contest_description_invalid_input_fails(self):
def test_contest_description_invalid_input_fails(self) -> None:
description = ContestDescriptionWithPlaceholders(
object_id="0@A.com-contest",