Π·Π΅ΡΠΊΠ°Π»ΠΎ ΠΈΠ· https://github.com/microsoft/electionguard-python.git
π. Big Integer and Base Element (#599)
* π§. Create BigInteger class A representation of BigInteger in python to specifically designate what is happening. Left as a str due to the hex representation being the underpinning. * π§. Refactor BaseElement to use BigInteger - Refactor BaseElement to use BigInteger and fix tests to use value instead of get_value - Update Tests * π§ Remove now unused Serializable * π¦ Add dacite to handle dictionary casting * β¨ Add custom deserialization to avoid pydantic issues * π₯ Remove add and subtract functions from big integer * β Better indicate padding purpose
This commit is contained in:
Π ΠΎΠ΄ΠΈΡΠ΅Π»Ρ
b7c8782e81
ΠΠΎΠΌΠΌΠΈΡ
392586832d
|
@ -189,6 +189,17 @@ sdist = ["setuptools_rust (>=0.11.4)"]
|
||||||
ssh = ["bcrypt (>=3.1.5)"]
|
ssh = ["bcrypt (>=3.1.5)"]
|
||||||
test = ["pytest (>=6.2.0)", "pytest-cov", "pytest-subtests", "pytest-xdist", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,!=3.79.2)"]
|
test = ["pytest (>=6.2.0)", "pytest-cov", "pytest-subtests", "pytest-xdist", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,!=3.79.2)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dacite"
|
||||||
|
version = "1.6.0"
|
||||||
|
description = "Simple creation of data classes from dictionaries."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["pytest (>=5)", "pytest-cov", "coveralls", "black", "mypy", "pylint"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "debugpy"
|
name = "debugpy"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
|
@ -1610,7 +1621,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.9.5"
|
python-versions = "^3.9.5"
|
||||||
content-hash = "d5dd120c9f2fca28b2e16b793b8e9b4c397d8d4a61304b96696e976db06c4d47"
|
content-hash = "9bd2777ce600469494106f880daad6ce7feb021f972be4dc84882cd6afbab16e"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
appnope = [
|
appnope = [
|
||||||
|
@ -1803,6 +1814,10 @@ cryptography = [
|
||||||
{file = "cryptography-36.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e167b6b710c7f7bc54e67ef593f8731e1f45aa35f8a8a7b72d6e42ec76afd4b3"},
|
{file = "cryptography-36.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e167b6b710c7f7bc54e67ef593f8731e1f45aa35f8a8a7b72d6e42ec76afd4b3"},
|
||||||
{file = "cryptography-36.0.2.tar.gz", hash = "sha256:70f8f4f7bb2ac9f340655cbac89d68c527af5bb4387522a8413e841e3e6628c9"},
|
{file = "cryptography-36.0.2.tar.gz", hash = "sha256:70f8f4f7bb2ac9f340655cbac89d68c527af5bb4387522a8413e841e3e6628c9"},
|
||||||
]
|
]
|
||||||
|
dacite = [
|
||||||
|
{file = "dacite-1.6.0-py3-none-any.whl", hash = "sha256:4331535f7aabb505c732fa4c3c094313fc0a1d5ea19907bf4726a7819a68b93f"},
|
||||||
|
{file = "dacite-1.6.0.tar.gz", hash = "sha256:d48125ed0a0352d3de9f493bf980038088f45f3f9d7498f090b50a847daaa6df"},
|
||||||
|
]
|
||||||
debugpy = [
|
debugpy = [
|
||||||
{file = "debugpy-1.6.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:eb1946efac0c0c3d411cea0b5ac772fbde744109fd9520fb0c5a51979faf05ad"},
|
{file = "debugpy-1.6.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:eb1946efac0c0c3d411cea0b5ac772fbde744109fd9520fb0c5a51979faf05ad"},
|
||||||
{file = "debugpy-1.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e3513399177dd37af4c1332df52da5da1d0c387e5927dc4c0709e26ee7302e8f"},
|
{file = "debugpy-1.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e3513399177dd37af4c1332df52da5da1d0c387e5927dc4c0709e26ee7302e8f"},
|
||||||
|
|
|
@ -39,6 +39,7 @@ gmpy2 = "^2.0.8"
|
||||||
psutil = ">=5.7.2"
|
psutil = ">=5.7.2"
|
||||||
pydantic = "1.9.0"
|
pydantic = "1.9.0"
|
||||||
click = "^8.1.0"
|
click = "^8.1.0"
|
||||||
|
dacite = "^1.6.0"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
atomicwrites = "*"
|
atomicwrites = "*"
|
||||||
|
|
|
@ -6,6 +6,7 @@ from electionguard import ballot_box
|
||||||
from electionguard import ballot_code
|
from electionguard import ballot_code
|
||||||
from electionguard import ballot_compact
|
from electionguard import ballot_compact
|
||||||
from electionguard import ballot_validator
|
from electionguard import ballot_validator
|
||||||
|
from electionguard import big_integer
|
||||||
from electionguard import chaum_pedersen
|
from electionguard import chaum_pedersen
|
||||||
from electionguard import constants
|
from electionguard import constants
|
||||||
from electionguard import data_store
|
from electionguard import data_store
|
||||||
|
@ -83,6 +84,9 @@ from electionguard.ballot_validator import (
|
||||||
contest_is_valid_for_style,
|
contest_is_valid_for_style,
|
||||||
selection_is_valid_for_style,
|
selection_is_valid_for_style,
|
||||||
)
|
)
|
||||||
|
from electionguard.big_integer import (
|
||||||
|
BigInteger,
|
||||||
|
)
|
||||||
from electionguard.chaum_pedersen import (
|
from electionguard.chaum_pedersen import (
|
||||||
ChaumPedersenProof,
|
ChaumPedersenProof,
|
||||||
ConstantChaumPedersenProof,
|
ConstantChaumPedersenProof,
|
||||||
|
@ -231,10 +235,8 @@ from electionguard.group import (
|
||||||
div_p,
|
div_p,
|
||||||
div_q,
|
div_q,
|
||||||
g_pow_p,
|
g_pow_p,
|
||||||
hex_to_int,
|
|
||||||
hex_to_p,
|
hex_to_p,
|
||||||
hex_to_q,
|
hex_to_q,
|
||||||
int_to_hex,
|
|
||||||
int_to_p,
|
int_to_p,
|
||||||
int_to_q,
|
int_to_q,
|
||||||
mult_inv_p,
|
mult_inv_p,
|
||||||
|
@ -336,8 +338,6 @@ from electionguard.schnorr import (
|
||||||
make_schnorr_proof,
|
make_schnorr_proof,
|
||||||
)
|
)
|
||||||
from electionguard.serialize import (
|
from electionguard.serialize import (
|
||||||
Private,
|
|
||||||
Serializable,
|
|
||||||
construct_path,
|
construct_path,
|
||||||
from_file,
|
from_file,
|
||||||
from_file_wrapper,
|
from_file_wrapper,
|
||||||
|
@ -390,6 +390,7 @@ __all__ = [
|
||||||
"BallotId",
|
"BallotId",
|
||||||
"BallotStyle",
|
"BallotStyle",
|
||||||
"BaseElement",
|
"BaseElement",
|
||||||
|
"BigInteger",
|
||||||
"Candidate",
|
"Candidate",
|
||||||
"CandidateContestDescription",
|
"CandidateContestDescription",
|
||||||
"CeremonyDetails",
|
"CeremonyDetails",
|
||||||
|
@ -483,7 +484,6 @@ __all__ = [
|
||||||
"PlaintextTallyContest",
|
"PlaintextTallyContest",
|
||||||
"PlaintextTallySelection",
|
"PlaintextTallySelection",
|
||||||
"PrimeOption",
|
"PrimeOption",
|
||||||
"Private",
|
|
||||||
"PrivateGuardianRecord",
|
"PrivateGuardianRecord",
|
||||||
"Proof",
|
"Proof",
|
||||||
"ProofOrRecovery",
|
"ProofOrRecovery",
|
||||||
|
@ -501,7 +501,6 @@ __all__ = [
|
||||||
"SecretCoefficient",
|
"SecretCoefficient",
|
||||||
"SelectionDescription",
|
"SelectionDescription",
|
||||||
"SelectionId",
|
"SelectionId",
|
||||||
"Serializable",
|
|
||||||
"Singleton",
|
"Singleton",
|
||||||
"SubmittedBallot",
|
"SubmittedBallot",
|
||||||
"VerifierId",
|
"VerifierId",
|
||||||
|
@ -518,6 +517,7 @@ __all__ = [
|
||||||
"ballot_is_valid_for_election",
|
"ballot_is_valid_for_election",
|
||||||
"ballot_is_valid_for_style",
|
"ballot_is_valid_for_style",
|
||||||
"ballot_validator",
|
"ballot_validator",
|
||||||
|
"big_integer",
|
||||||
"chaum_pedersen",
|
"chaum_pedersen",
|
||||||
"combine_election_public_keys",
|
"combine_election_public_keys",
|
||||||
"compensate_decrypt",
|
"compensate_decrypt",
|
||||||
|
@ -620,11 +620,9 @@ __all__ = [
|
||||||
"hash",
|
"hash",
|
||||||
"hash_elems",
|
"hash_elems",
|
||||||
"hashed_elgamal_encrypt",
|
"hashed_elgamal_encrypt",
|
||||||
"hex_to_int",
|
|
||||||
"hex_to_p",
|
"hex_to_p",
|
||||||
"hex_to_q",
|
"hex_to_q",
|
||||||
"hmac",
|
"hmac",
|
||||||
"int_to_hex",
|
|
||||||
"int_to_p",
|
"int_to_p",
|
||||||
"int_to_q",
|
"int_to_q",
|
||||||
"key_ceremony",
|
"key_ceremony",
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
from typing import Any, Tuple, Union
|
||||||
|
from base64 import b16decode
|
||||||
|
|
||||||
|
# pylint: disable=no-name-in-module
|
||||||
|
from gmpy2 import mpz
|
||||||
|
|
||||||
|
|
||||||
|
def _hex_to_int(input: str) -> int:
|
||||||
|
"""Given a hex string representing bytes, returns an int."""
|
||||||
|
return int(input, 16)
|
||||||
|
|
||||||
|
|
||||||
|
def _int_to_hex(input: int) -> str:
|
||||||
|
"""Given an int, returns a hex string representing bytes."""
|
||||||
|
|
||||||
|
def pad_hex(hex: str) -> str:
|
||||||
|
"""Pad hex to ensure 2 digit hexadecimal format maintained."""
|
||||||
|
return "0" + hex if len(hex) % 2 else hex
|
||||||
|
|
||||||
|
hex = format(input, "02X")
|
||||||
|
return pad_hex(hex)
|
||||||
|
|
||||||
|
|
||||||
|
_zero = mpz(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_element(data: Union[int, str]) -> Tuple[str, int]:
|
||||||
|
"""Convert element to consistent types"""
|
||||||
|
if isinstance(data, str):
|
||||||
|
hex = data
|
||||||
|
integer = _hex_to_int(data)
|
||||||
|
else:
|
||||||
|
hex = _int_to_hex(data)
|
||||||
|
integer = data
|
||||||
|
return (hex, integer)
|
||||||
|
|
||||||
|
|
||||||
|
class BigInteger(str):
|
||||||
|
"""A specialized representation of a big integer in python"""
|
||||||
|
|
||||||
|
_value: mpz = _zero
|
||||||
|
|
||||||
|
def __new__(cls, data: Union[int, str]): # type: ignore
|
||||||
|
(hex, integer) = _convert_to_element(data)
|
||||||
|
big_int = super(BigInteger, cls).__new__(cls, hex)
|
||||||
|
big_int._value = mpz(integer)
|
||||||
|
return big_int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self) -> mpz:
|
||||||
|
"""Get internal value for math calculations"""
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
def __int__(self) -> int:
|
||||||
|
"""Overload int conversion."""
|
||||||
|
return int(self.value)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Overload == (equal to) operator."""
|
||||||
|
return (
|
||||||
|
isinstance(other, BigInteger) and int(self.value) == int(other.value)
|
||||||
|
) or (isinstance(other, int) and int(self.value) == other)
|
||||||
|
|
||||||
|
def __ne__(self, other: Any) -> bool:
|
||||||
|
"""Overload != (not equal to) operator."""
|
||||||
|
return not self == other
|
||||||
|
|
||||||
|
def __lt__(self, other: Any) -> bool:
|
||||||
|
"""Overload <= (less than) operator."""
|
||||||
|
return (
|
||||||
|
isinstance(other, BigInteger) and int(self.value) < int(other.value)
|
||||||
|
) or (isinstance(other, int) and int(self.value) < other)
|
||||||
|
|
||||||
|
def __le__(self, other: Any) -> bool:
|
||||||
|
"""Overload <= (less than or equal) operator."""
|
||||||
|
return self.__lt__(other) or self.__eq__(other)
|
||||||
|
|
||||||
|
def __gt__(self, other: Any) -> bool:
|
||||||
|
"""Overload > (greater than) operator."""
|
||||||
|
return (
|
||||||
|
isinstance(other, BigInteger) and int(self.value) > int(other.value)
|
||||||
|
) or (isinstance(other, int) and int(self.value) > other)
|
||||||
|
|
||||||
|
def __ge__(self, other: Any) -> bool:
|
||||||
|
"""Overload >= (greater than or equal) operator."""
|
||||||
|
return self.__gt__(other) or self.__eq__(other)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""Overload the hashing function."""
|
||||||
|
return hash(self.value)
|
||||||
|
|
||||||
|
def to_hex(self) -> str:
|
||||||
|
"""
|
||||||
|
Convert from the element to the hex representation of bytes.
|
||||||
|
"""
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
def to_hex_bytes(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Convert from the element to the representation of bytes by first going through hex.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return b16decode(self)
|
|
@ -5,154 +5,41 @@ in the sense that performance may be less than hand-optimized C code, and no gua
|
||||||
made about timing or other side-channels.
|
made about timing or other side-channels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC
|
||||||
from typing import Any, Final, Optional, Tuple, Union
|
from typing import Final, Optional, Union
|
||||||
from base64 import b16decode
|
|
||||||
from secrets import randbelow
|
from secrets import randbelow
|
||||||
from sys import maxsize
|
from sys import maxsize
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module
|
# pylint: disable=no-name-in-module
|
||||||
from gmpy2 import mpz, powmod, invert
|
from gmpy2 import mpz, powmod, invert
|
||||||
|
|
||||||
from .serialize import Serializable, Private
|
from .big_integer import BigInteger
|
||||||
from .constants import get_large_prime, get_small_prime, get_generator
|
from .constants import get_large_prime, get_small_prime, get_generator
|
||||||
|
|
||||||
|
|
||||||
def hex_to_int(input: str) -> int:
|
class BaseElement(BigInteger, ABC):
|
||||||
"""Given a hex string representing bytes, returns an int."""
|
|
||||||
return int(input, 16)
|
|
||||||
|
|
||||||
|
|
||||||
def int_to_hex(input: int) -> str:
|
|
||||||
"""Given an int, returns a hex string representing bytes."""
|
|
||||||
hex = format(input, "02X")
|
|
||||||
if len(hex) % 2:
|
|
||||||
hex = "0" + hex
|
|
||||||
return hex
|
|
||||||
|
|
||||||
|
|
||||||
_zero = mpz(0)
|
|
||||||
|
|
||||||
|
|
||||||
def _mpz_zero() -> mpz:
|
|
||||||
return _zero
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_element(data: Union[int, str]) -> Tuple[str, int]:
|
|
||||||
"""Convert element to consistent types"""
|
|
||||||
if isinstance(data, str):
|
|
||||||
hex = data
|
|
||||||
integer = hex_to_int(data)
|
|
||||||
else:
|
|
||||||
hex = int_to_hex(data)
|
|
||||||
integer = data
|
|
||||||
return (hex, integer)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseElement(Serializable, ABC):
|
|
||||||
"""An element limited by mod T within [0, T) where T is determined by an upper_bound function."""
|
"""An element limited by mod T within [0, T) where T is determined by an upper_bound function."""
|
||||||
|
|
||||||
data: str
|
def __new__(cls, data: Union[int, str], check_within_bounds: bool = True): # type: ignore
|
||||||
|
|
||||||
_value: mpz = Private(default_factory=_mpz_zero)
|
|
||||||
"""Internal math representation of element"""
|
|
||||||
|
|
||||||
def __init__(self, data: Union[int, str], check_within_bounds: bool = True) -> None:
|
|
||||||
"""Instantiate element mod T where element is an int or its hex representation."""
|
"""Instantiate element mod T where element is an int or its hex representation."""
|
||||||
(hex, integer) = _convert_to_element(data)
|
element = super(BaseElement, cls).__new__(cls, data)
|
||||||
super().__init__(data=hex)
|
|
||||||
self._value = mpz(integer)
|
|
||||||
|
|
||||||
if check_within_bounds:
|
if check_within_bounds:
|
||||||
if not self.is_in_bounds():
|
if not 0 <= element.value < cls.get_upper_bound():
|
||||||
raise OverflowError
|
raise OverflowError
|
||||||
|
return element
|
||||||
|
|
||||||
def __str__(self) -> str:
|
@classmethod
|
||||||
"""Overload string representation"""
|
def get_upper_bound(cls) -> int:
|
||||||
return self.data
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
"""Overload object representation"""
|
|
||||||
return self.data
|
|
||||||
|
|
||||||
def __int__(self) -> int:
|
|
||||||
"""Overload int conversion."""
|
|
||||||
return int(self.get_value())
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Overload == (equal to) operator."""
|
|
||||||
return (
|
|
||||||
isinstance(other, BaseElement)
|
|
||||||
and int(self.get_value()) == int(other.get_value())
|
|
||||||
) or (isinstance(other, int) and int(self.get_value()) == other)
|
|
||||||
|
|
||||||
def __ne__(self, other: Any) -> bool:
|
|
||||||
"""Overload != (not equal to) operator."""
|
|
||||||
return not self == other
|
|
||||||
|
|
||||||
def __lt__(self, other: Any) -> bool:
|
|
||||||
"""Overload <= (less than) operator."""
|
|
||||||
return (
|
|
||||||
isinstance(other, BaseElement)
|
|
||||||
and int(self.get_value()) < int(other.get_value())
|
|
||||||
) or (isinstance(other, int) and int(self.get_value()) < other)
|
|
||||||
|
|
||||||
def __le__(self, other: Any) -> bool:
|
|
||||||
"""Overload <= (less than or equal) operator."""
|
|
||||||
return self.__lt__(other) or self.__eq__(other)
|
|
||||||
|
|
||||||
def __gt__(self, other: Any) -> bool:
|
|
||||||
"""Overload > (greater than) operator."""
|
|
||||||
return (
|
|
||||||
isinstance(other, BaseElement)
|
|
||||||
and int(self.get_value()) > int(other.get_value())
|
|
||||||
) or (isinstance(other, int) and int(self.get_value()) > other)
|
|
||||||
|
|
||||||
def __ge__(self, other: Any) -> bool:
|
|
||||||
"""Overload >= (greater than or equal) operator."""
|
|
||||||
return self.__gt__(other) or self.__eq__(other)
|
|
||||||
|
|
||||||
def __add__(self, other: Any) -> Any:
|
|
||||||
"""Overload addition operator."""
|
|
||||||
return self.get_value() + other
|
|
||||||
|
|
||||||
def __sub__(self, other: Any) -> Any:
|
|
||||||
"""Overload subtraction operator."""
|
|
||||||
return self.get_value() - other
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
"""Overload the hashing function."""
|
|
||||||
return hash(self.get_value())
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_upper_bound(self) -> int:
|
|
||||||
"""Get the upper bound for the element."""
|
"""Get the upper bound for the element."""
|
||||||
return maxsize
|
return maxsize
|
||||||
|
|
||||||
def get_value(self) -> mpz:
|
|
||||||
"""Get internal value for math calculations"""
|
|
||||||
return self._value
|
|
||||||
|
|
||||||
def to_hex(self) -> str:
|
|
||||||
"""
|
|
||||||
Convert from the element to the hex representation of bytes.
|
|
||||||
"""
|
|
||||||
return self.data
|
|
||||||
|
|
||||||
def to_hex_bytes(self) -> bytes:
|
|
||||||
"""
|
|
||||||
Convert from the element to the representation of bytes by first going through hex.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return b16decode(self.data)
|
|
||||||
|
|
||||||
def is_in_bounds(self) -> bool:
|
def is_in_bounds(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate that the element is actually within the bounds of [0,Q).
|
Validate that the element is actually within the bounds of [0,Q).
|
||||||
|
|
||||||
Returns true if all is good, false if something's wrong.
|
Returns true if all is good, false if something's wrong.
|
||||||
"""
|
"""
|
||||||
return 0 <= self.get_value() < self.get_upper_bound()
|
return 0 <= self.value < self.get_upper_bound()
|
||||||
|
|
||||||
def is_in_bounds_no_zero(self) -> bool:
|
def is_in_bounds_no_zero(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -160,13 +47,14 @@ class BaseElement(Serializable, ABC):
|
||||||
|
|
||||||
Returns true if all is good, false if something's wrong.
|
Returns true if all is good, false if something's wrong.
|
||||||
"""
|
"""
|
||||||
return 1 <= self.get_value() < self.get_upper_bound()
|
return 1 <= self.value < self.get_upper_bound()
|
||||||
|
|
||||||
|
|
||||||
class ElementModQ(BaseElement):
|
class ElementModQ(BaseElement):
|
||||||
"""An element of the smaller `mod q` space, i.e., in [0, Q), where Q is a 256-bit prime."""
|
"""An element of the smaller `mod q` space, i.e., in [0, Q), where Q is a 256-bit prime."""
|
||||||
|
|
||||||
def get_upper_bound(self) -> int:
|
@classmethod
|
||||||
|
def get_upper_bound(cls) -> int:
|
||||||
"""Get the upper bound for the element."""
|
"""Get the upper bound for the element."""
|
||||||
return get_small_prime()
|
return get_small_prime()
|
||||||
|
|
||||||
|
@ -174,7 +62,8 @@ class ElementModQ(BaseElement):
|
||||||
class ElementModP(BaseElement):
|
class ElementModP(BaseElement):
|
||||||
"""An element of the larger `mod p` space, i.e., in [0, P), where P is a 4096-bit prime."""
|
"""An element of the larger `mod p` space, i.e., in [0, P), where P is a 4096-bit prime."""
|
||||||
|
|
||||||
def get_upper_bound(self) -> int:
|
@classmethod
|
||||||
|
def get_upper_bound(cls) -> int:
|
||||||
"""Get the upper bound for the element."""
|
"""Get the upper bound for the element."""
|
||||||
return get_large_prime()
|
return get_large_prime()
|
||||||
|
|
||||||
|
@ -202,7 +91,7 @@ ElementModPorInt = Union[ElementModP, int]
|
||||||
def _get_mpz(input: Union[BaseElement, int]) -> mpz:
|
def _get_mpz(input: Union[BaseElement, int]) -> mpz:
|
||||||
"""Get BaseElement or integer as mpz."""
|
"""Get BaseElement or integer as mpz."""
|
||||||
if isinstance(input, BaseElement):
|
if isinstance(input, BaseElement):
|
||||||
return input.get_value()
|
return input.value
|
||||||
return mpz(input)
|
return mpz(input)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,18 @@
|
||||||
|
from datetime import datetime
|
||||||
from io import TextIOWrapper
|
from io import TextIOWrapper
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, Type, TypeVar, Union
|
from typing import Any, List, Type, TypeVar, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, PrivateAttr
|
from dacite import Config, from_dict
|
||||||
from pydantic.json import pydantic_encoder
|
from pydantic.json import pydantic_encoder
|
||||||
from pydantic.tools import parse_raw_as, parse_obj_as, schema_json_of
|
from pydantic.tools import parse_raw_as, schema_json_of
|
||||||
|
|
||||||
Private = PrivateAttr
|
|
||||||
|
|
||||||
|
|
||||||
class Serializable(BaseModel):
|
|
||||||
"""Serializable data object intended for exporting and importing"""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Model config to handle private properties"""
|
|
||||||
|
|
||||||
underscore_attrs_are_private = True
|
|
||||||
|
|
||||||
|
from .ballot_box import BallotBoxState
|
||||||
|
from .manifest import ElectionType, ReportingUnitType, VoteVariationType
|
||||||
|
from .group import ElementModP, ElementModQ
|
||||||
|
from .proof import ProofUsage
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
@ -26,6 +20,20 @@ _indent = 2
|
||||||
_encoding = "utf-8"
|
_encoding = "utf-8"
|
||||||
_file_extension = "json"
|
_file_extension = "json"
|
||||||
|
|
||||||
|
_config = Config(
|
||||||
|
cast=[
|
||||||
|
datetime,
|
||||||
|
ElementModP,
|
||||||
|
ElementModQ,
|
||||||
|
BallotBoxState,
|
||||||
|
ElectionType,
|
||||||
|
ReportingUnitType,
|
||||||
|
VoteVariationType,
|
||||||
|
ProofUsage,
|
||||||
|
],
|
||||||
|
type_hooks={datetime: datetime.fromisoformat},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def construct_path(
|
def construct_path(
|
||||||
target_file_name: str,
|
target_file_name: str,
|
||||||
|
@ -54,7 +62,7 @@ def from_file_wrapper(type_: Type[_T], file: TextIOWrapper) -> _T:
|
||||||
"""Deserialize json file as type."""
|
"""Deserialize json file as type."""
|
||||||
|
|
||||||
data = json.load(file)
|
data = json.load(file)
|
||||||
return parse_obj_as(type_, data)
|
return from_dict(type_, data, _config)
|
||||||
|
|
||||||
|
|
||||||
def from_file(type_: Type[_T], path: Union[str, Path]) -> _T:
|
def from_file(type_: Type[_T], path: Union[str, Path]) -> _T:
|
||||||
|
@ -62,7 +70,7 @@ def from_file(type_: Type[_T], path: Union[str, Path]) -> _T:
|
||||||
|
|
||||||
with open(path, "r", encoding=_encoding) as json_file:
|
with open(path, "r", encoding=_encoding) as json_file:
|
||||||
data = json.load(json_file)
|
data = json.load(json_file)
|
||||||
return parse_obj_as(type_, data)
|
return from_dict(type_, data, _config)
|
||||||
|
|
||||||
|
|
||||||
def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]:
|
def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]:
|
||||||
|
@ -72,7 +80,7 @@ def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]:
|
||||||
data = json.load(json_file)
|
data = json.load(json_file)
|
||||||
ls: List[_T] = []
|
ls: List[_T] = []
|
||||||
for item in data:
|
for item in data:
|
||||||
ls.append(parse_obj_as(type_, item))
|
ls.append(from_dict(type_, item, _config))
|
||||||
return ls
|
return ls
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,7 +90,7 @@ def from_list_in_file_wrapper(type_: Type[_T], file: TextIOWrapper) -> List[_T]:
|
||||||
data = json.load(file)
|
data = json.load(file)
|
||||||
ls: List[_T] = []
|
ls: List[_T] = []
|
||||||
for item in data:
|
for item in data:
|
||||||
ls.append(parse_obj_as(type_, item))
|
ls.append(from_dict(type_, item, _config))
|
||||||
return ls
|
return ls
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -54,12 +54,12 @@ class TestElGamal(BaseTestCase):
|
||||||
ciphertext = get_optional(elgamal_encrypt(0, nonce, keypair.public_key))
|
ciphertext = get_optional(elgamal_encrypt(0, nonce, keypair.public_key))
|
||||||
self.assertEqual(get_generator(), ciphertext.pad)
|
self.assertEqual(get_generator(), ciphertext.pad)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
pow(ciphertext.pad.get_value(), secret_key.get_value(), get_large_prime()),
|
pow(ciphertext.pad.value, secret_key.value, get_large_prime()),
|
||||||
pow(public_key.get_value(), nonce.get_value(), get_large_prime()),
|
pow(public_key.value, nonce.value, get_large_prime()),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
ciphertext.data.get_value(),
|
ciphertext.data.value,
|
||||||
pow(public_key.get_value(), nonce.get_value(), get_large_prime()),
|
pow(public_key.value, nonce.value, get_large_prime()),
|
||||||
)
|
)
|
||||||
|
|
||||||
plaintext = ciphertext.decrypt(keypair.secret_key)
|
plaintext = ciphertext.decrypt(keypair.secret_key)
|
||||||
|
|
|
@ -49,8 +49,8 @@ class TestEquality(BaseTestCase):
|
||||||
def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ) -> None:
|
def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ) -> None:
|
||||||
i = int(q)
|
i = int(q)
|
||||||
i2 = int(q2)
|
i2 = int(q2)
|
||||||
p = ElementModP(q.get_value())
|
p = ElementModP(q)
|
||||||
p2 = ElementModP(q2.get_value())
|
p2 = ElementModP(q2)
|
||||||
|
|
||||||
# same value should imply they're equal
|
# same value should imply they're equal
|
||||||
self.assertEqual(p, q)
|
self.assertEqual(p, q)
|
||||||
|
@ -143,8 +143,8 @@ class TestModularArithmetic(BaseTestCase):
|
||||||
@given(elements_mod_q())
|
@given(elements_mod_q())
|
||||||
def test_in_bounds_q(self, q: ElementModQ) -> None:
|
def test_in_bounds_q(self, q: ElementModQ) -> None:
|
||||||
self.assertTrue(q.is_in_bounds())
|
self.assertTrue(q.is_in_bounds())
|
||||||
too_big = q + get_small_prime()
|
too_big = q.value + get_small_prime()
|
||||||
too_small = q - get_small_prime()
|
too_small = q.value - get_small_prime()
|
||||||
self.assertFalse(ElementModQ(too_big, False).is_in_bounds())
|
self.assertFalse(ElementModQ(too_big, False).is_in_bounds())
|
||||||
self.assertFalse(ElementModQ(too_small, False).is_in_bounds())
|
self.assertFalse(ElementModQ(too_small, False).is_in_bounds())
|
||||||
self.assertEqual(None, int_to_q(too_big))
|
self.assertEqual(None, int_to_q(too_big))
|
||||||
|
@ -157,8 +157,8 @@ class TestModularArithmetic(BaseTestCase):
|
||||||
@given(elements_mod_p())
|
@given(elements_mod_p())
|
||||||
def test_in_bounds_p(self, p: ElementModP) -> None:
|
def test_in_bounds_p(self, p: ElementModP) -> None:
|
||||||
self.assertTrue(p.is_in_bounds())
|
self.assertTrue(p.is_in_bounds())
|
||||||
too_big = p + get_large_prime()
|
too_big = p.value + get_large_prime()
|
||||||
too_small = p - get_large_prime()
|
too_small = p.value - get_large_prime()
|
||||||
self.assertFalse(ElementModP(too_big, False).is_in_bounds())
|
self.assertFalse(ElementModP(too_big, False).is_in_bounds())
|
||||||
self.assertFalse(ElementModP(too_small, False).is_in_bounds())
|
self.assertFalse(ElementModP(too_small, False).is_in_bounds())
|
||||||
self.assertEqual(None, int_to_p(too_big))
|
self.assertEqual(None, int_to_p(too_big))
|
||||||
|
@ -173,10 +173,10 @@ class TestModularArithmetic(BaseTestCase):
|
||||||
self.assertTrue(q.is_in_bounds_no_zero())
|
self.assertTrue(q.is_in_bounds_no_zero())
|
||||||
self.assertFalse(ZERO_MOD_Q.is_in_bounds_no_zero())
|
self.assertFalse(ZERO_MOD_Q.is_in_bounds_no_zero())
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
ElementModQ(q + get_small_prime(), False).is_in_bounds_no_zero()
|
ElementModQ(q.value + get_small_prime(), False).is_in_bounds_no_zero()
|
||||||
)
|
)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
ElementModQ(q - get_small_prime(), False).is_in_bounds_no_zero()
|
ElementModQ(q.value - get_small_prime(), False).is_in_bounds_no_zero()
|
||||||
)
|
)
|
||||||
|
|
||||||
@given(elements_mod_p_no_zero())
|
@given(elements_mod_p_no_zero())
|
||||||
|
@ -184,15 +184,15 @@ class TestModularArithmetic(BaseTestCase):
|
||||||
self.assertTrue(p.is_in_bounds_no_zero())
|
self.assertTrue(p.is_in_bounds_no_zero())
|
||||||
self.assertFalse(ZERO_MOD_P.is_in_bounds_no_zero())
|
self.assertFalse(ZERO_MOD_P.is_in_bounds_no_zero())
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
ElementModP(p + get_large_prime(), False).is_in_bounds_no_zero()
|
ElementModP(p.value + get_large_prime(), False).is_in_bounds_no_zero()
|
||||||
)
|
)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
ElementModP(p - get_large_prime(), False).is_in_bounds_no_zero()
|
ElementModP(p.value - get_large_prime(), False).is_in_bounds_no_zero()
|
||||||
)
|
)
|
||||||
|
|
||||||
@given(elements_mod_q())
|
@given(elements_mod_q())
|
||||||
def test_large_values_rejected_by_int_to_q(self, q: ElementModQ) -> None:
|
def test_large_values_rejected_by_int_to_q(self, q: ElementModQ) -> None:
|
||||||
oversize = q + get_small_prime()
|
oversize = q.value + get_small_prime()
|
||||||
self.assertEqual(None, int_to_q(oversize))
|
self.assertEqual(None, int_to_q(oversize))
|
||||||
|
|
||||||
|
|
||||||
|
|
ΠΠ°Π³ΡΡΠ·ΠΊΠ°β¦
Π‘ΡΡΠ»ΠΊΠ° Π² Π½ΠΎΠ²ΠΎΠΉ Π·Π°Π΄Π°ΡΠ΅