379 строки
12 KiB
Python
379 строки
12 KiB
Python
import os, sys, mimetypes, requests, uuid, json
|
|
|
|
from kinto_http import Client, BearerTokenAuth
|
|
from packaging import version
|
|
|
|
from remote_settings.format import print_error, print_help
|
|
|
|
REMOTE_SETTINGS_BEARER_TOKEN = "REMOTE_SETTINGS_BEARER_TOKEN"
|
|
BEARER_TOKEN_HELP_MESSAGE = f"""\
|
|
Export the token as an environment variable called {REMOTE_SETTINGS_BEARER_TOKEN}.
|
|
|
|
You can retrieve a bearer token from the Remote Settings admin dashboards.
|
|
|
|
Dev: https://settings.dev.mozaws.net/v1/admin
|
|
Stage: https://remote-settings.allizom.org/v1/admin
|
|
Prod: https://remote-settings.mozilla.org/v1/admin
|
|
|
|
On the top right corner, use the 📋 icon to copy the authentication string
|
|
"""
|
|
BUCKET = "main-workspace"
|
|
COLLECTION = "translations-models"
|
|
SERVER_URLS = {
|
|
"dev": "https://remote-settings-dev.allizom.org/v1",
|
|
"stage": "https://remote-settings.allizom.org/v1",
|
|
"prod": "https://remote-settings.mozilla.org/v1",
|
|
}
|
|
|
|
|
|
class MockedClient:
|
|
def __init__(self, args):
|
|
self._server = args.server
|
|
|
|
def server_info(self):
|
|
return {
|
|
"url": SERVER_URLS.get(self._server),
|
|
"user": {
|
|
"id": "mocked_user",
|
|
},
|
|
}
|
|
|
|
|
|
class RemoteSettingsClient:
|
|
def __init__(self, args):
|
|
"""Initializes the RemoteSettingsClient by authenticating with the server.
|
|
|
|
The client may be mocked for testing if the --mock-connection flag was passed.
|
|
|
|
Args:
|
|
args (argparse.Namespace): The arguments passed through the CLI
|
|
"""
|
|
if args.mock_connection:
|
|
self._client = MockedClient(args)
|
|
return
|
|
|
|
self._auth_token = RemoteSettingsClient._retrieve_remote_settings_bearer_token()
|
|
self._client = Client(
|
|
server_url=SERVER_URLS.get(args.server),
|
|
bucket=BUCKET,
|
|
collection=COLLECTION,
|
|
auth=BearerTokenAuth(self._auth_token),
|
|
)
|
|
self._new_records = None
|
|
|
|
@classmethod
|
|
def init_for_create(cls, args):
|
|
"""Initializes the RemoteSettingsClient for the create subcommand
|
|
This expects the CLI args to have information regarding creating a
|
|
new record, which populates the _new_records data member.
|
|
|
|
Args:
|
|
args (argparse.Namespace): The arguments passed through the CLI
|
|
|
|
Returns:
|
|
RemoteSettingsClient: A RemoteSettingsClient that can create new records
|
|
"""
|
|
this = cls(args)
|
|
if args.path is not None:
|
|
new_record_info = RemoteSettingsClient._create_record_info(args.path, args.version)
|
|
this._new_records = [new_record_info]
|
|
else:
|
|
paths = this._paths_for_lang_pair(args)
|
|
this._new_records = [
|
|
RemoteSettingsClient._create_record_info(path, args.version) for path in paths
|
|
]
|
|
|
|
return this
|
|
|
|
@staticmethod
|
|
def _paths_for_lang_pair(args):
|
|
"""Retrieves all of the file paths for the given language pair and version in args.
|
|
|
|
Args:
|
|
args (argparse.Namespace): The arguments passed through the CLI
|
|
|
|
Returns:
|
|
List[str]: A list of file paths in the specified language-pair directory.
|
|
"""
|
|
parsed_version = version.parse(args.version)
|
|
|
|
if parsed_version.is_prerelease:
|
|
directory = os.path.join(RemoteSettingsClient._base_dir(args), "dev")
|
|
else:
|
|
directory = os.path.join(RemoteSettingsClient._base_dir(args), "prod")
|
|
|
|
full_path = os.path.join(directory, args.lang_pair)
|
|
|
|
if not os.path.exists(full_path):
|
|
print_error(f"Path does not exist: {full_path}")
|
|
exit(1)
|
|
|
|
return [os.path.join(full_path, f) for f in os.listdir(full_path) if not f.endswith(".gz")]
|
|
|
|
@staticmethod
|
|
def _create_record_info(path, version):
|
|
"""Creates a record-info dictionary for a file at the given path.
|
|
|
|
Args:
|
|
path (str): The path to the file
|
|
version (str): The version of the record attachment
|
|
|
|
Returns:
|
|
dict: A dictionary containing the record metadata
|
|
"""
|
|
name = os.path.basename(path)
|
|
file_type = RemoteSettingsClient._determine_file_type(name)
|
|
from_lang, to_lang = RemoteSettingsClient._determine_language_pair(name)
|
|
filter_expression = RemoteSettingsClient._determine_filter_expression(version)
|
|
mimetype, _ = mimetypes.guess_type(path)
|
|
return {
|
|
"id": str(uuid.uuid4()),
|
|
"data": {
|
|
"name": os.path.basename(path),
|
|
"fromLang": from_lang,
|
|
"toLang": to_lang,
|
|
"version": version,
|
|
"fileType": file_type,
|
|
"filter_expression": filter_expression,
|
|
},
|
|
"attachment": {
|
|
"path": path,
|
|
"mimeType": mimetype,
|
|
},
|
|
}
|
|
|
|
@staticmethod
|
|
def _retrieve_remote_settings_bearer_token():
|
|
"""
|
|
Attempts to retrieve a Remote Settings bearer token exported to an environment
|
|
variable called REMOTE_SETTINGS_BEARER_TOKEN.
|
|
|
|
Exits with failure if the token cannot be retrieved.
|
|
|
|
Returns:
|
|
String: The bearer token.
|
|
"""
|
|
|
|
token = os.environ.get(REMOTE_SETTINGS_BEARER_TOKEN)
|
|
if token is None:
|
|
print_error(f"Failed to retrieve {REMOTE_SETTINGS_BEARER_TOKEN}")
|
|
print_help(BEARER_TOKEN_HELP_MESSAGE)
|
|
sys.exit(1)
|
|
|
|
# When copying the Remote Settings token from the UI, it copies in the format
|
|
# "Bearer <token>". We want to strip just the token if the user did not strip
|
|
# it already themselves.
|
|
if token.startswith("Bearer "):
|
|
return token[len("Bearer ") :]
|
|
|
|
return token
|
|
|
|
@staticmethod
|
|
def _determine_filter_expression(semantic_version):
|
|
"""Determines the appropriate Remote Settings filter expression based on the version.
|
|
|
|
Alpha versions are available in local builds and nightly.
|
|
Beta versions are available in all builds except release.
|
|
Release versions are available in all builds.
|
|
|
|
Args:
|
|
semantic_version str: A semantic version string
|
|
|
|
Returns:
|
|
str: The appropriate Remote Settings filter expression based on the version
|
|
"""
|
|
record_version = version.parse(semantic_version)
|
|
base_version = record_version.base_version
|
|
|
|
if record_version < version.parse(f"{base_version}b"):
|
|
return "env.channel == 'default' || env.channel == 'nightly'"
|
|
elif record_version < version.parse(f"{base_version}"):
|
|
return "env.channel != 'release'"
|
|
else:
|
|
return ""
|
|
|
|
@staticmethod
|
|
def _determine_language_pair(name):
|
|
"""Determines the language pair based on the name of the file.
|
|
|
|
Args:
|
|
name str: The name of a file to attach to a record
|
|
|
|
Returns:
|
|
Tuple[str, str]: The (fromLang, toLang) pair for this file
|
|
"""
|
|
segments = name.split(".")
|
|
|
|
# File names are of the following formats:
|
|
# - model.{lang_pair}.intgemm8.bin.gz
|
|
# - lex.{lang_pair}.s2t.bin.gz
|
|
# - lex.50.50.{lang_pair}.s2t.bin.gz
|
|
# - trgvocab.{lang_pair}.spm.gz
|
|
# - srcvocab.{lang_pair}.spm.gz
|
|
# - qualityModel.{lang_pair}.bin.gz
|
|
# - vocab.{lang_pair}.spm.gz
|
|
#
|
|
# The lang_pair will always be in the one-index, except for
|
|
# the lex.50.50... file, in which case it is in the three-index segment.
|
|
lang_pair_segment = segments[1]
|
|
if len(lang_pair_segment) < 4:
|
|
lang_pair_segment = segments[3]
|
|
return (lang_pair_segment[:2], lang_pair_segment[-2:])
|
|
|
|
@staticmethod
|
|
def _determine_file_type(name):
|
|
"""Returns the type of the file based on the file name.
|
|
Note that this is different than the file extension.
|
|
The resulting type will be one of the following strings:
|
|
|
|
{"model", "lex", "vocab", "trgvocab", "srcvocab", "qualityModel"}
|
|
|
|
Args:
|
|
name str: The name of a file to attach to a record
|
|
|
|
Returns:
|
|
str: The type of the file
|
|
"""
|
|
segments = name.split(".")
|
|
file_type_segment = segments[0]
|
|
return file_type_segment
|
|
|
|
@staticmethod
|
|
def _base_dir(args):
|
|
"""Get the base directory in which to search for record attachments.
|
|
|
|
Args:
|
|
args (argparse.Namespace): The arguments passed through the CLI
|
|
|
|
Returns:
|
|
str: The base directory for record attachments.
|
|
"""
|
|
if args.test:
|
|
return os.path.join("tests", "remote_settings", "attachments")
|
|
else:
|
|
return "models"
|
|
|
|
def server_url(self):
|
|
"""Retrieves the url of the server that this client is connected to.
|
|
|
|
Returns:
|
|
str: The server url
|
|
"""
|
|
return self._client.server_info()["url"]
|
|
|
|
def authenticated_user(self):
|
|
"""Retrieves the user who is authenticated through this client.
|
|
|
|
Returns:
|
|
str: The authenticated user
|
|
"""
|
|
return self._client.server_info()["user"]["id"]
|
|
|
|
def attachment_path(self, index):
|
|
"""Retrieves the path of the attachment that will be attached to a newly created record.
|
|
|
|
Args:
|
|
index (int): The index of the record.
|
|
|
|
Returns:
|
|
str: The attachment path
|
|
"""
|
|
return self._new_records[index]["attachment"]["path"]
|
|
|
|
def attachment_name(self, index):
|
|
"""Retrieves the name of the attachment that will be attached to a newly created record.
|
|
|
|
Args:
|
|
index (int): The index of the record.
|
|
|
|
Returns:
|
|
str: The attachment name
|
|
"""
|
|
return os.path.basename(self.attachment_path(index))
|
|
|
|
def attachment_mimetype(self, index):
|
|
"""Retrieves the determined mimetype of the attachment that will be attached to a newly created record.
|
|
|
|
Args:
|
|
index (int): The index of the record.
|
|
|
|
Returns:
|
|
Union[None | str]: The determined mimetype
|
|
"""
|
|
return self._new_records[index]["attachment"]["mimeType"]
|
|
|
|
def attachment_content(self, index):
|
|
"""Retrieves the file content of the attachment that will be attached to a newly created record.
|
|
|
|
Args:
|
|
index (int): The index of the record.
|
|
|
|
Returns:
|
|
bytes: The content of the attachment
|
|
"""
|
|
with open(self.attachment_path(index), "rb") as f:
|
|
attachment_content = f.read()
|
|
return attachment_content
|
|
|
|
def record_count(self):
|
|
"""Returns the count of new records to be created"""
|
|
return len(self._new_records)
|
|
|
|
def record_info_json(self, index):
|
|
"""Returns the information of the record to be created as JSON data.
|
|
|
|
Args:
|
|
index (int): The index of the record.
|
|
|
|
Returns:
|
|
str: The JSON-formatted string containing the record info
|
|
"""
|
|
return json.dumps(self._new_records[index], indent=2)
|
|
|
|
def create_new_record(self, index):
|
|
"""Creates a new record in the Remote Settings server along with its file attachment.
|
|
|
|
Args:
|
|
index (int): The index of the record.
|
|
"""
|
|
id = self._new_records[index]["id"]
|
|
data = self._new_records[index]["data"]
|
|
self._client.create_record(id=id, data=data)
|
|
self.attach_file_to_record(index)
|
|
|
|
def attach_file_to_record(self, index):
|
|
"""Attaches the file attachment to the record of the matching id.
|
|
|
|
Args:
|
|
index (int): The index of the record.
|
|
|
|
Raises:
|
|
KintoException: An exception if the record was not able to be uploaded.
|
|
"""
|
|
headers = {"Authorization": f"Bearer {self._auth_token}"}
|
|
|
|
attachment_endpoint = "buckets/{}/collections/{}/records/{}/attachment".format(
|
|
BUCKET, COLLECTION, self._new_records[index]["id"]
|
|
)
|
|
|
|
response = requests.post(
|
|
f"{self.server_url()}{attachment_endpoint}",
|
|
files=[
|
|
(
|
|
"attachment",
|
|
(
|
|
self.attachment_name(index),
|
|
self.attachment_content(index),
|
|
self.attachment_mimetype(index),
|
|
),
|
|
)
|
|
],
|
|
headers=headers,
|
|
)
|
|
|
|
if response.status_code > 200:
|
|
raise KintoException(
|
|
f"Couldn't attach file at endpoint {self.sever_url()}{attachment_endpoint}: "
|
|
+ f"{response.content.decode('utf-8')}"
|
|
)
|