get user ssh keys from user extension secret instead of rest server (#36)

This commit is contained in:
Guoxin 2021-02-23 09:33:09 +08:00 коммит произвёл GitHub
Родитель c77c56291c
Коммит c5c6dbc340
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 40 добавлений и 46 удалений

Просмотреть файл

@ -166,7 +166,7 @@ python ${PAI_INIT_DIR}/framework_parser.py genconf framework.json > ${PAI_RUNTIM
# Init plugins
# priority=12
CHILD_PROCESS="PLUGIN_INITIALIZER"
python ${PAI_INIT_DIR}/initializer.py ${PAI_RUNTIME_DIR}/job_config.yaml ${PAI_SECRET_DIR}/secrets.yaml ${PAI_TOKEN_SECRET_DIR}/token ${PAI_WORK_DIR}/plugins ${PAI_RUNTIME_DIR} ${FC_TASKROLE_NAME}
python ${PAI_INIT_DIR}/initializer.py ${PAI_RUNTIME_DIR}/job_config.yaml ${PAI_SECRET_DIR}/secrets.yaml ${PAI_SECRET_DIR}/userExtensionSecrets.yaml ${PAI_TOKEN_SECRET_DIR}/token ${PAI_WORK_DIR}/plugins ${PAI_RUNTIME_DIR} ${FC_TASKROLE_NAME}
# Init plugins
# check port conflict

Просмотреть файл

@ -108,7 +108,7 @@ def collect_plugin_configs(jobconfig, taskrole):
for prerequisite_name in jobconfig['taskRoles'][taskrole]['prerequisites']:
prerequisite_config = copy.deepcopy(prerequisites_name2config[prerequisite_name])
if 'plugin' in prerequisite_config and prerequisite_config['plugin'].startswith(RUNTIME_PLUGIN_PLACE_HOLDER):
# convert prerequisite to runtime plugin config
# convert prerequisite to runtime plugin config
plugin_config = {
# plugin name follows the format "com.microsoft.pai.runtimeplugin.<plugin name>"
'plugin': prerequisite_config.pop('plugin')[len(RUNTIME_PLUGIN_PLACE_HOLDER) + 1:]
@ -119,7 +119,7 @@ def collect_plugin_configs(jobconfig, taskrole):
# the remaining keys (other than plugin, failurePolicy, and type) will be treated as parameters
plugin_config['parameters'] = copy.deepcopy(prerequisite_config)
plugin_configs.append(plugin_config)
# collect plugins from jobconfig["extras"]
if "extras" in jobconfig and RUNTIME_PLUGIN_PLACE_HOLDER in jobconfig["extras"]:
for plugin_config in jobconfig["extras"][RUNTIME_PLUGIN_PLACE_HOLDER]:
@ -130,13 +130,14 @@ def collect_plugin_configs(jobconfig, taskrole):
return plugin_configs
def init_plugins(jobconfig, secrets, application_token, commands, plugins_path, runtime_path,
def init_plugins(jobconfig, secrets, user_extension, application_token, commands, plugins_path, runtime_path,
taskrole):
"""Init plugins from jobconfig.
Args:
jobconfig: Jobconfig object generated by parser.py from framework.json.
secrets: user secrests passed to runtime.
secrets: config secrests passed to runtime.
user_extension: user extension passed to runtime.
application_token: application token path passed to runtime.
commands: Commands to call in precommands.sh and postcommands.sh.
plugins_path: The base path for all plugins.
@ -155,6 +156,8 @@ def init_plugins(jobconfig, secrets, application_token, commands, plugins_path,
taskrole))
plugin_config["parameters"] = parameters
plugin_config["user_extension"] = user_extension
if os.path.exists(application_token):
with open(application_token, "r") as f:
plugin_config["application_token"] = yaml.safe_load(f)
@ -219,7 +222,9 @@ def main():
"jobconfig_yaml",
help="jobConfig.yaml generated by parser.py from framework.json")
parser.add_argument("secret_file",
help="secrets.yaml user secrets passed to runtime")
help="secrets.yaml config secrets passed to runtime")
parser.add_argument("user_extension_secrets_file",
help="userExtensionSecrets.yaml user extension secrets passed to runtime")
parser.add_argument("application_token",
help="application token passed to runtime")
parser.add_argument("plugins_path", help="Plugins path")
@ -237,8 +242,14 @@ def main():
with open(args.secret_file) as f:
secrets = yaml.safe_load(f.read())
if not os.path.isfile(args.user_extension_secrets_file):
user_extension = None
else:
with open(args.user_extension_secrets_file) as f:
user_extension = yaml.safe_load(f.read())
commands = [[], []]
init_plugins(job_config, secrets, args.application_token, commands, args.plugins_path,
init_plugins(job_config, secrets, user_extension, args.application_token, commands, args.plugins_path,
args.runtime_path, args.task_role)
# pre-commands and post-commands already handled by rest-server.

Просмотреть файл

@ -19,7 +19,6 @@
import logging
import os
import sys
import requests
sys.path.append(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))
@ -28,22 +27,19 @@ from plugins.plugin_utils import plugin_init, PluginHelper, try_to_install_by_ca
LOGGER = logging.getLogger(__name__)
def get_user_public_keys(application_token, username):
def get_user_public_keys(user_extension):
"""
get user public keys from rest-server
get user public keys from user extension
Format of API `REST_SERVER_URI/api/v2/users/<username>` response:
Format of user extension:
{
"xxx": "xxx",
"extensions": {
"sshKeys": [
{
"title": "title-of-the-public-key",
"value": "ssh-rsa xxxx"
"time": "xxx"
}
]
}
"sshKeys": [
{
"title": "title-of-the-public-key",
"value": "ssh-rsa xxxx"
"time": "xxx"
}
]
}
Returns:
@ -51,15 +47,7 @@ def get_user_public_keys(application_token, username):
list
a list of public keys
"""
url = "{}/api/v2/users/{}".format(os.environ.get('REST_SERVER_URI'), username)
headers={
'Authorization': "Bearer {}".format(application_token),
}
response = requests.get(url, headers=headers, data={})
response.raise_for_status()
public_keys = [item["value"] for item in response.json()["extension"]["sshKeys"]]
public_keys = [item["value"] for item in user_extension["sshKeys"]]
return public_keys
@ -69,6 +57,7 @@ def main():
[plugin_config, pre_script, _] = plugin_init()
plugin_helper = PluginHelper(plugin_config)
parameters = plugin_config.get("parameters")
user_extension = plugin_config.get("user_extension")
if not parameters:
LOGGER.info("Ssh plugin parameters is empty, ignore this")
@ -86,16 +75,10 @@ def main():
cmd_params = [jobssh]
if "userssh" in parameters:
# get user public keys from rest server
application_token = plugin_config.get("application_token")
username = os.environ.get("PAI_USER_NAME")
# get user public keys from user extension secret
public_keys = []
if application_token:
try:
public_keys = get_user_public_keys(application_token, username)
except Exception: #pylint: disable=broad-except
LOGGER.error("Failed to get user public keys", exc_info=True)
sys.exit(1)
if user_extension and "sshKeys" in user_extension:
public_keys = get_user_public_keys(user_extension)
if "value" in parameters["userssh"] and parameters["userssh"]["value"] != "":
public_keys.append(parameters["userssh"]["value"])

Просмотреть файл

@ -60,7 +60,7 @@ class TestRuntime(unittest.TestCase):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")
def test_cmd_plugin_with_callbacks(self):
@ -69,7 +69,7 @@ class TestRuntime(unittest.TestCase):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")
def test_cmd_plugin_with_prerequisites(self):
@ -78,7 +78,7 @@ class TestRuntime(unittest.TestCase):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")
def test_ssh_plugin(self):
@ -88,7 +88,7 @@ class TestRuntime(unittest.TestCase):
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(
jobconfig, {"userssh": "ssh-rsa AAAAB3N/cTbWGQZtN1pai-ssh"}, "",
jobconfig, {"userssh": "ssh-rsa AAAAB3N/cTbWGQZtN1pai-ssh"}, {}, "",
commands, "../src/plugins", ".", "worker")
def test_ssh_plugin_barrier(self):
@ -97,10 +97,10 @@ class TestRuntime(unittest.TestCase):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "master")
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")
def test_git_plugin(self):
@ -109,7 +109,7 @@ class TestRuntime(unittest.TestCase):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins", ".",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins", ".",
"worker")
repo_local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src/code")
self.assertTrue(os.path.exists(repo_local_path))