Add sshbarrier to ssh plugin (#3587)
* Add ssh barrier - set sshbarrier to true to enable ssh barrier - set sshbarriertaskroles to identify specific task roles that barrier wait for ssh setup. If no sshbarriertaskroles, all the barrier will wait for all task roles. * Add error spec for SSH barrier.
This commit is contained in:
Родитель
c51a017e48
Коммит
d10405b277
|
@ -10,9 +10,14 @@ extras:
|
|||
- plugin: ssh
|
||||
parameters:
|
||||
jobssh: boolean
|
||||
sshbarrier: boolean
|
||||
sshbarriertaskroles:
|
||||
- taskrole
|
||||
userssh:
|
||||
type: string
|
||||
value: string
|
||||
```
|
||||
- jobssh: true to enable job container wise ssh, false to disable.
|
||||
- userssh: currently the userssh type should be system|custom. Type system means the value is a key stored in PAI, and type custom means the value is the string defined in job config.
|
||||
- sshbarrier: if set to true, wait until can ssh to all corresponding job containers. If not set, the defalut value is false.
|
||||
- sshbarriertaskroles: only valid if sshbarrier set to true. Defines the task roles that the barrier will test ssh to. If not defind, all taskroles will be included.
|
||||
- userssh: currently the userssh type should be ```custom```. Type ```custom``` means use the userssh value as the SSH public key to run job. User can use the corresponding SSH private key to connect to job container.
|
|
@ -33,12 +33,12 @@ logger = logging.getLogger(__name__)
|
|||
if __name__ == "__main__":
|
||||
[parameters, pre_script, post_script] = plugin_init()
|
||||
|
||||
cmdParams = []
|
||||
if parameters is not None:
|
||||
if "jobssh" in parameters:
|
||||
cmdParams.append(str(parameters["jobssh"]).lower())
|
||||
jobssh = str(parameters["jobssh"]).lower()
|
||||
else:
|
||||
cmdParams.append("false")
|
||||
jobssh = "false"
|
||||
cmdParams = [jobssh]
|
||||
|
||||
if "userssh" in parameters:
|
||||
if "type" in parameters["userssh"] and "value" in parameters["userssh"]:
|
||||
|
@ -46,5 +46,14 @@ if __name__ == "__main__":
|
|||
cmdParams.append("\'{}\'".format(parameters["userssh"]["value"]))
|
||||
|
||||
# write call to real executable script
|
||||
command = "{}/sshd.sh {}\n".format(os.path.dirname(os.path.abspath(__file__)), " ".join(cmdParams))
|
||||
inject_commands([command], pre_script)
|
||||
command = ["{}/sshd.sh {}\n".format(os.path.dirname(os.path.abspath(__file__)), " ".join(cmdParams))]
|
||||
|
||||
# ssh barrier
|
||||
if jobssh == "true" and "sshbarrier" in parameters and str(parameters["sshbarrier"]).lower() == "true":
|
||||
if "sshbarriertaskroles" in parameters:
|
||||
barrierParams = " ".join('"{}"'.format(tr) for tr in parameters["sshbarriertaskroles"])
|
||||
else:
|
||||
barrierParams = ""
|
||||
command.append("{}/sshbarrier.sh {}\n".format(os.path.dirname(os.path.abspath(__file__)), barrierParams))
|
||||
|
||||
inject_commands(command, pre_script)
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
#!/bin/bash
|
||||
# Copyright (c) Microsoft Corporation
|
||||
# All rights reserved.
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
|
||||
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
|
||||
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
|
||||
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
# no set -o errexit because use exitcode to judge ssh connectivity
|
||||
# no set -o nounset because use empty array to judge end
|
||||
set -o pipefail
|
||||
|
||||
readonly MAX_RETRY_COUNT=20
|
||||
readonly RETRY_INTERVAL=1
|
||||
|
||||
function check_ssh_connection()
|
||||
{
|
||||
ssh -q -o BatchMode=yes -o StrictHostKeyChecking=no $1 "exit 0"
|
||||
_RCODE=$?
|
||||
return $_RCODE
|
||||
}
|
||||
|
||||
taskRolesToCheck=()
|
||||
for barrierTaskRole in $@; do
|
||||
taskRolesToCheck+=($barrierTaskRole)
|
||||
done
|
||||
|
||||
instancesToCheck=()
|
||||
# Set ssh config for all task role instances
|
||||
taskRoleInstances=(${PAI_TASK_ROLE_INSTANCES//,/ })
|
||||
for i in "${taskRoleInstances[@]}"; do
|
||||
instancePair=(${i//:/ })
|
||||
taskRole=${instancePair[0]}
|
||||
index=${instancePair[1]}
|
||||
|
||||
if [[ $taskRole = $FC_TASKROLE_NAME ]] && [[ $index = $FC_TASK_INDEX ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# If barrier task roles defined, then only check instances for defined task roles. Otherwise check all instances.
|
||||
if [[ ${#taskRolesToCheck[@]} != 0 ]]; then
|
||||
if [[ ${taskRolesToCheck[@]} =~ ${taskRole} ]]; then
|
||||
instancesToCheck+=("${taskRole}-${index}")
|
||||
fi
|
||||
else
|
||||
instancesToCheck+=("${taskRole}-${index}")
|
||||
fi
|
||||
done
|
||||
|
||||
retryCount=0
|
||||
while true
|
||||
do
|
||||
echo "Trying to SSH to instances: ${instancesToCheck[*]}"
|
||||
|
||||
instanceFailed=()
|
||||
for instance in "${instancesToCheck[@]}"; do
|
||||
check_ssh_connection "$instance"
|
||||
if [[ $? != 0 ]]; then
|
||||
instanceFailed+=("$instance")
|
||||
fi
|
||||
done
|
||||
|
||||
[[ ${#instanceFailed[@]} = 0 ]] && break
|
||||
|
||||
if (( $retryCount >= $MAX_RETRY_COUNT )); then
|
||||
echo "SSH barrier reaches max retry count. Failed instances: ${instancesToCheck[*]} Exit..." >&2
|
||||
exit 10
|
||||
fi
|
||||
|
||||
instancesToCheck=(${instanceFailed[*]})
|
||||
((retryCount++))
|
||||
|
||||
sleep $RETRY_INTERVAL
|
||||
done
|
||||
|
||||
echo "All ssh connections are established, continue..."
|
|
@ -16,6 +16,10 @@
|
|||
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
set -o errexit
|
||||
set -o nounset
|
||||
set -o pipefail
|
||||
|
||||
PAI_WORK_DIR=/usr/local/pai
|
||||
SSH_DIR=/root/.ssh
|
||||
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
protocolVersion: 2
|
||||
name: sshbarrier_test_job
|
||||
type: job
|
||||
version: horovod0.16.4-tf1.12.0-torch1.1.0-mxnet1.4.1-py3.5
|
||||
contributor: OpenPAI
|
||||
description: |
|
||||
This is a distributed synthetic benchmark for Horovod with PyTorch backend running on OpenPAI.
|
||||
It runs [Horovod with Open MPI](https://github.com/horovod/horovod/blob/master/docs/mpirun.rst).
|
||||
parameters:
|
||||
model: resnet50
|
||||
batchsize: 64
|
||||
|
||||
prerequisites:
|
||||
- protocolVersion: 2
|
||||
name: horovod_official
|
||||
type: dockerimage
|
||||
contributor : Horovod
|
||||
uri : horovod/horovod:0.16.4-tf1.12.0-torch1.1.0-mxnet1.4.1-py3.5
|
||||
|
||||
taskRoles:
|
||||
master:
|
||||
instances: 1
|
||||
completion:
|
||||
minSucceededInstances: 1
|
||||
dockerImage: horovod_official
|
||||
resourcePerInstance:
|
||||
cpu: 8
|
||||
memoryMB: 16384
|
||||
gpu: 2
|
||||
commands:
|
||||
- sleep 10
|
||||
- >
|
||||
horovodrun -np 4 -H master-0:2,worker-0:2
|
||||
python pytorch_synthetic_benchmark.py
|
||||
--model <% $parameters.model %>
|
||||
--batch-size <% $parameters.batchsize %>
|
||||
worker:
|
||||
instances: 1
|
||||
dockerImage: horovod_official
|
||||
resourcePerInstance:
|
||||
cpu: 8
|
||||
memoryMB: 16384
|
||||
gpu: 2
|
||||
commands:
|
||||
- sleep infinity
|
||||
|
||||
extras:
|
||||
com.microsoft.pai.runtimeplugin:
|
||||
- plugin: ssh
|
||||
taskroles:
|
||||
- master
|
||||
parameters:
|
||||
jobssh: true
|
||||
sshbarrier: true
|
||||
- plugin: ssh
|
||||
taskroles:
|
||||
- worker
|
||||
parameters:
|
||||
jobssh: true
|
|
@ -52,6 +52,16 @@ class TestRuntimeInitializer(unittest.TestCase):
|
|||
commands = [[],[]]
|
||||
init_plugins(jobconfig, commands, "../src/plugins", ".", "worker")
|
||||
|
||||
def test_ssh_plugin_barrier(self):
|
||||
job_path = "sshbarrier_test_job.yaml"
|
||||
if os.path.exists(job_path):
|
||||
with open(job_path, 'rt') as f:
|
||||
jobconfig = yaml.load(f)
|
||||
commands = [[],[]]
|
||||
init_plugins(jobconfig, commands, "../src/plugins", ".", "master")
|
||||
commands = [[],[]]
|
||||
init_plugins(jobconfig, commands, "../src/plugins", ".", "worker")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче