From bc5323d2fd03eba4c368cc03f0198eabceca202b Mon Sep 17 00:00:00 2001 From: Miltos Allamanis Date: Mon, 25 Jan 2021 08:32:49 +0000 Subject: [PATCH] Minor AML updates. --- Dockerfile | 18 ++++++++++++++---- ptgnn/implementations/typilus/train.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 6951681..9744947 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,16 @@ FROM mcr.microsoft.com/azureml/base-gpu:openmpi3.1.2-cuda10.1-cudnn7-ubuntu18.04 -RUN python3 -m pip install --upgrade --no-cache-dir numpy scipy docopt dpu-utils more-itertools typing_extensions sentencepiece azureml-sdk pyyaml dill jellyfish -RUN pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html -# Install torch scatter -RUN pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.6.0.html +RUN apt update && apt install -y python3.8-dev python3.8 python3.8-venv python3-pip python3-cffi +RUN python3.8 -m pip install --no-cache-dir torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + + +RUN python3.8 -m pip install --upgrade wheel pip cffi +RUN python3.8 -m pip install --no-cache-dir sentencepiece==0.1.90 +RUN python3.8 -m pip install --no-cache-dir azureml-sdk annoy chardet datasketch docopt jedi libcst msgpack opentelemetry-api opentelemetry-exporter-jaeger opentelemetry-exporter-prometheus opentelemetry-sdk prometheus-client pystache pyzmq tqdm typing_extensions dpu-utils + + +# ReInstall torch scatter +RUN python3.8 -m pip install --no-cache-dir --upgrade torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html + +# Test installation +RUN python3.8 -c "import torch_scatter" diff --git a/ptgnn/implementations/typilus/train.py b/ptgnn/implementations/typilus/train.py index c3c64f3..ac5cdb5 100644 --- a/ptgnn/implementations/typilus/train.py +++ b/ptgnn/implementations/typilus/train.py @@ -149,7 +149,7 @@ def run(arguments): initialize_metadata = True restore_path = arguments.get("--restore-path", None) - if restore_path: + if restore_path or (arguments["--aml"] and model_path.exists()): initialize_metadata = False model, nn = Graph2Class.restore_model(Path(restore_path)) else: