for MPP-3802: refactor assert statements into if/raise checks

This commit is contained in:
groovecoder 2024-05-07 13:29:02 -05:00
Родитель f27a5afd46
Коммит c6a71b5dd7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4825AB58E974B712
23 изменённых файлов: 182 добавлений и 77 удалений

Просмотреть файл

@ -58,13 +58,17 @@ class RelayAPIException(APIException):
self, detail: _APIExceptionInput = None, code: str | None = None
) -> None:
"""Check that derived classes have set the required data."""
assert isinstance(self.default_code, str)
assert isinstance(self.status_code, int)
if not isinstance(self.default_code, str):
raise TypeError("default_code must be type str")
if not isinstance(self.status_code, int):
raise TypeError("self.status_code must be type int")
if hasattr(self, "default_detail_template"):
context = self.error_context()
assert context
if not context:
raise ValueError("error_context is required")
self.default_detail = self.default_detail_template.format(**context)
assert isinstance(self.default_detail, str)
if not isinstance(self.default_detail, str):
raise TypeError("self.default_detail must be type str")
super().__init__(detail, code)
def error_context(self) -> ErrorContextType:
@ -76,7 +80,8 @@ class RelayAPIException(APIException):
# For RelayAPIException classes, this is the default_code and is a string
error_code = self.get_codes()
assert isinstance(error_code, str)
if not isinstance(error_code, str):
raise TypeError("error_code must be type str")
# Build the Fluent error ID
ftl_id_sub = "api-error-"

Просмотреть файл

@ -12,7 +12,8 @@ class PremiumValidatorsMixin:
def validate_block_list_emails(self, value):
if not value:
return value
assert hasattr(self, "context")
if not hasattr(self, "context"):
raise AttributeError("self must have attribute context")
user = self.context["request"].user
prefetch_related_objects([user], "socialaccount_set", "profile")
if not user.profile.has_premium:

Просмотреть файл

@ -12,8 +12,10 @@ class SaveToRequestUser:
"""ModelViewSet mixin for creating object for the authenticated user."""
def perform_create(self, serializer):
assert hasattr(self, "request")
assert hasattr(self.request, "user")
if not hasattr(self, "request"):
raise AttributeError("self must have request attribute.")
if not hasattr(self.request, "user"):
raise AttributeError("self.request must have user attribute.")
serializer.save(user=self.request.user)

Просмотреть файл

@ -74,7 +74,8 @@ _Address = TypeVar("_Address", RelayAddress, DomainAddress)
class AddressViewSet(Generic[_Address], SaveToRequestUser, ModelViewSet):
def perform_create(self, serializer: BaseSerializer[_Address]) -> None:
super().perform_create(serializer)
assert serializer.instance
if not serializer.instance:
raise ValueError("serializer.instance must be truthy value.")
glean_logger().log_email_mask_created(
request=self.request,
mask=serializer.instance,
@ -82,7 +83,8 @@ class AddressViewSet(Generic[_Address], SaveToRequestUser, ModelViewSet):
)
def perform_update(self, serializer: BaseSerializer[_Address]) -> None:
assert serializer.instance is not None
if not serializer.instance:
raise ValueError("serializer.instance must be truthy value.")
old_description = serializer.instance.description
super().perform_update(serializer)
new_description = serializer.instance.description
@ -200,10 +202,13 @@ def first_forwarded_email(request):
return Response(f"{mask} does not exist for user.", status=HTTP_404_NOT_FOUND)
profile = user.profile
app_config = apps.get_app_config("emails")
assert isinstance(app_config, EmailsConfig)
if not isinstance(app_config, EmailsConfig):
raise TypeError("app_config must be type EmailsConfig")
ses_client = app_config.ses_client
assert ses_client
assert settings.RELAY_FROM_ADDRESS
if not ses_client:
raise ValueError("ses_client must be truthy value.")
if not settings.RELAY_FROM_ADDRESS:
raise ValueError("settings.RELAY_FROM_ADDRESS must have a value.")
with django_ftl.override(profile.language):
translated_subject = ftl_bundle.format("forwarded-email-hero-header")
first_forwarded_email_html = render_to_string(

Просмотреть файл

@ -1316,9 +1316,15 @@ class RelaySMSException(Exception):
def __init__(self, critical=False, *args, **kwargs):
self.critical = critical
assert (
if not (
self.default_detail is not None and self.default_detail_template is None
) or (self.default_detail is None and self.default_detail_template is not None)
) and not (
self.default_detail is None and self.default_detail_template is not None
):
raise ValueError(
"One and only one of default_detail or "
"default_detail_template must be None."
)
super().__init__(*args, **kwargs)
@property
@ -1326,7 +1332,8 @@ class RelaySMSException(Exception):
if self.default_detail:
return self.default_detail
else:
assert self.default_detail_template is not None
if self.default_detail_template is None:
raise ValueError("self.default_detail_template must not be None.")
return self.default_detail_template.format(**self.error_context())
def get_codes(self):
@ -1438,14 +1445,16 @@ def _prepare_sms_reply(
if match and not match.contacts and match.match_type == "full":
raise FullNumberMatchesNoSenders(full_number=match.detected)
if match and len(match.contacts) > 1:
assert match.match_type == "short"
if not match.match_type == "short":
raise ValueError("match.match_type must be 'short'.")
raise MultipleNumberMatches(short_prefix=match.detected)
# Determine the destination number
destination_number: str | None = None
if match:
# Use the sender matched by the prefix
assert len(match.contacts) == 1
if not len(match.contacts) == 1:
raise ValueError("len(match.contacts) must be 1.")
destination_number = match.contacts[0].inbound_number
else:
# No prefix, default to last sender if any

Просмотреть файл

@ -65,7 +65,8 @@ class EmailsConfig(AppConfig):
def emails_config() -> EmailsConfig:
emails_config = apps.get_app_config("emails")
assert isinstance(emails_config, EmailsConfig)
if not isinstance(emails_config, EmailsConfig):
raise TypeError("emails_config must be type EmailsConfig")
return emails_config

Просмотреть файл

@ -37,7 +37,8 @@ class CommandFromDjangoSettings(BaseCommand):
* Add the Django settings and their values to the command help
* Override the verbosity from an environment variable
"""
assert self.settings_to_locals
if not self.settings_to_locals:
raise ValueError("self.settings_to_locals must be truthy value.")
epilog_lines = [
(
"Parameters are read from Django settings and the related environment"
@ -65,15 +66,18 @@ class CommandFromDjangoSettings(BaseCommand):
epilog = "\n".join(epilog_lines)
parser = super().create_parser(prog_name, subcommand, epilog=epilog, **kwargs)
assert parser.formatter_class == DjangoHelpFormatter
if parser.formatter_class != DjangoHelpFormatter:
raise TypeError("parser.formatter_class must be DjangoHelpFormatter")
parser.formatter_class = RawDescriptionDjangoHelpFormatter
assert verbosity_override is not None
if verbosity_override is None:
raise ValueError("verbosity_override must not be None.")
parser.set_defaults(verbosity=verbosity_override)
return parser
def init_from_settings(self, verbosity):
"""Initialize local variables from settings"""
assert self.settings_to_locals
if not self.settings_to_locals:
raise ValueError("self.settings_to_locals must be truthy value.")
for setting_key, local_name, help_str, validator in self.settings_to_locals:
value = getattr(settings, setting_key)
if not validator(value):

Просмотреть файл

@ -152,8 +152,10 @@ class Command(CommandFromDjangoSettings):
def create_client(self):
"""Create the SQS client."""
assert self.aws_region
assert self.sqs_url
if not self.aws_region:
raise ValueError("self.aws_region must be truthy value.")
if not self.sqs_url:
raise ValueError("self.sqs_url must be truthy value.")
sqs_client = boto3.resource("sqs", region_name=self.aws_region)
return sqs_client.Queue(self.sqs_url)

Просмотреть файл

@ -40,10 +40,13 @@ def _ses_message_props(data: str) -> ContentTypeDef:
def send_welcome_email(profile: Profile) -> None:
user = profile.user
app_config = apps.get_app_config("emails")
assert isinstance(app_config, EmailsConfig)
if not isinstance(app_config, EmailsConfig):
raise TypeError("app_config must be type EmailsConfig")
ses_client = app_config.ses_client
assert ses_client
assert settings.RELAY_FROM_ADDRESS
if not ses_client:
raise ValueError("ses_client must be truthy value")
if not settings.RELAY_FROM_ADDRESS:
raise ValueError("settings.RELAY_FROM_ADDRESS must be truthy value.")
with django_ftl.override(profile.language):
translated_subject = ftl_bundle.format("first-time-user-email-welcome")
try:

Просмотреть файл

@ -266,13 +266,16 @@ class Profile(models.Model):
return datetime.now(UTC)
if bounce_type == "soft":
assert self.last_soft_bounce
if not self.last_soft_bounce:
raise ValueError("self.last_soft_bounce must be truthy value.")
return self.last_soft_bounce + timedelta(
days=settings.SOFT_BOUNCE_ALLOWED_DAYS
)
assert bounce_type == "hard"
assert self.last_hard_bounce
if bounce_type != "hard":
raise ValueError("bounce_type must be either 'soft' or 'hard'")
if not self.last_hard_bounce:
raise ValueError("self.last_hard_bounce must be truthy value.")
return self.last_hard_bounce + timedelta(days=settings.HARD_BOUNCE_ALLOWED_DAYS)
@property
@ -293,7 +296,8 @@ class Profile(models.Model):
# Note: we are NOT using .filter() here because it invalidates
# any profile instances that were queried with prefetch_related, which
# we use in at least the profile view to minimize queries
assert hasattr(self.user, "socialaccount_set")
if not hasattr(self.user, "socialaccount_set"):
raise AttributeError("self.user must have socialaccount_set attribute")
for sa in self.user.socialaccount_set.all():
if sa.provider == "fxa":
return sa
@ -310,7 +314,8 @@ class Profile(models.Model):
@property
def custom_domain(self) -> str:
assert self.subdomain
if not self.subdomain:
raise ValueError("self.subdomain must be truthy value.")
return f"@{self.subdomain}.{settings.MOZMAIL_DOMAIN}"
@property
@ -844,7 +849,8 @@ class RelayAddress(models.Model):
@property
def metrics_id(self) -> str:
assert self.id
if not self.id:
raise ValueError("self.id must be truthy value.")
# Prefix with 'R' for RelayAddress, since there may be a DomainAddress with the
# same row ID
return f"R{self.id}"
@ -1002,7 +1008,8 @@ class DomainAddress(models.Model):
# DomainAlias will be a feature
address = address_default()
# Only check for bad words if randomly generated
assert isinstance(address, str)
if not isinstance(address, str):
raise TypeError("address must be type str")
first_emailed_at = datetime.now(UTC) if made_via_email else None
domain_address = DomainAddress.objects.create(
@ -1052,7 +1059,8 @@ class DomainAddress(models.Model):
@property
def metrics_id(self) -> str:
assert self.id
if not self.id:
raise ValueError("self.id must be truthy value.")
# Prefix with 'D' for DomainAddress, since there may be a RelayAddress with the
# same row ID
return f"D{self.id}"
@ -1083,7 +1091,8 @@ class Reply(models.Model):
def increment_num_replied(self):
address = self.relay_address or self.domain_address
assert address
if not address:
raise ValueError("address must be truthy value")
address.num_replied += 1
address.last_used_at = datetime.now(UTC)
address.save(update_fields=["num_replied", "last_used_at"])

Просмотреть файл

@ -198,7 +198,8 @@ def _get_hero_img_src(lang_code):
if major_lang in avail_l10n_image_codes:
img_locale = major_lang
assert settings.SITE_ORIGIN
if not settings.SITE_ORIGIN:
raise ValueError("settings.SITE_ORIGIN must have a value")
return (
settings.SITE_ORIGIN
+ f"/static/images/email-images/first-time-user/hero-image-{img_locale}.png"
@ -229,8 +230,11 @@ def ses_send_raw_email(
destination_address: str,
message: EmailMessage,
) -> SendRawEmailResponseTypeDef:
assert (client := ses_client()) is not None
assert settings.AWS_SES_CONFIGSET
client = ses_client()
if client is None:
raise ValueError("client must have a value")
if not settings.AWS_SES_CONFIGSET:
raise ValueError("settings.AWS_SES_CONFIGSET must have a value")
data = message.as_string()
try:
@ -393,7 +397,9 @@ def _get_bucket_and_key_from_s3_json(message_json):
@time_if_enabled("s3_get_message_content")
def get_message_content_from_s3(bucket, object_key):
if bucket and object_key:
assert (client := s3_client()) is not None
client = s3_client()
if client is None:
raise ValueError("client must not be None")
streamed_s3_object = client.get_object(Bucket=bucket, Key=object_key).get(
"Body"
)
@ -405,7 +411,9 @@ def remove_message_from_s3(bucket, object_key):
if bucket is None or object_key is None:
return False
try:
assert (client := s3_client()) is not None
client = s3_client()
if client is None:
raise ValueError("client must not be None")
response = client.delete_object(Bucket=bucket, Key=object_key)
return response.get("DeleteMarker")
except ClientError as e:

Просмотреть файл

@ -202,13 +202,15 @@ def wrapped_email_test(request: HttpRequest) -> HttpResponse:
if "language" in request.GET:
language = request.GET["language"]
else:
assert user_profile is not None
if user_profile is None:
raise ValueError("user_profile must not be None")
language = user_profile.language
if "has_premium" in request.GET:
has_premium = strtobool(request.GET["has_premium"])
else:
assert user_profile is not None
if user_profile is None:
raise ValueError("user_profile must not be None")
has_premium = user_profile.has_premium
if "num_level_one_email_trackers_removed" in request.GET:
@ -335,7 +337,8 @@ def _store_reply_record(
if isinstance(address, DomainAddress):
reply_create_args["domain_address"] = address
else:
assert isinstance(address, RelayAddress)
if not isinstance(address, RelayAddress):
raise TypeError("address must be type RelayAddress")
reply_create_args["relay_address"] = address
Reply.objects.create(**reply_create_args)
return mail
@ -492,7 +495,10 @@ def _sns_message(message_json: AWS_SNSMessageJSON) -> HttpResponse:
return _handle_bounce(message_json)
if notification_type == "Complaint" or event_type == "Complaint":
return _handle_complaint(message_json)
assert notification_type == "Received" and event_type is None
if notification_type != "Received":
raise ValueError('notification_type must be "Received"')
if event_type is not None:
raise ValueError("event_type must be None")
return _handle_received(message_json)
@ -928,7 +934,8 @@ def _convert_to_forwarded_email(
# python/typeshed issue 2418
# The Python 3.2 default was Message, 3.6 uses policy.message_factory, and
# policy.default.message_factory is EmailMessage
assert isinstance(email, EmailMessage)
if not isinstance(email, EmailMessage):
raise TypeError("email must be type EmailMessage")
# Replace headers in the original email
header_issues = _replace_headers(email, headers)
@ -939,7 +946,8 @@ def _convert_to_forwarded_email(
has_text = False
if text_body:
has_text = True
assert isinstance(text_body, EmailMessage)
if not isinstance(text_body, EmailMessage):
raise TypeError("text_body must be type EmailMessage")
text_content = text_body.get_content()
new_text_content = _convert_text_content(text_content, to_address)
text_body.set_content(new_text_content)
@ -950,7 +958,8 @@ def _convert_to_forwarded_email(
has_html = False
if html_body:
has_html = True
assert isinstance(html_body, EmailMessage)
if not isinstance(html_body, EmailMessage):
raise TypeError("html_body must be type EmailMessage")
html_content = html_body.get_content()
new_content, level_one_trackers_removed = _convert_html_content(
html_content,
@ -974,7 +983,8 @@ def _convert_to_forwarded_email(
sample_trackers,
remove_level_one_trackers,
)
assert isinstance(text_body, EmailMessage)
if not isinstance(text_body, EmailMessage):
raise TypeError("text_body must be type EmailMessage")
try:
text_body.add_alternative(new_content, subtype="html")
except TypeError as e:
@ -1303,7 +1313,8 @@ def _handle_reply(
return HttpResponse("Cannot fetch the message content from S3", status=503)
email = message_from_bytes(email_bytes, policy=relay_policy)
assert isinstance(email, EmailMessage)
if not isinstance(email, EmailMessage):
raise TypeError("email must be type EmailMessage")
# Convert to a reply email
# TODO: Issue #1747 - Remove wrapper / prefix in replies

Просмотреть файл

@ -27,7 +27,8 @@ class PhonesConfig(AppConfig):
instance = self.twilio_client.applications(
settings.TWILIO_SMS_APPLICATION_SID
).fetch()
assert isinstance(instance, InstanceResource)
if not isinstance(instance, InstanceResource):
raise TypeError("instance must be type InstanceResource")
return instance
@cached_property
@ -47,10 +48,15 @@ class PhonesConfig(AppConfig):
def phones_config() -> PhonesConfig:
phones_config = apps.get_app_config("phones")
assert isinstance(phones_config, PhonesConfig)
if not isinstance(phones_config, PhonesConfig):
raise TypeError("phones_config must be type PhonesConfig")
return phones_config
def twilio_client() -> Client:
assert not settings.PHONES_NO_CLIENT_CALLS_IN_TEST
if settings.PHONES_NO_CLIENT_CALLS_IN_TEST:
raise ValueError(
"settings.PHONES_NO_CLIENT_CALLS_IN_TEST must be False when "
"calling twilio_client()"
)
return phones_config().twilio_client

Просмотреть файл

@ -328,7 +328,11 @@ class CachedList:
def register_with_messaging_service(client: Client, number_sid: str) -> None:
"""Register a Twilio US phone number with a Messaging Service."""
assert settings.TWILIO_MESSAGING_SERVICE_SID
if not settings.TWILIO_MESSAGING_SERVICE_SID:
raise ValueError(
"settings.TWILIO_MESSAGING_SERVICE_SID must contain a value when calling "
"register_with_messaging_service"
)
closed_sids = CachedList("twilio_messaging_service_closed")
@ -386,7 +390,11 @@ def relaynumber_post_save(sender, instance, created, **kwargs):
def send_welcome_message(user, relay_number):
real_phone = RealPhone.objects.get(user=user)
assert settings.SITE_ORIGIN
if not settings.SITE_ORIGIN:
raise ValueError(
"settings.SITE_ORIGIN must contain a value when calling "
"send_welcome_message"
)
media_url = settings.SITE_ORIGIN + reverse(
"vCard", kwargs={"lookup_key": relay_number.vcard_lookup_key}
)

Просмотреть файл

@ -17,7 +17,10 @@ class AccountAdapter(DefaultAccountAdapter):
"""
Redirect to dashboard, preserving utm params from FXA.
"""
assert request.user.is_authenticated
if not request.user.is_authenticated:
raise ValueError(
"request.user must be authenticated when calling get_login_redirect_url"
)
url = "/accounts/profile/?"
utm_params = {k: v for k, v in request.GET.items() if k.startswith("utm")}
url += urlencode(utm_params)

Просмотреть файл

@ -81,7 +81,9 @@ class PrivateRelayConfig(AppConfig):
import privaterelay.signals
assert privaterelay.signals # Suppress "imported but unused" warnings
assert ( # noqa S101
privaterelay.signals
) # Suppress "imported but unused" warnings
try:
del self.fxa_verifying_keys # Clear cache

Просмотреть файл

@ -29,15 +29,22 @@ class DataIssueTask:
def counts(self) -> Counts:
"""Get relevant counts for data issues and prepare to clean if possible."""
if self._counts is None:
assert self._cleanup_data is None
if self._cleanup_data is not None:
raise ValueError(
"self.cleanup_data should be None when self._counts is None"
)
self._counts, self._cleanup_data = self._get_counts_and_data()
return self._counts
@property
def cleanup_data(self) -> CleanupData:
"""Get data needed to clean data issues."""
assert self.counts # Populate self._cleanup_data if not populated
assert self._cleanup_data
if not self.counts:
raise ValueError("self.counts must have a value when calling cleanup_data.")
if not self._cleanup_data:
raise ValueError(
"self._cleanup_data must have a value when calling cleanup_data."
)
return self._cleanup_data
def issues(self) -> int:
@ -72,7 +79,8 @@ class DataIssueTask:
@staticmethod
def _as_percent(part: int, whole: int) -> str:
"""Return value followed by percent of whole, like '5 ( 30.0%)'"""
assert whole > 0
if not whole > 0:
raise ValueError("whole must be greater than 0 when calling _as_percent")
len_whole = len(str(whole))
return f"{part:{len_whole}d} ({part / whole:6.1%})"

Просмотреть файл

@ -137,7 +137,10 @@ class RelayGleanLogger(EventsServerEventLogger):
app_display_version: str,
channel: RELAY_CHANNEL_NAME,
):
assert settings.GLEAN_EVENT_MOZLOG_TYPE == GLEAN_EVENT_MOZLOG_TYPE
if not settings.GLEAN_EVENT_MOZLOG_TYPE == GLEAN_EVENT_MOZLOG_TYPE:
raise ValueError(
"settings.GLEAN_EVENT_MOZLOG_TYPE must equal GLEAN_EVENT_MOZLOG_TYPE"
)
self._logger = getLogger(GLEAN_EVENT_MOZLOG_TYPE)
super().__init__(
application_id=application_id,

Просмотреть файл

@ -57,7 +57,8 @@ class Command(BaseCommand):
epilog = "\n".join(epilog_lines)
parser = super().create_parser(prog_name, subcommand, epilog=epilog, **kwargs)
assert parser.formatter_class == DjangoHelpFormatter
if not parser.formatter_class == DjangoHelpFormatter:
raise TypeError("parser.formatter_class must be type DjangoHelpFormatter")
parser.formatter_class = RawDescriptionDjangoHelpFormatter
return parser

Просмотреть файл

@ -557,15 +557,18 @@ def _cached_country_language_mapping(
mapping: PlanCountryLangMapping = {}
for relay_country in relay_maps.get("by_country", []):
assert relay_country not in mapping
if relay_country in mapping:
raise ValueError("relay_country should not be in mapping.")
mapping[relay_country] = {"*": _get_stripe_prices(relay_country, stripe_data)}
for relay_country, override in relay_maps.get("by_country_override", {}).items():
assert relay_country not in mapping
if relay_country in mapping:
raise ValueError("relay_country should not be in mapping.")
mapping[relay_country] = {"*": _get_stripe_prices(override, stripe_data)}
for relay_country, languages in relay_maps.get("by_country_and_lang", {}).items():
assert relay_country not in mapping
if relay_country in mapping:
raise ValueError("relay_country should not be in mapping.")
mapping[relay_country] = {
lang: _get_stripe_prices(stripe_country, stripe_data)
for lang, stripe_country in languages.items()
@ -586,16 +589,19 @@ def _get_stripe_prices(
# mypy thinks stripe_details _could_ be _StripeYearlyPriceDetails,
# so extra asserts are needed to make mypy happy.
monthly_id = str(stripe_details.get("monthly_id"))
assert monthly_id.startswith("price_")
if not monthly_id.startswith("price_"):
raise ValueError("monthly_id must start with 'price_'")
price = prices.get("monthly", 0.0)
assert price and isinstance(price, float)
if not isinstance(price, float):
raise TypeError("price must be of type float.")
period_to_details["monthly"] = {
"id": monthly_id,
"currency": currency,
"price": price,
}
yearly_id = stripe_details["yearly_id"]
assert yearly_id.startswith("price_")
if not yearly_id.startswith("price_"):
raise ValueError("yearly_id must start with 'price_'")
period_to_details["yearly"] = {
"id": yearly_id,
"currency": currency,

Просмотреть файл

@ -40,7 +40,7 @@ try:
# https://github.com/jazzband/django-silk
import silk
assert silk # Suppress "imported but unused" warning
assert silk # Suppress "imported but unused" warning # noqa S101
HAS_SILK = True
except ImportError:
@ -119,7 +119,10 @@ if FXA_BASE_ORIGIN == "https://accounts.firefox.com":
]
_ACCOUNT_CONNECT_SRC = [FXA_BASE_ORIGIN]
else:
assert FXA_BASE_ORIGIN == "https://accounts.stage.mozaws.net"
if not FXA_BASE_ORIGIN == "https://accounts.stage.mozaws.net":
raise ValueError(
"FXA_BASE_ORIGIN must be either https://accounts.firefox.com or https://accounts.stage.mozaws.net"
)
_AVATAR_IMG_SRC = [
"mozillausercontent.com",
"https://profile.stage.mozaws.net",

Просмотреть файл

@ -37,7 +37,8 @@ class RelayStaticFilesStorage(CompressedManifestStaticFilesStorage):
return name
else:
new_name = super().hashed_name(name, content, filename)
assert isinstance(new_name, str)
if not isinstance(new_name, str):
raise TypeError("new_name must be type str")
return new_name
def url_converter(

Просмотреть файл

@ -168,7 +168,8 @@ def _parse_jwt_from_request(request: HttpRequest) -> str:
def fxa_verifying_keys(reload: bool = False) -> list[dict[str, Any]]:
"""Get list of FxA verifying (public) keys."""
private_relay_config = apps.get_app_config("privaterelay")
assert isinstance(private_relay_config, PrivateRelayConfig)
if not isinstance(private_relay_config, PrivateRelayConfig):
raise TypeError("private_relay_config must be PrivateRelayConfig")
if reload:
private_relay_config.ready()
return private_relay_config.fxa_verifying_keys
@ -177,7 +178,8 @@ def fxa_verifying_keys(reload: bool = False) -> list[dict[str, Any]]:
def fxa_social_app(reload: bool = False) -> SocialApp:
"""Get FxA SocialApp from app config or DB."""
private_relay_config = apps.get_app_config("privaterelay")
assert isinstance(private_relay_config, PrivateRelayConfig)
if not isinstance(private_relay_config, PrivateRelayConfig):
raise TypeError("private_relay_config must be PrivateRelayConfig")
if reload:
private_relay_config.ready()
return private_relay_config.fxa_social_app
@ -222,11 +224,13 @@ def _verify_jwt_with_fxa_key(
social_app = fxa_social_app()
if not social_app:
raise Exception("FXA SocialApp is not available.")
assert isinstance(social_app, SocialApp)
if not isinstance(social_app, SocialApp):
raise TypeError("social_app must be SocialApp")
for verifying_key in verifying_keys:
if verifying_key["alg"] == "RS256":
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(verifying_key))
assert isinstance(public_key, RSAPublicKey)
if not isinstance(public_key, RSAPublicKey):
raise TypeError("public_key must be RSAPublicKey")
try:
security_event = jwt.decode(
req_jwt,