get user ssh keys from user extension secret instead of rest server (#36)
This commit is contained in:
Родитель
c77c56291c
Коммит
c5c6dbc340
2
src/init
2
src/init
|
@ -166,7 +166,7 @@ python ${PAI_INIT_DIR}/framework_parser.py genconf framework.json > ${PAI_RUNTIM
|
|||
# Init plugins
|
||||
# priority=12
|
||||
CHILD_PROCESS="PLUGIN_INITIALIZER"
|
||||
python ${PAI_INIT_DIR}/initializer.py ${PAI_RUNTIME_DIR}/job_config.yaml ${PAI_SECRET_DIR}/secrets.yaml ${PAI_TOKEN_SECRET_DIR}/token ${PAI_WORK_DIR}/plugins ${PAI_RUNTIME_DIR} ${FC_TASKROLE_NAME}
|
||||
python ${PAI_INIT_DIR}/initializer.py ${PAI_RUNTIME_DIR}/job_config.yaml ${PAI_SECRET_DIR}/secrets.yaml ${PAI_SECRET_DIR}/userExtensionSecrets.yaml ${PAI_TOKEN_SECRET_DIR}/token ${PAI_WORK_DIR}/plugins ${PAI_RUNTIME_DIR} ${FC_TASKROLE_NAME}
|
||||
|
||||
# Init plugins
|
||||
# check port conflict
|
||||
|
|
|
@ -108,7 +108,7 @@ def collect_plugin_configs(jobconfig, taskrole):
|
|||
for prerequisite_name in jobconfig['taskRoles'][taskrole]['prerequisites']:
|
||||
prerequisite_config = copy.deepcopy(prerequisites_name2config[prerequisite_name])
|
||||
if 'plugin' in prerequisite_config and prerequisite_config['plugin'].startswith(RUNTIME_PLUGIN_PLACE_HOLDER):
|
||||
# convert prerequisite to runtime plugin config
|
||||
# convert prerequisite to runtime plugin config
|
||||
plugin_config = {
|
||||
# plugin name follows the format "com.microsoft.pai.runtimeplugin.<plugin name>"
|
||||
'plugin': prerequisite_config.pop('plugin')[len(RUNTIME_PLUGIN_PLACE_HOLDER) + 1:]
|
||||
|
@ -119,7 +119,7 @@ def collect_plugin_configs(jobconfig, taskrole):
|
|||
# the remaining keys (other than plugin, failurePolicy, and type) will be treated as parameters
|
||||
plugin_config['parameters'] = copy.deepcopy(prerequisite_config)
|
||||
plugin_configs.append(plugin_config)
|
||||
|
||||
|
||||
# collect plugins from jobconfig["extras"]
|
||||
if "extras" in jobconfig and RUNTIME_PLUGIN_PLACE_HOLDER in jobconfig["extras"]:
|
||||
for plugin_config in jobconfig["extras"][RUNTIME_PLUGIN_PLACE_HOLDER]:
|
||||
|
@ -130,13 +130,14 @@ def collect_plugin_configs(jobconfig, taskrole):
|
|||
return plugin_configs
|
||||
|
||||
|
||||
def init_plugins(jobconfig, secrets, application_token, commands, plugins_path, runtime_path,
|
||||
def init_plugins(jobconfig, secrets, user_extension, application_token, commands, plugins_path, runtime_path,
|
||||
taskrole):
|
||||
"""Init plugins from jobconfig.
|
||||
|
||||
Args:
|
||||
jobconfig: Jobconfig object generated by parser.py from framework.json.
|
||||
secrets: user secrests passed to runtime.
|
||||
secrets: config secrests passed to runtime.
|
||||
user_extension: user extension passed to runtime.
|
||||
application_token: application token path passed to runtime.
|
||||
commands: Commands to call in precommands.sh and postcommands.sh.
|
||||
plugins_path: The base path for all plugins.
|
||||
|
@ -155,6 +156,8 @@ def init_plugins(jobconfig, secrets, application_token, commands, plugins_path,
|
|||
taskrole))
|
||||
plugin_config["parameters"] = parameters
|
||||
|
||||
plugin_config["user_extension"] = user_extension
|
||||
|
||||
if os.path.exists(application_token):
|
||||
with open(application_token, "r") as f:
|
||||
plugin_config["application_token"] = yaml.safe_load(f)
|
||||
|
@ -219,7 +222,9 @@ def main():
|
|||
"jobconfig_yaml",
|
||||
help="jobConfig.yaml generated by parser.py from framework.json")
|
||||
parser.add_argument("secret_file",
|
||||
help="secrets.yaml user secrets passed to runtime")
|
||||
help="secrets.yaml config secrets passed to runtime")
|
||||
parser.add_argument("user_extension_secrets_file",
|
||||
help="userExtensionSecrets.yaml user extension secrets passed to runtime")
|
||||
parser.add_argument("application_token",
|
||||
help="application token passed to runtime")
|
||||
parser.add_argument("plugins_path", help="Plugins path")
|
||||
|
@ -237,8 +242,14 @@ def main():
|
|||
with open(args.secret_file) as f:
|
||||
secrets = yaml.safe_load(f.read())
|
||||
|
||||
if not os.path.isfile(args.user_extension_secrets_file):
|
||||
user_extension = None
|
||||
else:
|
||||
with open(args.user_extension_secrets_file) as f:
|
||||
user_extension = yaml.safe_load(f.read())
|
||||
|
||||
commands = [[], []]
|
||||
init_plugins(job_config, secrets, args.application_token, commands, args.plugins_path,
|
||||
init_plugins(job_config, secrets, user_extension, args.application_token, commands, args.plugins_path,
|
||||
args.runtime_path, args.task_role)
|
||||
|
||||
# pre-commands and post-commands already handled by rest-server.
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import requests
|
||||
|
||||
sys.path.append(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))
|
||||
|
@ -28,22 +27,19 @@ from plugins.plugin_utils import plugin_init, PluginHelper, try_to_install_by_ca
|
|||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_user_public_keys(application_token, username):
|
||||
def get_user_public_keys(user_extension):
|
||||
"""
|
||||
get user public keys from rest-server
|
||||
get user public keys from user extension
|
||||
|
||||
Format of API `REST_SERVER_URI/api/v2/users/<username>` response:
|
||||
Format of user extension:
|
||||
{
|
||||
"xxx": "xxx",
|
||||
"extensions": {
|
||||
"sshKeys": [
|
||||
{
|
||||
"title": "title-of-the-public-key",
|
||||
"value": "ssh-rsa xxxx"
|
||||
"time": "xxx"
|
||||
}
|
||||
]
|
||||
}
|
||||
"sshKeys": [
|
||||
{
|
||||
"title": "title-of-the-public-key",
|
||||
"value": "ssh-rsa xxxx"
|
||||
"time": "xxx"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Returns:
|
||||
|
@ -51,15 +47,7 @@ def get_user_public_keys(application_token, username):
|
|||
list
|
||||
a list of public keys
|
||||
"""
|
||||
url = "{}/api/v2/users/{}".format(os.environ.get('REST_SERVER_URI'), username)
|
||||
headers={
|
||||
'Authorization': "Bearer {}".format(application_token),
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers, data={})
|
||||
response.raise_for_status()
|
||||
|
||||
public_keys = [item["value"] for item in response.json()["extension"]["sshKeys"]]
|
||||
public_keys = [item["value"] for item in user_extension["sshKeys"]]
|
||||
|
||||
return public_keys
|
||||
|
||||
|
@ -69,6 +57,7 @@ def main():
|
|||
[plugin_config, pre_script, _] = plugin_init()
|
||||
plugin_helper = PluginHelper(plugin_config)
|
||||
parameters = plugin_config.get("parameters")
|
||||
user_extension = plugin_config.get("user_extension")
|
||||
|
||||
if not parameters:
|
||||
LOGGER.info("Ssh plugin parameters is empty, ignore this")
|
||||
|
@ -86,16 +75,10 @@ def main():
|
|||
cmd_params = [jobssh]
|
||||
|
||||
if "userssh" in parameters:
|
||||
# get user public keys from rest server
|
||||
application_token = plugin_config.get("application_token")
|
||||
username = os.environ.get("PAI_USER_NAME")
|
||||
# get user public keys from user extension secret
|
||||
public_keys = []
|
||||
if application_token:
|
||||
try:
|
||||
public_keys = get_user_public_keys(application_token, username)
|
||||
except Exception: #pylint: disable=broad-except
|
||||
LOGGER.error("Failed to get user public keys", exc_info=True)
|
||||
sys.exit(1)
|
||||
if user_extension and "sshKeys" in user_extension:
|
||||
public_keys = get_user_public_keys(user_extension)
|
||||
|
||||
if "value" in parameters["userssh"] and parameters["userssh"]["value"] != "":
|
||||
public_keys.append(parameters["userssh"]["value"])
|
||||
|
|
|
@ -60,7 +60,7 @@ class TestRuntime(unittest.TestCase):
|
|||
with open(job_path, 'rt') as f:
|
||||
jobconfig = yaml.safe_load(f)
|
||||
commands = [[], []]
|
||||
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
|
||||
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
|
||||
".", "worker")
|
||||
|
||||
def test_cmd_plugin_with_callbacks(self):
|
||||
|
@ -69,7 +69,7 @@ class TestRuntime(unittest.TestCase):
|
|||
with open(job_path, 'rt') as f:
|
||||
jobconfig = yaml.safe_load(f)
|
||||
commands = [[], []]
|
||||
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
|
||||
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
|
||||
".", "worker")
|
||||
|
||||
def test_cmd_plugin_with_prerequisites(self):
|
||||
|
@ -78,7 +78,7 @@ class TestRuntime(unittest.TestCase):
|
|||
with open(job_path, 'rt') as f:
|
||||
jobconfig = yaml.safe_load(f)
|
||||
commands = [[], []]
|
||||
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
|
||||
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
|
||||
".", "worker")
|
||||
|
||||
def test_ssh_plugin(self):
|
||||
|
@ -88,7 +88,7 @@ class TestRuntime(unittest.TestCase):
|
|||
jobconfig = yaml.safe_load(f)
|
||||
commands = [[], []]
|
||||
initializer.init_plugins(
|
||||
jobconfig, {"userssh": "ssh-rsa AAAAB3N/cTbWGQZtN1pai-ssh"}, "",
|
||||
jobconfig, {"userssh": "ssh-rsa AAAAB3N/cTbWGQZtN1pai-ssh"}, {}, "",
|
||||
commands, "../src/plugins", ".", "worker")
|
||||
|
||||
def test_ssh_plugin_barrier(self):
|
||||
|
@ -97,10 +97,10 @@ class TestRuntime(unittest.TestCase):
|
|||
with open(job_path, 'rt') as f:
|
||||
jobconfig = yaml.safe_load(f)
|
||||
commands = [[], []]
|
||||
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
|
||||
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
|
||||
".", "master")
|
||||
commands = [[], []]
|
||||
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
|
||||
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
|
||||
".", "worker")
|
||||
|
||||
def test_git_plugin(self):
|
||||
|
@ -109,7 +109,7 @@ class TestRuntime(unittest.TestCase):
|
|||
with open(job_path, 'rt') as f:
|
||||
jobconfig = yaml.safe_load(f)
|
||||
commands = [[], []]
|
||||
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins", ".",
|
||||
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins", ".",
|
||||
"worker")
|
||||
repo_local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src/code")
|
||||
self.assertTrue(os.path.exists(repo_local_path))
|
||||
|
|
Загрузка…
Ссылка в новой задаче