✨ Refactor to use pydantic models for schema (#512)
* 📦. Move pydantic to main dependencies * 🐞 Remove pydantic from mypy due to existing bug around inits * ✨ Reimplement Serializable using BaseModel * 🚧 Create BaseElement using BaseModel * 🚧 Slight Edits * 📦 Update Pydantic to 1.9.0 * ✨Handle Serialization * ✅ Fix mypy issues with linting * ✨ Add schema functionality, fix factories, fix tests * ✅ Add add and sub to BaseElement * 🧹Boyscout test files for mypy * ✅ Test Chaum Pedersen Adjustment * ♻️ Refactor from_file_to_dataclass to to_file * fix path change Co-authored-by: Matt Wilhelm <github@addressxception.com>
This commit is contained in:
Родитель
a4d6cb43cb
Коммит
922fd68a75
|
@ -145,6 +145,7 @@ election_record
|
|||
election_private_data
|
||||
election_record.zip
|
||||
election_private_data.zip
|
||||
schemas
|
||||
|
||||
# VS Code
|
||||
.vscode/settings.json
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -31,12 +31,13 @@ packages = [{ include = "electionguard", from = "src" }]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9.5"
|
||||
gmpy2 = ">=2.0.8"
|
||||
gmpy2 = "2.0.8"
|
||||
# For Windows Builds
|
||||
# gmpy2 = { path = "./packages/gmpy2-2.0.8-cp39-cp39-win_amd64.whl" } # 64 bit
|
||||
# gmpy2 = { path = "./packages/gmpy2-2.0.8-cp39-cp39-win32.whl" } # 32 bit
|
||||
cryptography = ">=3.2"
|
||||
psutil = ">=5.7.2"
|
||||
pydantic = "1.9.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
atomicwrites = "*"
|
||||
|
@ -53,7 +54,6 @@ mkinit = "^0.3.3"
|
|||
mypy = "^0.910"
|
||||
pydeps = "*"
|
||||
pylint = "*"
|
||||
pydantic = "^1.8.2"
|
||||
pytest = "*"
|
||||
secretstorage = "*"
|
||||
twine = "*"
|
||||
|
@ -101,7 +101,6 @@ build-backend = "poetry.core.masonry.api"
|
|||
|
||||
[tool.mypy]
|
||||
python_version = 3.9
|
||||
plugins = ["pydantic.mypy"]
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
|
|
|
@ -34,6 +34,7 @@ from electionguard import proof
|
|||
from electionguard import rsa
|
||||
from electionguard import scheduler
|
||||
from electionguard import schnorr
|
||||
from electionguard import serialize
|
||||
from electionguard import singleton
|
||||
from electionguard import tally
|
||||
from electionguard import type
|
||||
|
@ -171,6 +172,7 @@ from electionguard.discrete_log import (
|
|||
DiscreteLog,
|
||||
compute_discrete_log,
|
||||
compute_discrete_log_cache,
|
||||
discrete_log_async,
|
||||
)
|
||||
from electionguard.election import (
|
||||
CiphertextElectionContext,
|
||||
|
@ -347,6 +349,17 @@ from electionguard.schnorr import (
|
|||
SchnorrProof,
|
||||
make_schnorr_proof,
|
||||
)
|
||||
from electionguard.serialize import (
|
||||
Private,
|
||||
Serializable,
|
||||
construct_path,
|
||||
from_file,
|
||||
from_list_in_file,
|
||||
from_raw,
|
||||
get_schema,
|
||||
to_file,
|
||||
to_raw,
|
||||
)
|
||||
from electionguard.singleton import (
|
||||
Singleton,
|
||||
)
|
||||
|
@ -488,6 +501,7 @@ __all__ = [
|
|||
"PlaintextTallyContest",
|
||||
"PlaintextTallySelection",
|
||||
"PrimeOption",
|
||||
"Private",
|
||||
"PrivateGuardianRecord",
|
||||
"Proof",
|
||||
"ProofOrRecovery",
|
||||
|
@ -507,6 +521,7 @@ __all__ = [
|
|||
"SecretCoefficient",
|
||||
"SelectionDescription",
|
||||
"SelectionId",
|
||||
"Serializable",
|
||||
"Singleton",
|
||||
"SubmittedBallot",
|
||||
"VerifierId",
|
||||
|
@ -545,6 +560,7 @@ __all__ = [
|
|||
"compute_polynomial_coordinate",
|
||||
"compute_recovery_public_key",
|
||||
"constants",
|
||||
"construct_path",
|
||||
"contest_description_with_placeholders_from",
|
||||
"contest_from",
|
||||
"contest_is_valid_for_style",
|
||||
|
@ -568,6 +584,7 @@ __all__ = [
|
|||
"decryption_mediator",
|
||||
"decryption_share",
|
||||
"discrete_log",
|
||||
"discrete_log_async",
|
||||
"div_p",
|
||||
"div_q",
|
||||
"election",
|
||||
|
@ -589,6 +606,9 @@ __all__ = [
|
|||
"expand_compact_submitted_ballot",
|
||||
"flatmap_optional",
|
||||
"from_ciphertext_ballot",
|
||||
"from_file",
|
||||
"from_list_in_file",
|
||||
"from_raw",
|
||||
"g_pow_p",
|
||||
"generate_device_uuid",
|
||||
"generate_election_key_pair",
|
||||
|
@ -610,6 +630,7 @@ __all__ = [
|
|||
"get_optional",
|
||||
"get_or_else_optional",
|
||||
"get_or_else_optional_func",
|
||||
"get_schema",
|
||||
"get_shares_for_selection",
|
||||
"get_small_prime",
|
||||
"get_stream_handler",
|
||||
|
@ -672,12 +693,15 @@ __all__ = [
|
|||
"selection_from",
|
||||
"selection_is_valid_for_style",
|
||||
"sequence_order_sort",
|
||||
"serialize",
|
||||
"singleton",
|
||||
"space_between_capitals",
|
||||
"tally",
|
||||
"tally_ballot",
|
||||
"tally_ballots",
|
||||
"to_file",
|
||||
"to_iso_date_string",
|
||||
"to_raw",
|
||||
"to_ticks",
|
||||
"type",
|
||||
"utils",
|
||||
|
|
|
@ -95,11 +95,11 @@ def compute_polynomial_coordinate(
|
|||
:return: Polynomial used to share election keys
|
||||
"""
|
||||
|
||||
exponent_modifier = ElementModQ(exponent_modifier)
|
||||
exponent_modifier_mod_q = ElementModQ(exponent_modifier)
|
||||
|
||||
computed_value = ZERO_MOD_Q
|
||||
for (i, coefficient) in enumerate(polynomial.coefficients):
|
||||
exponent = pow_q(exponent_modifier, i)
|
||||
exponent = pow_q(exponent_modifier_mod_q, i)
|
||||
factor = mult_q(coefficient.value, exponent)
|
||||
computed_value = add_q(computed_value, factor)
|
||||
return computed_value
|
||||
|
@ -144,11 +144,11 @@ def verify_polynomial_coordinate(
|
|||
:return: True if verified on polynomial
|
||||
"""
|
||||
|
||||
exponent_modifier = ElementModQ(exponent_modifier)
|
||||
exponent_modifier_mod_q = ElementModQ(exponent_modifier)
|
||||
|
||||
commitment_output = ONE_MOD_P
|
||||
for (i, commitment) in enumerate(commitments):
|
||||
exponent = pow_p(exponent_modifier, i)
|
||||
exponent = pow_p(exponent_modifier_mod_q, i)
|
||||
factor = pow_p(commitment, exponent)
|
||||
commitment_output = mult_p(commitment_output, factor)
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@ in the sense that performance may be less than hand-optimized C code, and no gua
|
|||
made about timing or other side-channels.
|
||||
"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Final, Optional, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Final, Optional, Tuple, Union
|
||||
from base64 import b16decode
|
||||
from secrets import randbelow
|
||||
from sys import maxsize
|
||||
|
@ -14,53 +14,137 @@ from sys import maxsize
|
|||
# pylint: disable=no-name-in-module
|
||||
from gmpy2 import mpz, powmod, invert
|
||||
|
||||
from .serialize import Serializable, Private
|
||||
from .constants import get_large_prime, get_small_prime, get_generator
|
||||
|
||||
|
||||
class BaseElement(ABC, int):
|
||||
def hex_to_int(input: str) -> int:
|
||||
"""Given a hex string representing bytes, returns an int."""
|
||||
return int(input, 16)
|
||||
|
||||
|
||||
def int_to_hex(input: int) -> str:
|
||||
"""Given an int, returns a hex string representing bytes."""
|
||||
hex = format(input, "02X")
|
||||
if len(hex) % 2:
|
||||
hex = "0" + hex
|
||||
return hex
|
||||
|
||||
|
||||
_zero = mpz(0)
|
||||
|
||||
|
||||
def _mpz_zero() -> mpz:
|
||||
return _zero
|
||||
|
||||
|
||||
def _convert_to_element(data: Union[int, str]) -> Tuple[str, int]:
|
||||
"""Convert element to consistent types"""
|
||||
if isinstance(data, str):
|
||||
hex = data
|
||||
integer = hex_to_int(data)
|
||||
else:
|
||||
hex = int_to_hex(data)
|
||||
integer = data
|
||||
return (hex, integer)
|
||||
|
||||
|
||||
class BaseElement(Serializable, ABC):
|
||||
"""An element limited by mod T within [0, T) where T is determined by an upper_bound function."""
|
||||
|
||||
def __new__(cls, elem: Union[int, str], check_within_bounds: bool = True): # type: ignore
|
||||
"""Instantiate ElementModT where elem is an int or its hex representation or mpz."""
|
||||
if isinstance(elem, str):
|
||||
elem = hex_to_int(elem)
|
||||
data: str
|
||||
|
||||
_value: mpz = Private(default_factory=_mpz_zero)
|
||||
"""Internal math representation of element"""
|
||||
|
||||
def __init__(self, data: Union[int, str], check_within_bounds: bool = True) -> None:
|
||||
"""Instantiate element mod T where element is an int or its hex representation."""
|
||||
(hex, integer) = _convert_to_element(data)
|
||||
super().__init__(data=hex)
|
||||
self._value = mpz(integer)
|
||||
|
||||
if check_within_bounds:
|
||||
if not 0 <= elem < cls.get_upper_bound():
|
||||
if not self.is_in_bounds():
|
||||
raise OverflowError
|
||||
return super(BaseElement, cls).__new__(cls, elem)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Overload string representation"""
|
||||
return self.data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Overload object representation"""
|
||||
return self.data
|
||||
|
||||
def __int__(self) -> int:
|
||||
"""Overload int conversion."""
|
||||
return int(self.get_value())
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Overload == (equal to) operator."""
|
||||
return (
|
||||
isinstance(other, BaseElement)
|
||||
and int(self.get_value()) == int(other.get_value())
|
||||
) or (isinstance(other, int) and int(self.get_value()) == other)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
"""Overload != (not equal to) operator."""
|
||||
return not self == other
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Overload == (equal to) operator."""
|
||||
return isinstance(other, (BaseElement, int)) and int(self) == other
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
"""Overload <= (less than) operator."""
|
||||
return (
|
||||
isinstance(other, BaseElement)
|
||||
and int(self.get_value()) < int(other.get_value())
|
||||
) or (isinstance(other, int) and int(self.get_value()) < other)
|
||||
|
||||
def __le__(self, other: Any) -> bool:
|
||||
"""Overload <= (less than or equal) operator."""
|
||||
return self.__lt__(other) or self.__eq__(other)
|
||||
|
||||
def __gt__(self, other: Any) -> bool:
|
||||
"""Overload > (greater than) operator."""
|
||||
return (
|
||||
isinstance(other, BaseElement)
|
||||
and int(self.get_value()) > int(other.get_value())
|
||||
) or (isinstance(other, int) and int(self.get_value()) > other)
|
||||
|
||||
def __ge__(self, other: Any) -> bool:
|
||||
"""Overload >= (greater than or equal) operator."""
|
||||
return self.__gt__(other) or self.__eq__(other)
|
||||
|
||||
def __add__(self, other: Any) -> Any:
|
||||
"""Overload addition operator."""
|
||||
return self.get_value() + other
|
||||
|
||||
def __sub__(self, other: Any) -> Any:
|
||||
"""Overload subtraction operator."""
|
||||
return self.get_value() - other
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Overload the hashing function."""
|
||||
return hash(self.__int__())
|
||||
return hash(self.get_value())
|
||||
|
||||
@classmethod
|
||||
def get_upper_bound(cls) -> int: # pylint: disable=no-self-use
|
||||
@abstractmethod
|
||||
def get_upper_bound(self) -> int:
|
||||
"""Get the upper bound for the element."""
|
||||
return maxsize
|
||||
|
||||
def get_value(self) -> mpz:
|
||||
"""Get internal value for math calculations"""
|
||||
return self._value
|
||||
|
||||
def to_hex(self) -> str:
|
||||
"""
|
||||
Convert from the element to the hex representation of bytes.
|
||||
|
||||
This is preferable to directly accessing `elem`, whose representation might change.
|
||||
"""
|
||||
return int_to_hex(self.__int__())
|
||||
return self.data
|
||||
|
||||
def to_hex_bytes(self) -> bytes:
|
||||
"""
|
||||
Convert from the element to the representation of bytes by first going through hex.
|
||||
|
||||
This is preferable to directly accessing `elem`, whose representation might change.
|
||||
"""
|
||||
return b16decode(self.to_hex())
|
||||
|
||||
return b16decode(self.data)
|
||||
|
||||
def is_in_bounds(self) -> bool:
|
||||
"""
|
||||
|
@ -68,7 +152,7 @@ class BaseElement(ABC, int):
|
|||
|
||||
Returns true if all is good, false if something's wrong.
|
||||
"""
|
||||
return 0 <= self.__int__() < self.get_upper_bound()
|
||||
return 0 <= self.get_value() < self.get_upper_bound()
|
||||
|
||||
def is_in_bounds_no_zero(self) -> bool:
|
||||
"""
|
||||
|
@ -76,14 +160,13 @@ class BaseElement(ABC, int):
|
|||
|
||||
Returns true if all is good, false if something's wrong.
|
||||
"""
|
||||
return 1 <= self.__int__() < self.get_upper_bound()
|
||||
return 1 <= self.get_value() < self.get_upper_bound()
|
||||
|
||||
|
||||
class ElementModQ(BaseElement):
|
||||
"""An element of the smaller `mod q` space, i.e., in [0, Q), where Q is a 256-bit prime."""
|
||||
|
||||
@classmethod
|
||||
def get_upper_bound(cls) -> int:
|
||||
def get_upper_bound(self) -> int:
|
||||
"""Get the upper bound for the element."""
|
||||
return get_small_prime()
|
||||
|
||||
|
@ -91,8 +174,7 @@ class ElementModQ(BaseElement):
|
|||
class ElementModP(BaseElement):
|
||||
"""An element of the larger `mod p` space, i.e., in [0, P), where P is a 4096-bit prime."""
|
||||
|
||||
@classmethod
|
||||
def get_upper_bound(cls) -> int:
|
||||
def get_upper_bound(self) -> int:
|
||||
"""Get the upper bound for the element."""
|
||||
return get_large_prime()
|
||||
|
||||
|
@ -119,22 +201,11 @@ ElementModPorInt = Union[ElementModP, int]
|
|||
|
||||
def _get_mpz(input: Union[BaseElement, int]) -> mpz:
|
||||
"""Get BaseElement or integer as mpz."""
|
||||
if isinstance(input, BaseElement):
|
||||
return input.get_value()
|
||||
return mpz(input)
|
||||
|
||||
|
||||
def hex_to_int(input: str) -> int:
|
||||
"""Given a hex string representing bytes, returns an int."""
|
||||
return int(input, 16)
|
||||
|
||||
|
||||
def int_to_hex(input: int) -> str:
|
||||
"""Given an int, returns a hex string representing bytes."""
|
||||
hex = format(input, "02X")
|
||||
if len(hex) % 2:
|
||||
hex = "0" + hex
|
||||
return hex
|
||||
|
||||
|
||||
def hex_to_q(input: str) -> Optional[ElementModQ]:
|
||||
"""
|
||||
Given a hex string representing bytes, returns an ElementModQ.
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
from pydantic.json import pydantic_encoder
|
||||
from pydantic.tools import parse_raw_as, parse_obj_as, schema_json_of
|
||||
|
||||
Private = PrivateAttr
|
||||
|
||||
|
||||
class Serializable(BaseModel):
|
||||
"""Serializable data object intended for exporting and importing"""
|
||||
|
||||
class Config:
|
||||
"""Model config to handle private properties"""
|
||||
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
_indent = 2
|
||||
_encoding = "utf-8"
|
||||
_file_extension = "json"
|
||||
|
||||
|
||||
def construct_path(
|
||||
target_file_name: str,
|
||||
target_path: str = "",
|
||||
target_file_extension: str = _file_extension,
|
||||
) -> str:
|
||||
"""Construct path from file name, path, and extension."""
|
||||
|
||||
target_file = f"{target_file_name}.{target_file_extension}"
|
||||
return os.path.join(target_path, target_file)
|
||||
|
||||
|
||||
def from_raw(type_: Type[_T], obj: Any) -> _T:
|
||||
"""Deserialize raw as type."""
|
||||
|
||||
return parse_raw_as(type_, obj)
|
||||
|
||||
|
||||
def to_raw(data: Any) -> Any:
|
||||
"""Serialize data to raw json format."""
|
||||
|
||||
return json.dumps(data, indent=_indent, default=pydantic_encoder)
|
||||
|
||||
|
||||
def from_file(type_: Type[_T], path: Union[str, Path]) -> _T:
|
||||
"""Deserialize json file as type."""
|
||||
|
||||
with open(path, "r", encoding=_encoding) as json_file:
|
||||
data = json.load(json_file)
|
||||
return parse_obj_as(type_, data)
|
||||
|
||||
|
||||
def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]:
|
||||
"""Deserialize json file that has an array of certain type."""
|
||||
|
||||
with open(path, "r", encoding=_encoding) as json_file:
|
||||
data = json.load(json_file)
|
||||
ls: List[_T] = []
|
||||
for item in data:
|
||||
ls.append(parse_obj_as(type_, item))
|
||||
return ls
|
||||
|
||||
|
||||
def to_file(
|
||||
data: Any,
|
||||
target_file_name: str,
|
||||
target_path: str = "",
|
||||
) -> None:
|
||||
"""Serialize object to JSON"""
|
||||
|
||||
if not os.path.exists(target_path):
|
||||
os.makedirs(target_path)
|
||||
|
||||
with open(
|
||||
construct_path(target_file_name, target_path),
|
||||
"w",
|
||||
encoding=_encoding,
|
||||
) as outfile:
|
||||
json.dump(data, outfile, indent=_indent, default=pydantic_encoder)
|
||||
|
||||
|
||||
def get_schema(_type: Any) -> str:
|
||||
"""Get JSON Schema for type"""
|
||||
|
||||
return schema_json_of(_type)
|
|
@ -34,8 +34,6 @@ from electionguard_tools.helpers import (
|
|||
GUARDIAN_PREFIX,
|
||||
KeyCeremonyOrchestrator,
|
||||
MANIFEST_FILE_NAME,
|
||||
NumberEncodeOption,
|
||||
OPTION,
|
||||
PLAINTEXT_BALLOT_PREFIX,
|
||||
PRIVATE_DATA_DIR,
|
||||
PRIVATE_GUARDIAN_PREFIX,
|
||||
|
@ -43,28 +41,17 @@ from electionguard_tools.helpers import (
|
|||
SPOILED_BALLOT_PREFIX,
|
||||
SUBMITTED_BALLOTS_DIR,
|
||||
SUBMITTED_BALLOT_PREFIX,
|
||||
T,
|
||||
TALLY_FILE_NAME,
|
||||
TallyCeremonyOrchestrator,
|
||||
accumulate_plaintext_ballots,
|
||||
banlist,
|
||||
construct_path,
|
||||
custom_decoder,
|
||||
custom_encoder,
|
||||
export,
|
||||
export_private_data,
|
||||
from_file_to_dataclass,
|
||||
from_list_in_file_to_dataclass,
|
||||
from_raw,
|
||||
identity_auxiliary_decrypt,
|
||||
identity_auxiliary_encrypt,
|
||||
identity_encrypt,
|
||||
key_ceremony_orchestrator,
|
||||
serialize,
|
||||
tally_accumulate,
|
||||
tally_ceremony_orchestrator,
|
||||
to_file,
|
||||
to_raw,
|
||||
)
|
||||
from electionguard_tools.scripts import (
|
||||
DEFAULT_NUMBER_OF_BALLOTS,
|
||||
|
@ -140,8 +127,6 @@ __all__ = [
|
|||
"KeyCeremonyOrchestrator",
|
||||
"MANIFEST_FILE_NAME",
|
||||
"NUMBER_OF_GUARDIANS",
|
||||
"NumberEncodeOption",
|
||||
"OPTION",
|
||||
"PLAINTEXT_BALLOT_PREFIX",
|
||||
"PRIVATE_DATA_DIR",
|
||||
"PRIVATE_GUARDIAN_PREFIX",
|
||||
|
@ -150,7 +135,6 @@ __all__ = [
|
|||
"SPOILED_BALLOT_PREFIX",
|
||||
"SUBMITTED_BALLOTS_DIR",
|
||||
"SUBMITTED_BALLOT_PREFIX",
|
||||
"T",
|
||||
"TALLY_FILE_NAME",
|
||||
"TallyCeremonyOrchestrator",
|
||||
"accumulate_plaintext_ballots",
|
||||
|
@ -158,16 +142,12 @@ __all__ = [
|
|||
"annotated_strings",
|
||||
"ballot_factory",
|
||||
"ballot_styles",
|
||||
"banlist",
|
||||
"candidate_contest_descriptions",
|
||||
"candidates",
|
||||
"ciphertext_elections",
|
||||
"construct_path",
|
||||
"contact_infos",
|
||||
"contest_descriptions",
|
||||
"contest_descriptions_room_for_overvoting",
|
||||
"custom_decoder",
|
||||
"custom_encoder",
|
||||
"data",
|
||||
"election",
|
||||
"election_descriptions",
|
||||
|
@ -183,9 +163,6 @@ __all__ = [
|
|||
"export",
|
||||
"export_private_data",
|
||||
"factories",
|
||||
"from_file_to_dataclass",
|
||||
"from_list_in_file_to_dataclass",
|
||||
"from_raw",
|
||||
"geopolitical_units",
|
||||
"get_contest_description_well_formed",
|
||||
"get_selection_description_well_formed",
|
||||
|
@ -209,12 +186,9 @@ __all__ = [
|
|||
"reporting_unit_types",
|
||||
"sample_generator",
|
||||
"scripts",
|
||||
"serialize",
|
||||
"strategies",
|
||||
"tally_accumulate",
|
||||
"tally_ceremony_orchestrator",
|
||||
"to_file",
|
||||
"to_raw",
|
||||
"two_letter_codes",
|
||||
]
|
||||
|
||||
|
|
|
@ -23,10 +23,7 @@ from electionguard.manifest import (
|
|||
SelectionDescription,
|
||||
InternalManifest,
|
||||
)
|
||||
from electionguard_tools.helpers.serialize import (
|
||||
from_file_to_dataclass,
|
||||
from_list_in_file_to_dataclass,
|
||||
)
|
||||
from electionguard.serialize import from_file, from_list_in_file
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
@ -146,13 +143,11 @@ class BallotFactory:
|
|||
|
||||
@staticmethod
|
||||
def _get_ballot_from_file(filename: str) -> PlaintextBallot:
|
||||
return from_file_to_dataclass(PlaintextBallot, os.path.join(data, filename))
|
||||
return from_file(PlaintextBallot, os.path.join(data, filename))
|
||||
|
||||
@staticmethod
|
||||
def _get_ballots_from_file(filename: str) -> List[PlaintextBallot]:
|
||||
return from_list_in_file_to_dataclass(
|
||||
PlaintextBallot, os.path.join(data, filename)
|
||||
)
|
||||
return from_list_in_file(PlaintextBallot, os.path.join(data, filename))
|
||||
|
||||
|
||||
@composite
|
||||
|
|
|
@ -40,12 +40,13 @@ from electionguard.manifest import (
|
|||
CandidateContestDescription,
|
||||
ReferendumContestDescription,
|
||||
)
|
||||
from electionguard.serialize import from_file
|
||||
from electionguard.utils import get_optional
|
||||
|
||||
from electionguard_tools.helpers.key_ceremony_orchestrator import (
|
||||
KeyCeremonyOrchestrator,
|
||||
)
|
||||
from electionguard_tools.helpers.serialize import from_file_to_dataclass
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_DrawType = Callable[[SearchStrategy[_T]], _T]
|
||||
|
@ -86,17 +87,22 @@ class ElectionFactory:
|
|||
@staticmethod
|
||||
def get_manifest_from_file(spec_version: str, sample_manifest: str) -> Manifest:
|
||||
"""Get simple manifest from json file."""
|
||||
return from_file_to_dataclass(
|
||||
return from_file(
|
||||
Manifest,
|
||||
os.path.join(
|
||||
data, spec_version, "sample", sample_manifest, "manifest.json"
|
||||
data,
|
||||
spec_version,
|
||||
"sample",
|
||||
sample_manifest,
|
||||
"election_record",
|
||||
"manifest.json",
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_hamilton_manifest_from_file() -> Manifest:
|
||||
"""Get Hamilton County manifest from json file."""
|
||||
return from_file_to_dataclass(
|
||||
return from_file(
|
||||
Manifest,
|
||||
os.path.join(
|
||||
data, os.path.join(data, "hamilton-county", "election_manifest.json")
|
||||
|
@ -253,7 +259,7 @@ class ElectionFactory:
|
|||
|
||||
@staticmethod
|
||||
def _get_manifest_from_file(filename: str) -> Manifest:
|
||||
return from_file_to_dataclass(Manifest, os.path.join(data, filename))
|
||||
return from_file(Manifest, os.path.join(data, filename))
|
||||
|
||||
@staticmethod
|
||||
def get_encryption_device() -> EncryptionDevice:
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from electionguard_tools.helpers import export
|
||||
from electionguard_tools.helpers import identity_encrypt
|
||||
from electionguard_tools.helpers import key_ceremony_orchestrator
|
||||
from electionguard_tools.helpers import serialize
|
||||
from electionguard_tools.helpers import tally_accumulate
|
||||
from electionguard_tools.helpers import tally_ceremony_orchestrator
|
||||
|
||||
|
@ -35,20 +34,6 @@ from electionguard_tools.helpers.identity_encrypt import (
|
|||
from electionguard_tools.helpers.key_ceremony_orchestrator import (
|
||||
KeyCeremonyOrchestrator,
|
||||
)
|
||||
from electionguard_tools.helpers.serialize import (
|
||||
NumberEncodeOption,
|
||||
OPTION,
|
||||
T,
|
||||
banlist,
|
||||
construct_path,
|
||||
custom_decoder,
|
||||
custom_encoder,
|
||||
from_file_to_dataclass,
|
||||
from_list_in_file_to_dataclass,
|
||||
from_raw,
|
||||
to_file,
|
||||
to_raw,
|
||||
)
|
||||
from electionguard_tools.helpers.tally_accumulate import (
|
||||
accumulate_plaintext_ballots,
|
||||
)
|
||||
|
@ -69,8 +54,6 @@ __all__ = [
|
|||
"GUARDIAN_PREFIX",
|
||||
"KeyCeremonyOrchestrator",
|
||||
"MANIFEST_FILE_NAME",
|
||||
"NumberEncodeOption",
|
||||
"OPTION",
|
||||
"PLAINTEXT_BALLOT_PREFIX",
|
||||
"PRIVATE_DATA_DIR",
|
||||
"PRIVATE_GUARDIAN_PREFIX",
|
||||
|
@ -78,26 +61,15 @@ __all__ = [
|
|||
"SPOILED_BALLOT_PREFIX",
|
||||
"SUBMITTED_BALLOTS_DIR",
|
||||
"SUBMITTED_BALLOT_PREFIX",
|
||||
"T",
|
||||
"TALLY_FILE_NAME",
|
||||
"TallyCeremonyOrchestrator",
|
||||
"accumulate_plaintext_ballots",
|
||||
"banlist",
|
||||
"construct_path",
|
||||
"custom_decoder",
|
||||
"custom_encoder",
|
||||
"export",
|
||||
"export_private_data",
|
||||
"from_file_to_dataclass",
|
||||
"from_list_in_file_to_dataclass",
|
||||
"from_raw",
|
||||
"identity_auxiliary_decrypt",
|
||||
"identity_auxiliary_encrypt",
|
||||
"identity_encrypt",
|
||||
"key_ceremony_orchestrator",
|
||||
"serialize",
|
||||
"tally_accumulate",
|
||||
"tally_ceremony_orchestrator",
|
||||
"to_file",
|
||||
"to_raw",
|
||||
]
|
||||
|
|
|
@ -18,9 +18,9 @@ from electionguard.guardian import GuardianRecord, PrivateGuardianRecord
|
|||
from electionguard.election import CiphertextElectionContext
|
||||
from electionguard.encrypt import EncryptionDevice
|
||||
from electionguard.manifest import Manifest
|
||||
from electionguard.serialize import to_file
|
||||
from electionguard.tally import PlaintextTally, PublishedCiphertextTally
|
||||
|
||||
from .serialize import to_file
|
||||
|
||||
# Public
|
||||
ELECTION_RECORD_DIR = "election_record"
|
||||
|
|
|
@ -1,142 +0,0 @@
|
|||
"""
|
||||
Testing tool to validate serialization and deserialization.
|
||||
|
||||
WARNING: Not for production use.
|
||||
|
||||
Specifically constructed to assist with json loading and dumping within the library.
|
||||
As a secondary case, this displays how many python serializers/deserializers should be able to take
|
||||
advantage of the dataclass usage.
|
||||
"""
|
||||
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from enum import Enum
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
from pydantic.json import pydantic_encoder
|
||||
from pydantic.tools import parse_raw_as, parse_obj_as
|
||||
|
||||
from electionguard.group import hex_to_int, int_to_hex
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def construct_path(
|
||||
target_file_name: str,
|
||||
target_path: Optional[Path] = None,
|
||||
target_file_extension="json",
|
||||
) -> Path:
|
||||
"""Construct path from file name, path, and extension."""
|
||||
target_file = f"{target_file_name}.{target_file_extension}"
|
||||
return os.path.join(target_path, target_file)
|
||||
|
||||
|
||||
def from_raw(type_: Type[T], obj: Any) -> T:
|
||||
"""Deserialize raw as type."""
|
||||
obj = custom_decoder(obj)
|
||||
return cast(type_, parse_raw_as(type_, obj))
|
||||
|
||||
|
||||
def to_raw(data: Any) -> Any:
|
||||
"""Serialize data to raw json format."""
|
||||
return json.dumps(data, indent=4, default=custom_encoder)
|
||||
|
||||
|
||||
def from_file_to_dataclass(dataclass_type_: Type[T], path: Union[str, Path]) -> T:
|
||||
"""Deserialize file as dataclass type."""
|
||||
with open(path, "r") as json_file:
|
||||
data = json.load(json_file)
|
||||
data = custom_decoder(data)
|
||||
return parse_obj_as(dataclass_type_, data)
|
||||
|
||||
|
||||
def from_list_in_file_to_dataclass(
|
||||
dataclass_type_: Type[T], path: Union[str, Path]
|
||||
) -> T:
|
||||
"""Deserialize list of objects in file as dataclass type."""
|
||||
with open(path, "r") as json_file:
|
||||
data = json.load(json_file)
|
||||
data = custom_decoder(data)
|
||||
return cast(dataclass_type_, parse_obj_as(List[dataclass_type_], data))
|
||||
|
||||
|
||||
def to_file(
|
||||
data: Any,
|
||||
target_file_name: str,
|
||||
target_path: Optional[Path] = None,
|
||||
target_file_extension="json",
|
||||
) -> None:
|
||||
"""Serialize object to file (defaultly json)."""
|
||||
if not os.path.exists(target_path):
|
||||
os.makedirs(target_path)
|
||||
|
||||
with open(
|
||||
construct_path(target_file_name, target_path, target_file_extension), "w"
|
||||
) as outfile:
|
||||
json.dump(data, outfile, indent=4, default=custom_encoder)
|
||||
|
||||
|
||||
# Color and abbreviation can both be of type hex but should not be converted
|
||||
banlist = ["color", "abbreviation", "is_write_in"]
|
||||
|
||||
|
||||
def _recursive_replace(object, type_: Type, replace: Callable[[Any], Any]):
|
||||
"""Iterate through object to replace."""
|
||||
if isinstance(object, dict):
|
||||
for key, item in object.items():
|
||||
if isinstance(item, (dict, list)):
|
||||
object[key] = _recursive_replace(item, type_, replace)
|
||||
if isinstance(item, type_) and key not in banlist:
|
||||
object[key] = replace(item)
|
||||
if isinstance(object, list):
|
||||
for index, item in enumerate(object):
|
||||
if isinstance(item, (dict, list)):
|
||||
object[index] = _recursive_replace(item, type_, replace)
|
||||
if isinstance(item, type_):
|
||||
object[index] = replace(item)
|
||||
return object
|
||||
|
||||
|
||||
class NumberEncodeOption(Enum):
|
||||
"""Option for encoding numbers."""
|
||||
|
||||
Int = "int"
|
||||
Hex = "hex"
|
||||
# Base64 = "base64"
|
||||
|
||||
|
||||
OPTION = NumberEncodeOption.Hex
|
||||
|
||||
|
||||
def _get_int_encoder() -> Callable[[Any], Any]:
|
||||
if OPTION is NumberEncodeOption.Hex:
|
||||
return int_to_hex
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def custom_encoder(obj: Any) -> Any:
|
||||
"""Integer encoder to convert int representations to type for json."""
|
||||
if is_dataclass(obj):
|
||||
new_dict = asdict(obj)
|
||||
obj = _recursive_replace(new_dict, int, _get_int_encoder())
|
||||
return obj
|
||||
return pydantic_encoder(obj)
|
||||
|
||||
|
||||
def _get_int_decoder() -> Callable[[Any], Any]:
|
||||
def safe_hex_to_int(input: str) -> Union[int, str]:
|
||||
try:
|
||||
return hex_to_int(input)
|
||||
except ValueError:
|
||||
return input
|
||||
|
||||
if OPTION is NumberEncodeOption.Hex:
|
||||
return safe_hex_to_int
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def custom_decoder(obj: Any) -> Any:
|
||||
"""Integer decoder to convert json stored int back to int representations."""
|
||||
return _recursive_replace(obj, str, _get_int_decoder())
|
|
@ -0,0 +1,107 @@
|
|||
import json
|
||||
from unittest import TestCase
|
||||
from shutil import rmtree
|
||||
|
||||
from electionguard.constants import ElectionConstants
|
||||
from electionguard.election import CiphertextElectionContext
|
||||
from electionguard.guardian import GuardianRecord
|
||||
from electionguard.manifest import Manifest
|
||||
from electionguard.ballot import (
|
||||
CiphertextBallot,
|
||||
PlaintextBallot,
|
||||
SubmittedBallot,
|
||||
)
|
||||
from electionguard.ballot_compact import CompactPlaintextBallot, CompactSubmittedBallot
|
||||
from electionguard.election_polynomial import LagrangeCoefficientsRecord
|
||||
from electionguard.encrypt import EncryptionDevice
|
||||
from electionguard.tally import (
|
||||
PublishedCiphertextTally,
|
||||
PlaintextTally,
|
||||
)
|
||||
from electionguard.serialize import construct_path, get_schema, to_file
|
||||
|
||||
|
||||
class TestCreateSchema(TestCase):
|
||||
"""Test creating schema."""
|
||||
|
||||
schema_dir = "schemas"
|
||||
remove_schema = False
|
||||
|
||||
# TODO Fix Pydantic errors with json schema
|
||||
resolve_pydantic_errors = False
|
||||
|
||||
def test_create_schema(self) -> None:
|
||||
|
||||
to_file(
|
||||
json.loads((get_schema(CiphertextElectionContext))),
|
||||
construct_path("context_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
|
||||
to_file(
|
||||
json.loads(get_schema(ElectionConstants)),
|
||||
construct_path("constants_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
|
||||
to_file(
|
||||
json.loads(get_schema(EncryptionDevice)),
|
||||
construct_path("device_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
|
||||
to_file(
|
||||
json.loads(get_schema(LagrangeCoefficientsRecord)),
|
||||
construct_path("coefficients_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
|
||||
if self.resolve_pydantic_errors:
|
||||
to_file(
|
||||
json.loads(get_schema(Manifest)),
|
||||
construct_path("manifest_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(GuardianRecord)),
|
||||
construct_path("guardian_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(PlaintextBallot)),
|
||||
construct_path("plaintext_ballot_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(CiphertextBallot)),
|
||||
construct_path("ciphertext_ballot_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(SubmittedBallot)),
|
||||
construct_path("submitted_ballot_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(CompactPlaintextBallot)),
|
||||
construct_path("compact_plaintext_ballot_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(CompactSubmittedBallot)),
|
||||
construct_path("compact_submitted_ballot_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(PlaintextTally)),
|
||||
construct_path("plaintext_tally_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
to_file(
|
||||
json.loads(get_schema(PublishedCiphertextTally)),
|
||||
construct_path("ciphertext_tally_schema"),
|
||||
self.schema_dir,
|
||||
)
|
||||
|
||||
if self.remove_schema:
|
||||
rmtree(self.schema_dir)
|
|
@ -46,6 +46,7 @@ from electionguard.decryption_mediator import DecryptionMediator
|
|||
from electionguard.election_polynomial import LagrangeCoefficientsRecord
|
||||
|
||||
# Step 5 - Publish and Verify
|
||||
from electionguard.serialize import from_file, construct_path
|
||||
from electionguard_tools.helpers.export import (
|
||||
COEFFICIENTS_FILE_NAME,
|
||||
DEVICES_DIR,
|
||||
|
@ -66,7 +67,6 @@ from electionguard_tools.helpers.export import (
|
|||
TALLY_FILE_NAME,
|
||||
export_private_data,
|
||||
)
|
||||
from electionguard_tools.helpers.serialize import from_file_to_dataclass, construct_path
|
||||
|
||||
from electionguard_tools.factories.ballot_factory import BallotFactory
|
||||
from electionguard_tools.factories.election_factory import (
|
||||
|
@ -554,30 +554,30 @@ class TestEndToEndElection(BaseTestCase):
|
|||
"""Ensure published data can be deserialized."""
|
||||
|
||||
# Deserialize
|
||||
manifest_from_file = from_file_to_dataclass(
|
||||
manifest_from_file = from_file(
|
||||
Manifest,
|
||||
construct_path(MANIFEST_FILE_NAME, ELECTION_RECORD_DIR),
|
||||
)
|
||||
self.assertEqualAsDicts(self.manifest, manifest_from_file)
|
||||
|
||||
context_from_file = from_file_to_dataclass(
|
||||
context_from_file = from_file(
|
||||
CiphertextElectionContext,
|
||||
construct_path(CONTEXT_FILE_NAME, ELECTION_RECORD_DIR),
|
||||
)
|
||||
self.assertEqualAsDicts(self.context, context_from_file)
|
||||
|
||||
constants_from_file = from_file_to_dataclass(
|
||||
constants_from_file = from_file(
|
||||
ElectionConstants, construct_path(CONSTANTS_FILE_NAME, ELECTION_RECORD_DIR)
|
||||
)
|
||||
self.assertEqualAsDicts(self.constants, constants_from_file)
|
||||
|
||||
coefficients_from_file = from_file_to_dataclass(
|
||||
coefficients_from_file = from_file(
|
||||
LagrangeCoefficientsRecord,
|
||||
construct_path(COEFFICIENTS_FILE_NAME, ELECTION_RECORD_DIR),
|
||||
)
|
||||
self.assertEqualAsDicts(self.lagrange_coefficients, coefficients_from_file)
|
||||
|
||||
device_from_file = from_file_to_dataclass(
|
||||
device_from_file = from_file(
|
||||
EncryptionDevice,
|
||||
construct_path(
|
||||
DEVICE_PREFIX + str(self.device.device_id), devices_directory
|
||||
|
@ -586,7 +586,7 @@ class TestEndToEndElection(BaseTestCase):
|
|||
self.assertEqualAsDicts(self.device, device_from_file)
|
||||
|
||||
for ballot in self.ballot_store.all():
|
||||
ballot_from_file = from_file_to_dataclass(
|
||||
ballot_from_file = from_file(
|
||||
SubmittedBallot,
|
||||
construct_path(
|
||||
SUBMITTED_BALLOT_PREFIX + ballot.object_id,
|
||||
|
@ -596,7 +596,7 @@ class TestEndToEndElection(BaseTestCase):
|
|||
self.assertEqualAsDicts(ballot, ballot_from_file)
|
||||
|
||||
for spoiled_ballot in self.plaintext_spoiled_ballots.values():
|
||||
spoiled_ballot_from_file = from_file_to_dataclass(
|
||||
spoiled_ballot_from_file = from_file(
|
||||
PlaintextTally,
|
||||
construct_path(
|
||||
SPOILED_BALLOT_PREFIX + spoiled_ballot.object_id,
|
||||
|
@ -605,7 +605,7 @@ class TestEndToEndElection(BaseTestCase):
|
|||
)
|
||||
self.assertEqualAsDicts(spoiled_ballot, spoiled_ballot_from_file)
|
||||
|
||||
published_ciphertext_tally_from_file = from_file_to_dataclass(
|
||||
published_ciphertext_tally_from_file = from_file(
|
||||
PublishedCiphertextTally,
|
||||
construct_path(ENCRYPTED_TALLY_FILE_NAME, ELECTION_RECORD_DIR),
|
||||
)
|
||||
|
@ -614,13 +614,13 @@ class TestEndToEndElection(BaseTestCase):
|
|||
published_ciphertext_tally_from_file,
|
||||
)
|
||||
|
||||
plainttext_tally_from_file = from_file_to_dataclass(
|
||||
plainttext_tally_from_file = from_file(
|
||||
PlaintextTally, construct_path(TALLY_FILE_NAME, ELECTION_RECORD_DIR)
|
||||
)
|
||||
self.assertEqualAsDicts(self.plaintext_tally, plainttext_tally_from_file)
|
||||
|
||||
for guardian_record in self.guardian_records:
|
||||
guardian_record_from_file = from_file_to_dataclass(
|
||||
guardian_record_from_file = from_file(
|
||||
GuardianRecord,
|
||||
construct_path(
|
||||
GUARDIAN_PREFIX + guardian_record.guardian_id, guardians_directory
|
||||
|
@ -639,7 +639,7 @@ class TestEndToEndElection(BaseTestCase):
|
|||
print(f"{name}: {message}: {result}")
|
||||
self.assertTrue(result)
|
||||
|
||||
def assertEqualAsDicts(self, first: object, second: object):
|
||||
def assertEqualAsDicts(self, first: object, second: object) -> None:
|
||||
"""
|
||||
Specialty assertEqual to compare dataclasses as dictionaries.
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from datetime import timedelta
|
||||
from hypothesis import given, settings, HealthCheck
|
||||
from hypothesis import given, settings, HealthCheck, Phase
|
||||
from hypothesis.strategies import integers
|
||||
|
||||
|
||||
|
@ -155,6 +155,7 @@ class TestChaumPedersen(BaseTestCase):
|
|||
deadline=timedelta(milliseconds=2000),
|
||||
suppress_health_check=[HealthCheck.too_slow],
|
||||
max_examples=10,
|
||||
phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
|
||||
)
|
||||
@given(
|
||||
elgamal_keypairs(),
|
||||
|
|
|
@ -37,7 +37,7 @@ from electionguard_tools.strategies.group import elements_mod_q_no_zero
|
|||
class TestElGamal(BaseTestCase):
|
||||
"""ElGamal tests"""
|
||||
|
||||
def test_simple_elgamal_encryption_decryption(self):
|
||||
def test_simple_elgamal_encryption_decryption(self) -> None:
|
||||
nonce = ONE_MOD_Q
|
||||
secret_key = TWO_MOD_Q
|
||||
keypair = get_optional(elgamal_keypair_from_secret(secret_key))
|
||||
|
@ -50,12 +50,12 @@ class TestElGamal(BaseTestCase):
|
|||
ciphertext = get_optional(elgamal_encrypt(0, nonce, keypair.public_key))
|
||||
self.assertEqual(get_generator(), ciphertext.pad)
|
||||
self.assertEqual(
|
||||
pow(ciphertext.pad, secret_key, get_large_prime()),
|
||||
pow(public_key, nonce, get_large_prime()),
|
||||
pow(ciphertext.pad.get_value(), secret_key.get_value(), get_large_prime()),
|
||||
pow(public_key.get_value(), nonce.get_value(), get_large_prime()),
|
||||
)
|
||||
self.assertEqual(
|
||||
ciphertext.data,
|
||||
pow(public_key, nonce, get_large_prime()),
|
||||
ciphertext.data.get_value(),
|
||||
pow(public_key.get_value(), nonce.get_value(), get_large_prime()),
|
||||
)
|
||||
|
||||
plaintext = ciphertext.decrypt(keypair.secret_key)
|
||||
|
@ -65,17 +65,17 @@ class TestElGamal(BaseTestCase):
|
|||
@given(integers(0, 100), elgamal_keypairs())
|
||||
def test_elgamal_encrypt_requires_nonzero_nonce(
|
||||
self, message: int, keypair: ElGamalKeyPair
|
||||
):
|
||||
) -> None:
|
||||
self.assertEqual(None, elgamal_encrypt(message, ZERO_MOD_Q, keypair.public_key))
|
||||
|
||||
def test_elgamal_keypair_from_secret_requires_key_greater_than_one(self):
|
||||
def test_elgamal_keypair_from_secret_requires_key_greater_than_one(self) -> None:
|
||||
self.assertEqual(None, elgamal_keypair_from_secret(ZERO_MOD_Q))
|
||||
self.assertEqual(None, elgamal_keypair_from_secret(ONE_MOD_Q))
|
||||
|
||||
@given(integers(0, 100), elements_mod_q_no_zero(), elgamal_keypairs())
|
||||
def test_elgamal_encryption_decryption_inverses(
|
||||
self, message: int, nonce: ElementModQ, keypair: ElGamalKeyPair
|
||||
):
|
||||
) -> None:
|
||||
ciphertext = get_optional(elgamal_encrypt(message, nonce, keypair.public_key))
|
||||
plaintext = ciphertext.decrypt(keypair.secret_key)
|
||||
|
||||
|
@ -84,14 +84,16 @@ class TestElGamal(BaseTestCase):
|
|||
@given(integers(0, 100), elements_mod_q_no_zero(), elgamal_keypairs())
|
||||
def test_elgamal_encryption_decryption_with_known_nonce_inverses(
|
||||
self, message: int, nonce: ElementModQ, keypair: ElGamalKeyPair
|
||||
):
|
||||
) -> None:
|
||||
ciphertext = get_optional(elgamal_encrypt(message, nonce, keypair.public_key))
|
||||
plaintext = ciphertext.decrypt_known_nonce(keypair.public_key, nonce)
|
||||
|
||||
self.assertEqual(message, plaintext)
|
||||
|
||||
@given(elgamal_keypairs())
|
||||
def test_elgamal_generated_keypairs_are_within_range(self, keypair: ElGamalKeyPair):
|
||||
def test_elgamal_generated_keypairs_are_within_range(
|
||||
self, keypair: ElGamalKeyPair
|
||||
) -> None:
|
||||
self.assertLess(keypair.public_key, get_large_prime())
|
||||
self.assertLess(keypair.secret_key, get_small_prime())
|
||||
self.assertEqual(g_pow_p(keypair.secret_key), keypair.public_key)
|
||||
|
@ -110,7 +112,7 @@ class TestElGamal(BaseTestCase):
|
|||
r1: ElementModQ,
|
||||
m2: int,
|
||||
r2: ElementModQ,
|
||||
):
|
||||
) -> None:
|
||||
c1 = get_optional(elgamal_encrypt(m1, r1, keypair.public_key))
|
||||
c2 = get_optional(elgamal_encrypt(m2, r2, keypair.public_key))
|
||||
c_sum = elgamal_add(c1, c2)
|
||||
|
@ -118,14 +120,14 @@ class TestElGamal(BaseTestCase):
|
|||
|
||||
self.assertEqual(total, m1 + m2)
|
||||
|
||||
def test_elgamal_add_requires_args(self):
|
||||
def test_elgamal_add_requires_args(self) -> None:
|
||||
self.assertRaises(Exception, elgamal_add)
|
||||
|
||||
@given(elgamal_keypairs())
|
||||
def test_elgamal_keypair_produces_valid_residue(self, keypair):
|
||||
def test_elgamal_keypair_produces_valid_residue(self, keypair) -> None:
|
||||
self.assertTrue(keypair.public_key.is_valid_residue())
|
||||
|
||||
def test_elgamal_keypair_random(self):
|
||||
def test_elgamal_keypair_random(self) -> None:
|
||||
# Act
|
||||
random_keypair = elgamal_keypair_random()
|
||||
random_keypair_two = elgamal_keypair_random()
|
||||
|
@ -136,7 +138,7 @@ class TestElGamal(BaseTestCase):
|
|||
self.assertIsNotNone(random_keypair.secret_key)
|
||||
self.assertNotEqual(random_keypair, random_keypair_two)
|
||||
|
||||
def test_elgamal_combine_public_keys(self):
|
||||
def test_elgamal_combine_public_keys(self) -> None:
|
||||
# Arrange
|
||||
random_keypair = elgamal_keypair_random()
|
||||
random_keypair_two = elgamal_keypair_random()
|
||||
|
@ -150,7 +152,7 @@ class TestElGamal(BaseTestCase):
|
|||
self.assertNotEqual(joint_key, random_keypair.public_key)
|
||||
self.assertNotEqual(joint_key, random_keypair_two.public_key)
|
||||
|
||||
def test_gmpy2_parallelism_is_safe(self):
|
||||
def test_gmpy2_parallelism_is_safe(self) -> None:
|
||||
"""
|
||||
Ensures running lots of parallel exponentiations still yields the correct answer.
|
||||
This verifies that nothing incorrect is happening in the GMPY2 library
|
||||
|
|
|
@ -46,11 +46,11 @@ class TestEquality(BaseTestCase):
|
|||
"""Math equality tests"""
|
||||
|
||||
@given(elements_mod_q(), elements_mod_q())
|
||||
def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ):
|
||||
def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ) -> None:
|
||||
i = int(q)
|
||||
i2 = int(q2)
|
||||
p = ElementModP(q)
|
||||
p2 = ElementModP(q2)
|
||||
p = ElementModP(q.get_value())
|
||||
p2 = ElementModP(q2.get_value())
|
||||
|
||||
# same value should imply they're equal
|
||||
self.assertEqual(p, q)
|
||||
|
@ -78,54 +78,54 @@ class TestModularArithmetic(BaseTestCase):
|
|||
"""Math Modular Arithmetic tests"""
|
||||
|
||||
@given(elements_mod_q())
|
||||
def test_add_q(self, q: ElementModQ):
|
||||
def test_add_q(self, q: ElementModQ) -> None:
|
||||
as_int = add_q(q, 1)
|
||||
as_elem = add_q(q, ElementModQ(1))
|
||||
self.assertEqual(as_int, as_elem)
|
||||
|
||||
@given(elements_mod_q())
|
||||
def test_a_plus_bc_q(self, q: ElementModQ):
|
||||
def test_a_plus_bc_q(self, q: ElementModQ) -> None:
|
||||
as_int = a_plus_bc_q(q, 1, 1)
|
||||
as_elem = a_plus_bc_q(q, ElementModQ(1), ElementModQ(1))
|
||||
self.assertEqual(as_int, as_elem)
|
||||
|
||||
@given(elements_mod_q())
|
||||
def test_a_minus_b_q(self, q: ElementModQ):
|
||||
def test_a_minus_b_q(self, q: ElementModQ) -> None:
|
||||
as_int = a_minus_b_q(q, 1)
|
||||
as_elem = a_minus_b_q(q, ElementModQ(1))
|
||||
self.assertEqual(as_int, as_elem)
|
||||
|
||||
@given(elements_mod_q())
|
||||
def test_div_q(self, q: ElementModQ):
|
||||
def test_div_q(self, q: ElementModQ) -> None:
|
||||
as_int = div_q(q, 1)
|
||||
as_elem = div_q(q, ElementModQ(1))
|
||||
self.assertEqual(as_int, as_elem)
|
||||
|
||||
@given(elements_mod_p())
|
||||
def test_div_p(self, p: ElementModQ):
|
||||
def test_div_p(self, p: ElementModQ) -> None:
|
||||
as_int = div_p(p, 1)
|
||||
as_elem = div_p(p, ElementModP(1))
|
||||
self.assertEqual(as_int, as_elem)
|
||||
|
||||
def test_no_mult_inv_of_zero(self):
|
||||
def test_no_mult_inv_of_zero(self) -> None:
|
||||
self.assertRaises(Exception, mult_inv_p, ZERO_MOD_P)
|
||||
|
||||
@given(elements_mod_p_no_zero())
|
||||
def test_mult_inverses(self, elem: ElementModP):
|
||||
def test_mult_inverses(self, elem: ElementModP) -> None:
|
||||
inv = mult_inv_p(elem)
|
||||
self.assertEqual(mult_p(elem, inv), ONE_MOD_P)
|
||||
|
||||
@given(elements_mod_p())
|
||||
def test_mult_identity(self, elem: ElementModP):
|
||||
def test_mult_identity(self, elem: ElementModP) -> None:
|
||||
self.assertEqual(elem, mult_p(elem))
|
||||
|
||||
def test_mult_noargs(self):
|
||||
def test_mult_noargs(self) -> None:
|
||||
self.assertEqual(ONE_MOD_P, mult_p())
|
||||
|
||||
def test_add_noargs(self):
|
||||
def test_add_noargs(self) -> None:
|
||||
self.assertEqual(ZERO_MOD_Q, add_q())
|
||||
|
||||
def test_properties_for_constants(self):
|
||||
def test_properties_for_constants(self) -> None:
|
||||
self.assertNotEqual(get_generator(), 1)
|
||||
self.assertEqual(
|
||||
(get_cofactor() * get_small_prime()) % get_large_prime(),
|
||||
|
@ -135,13 +135,13 @@ class TestModularArithmetic(BaseTestCase):
|
|||
self.assertLess(get_generator(), get_large_prime())
|
||||
self.assertLess(get_cofactor(), get_large_prime())
|
||||
|
||||
def test_simple_powers(self):
|
||||
def test_simple_powers(self) -> None:
|
||||
gp = int_to_p(get_generator())
|
||||
self.assertEqual(gp, g_pow_p(ONE_MOD_Q))
|
||||
self.assertEqual(ONE_MOD_P, g_pow_p(ZERO_MOD_Q))
|
||||
|
||||
@given(elements_mod_q())
|
||||
def test_in_bounds_q(self, q: ElementModQ):
|
||||
def test_in_bounds_q(self, q: ElementModQ) -> None:
|
||||
self.assertTrue(q.is_in_bounds())
|
||||
too_big = q + get_small_prime()
|
||||
too_small = q - get_small_prime()
|
||||
|
@ -155,7 +155,7 @@ class TestModularArithmetic(BaseTestCase):
|
|||
ElementModQ(too_small)
|
||||
|
||||
@given(elements_mod_p())
|
||||
def test_in_bounds_p(self, p: ElementModP):
|
||||
def test_in_bounds_p(self, p: ElementModP) -> None:
|
||||
self.assertTrue(p.is_in_bounds())
|
||||
too_big = p + get_large_prime()
|
||||
too_small = p - get_large_prime()
|
||||
|
@ -180,7 +180,7 @@ class TestModularArithmetic(BaseTestCase):
|
|||
)
|
||||
|
||||
@given(elements_mod_p_no_zero())
|
||||
def test_in_bounds_p_no_zero(self, p: ElementModP):
|
||||
def test_in_bounds_p_no_zero(self, p: ElementModP) -> None:
|
||||
self.assertTrue(p.is_in_bounds_no_zero())
|
||||
self.assertFalse(ZERO_MOD_P.is_in_bounds_no_zero())
|
||||
self.assertFalse(
|
||||
|
@ -191,7 +191,7 @@ class TestModularArithmetic(BaseTestCase):
|
|||
)
|
||||
|
||||
@given(elements_mod_q())
|
||||
def test_large_values_rejected_by_int_to_q(self, q: ElementModQ):
|
||||
def test_large_values_rejected_by_int_to_q(self, q: ElementModQ) -> None:
|
||||
oversize = q + get_small_prime()
|
||||
self.assertEqual(None, int_to_q(oversize))
|
||||
|
||||
|
@ -199,28 +199,28 @@ class TestModularArithmetic(BaseTestCase):
|
|||
class TestOptionalFunctions(BaseTestCase):
|
||||
"""Math Optional Functions tests"""
|
||||
|
||||
def test_unwrap(self):
|
||||
def test_unwrap(self) -> None:
|
||||
good: Optional[int] = 3
|
||||
bad: Optional[int] = None
|
||||
|
||||
self.assertEqual(get_optional(good), 3)
|
||||
self.assertRaises(Exception, get_optional, bad)
|
||||
|
||||
def test_match(self):
|
||||
def test_match(self) -> None:
|
||||
good: Optional[int] = 3
|
||||
bad: Optional[int] = None
|
||||
|
||||
self.assertEqual(5, match_optional(good, lambda: 1, lambda x: x + 2))
|
||||
self.assertEqual(1, match_optional(bad, lambda: 1, lambda x: x + 2))
|
||||
|
||||
def test_get_or_else(self):
|
||||
def test_get_or_else(self) -> None:
|
||||
good: Optional[int] = 3
|
||||
bad: Optional[int] = None
|
||||
|
||||
self.assertEqual(3, get_or_else_optional(good, 5))
|
||||
self.assertEqual(5, get_or_else_optional(bad, 5))
|
||||
|
||||
def test_flatmap(self):
|
||||
def test_flatmap(self) -> None:
|
||||
good: Optional[int] = 3
|
||||
bad: Optional[int] = None
|
||||
|
||||
|
|
|
@ -9,9 +9,10 @@ from electionguard.manifest import (
|
|||
SelectionDescription,
|
||||
VoteVariationType,
|
||||
)
|
||||
from electionguard.serialize import from_raw, to_raw
|
||||
import electionguard_tools.factories.election_factory as ElectionFactory
|
||||
import electionguard_tools.factories.ballot_factory as BallotFactory
|
||||
from electionguard_tools.helpers.serialize import from_raw, to_raw
|
||||
|
||||
|
||||
election_factory = ElectionFactory.ElectionFactory()
|
||||
ballot_factory = BallotFactory.BallotFactory()
|
||||
|
@ -20,7 +21,7 @@ ballot_factory = BallotFactory.BallotFactory()
|
|||
class TestManifest(BaseTestCase):
|
||||
"""Manifest tests"""
|
||||
|
||||
def test_simple_manifest_is_valid(self):
|
||||
def test_simple_manifest_is_valid(self) -> None:
|
||||
|
||||
# Act
|
||||
subject = election_factory.get_simple_manifest_from_file()
|
||||
|
@ -30,7 +31,7 @@ class TestManifest(BaseTestCase):
|
|||
self.assertEqual(subject.election_scope_id, "jefferson-county-primary")
|
||||
self.assertTrue(subject.is_valid())
|
||||
|
||||
def test_simple_manifest_can_serialize(self):
|
||||
def test_simple_manifest_can_serialize(self) -> None:
|
||||
# Arrange
|
||||
subject = election_factory.get_simple_manifest_from_file()
|
||||
intermediate = to_raw(subject)
|
||||
|
@ -42,7 +43,7 @@ class TestManifest(BaseTestCase):
|
|||
self.assertIsNotNone(result.election_scope_id)
|
||||
self.assertEqual(result.election_scope_id, "jefferson-county-primary")
|
||||
|
||||
def test_manifest_has_deterministic_hash(self):
|
||||
def test_manifest_has_deterministic_hash(self) -> None:
|
||||
|
||||
# Act
|
||||
subject1 = election_factory.get_simple_manifest_from_file()
|
||||
|
@ -51,7 +52,7 @@ class TestManifest(BaseTestCase):
|
|||
# Assert
|
||||
self.assertEqual(subject1.crypto_hash(), subject2.crypto_hash())
|
||||
|
||||
def test_manifest_hash_is_consistent_regardless_of_format(self):
|
||||
def test_manifest_hash_is_consistent_regardless_of_format(self) -> None:
|
||||
|
||||
# Act
|
||||
subject1 = election_factory.get_simple_manifest_from_file()
|
||||
|
@ -72,7 +73,7 @@ class TestManifest(BaseTestCase):
|
|||
|
||||
def test_manifest_from_file_generates_consistent_internal_description_contest_hashes(
|
||||
self,
|
||||
):
|
||||
) -> None:
|
||||
# Arrange
|
||||
comparator = election_factory.get_simple_manifest_from_file()
|
||||
subject = InternalManifest(comparator)
|
||||
|
@ -84,7 +85,7 @@ class TestManifest(BaseTestCase):
|
|||
if expected.object_id == actual.object_id:
|
||||
self.assertEqual(expected.crypto_hash(), actual.crypto_hash())
|
||||
|
||||
def test_contest_description_valid_input_succeeds(self):
|
||||
def test_contest_description_valid_input_succeeds(self) -> None:
|
||||
description = ContestDescriptionWithPlaceholders(
|
||||
object_id="0@A.com-contest",
|
||||
electoral_district_id="0@A.com-gp-unit",
|
||||
|
@ -118,7 +119,7 @@ class TestManifest(BaseTestCase):
|
|||
|
||||
self.assertTrue(description.is_valid())
|
||||
|
||||
def test_contest_description_invalid_input_fails(self):
|
||||
def test_contest_description_invalid_input_fails(self) -> None:
|
||||
|
||||
description = ContestDescriptionWithPlaceholders(
|
||||
object_id="0@A.com-contest",
|
||||
|
|
Загрузка…
Ссылка в новой задаче