* 🚧. Create BigInteger class

A representation of BigInteger in python to specifically designate what is happening. Left as a str due to the hex representation being the underpinning.

* 🚧. Refactor BaseElement to use BigInteger

- Refactor BaseElement to use BigInteger and fix tests to use value instead of get_value
- Update Tests

* 🚧 Remove now unused Serializable

* πŸ“¦ Add dacite to handle dictionary casting

* ✨ Add custom deserialization to avoid pydantic issues

* πŸ”₯ Remove add and subtract functions from big integer

* βœ… Better indicate padding purpose
This commit is contained in:
Keith Fung 2022-04-19 10:06:18 -04:00 ΠΊΠΎΠΌΠΌΠΈΡ‚ ΠΏΡ€ΠΎΠΈΠ·Π²Ρ‘Π» GitHub
Π ΠΎΠ΄ΠΈΡ‚Π΅Π»ΡŒ b7c8782e81
ΠšΠΎΠΌΠΌΠΈΡ‚ 392586832d
НС Π½Π°ΠΉΠ΄Π΅Π½ ΠΊΠ»ΡŽΡ‡, ΡΠΎΠΎΡ‚Π²Π΅Ρ‚ΡΡ‚Π²ΡƒΡŽΡ‰ΠΈΠΉ Π΄Π°Π½Π½ΠΎΠΉ подписи
Π˜Π΄Π΅Π½Ρ‚ΠΈΡ„ΠΈΠΊΠ°Ρ‚ΠΎΡ€ ΠΊΠ»ΡŽΡ‡Π° GPG: 4AEE18F83AFDEB23
8 ΠΈΠ·ΠΌΠ΅Π½Ρ‘Π½Π½Ρ‹Ρ… Ρ„Π°ΠΉΠ»ΠΎΠ²: 183 Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΠΉ ΠΈ 169 ΡƒΠ΄Π°Π»Π΅Π½ΠΈΠΉ

17
poetry.lock сгСнСрированный
ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -189,6 +189,17 @@ sdist = ["setuptools_rust (>=0.11.4)"]
ssh = ["bcrypt (>=3.1.5)"] ssh = ["bcrypt (>=3.1.5)"]
test = ["pytest (>=6.2.0)", "pytest-cov", "pytest-subtests", "pytest-xdist", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,!=3.79.2)"] test = ["pytest (>=6.2.0)", "pytest-cov", "pytest-subtests", "pytest-xdist", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,!=3.79.2)"]
[[package]]
name = "dacite"
version = "1.6.0"
description = "Simple creation of data classes from dictionaries."
category = "main"
optional = false
python-versions = ">=3.6"
[package.extras]
dev = ["pytest (>=5)", "pytest-cov", "coveralls", "black", "mypy", "pylint"]
[[package]] [[package]]
name = "debugpy" name = "debugpy"
version = "1.6.0" version = "1.6.0"
@ -1610,7 +1621,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.9.5" python-versions = "^3.9.5"
content-hash = "d5dd120c9f2fca28b2e16b793b8e9b4c397d8d4a61304b96696e976db06c4d47" content-hash = "9bd2777ce600469494106f880daad6ce7feb021f972be4dc84882cd6afbab16e"
[metadata.files] [metadata.files]
appnope = [ appnope = [
@ -1803,6 +1814,10 @@ cryptography = [
{file = "cryptography-36.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e167b6b710c7f7bc54e67ef593f8731e1f45aa35f8a8a7b72d6e42ec76afd4b3"}, {file = "cryptography-36.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e167b6b710c7f7bc54e67ef593f8731e1f45aa35f8a8a7b72d6e42ec76afd4b3"},
{file = "cryptography-36.0.2.tar.gz", hash = "sha256:70f8f4f7bb2ac9f340655cbac89d68c527af5bb4387522a8413e841e3e6628c9"}, {file = "cryptography-36.0.2.tar.gz", hash = "sha256:70f8f4f7bb2ac9f340655cbac89d68c527af5bb4387522a8413e841e3e6628c9"},
] ]
dacite = [
{file = "dacite-1.6.0-py3-none-any.whl", hash = "sha256:4331535f7aabb505c732fa4c3c094313fc0a1d5ea19907bf4726a7819a68b93f"},
{file = "dacite-1.6.0.tar.gz", hash = "sha256:d48125ed0a0352d3de9f493bf980038088f45f3f9d7498f090b50a847daaa6df"},
]
debugpy = [ debugpy = [
{file = "debugpy-1.6.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:eb1946efac0c0c3d411cea0b5ac772fbde744109fd9520fb0c5a51979faf05ad"}, {file = "debugpy-1.6.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:eb1946efac0c0c3d411cea0b5ac772fbde744109fd9520fb0c5a51979faf05ad"},
{file = "debugpy-1.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e3513399177dd37af4c1332df52da5da1d0c387e5927dc4c0709e26ee7302e8f"}, {file = "debugpy-1.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e3513399177dd37af4c1332df52da5da1d0c387e5927dc4c0709e26ee7302e8f"},

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -39,6 +39,7 @@ gmpy2 = "^2.0.8"
psutil = ">=5.7.2" psutil = ">=5.7.2"
pydantic = "1.9.0" pydantic = "1.9.0"
click = "^8.1.0" click = "^8.1.0"
dacite = "^1.6.0"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
atomicwrites = "*" atomicwrites = "*"

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -6,6 +6,7 @@ from electionguard import ballot_box
from electionguard import ballot_code from electionguard import ballot_code
from electionguard import ballot_compact from electionguard import ballot_compact
from electionguard import ballot_validator from electionguard import ballot_validator
from electionguard import big_integer
from electionguard import chaum_pedersen from electionguard import chaum_pedersen
from electionguard import constants from electionguard import constants
from electionguard import data_store from electionguard import data_store
@ -83,6 +84,9 @@ from electionguard.ballot_validator import (
contest_is_valid_for_style, contest_is_valid_for_style,
selection_is_valid_for_style, selection_is_valid_for_style,
) )
from electionguard.big_integer import (
BigInteger,
)
from electionguard.chaum_pedersen import ( from electionguard.chaum_pedersen import (
ChaumPedersenProof, ChaumPedersenProof,
ConstantChaumPedersenProof, ConstantChaumPedersenProof,
@ -231,10 +235,8 @@ from electionguard.group import (
div_p, div_p,
div_q, div_q,
g_pow_p, g_pow_p,
hex_to_int,
hex_to_p, hex_to_p,
hex_to_q, hex_to_q,
int_to_hex,
int_to_p, int_to_p,
int_to_q, int_to_q,
mult_inv_p, mult_inv_p,
@ -336,8 +338,6 @@ from electionguard.schnorr import (
make_schnorr_proof, make_schnorr_proof,
) )
from electionguard.serialize import ( from electionguard.serialize import (
Private,
Serializable,
construct_path, construct_path,
from_file, from_file,
from_file_wrapper, from_file_wrapper,
@ -390,6 +390,7 @@ __all__ = [
"BallotId", "BallotId",
"BallotStyle", "BallotStyle",
"BaseElement", "BaseElement",
"BigInteger",
"Candidate", "Candidate",
"CandidateContestDescription", "CandidateContestDescription",
"CeremonyDetails", "CeremonyDetails",
@ -483,7 +484,6 @@ __all__ = [
"PlaintextTallyContest", "PlaintextTallyContest",
"PlaintextTallySelection", "PlaintextTallySelection",
"PrimeOption", "PrimeOption",
"Private",
"PrivateGuardianRecord", "PrivateGuardianRecord",
"Proof", "Proof",
"ProofOrRecovery", "ProofOrRecovery",
@ -501,7 +501,6 @@ __all__ = [
"SecretCoefficient", "SecretCoefficient",
"SelectionDescription", "SelectionDescription",
"SelectionId", "SelectionId",
"Serializable",
"Singleton", "Singleton",
"SubmittedBallot", "SubmittedBallot",
"VerifierId", "VerifierId",
@ -518,6 +517,7 @@ __all__ = [
"ballot_is_valid_for_election", "ballot_is_valid_for_election",
"ballot_is_valid_for_style", "ballot_is_valid_for_style",
"ballot_validator", "ballot_validator",
"big_integer",
"chaum_pedersen", "chaum_pedersen",
"combine_election_public_keys", "combine_election_public_keys",
"compensate_decrypt", "compensate_decrypt",
@ -620,11 +620,9 @@ __all__ = [
"hash", "hash",
"hash_elems", "hash_elems",
"hashed_elgamal_encrypt", "hashed_elgamal_encrypt",
"hex_to_int",
"hex_to_p", "hex_to_p",
"hex_to_q", "hex_to_q",
"hmac", "hmac",
"int_to_hex",
"int_to_p", "int_to_p",
"int_to_q", "int_to_q",
"key_ceremony", "key_ceremony",

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -0,0 +1,103 @@
from typing import Any, Tuple, Union
from base64 import b16decode
# pylint: disable=no-name-in-module
from gmpy2 import mpz
def _hex_to_int(input: str) -> int:
"""Given a hex string representing bytes, returns an int."""
return int(input, 16)
def _int_to_hex(input: int) -> str:
"""Given an int, returns a hex string representing bytes."""
def pad_hex(hex: str) -> str:
"""Pad hex to ensure 2 digit hexadecimal format maintained."""
return "0" + hex if len(hex) % 2 else hex
hex = format(input, "02X")
return pad_hex(hex)
_zero = mpz(0)
def _convert_to_element(data: Union[int, str]) -> Tuple[str, int]:
"""Convert element to consistent types"""
if isinstance(data, str):
hex = data
integer = _hex_to_int(data)
else:
hex = _int_to_hex(data)
integer = data
return (hex, integer)
class BigInteger(str):
"""A specialized representation of a big integer in python"""
_value: mpz = _zero
def __new__(cls, data: Union[int, str]): # type: ignore
(hex, integer) = _convert_to_element(data)
big_int = super(BigInteger, cls).__new__(cls, hex)
big_int._value = mpz(integer)
return big_int
@property
def value(self) -> mpz:
"""Get internal value for math calculations"""
return self._value
def __int__(self) -> int:
"""Overload int conversion."""
return int(self.value)
def __eq__(self, other: Any) -> bool:
"""Overload == (equal to) operator."""
return (
isinstance(other, BigInteger) and int(self.value) == int(other.value)
) or (isinstance(other, int) and int(self.value) == other)
def __ne__(self, other: Any) -> bool:
"""Overload != (not equal to) operator."""
return not self == other
def __lt__(self, other: Any) -> bool:
"""Overload <= (less than) operator."""
return (
isinstance(other, BigInteger) and int(self.value) < int(other.value)
) or (isinstance(other, int) and int(self.value) < other)
def __le__(self, other: Any) -> bool:
"""Overload <= (less than or equal) operator."""
return self.__lt__(other) or self.__eq__(other)
def __gt__(self, other: Any) -> bool:
"""Overload > (greater than) operator."""
return (
isinstance(other, BigInteger) and int(self.value) > int(other.value)
) or (isinstance(other, int) and int(self.value) > other)
def __ge__(self, other: Any) -> bool:
"""Overload >= (greater than or equal) operator."""
return self.__gt__(other) or self.__eq__(other)
def __hash__(self) -> int:
"""Overload the hashing function."""
return hash(self.value)
def to_hex(self) -> str:
"""
Convert from the element to the hex representation of bytes.
"""
return str(self)
def to_hex_bytes(self) -> bytes:
"""
Convert from the element to the representation of bytes by first going through hex.
"""
return b16decode(self)

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -5,154 +5,41 @@ in the sense that performance may be less than hand-optimized C code, and no gua
made about timing or other side-channels. made about timing or other side-channels.
""" """
from abc import ABC, abstractmethod from abc import ABC
from typing import Any, Final, Optional, Tuple, Union from typing import Final, Optional, Union
from base64 import b16decode
from secrets import randbelow from secrets import randbelow
from sys import maxsize from sys import maxsize
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from gmpy2 import mpz, powmod, invert from gmpy2 import mpz, powmod, invert
from .serialize import Serializable, Private from .big_integer import BigInteger
from .constants import get_large_prime, get_small_prime, get_generator from .constants import get_large_prime, get_small_prime, get_generator
def hex_to_int(input: str) -> int: class BaseElement(BigInteger, ABC):
"""Given a hex string representing bytes, returns an int."""
return int(input, 16)
def int_to_hex(input: int) -> str:
"""Given an int, returns a hex string representing bytes."""
hex = format(input, "02X")
if len(hex) % 2:
hex = "0" + hex
return hex
_zero = mpz(0)
def _mpz_zero() -> mpz:
return _zero
def _convert_to_element(data: Union[int, str]) -> Tuple[str, int]:
"""Convert element to consistent types"""
if isinstance(data, str):
hex = data
integer = hex_to_int(data)
else:
hex = int_to_hex(data)
integer = data
return (hex, integer)
class BaseElement(Serializable, ABC):
"""An element limited by mod T within [0, T) where T is determined by an upper_bound function.""" """An element limited by mod T within [0, T) where T is determined by an upper_bound function."""
data: str def __new__(cls, data: Union[int, str], check_within_bounds: bool = True): # type: ignore
_value: mpz = Private(default_factory=_mpz_zero)
"""Internal math representation of element"""
def __init__(self, data: Union[int, str], check_within_bounds: bool = True) -> None:
"""Instantiate element mod T where element is an int or its hex representation.""" """Instantiate element mod T where element is an int or its hex representation."""
(hex, integer) = _convert_to_element(data) element = super(BaseElement, cls).__new__(cls, data)
super().__init__(data=hex)
self._value = mpz(integer)
if check_within_bounds: if check_within_bounds:
if not self.is_in_bounds(): if not 0 <= element.value < cls.get_upper_bound():
raise OverflowError raise OverflowError
return element
def __str__(self) -> str: @classmethod
"""Overload string representation""" def get_upper_bound(cls) -> int:
return self.data
def __repr__(self) -> str:
"""Overload object representation"""
return self.data
def __int__(self) -> int:
"""Overload int conversion."""
return int(self.get_value())
def __eq__(self, other: Any) -> bool:
"""Overload == (equal to) operator."""
return (
isinstance(other, BaseElement)
and int(self.get_value()) == int(other.get_value())
) or (isinstance(other, int) and int(self.get_value()) == other)
def __ne__(self, other: Any) -> bool:
"""Overload != (not equal to) operator."""
return not self == other
def __lt__(self, other: Any) -> bool:
"""Overload <= (less than) operator."""
return (
isinstance(other, BaseElement)
and int(self.get_value()) < int(other.get_value())
) or (isinstance(other, int) and int(self.get_value()) < other)
def __le__(self, other: Any) -> bool:
"""Overload <= (less than or equal) operator."""
return self.__lt__(other) or self.__eq__(other)
def __gt__(self, other: Any) -> bool:
"""Overload > (greater than) operator."""
return (
isinstance(other, BaseElement)
and int(self.get_value()) > int(other.get_value())
) or (isinstance(other, int) and int(self.get_value()) > other)
def __ge__(self, other: Any) -> bool:
"""Overload >= (greater than or equal) operator."""
return self.__gt__(other) or self.__eq__(other)
def __add__(self, other: Any) -> Any:
"""Overload addition operator."""
return self.get_value() + other
def __sub__(self, other: Any) -> Any:
"""Overload subtraction operator."""
return self.get_value() - other
def __hash__(self) -> int:
"""Overload the hashing function."""
return hash(self.get_value())
@abstractmethod
def get_upper_bound(self) -> int:
"""Get the upper bound for the element.""" """Get the upper bound for the element."""
return maxsize return maxsize
def get_value(self) -> mpz:
"""Get internal value for math calculations"""
return self._value
def to_hex(self) -> str:
"""
Convert from the element to the hex representation of bytes.
"""
return self.data
def to_hex_bytes(self) -> bytes:
"""
Convert from the element to the representation of bytes by first going through hex.
"""
return b16decode(self.data)
def is_in_bounds(self) -> bool: def is_in_bounds(self) -> bool:
""" """
Validate that the element is actually within the bounds of [0,Q). Validate that the element is actually within the bounds of [0,Q).
Returns true if all is good, false if something's wrong. Returns true if all is good, false if something's wrong.
""" """
return 0 <= self.get_value() < self.get_upper_bound() return 0 <= self.value < self.get_upper_bound()
def is_in_bounds_no_zero(self) -> bool: def is_in_bounds_no_zero(self) -> bool:
""" """
@ -160,13 +47,14 @@ class BaseElement(Serializable, ABC):
Returns true if all is good, false if something's wrong. Returns true if all is good, false if something's wrong.
""" """
return 1 <= self.get_value() < self.get_upper_bound() return 1 <= self.value < self.get_upper_bound()
class ElementModQ(BaseElement): class ElementModQ(BaseElement):
"""An element of the smaller `mod q` space, i.e., in [0, Q), where Q is a 256-bit prime.""" """An element of the smaller `mod q` space, i.e., in [0, Q), where Q is a 256-bit prime."""
def get_upper_bound(self) -> int: @classmethod
def get_upper_bound(cls) -> int:
"""Get the upper bound for the element.""" """Get the upper bound for the element."""
return get_small_prime() return get_small_prime()
@ -174,7 +62,8 @@ class ElementModQ(BaseElement):
class ElementModP(BaseElement): class ElementModP(BaseElement):
"""An element of the larger `mod p` space, i.e., in [0, P), where P is a 4096-bit prime.""" """An element of the larger `mod p` space, i.e., in [0, P), where P is a 4096-bit prime."""
def get_upper_bound(self) -> int: @classmethod
def get_upper_bound(cls) -> int:
"""Get the upper bound for the element.""" """Get the upper bound for the element."""
return get_large_prime() return get_large_prime()
@ -202,7 +91,7 @@ ElementModPorInt = Union[ElementModP, int]
def _get_mpz(input: Union[BaseElement, int]) -> mpz: def _get_mpz(input: Union[BaseElement, int]) -> mpz:
"""Get BaseElement or integer as mpz.""" """Get BaseElement or integer as mpz."""
if isinstance(input, BaseElement): if isinstance(input, BaseElement):
return input.get_value() return input.value
return mpz(input) return mpz(input)

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -1,24 +1,18 @@
from datetime import datetime
from io import TextIOWrapper from io import TextIOWrapper
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, List, Type, TypeVar, Union from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel, PrivateAttr from dacite import Config, from_dict
from pydantic.json import pydantic_encoder from pydantic.json import pydantic_encoder
from pydantic.tools import parse_raw_as, parse_obj_as, schema_json_of from pydantic.tools import parse_raw_as, schema_json_of
Private = PrivateAttr
class Serializable(BaseModel):
"""Serializable data object intended for exporting and importing"""
class Config:
"""Model config to handle private properties"""
underscore_attrs_are_private = True
from .ballot_box import BallotBoxState
from .manifest import ElectionType, ReportingUnitType, VoteVariationType
from .group import ElementModP, ElementModQ
from .proof import ProofUsage
_T = TypeVar("_T") _T = TypeVar("_T")
@ -26,6 +20,20 @@ _indent = 2
_encoding = "utf-8" _encoding = "utf-8"
_file_extension = "json" _file_extension = "json"
_config = Config(
cast=[
datetime,
ElementModP,
ElementModQ,
BallotBoxState,
ElectionType,
ReportingUnitType,
VoteVariationType,
ProofUsage,
],
type_hooks={datetime: datetime.fromisoformat},
)
def construct_path( def construct_path(
target_file_name: str, target_file_name: str,
@ -54,7 +62,7 @@ def from_file_wrapper(type_: Type[_T], file: TextIOWrapper) -> _T:
"""Deserialize json file as type.""" """Deserialize json file as type."""
data = json.load(file) data = json.load(file)
return parse_obj_as(type_, data) return from_dict(type_, data, _config)
def from_file(type_: Type[_T], path: Union[str, Path]) -> _T: def from_file(type_: Type[_T], path: Union[str, Path]) -> _T:
@ -62,7 +70,7 @@ def from_file(type_: Type[_T], path: Union[str, Path]) -> _T:
with open(path, "r", encoding=_encoding) as json_file: with open(path, "r", encoding=_encoding) as json_file:
data = json.load(json_file) data = json.load(json_file)
return parse_obj_as(type_, data) return from_dict(type_, data, _config)
def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]: def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]:
@ -72,7 +80,7 @@ def from_list_in_file(type_: Type[_T], path: Union[str, Path]) -> List[_T]:
data = json.load(json_file) data = json.load(json_file)
ls: List[_T] = [] ls: List[_T] = []
for item in data: for item in data:
ls.append(parse_obj_as(type_, item)) ls.append(from_dict(type_, item, _config))
return ls return ls
@ -82,7 +90,7 @@ def from_list_in_file_wrapper(type_: Type[_T], file: TextIOWrapper) -> List[_T]:
data = json.load(file) data = json.load(file)
ls: List[_T] = [] ls: List[_T] = []
for item in data: for item in data:
ls.append(parse_obj_as(type_, item)) ls.append(from_dict(type_, item, _config))
return ls return ls

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -54,12 +54,12 @@ class TestElGamal(BaseTestCase):
ciphertext = get_optional(elgamal_encrypt(0, nonce, keypair.public_key)) ciphertext = get_optional(elgamal_encrypt(0, nonce, keypair.public_key))
self.assertEqual(get_generator(), ciphertext.pad) self.assertEqual(get_generator(), ciphertext.pad)
self.assertEqual( self.assertEqual(
pow(ciphertext.pad.get_value(), secret_key.get_value(), get_large_prime()), pow(ciphertext.pad.value, secret_key.value, get_large_prime()),
pow(public_key.get_value(), nonce.get_value(), get_large_prime()), pow(public_key.value, nonce.value, get_large_prime()),
) )
self.assertEqual( self.assertEqual(
ciphertext.data.get_value(), ciphertext.data.value,
pow(public_key.get_value(), nonce.get_value(), get_large_prime()), pow(public_key.value, nonce.value, get_large_prime()),
) )
plaintext = ciphertext.decrypt(keypair.secret_key) plaintext = ciphertext.decrypt(keypair.secret_key)

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -49,8 +49,8 @@ class TestEquality(BaseTestCase):
def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ) -> None: def test_p_not_equal_to_q(self, q: ElementModQ, q2: ElementModQ) -> None:
i = int(q) i = int(q)
i2 = int(q2) i2 = int(q2)
p = ElementModP(q.get_value()) p = ElementModP(q)
p2 = ElementModP(q2.get_value()) p2 = ElementModP(q2)
# same value should imply they're equal # same value should imply they're equal
self.assertEqual(p, q) self.assertEqual(p, q)
@ -143,8 +143,8 @@ class TestModularArithmetic(BaseTestCase):
@given(elements_mod_q()) @given(elements_mod_q())
def test_in_bounds_q(self, q: ElementModQ) -> None: def test_in_bounds_q(self, q: ElementModQ) -> None:
self.assertTrue(q.is_in_bounds()) self.assertTrue(q.is_in_bounds())
too_big = q + get_small_prime() too_big = q.value + get_small_prime()
too_small = q - get_small_prime() too_small = q.value - get_small_prime()
self.assertFalse(ElementModQ(too_big, False).is_in_bounds()) self.assertFalse(ElementModQ(too_big, False).is_in_bounds())
self.assertFalse(ElementModQ(too_small, False).is_in_bounds()) self.assertFalse(ElementModQ(too_small, False).is_in_bounds())
self.assertEqual(None, int_to_q(too_big)) self.assertEqual(None, int_to_q(too_big))
@ -157,8 +157,8 @@ class TestModularArithmetic(BaseTestCase):
@given(elements_mod_p()) @given(elements_mod_p())
def test_in_bounds_p(self, p: ElementModP) -> None: def test_in_bounds_p(self, p: ElementModP) -> None:
self.assertTrue(p.is_in_bounds()) self.assertTrue(p.is_in_bounds())
too_big = p + get_large_prime() too_big = p.value + get_large_prime()
too_small = p - get_large_prime() too_small = p.value - get_large_prime()
self.assertFalse(ElementModP(too_big, False).is_in_bounds()) self.assertFalse(ElementModP(too_big, False).is_in_bounds())
self.assertFalse(ElementModP(too_small, False).is_in_bounds()) self.assertFalse(ElementModP(too_small, False).is_in_bounds())
self.assertEqual(None, int_to_p(too_big)) self.assertEqual(None, int_to_p(too_big))
@ -173,10 +173,10 @@ class TestModularArithmetic(BaseTestCase):
self.assertTrue(q.is_in_bounds_no_zero()) self.assertTrue(q.is_in_bounds_no_zero())
self.assertFalse(ZERO_MOD_Q.is_in_bounds_no_zero()) self.assertFalse(ZERO_MOD_Q.is_in_bounds_no_zero())
self.assertFalse( self.assertFalse(
ElementModQ(q + get_small_prime(), False).is_in_bounds_no_zero() ElementModQ(q.value + get_small_prime(), False).is_in_bounds_no_zero()
) )
self.assertFalse( self.assertFalse(
ElementModQ(q - get_small_prime(), False).is_in_bounds_no_zero() ElementModQ(q.value - get_small_prime(), False).is_in_bounds_no_zero()
) )
@given(elements_mod_p_no_zero()) @given(elements_mod_p_no_zero())
@ -184,15 +184,15 @@ class TestModularArithmetic(BaseTestCase):
self.assertTrue(p.is_in_bounds_no_zero()) self.assertTrue(p.is_in_bounds_no_zero())
self.assertFalse(ZERO_MOD_P.is_in_bounds_no_zero()) self.assertFalse(ZERO_MOD_P.is_in_bounds_no_zero())
self.assertFalse( self.assertFalse(
ElementModP(p + get_large_prime(), False).is_in_bounds_no_zero() ElementModP(p.value + get_large_prime(), False).is_in_bounds_no_zero()
) )
self.assertFalse( self.assertFalse(
ElementModP(p - get_large_prime(), False).is_in_bounds_no_zero() ElementModP(p.value - get_large_prime(), False).is_in_bounds_no_zero()
) )
@given(elements_mod_q()) @given(elements_mod_q())
def test_large_values_rejected_by_int_to_q(self, q: ElementModQ) -> None: def test_large_values_rejected_by_int_to_q(self, q: ElementModQ) -> None:
oversize = q + get_small_prime() oversize = q.value + get_small_prime()
self.assertEqual(None, int_to_q(oversize)) self.assertEqual(None, int_to_q(oversize))