Add aiida-dynamic-workflows source
This commit is contained in:
Родитель
bf371387a0
Коммит
b9df2d8ccc
|
@ -0,0 +1 @@
|
||||||
|
aiida_dynamic_workflows/_static_version.py export-subst
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from . import (
|
||||||
|
calculations,
|
||||||
|
common,
|
||||||
|
control,
|
||||||
|
data,
|
||||||
|
engine,
|
||||||
|
parsers,
|
||||||
|
query,
|
||||||
|
report,
|
||||||
|
utils,
|
||||||
|
workflow,
|
||||||
|
)
|
||||||
|
from ._version import __version__ # noqa: F401
|
||||||
|
from .samples import input_samples
|
||||||
|
from .step import step
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"calculations",
|
||||||
|
"common",
|
||||||
|
"control",
|
||||||
|
"data",
|
||||||
|
"engine",
|
||||||
|
"input_samples",
|
||||||
|
"parsers",
|
||||||
|
"report",
|
||||||
|
"query",
|
||||||
|
"step",
|
||||||
|
"utils",
|
||||||
|
"workflow",
|
||||||
|
"__version__",
|
||||||
|
]
|
|
@ -0,0 +1,13 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
# This file will be overwritten by setup.py when a source or binary
|
||||||
|
# distribution is made. The magic value "__use_git__" is interpreted by
|
||||||
|
# version.py.
|
||||||
|
|
||||||
|
version = "__use_git__"
|
||||||
|
|
||||||
|
# These values are only set if the distribution was created with 'git archive'
|
||||||
|
refnames = "$Format:%D$"
|
||||||
|
git_hash = "$Format:%h$"
|
|
@ -0,0 +1,209 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from setuptools.command.build_py import build_py as build_py_orig
|
||||||
|
from setuptools.command.sdist import sdist as sdist_orig
|
||||||
|
|
||||||
|
Version = namedtuple("Version", ("release", "dev", "labels"))
|
||||||
|
|
||||||
|
# No public API
|
||||||
|
__all__ = []
|
||||||
|
|
||||||
|
package_root = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
package_name = os.path.basename(package_root)
|
||||||
|
distr_root = os.path.dirname(package_root)
|
||||||
|
# If the package is inside a "src" directory the
|
||||||
|
# distribution root is 1 level up.
|
||||||
|
if os.path.split(distr_root)[1] == "src":
|
||||||
|
_package_root_inside_src = True
|
||||||
|
distr_root = os.path.dirname(distr_root)
|
||||||
|
else:
|
||||||
|
_package_root_inside_src = False
|
||||||
|
|
||||||
|
STATIC_VERSION_FILE = "_static_version.py"
|
||||||
|
|
||||||
|
|
||||||
|
def get_version(version_file=STATIC_VERSION_FILE):
|
||||||
|
version_info = get_static_version_info(version_file)
|
||||||
|
version = version_info["version"]
|
||||||
|
if version == "__use_git__":
|
||||||
|
version = get_version_from_git()
|
||||||
|
if not version:
|
||||||
|
version = get_version_from_git_archive(version_info)
|
||||||
|
if not version:
|
||||||
|
version = Version("unknown", None, None)
|
||||||
|
return pep440_format(version)
|
||||||
|
else:
|
||||||
|
return version
|
||||||
|
|
||||||
|
|
||||||
|
def get_static_version_info(version_file=STATIC_VERSION_FILE):
|
||||||
|
version_info = {}
|
||||||
|
with open(os.path.join(package_root, version_file), "rb") as f:
|
||||||
|
exec(f.read(), {}, version_info)
|
||||||
|
return version_info
|
||||||
|
|
||||||
|
|
||||||
|
def version_is_from_git(version_file=STATIC_VERSION_FILE):
|
||||||
|
return get_static_version_info(version_file)["version"] == "__use_git__"
|
||||||
|
|
||||||
|
|
||||||
|
def pep440_format(version_info):
|
||||||
|
release, dev, labels = version_info
|
||||||
|
|
||||||
|
version_parts = [release]
|
||||||
|
if dev:
|
||||||
|
if release.endswith("-dev") or release.endswith(".dev"):
|
||||||
|
version_parts.append(dev)
|
||||||
|
else: # prefer PEP440 over strict adhesion to semver
|
||||||
|
version_parts.append(".dev{}".format(dev))
|
||||||
|
|
||||||
|
if labels:
|
||||||
|
version_parts.append("+")
|
||||||
|
version_parts.append(".".join(labels))
|
||||||
|
|
||||||
|
return "".join(version_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def get_version_from_git():
|
||||||
|
try:
|
||||||
|
p = subprocess.Popen(
|
||||||
|
["git", "rev-parse", "--show-toplevel"],
|
||||||
|
cwd=distr_root,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
if p.wait() != 0:
|
||||||
|
return
|
||||||
|
if not os.path.samefile(p.communicate()[0].decode().rstrip("\n"), distr_root):
|
||||||
|
# The top-level directory of the current Git repository is not the same
|
||||||
|
# as the root directory of the distribution: do not extract the
|
||||||
|
# version from Git.
|
||||||
|
return
|
||||||
|
|
||||||
|
# git describe --first-parent does not take into account tags from branches
|
||||||
|
# that were merged-in. The '--long' flag gets us the 'dev' version and
|
||||||
|
# git hash, '--always' returns the git hash even if there are no tags.
|
||||||
|
for opts in [["--first-parent"], []]:
|
||||||
|
try:
|
||||||
|
p = subprocess.Popen(
|
||||||
|
["git", "describe", "--long", "--always"] + opts,
|
||||||
|
cwd=distr_root,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
if p.wait() == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
description = (
|
||||||
|
p.communicate()[0]
|
||||||
|
.decode()
|
||||||
|
.strip("v") # Tags can have a leading 'v', but the version should not
|
||||||
|
.rstrip("\n")
|
||||||
|
.rsplit("-", 2) # Split the latest tag, commits since tag, and hash
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
release, dev, git = description
|
||||||
|
except ValueError: # No tags, only the git hash
|
||||||
|
# prepend 'g' to match with format returned by 'git describe'
|
||||||
|
git = "g{}".format(*description)
|
||||||
|
release = "unknown"
|
||||||
|
dev = None
|
||||||
|
|
||||||
|
labels = []
|
||||||
|
if dev == "0":
|
||||||
|
dev = None
|
||||||
|
else:
|
||||||
|
labels.append(git)
|
||||||
|
|
||||||
|
try:
|
||||||
|
p = subprocess.Popen(["git", "diff", "--quiet"], cwd=distr_root)
|
||||||
|
except OSError:
|
||||||
|
labels.append("confused") # This should never happen.
|
||||||
|
else:
|
||||||
|
if p.wait() == 1:
|
||||||
|
labels.append("dirty")
|
||||||
|
|
||||||
|
return Version(release, dev, labels)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: change this logic when there is a git pretty-format
|
||||||
|
# that gives the same output as 'git describe'.
|
||||||
|
# Currently we can only tell the tag the current commit is
|
||||||
|
# pointing to, or its hash (with no version info)
|
||||||
|
# if it is not tagged.
|
||||||
|
def get_version_from_git_archive(version_info):
|
||||||
|
try:
|
||||||
|
refnames = version_info["refnames"]
|
||||||
|
git_hash = version_info["git_hash"]
|
||||||
|
except KeyError:
|
||||||
|
# These fields are not present if we are running from an sdist.
|
||||||
|
# Execution should never reach here, though
|
||||||
|
return None
|
||||||
|
|
||||||
|
if git_hash.startswith("$Format") or refnames.startswith("$Format"):
|
||||||
|
# variables not expanded during 'git archive'
|
||||||
|
return None
|
||||||
|
|
||||||
|
VTAG = "tag: v"
|
||||||
|
refs = set(r.strip() for r in refnames.split(","))
|
||||||
|
version_tags = set(r[len(VTAG) :] for r in refs if r.startswith(VTAG))
|
||||||
|
if version_tags:
|
||||||
|
release, *_ = sorted(version_tags) # prefer e.g. "2.0" over "2.0rc1"
|
||||||
|
return Version(release, dev=None, labels=None)
|
||||||
|
else:
|
||||||
|
return Version("unknown", dev=None, labels=["g{}".format(git_hash)])
|
||||||
|
|
||||||
|
|
||||||
|
__version__ = get_version()
|
||||||
|
|
||||||
|
|
||||||
|
# The following section defines a module global 'cmdclass',
|
||||||
|
# which can be used from setup.py. The 'package_name' and
|
||||||
|
# '__version__' module globals are used (but not modified).
|
||||||
|
|
||||||
|
|
||||||
|
def _write_version(fname):
|
||||||
|
# This could be a hard link, so try to delete it first. Is there any way
|
||||||
|
# to do this atomically together with opening?
|
||||||
|
try:
|
||||||
|
os.remove(fname)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
with open(fname, "w") as f:
|
||||||
|
f.write(
|
||||||
|
"# This file has been created by setup.py.\n"
|
||||||
|
"version = '{}'\n".format(__version__)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _build_py(build_py_orig):
|
||||||
|
def run(self):
|
||||||
|
super().run()
|
||||||
|
_write_version(os.path.join(self.build_lib, package_name, STATIC_VERSION_FILE))
|
||||||
|
|
||||||
|
|
||||||
|
class _sdist(sdist_orig):
|
||||||
|
def make_release_tree(self, base_dir, files):
|
||||||
|
super().make_release_tree(base_dir, files)
|
||||||
|
if _package_root_inside_src:
|
||||||
|
p = os.path.join("src", package_name)
|
||||||
|
else:
|
||||||
|
p = package_name
|
||||||
|
_write_version(os.path.join(base_dir, p, STATIC_VERSION_FILE))
|
||||||
|
|
||||||
|
|
||||||
|
cmdclass = dict(sdist=_sdist, build_py=_build_py)
|
|
@ -0,0 +1,672 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""Aiida Calculations for running arbitrary Python functions."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import textwrap
|
||||||
|
from typing import Any, Dict, Sequence
|
||||||
|
|
||||||
|
import aiida.common
|
||||||
|
import aiida.engine
|
||||||
|
import numpy as np
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
from . import common
|
||||||
|
from .data import (
|
||||||
|
Nil,
|
||||||
|
PyArray,
|
||||||
|
PyData,
|
||||||
|
PyFunction,
|
||||||
|
PyRemoteArray,
|
||||||
|
PyRemoteData,
|
||||||
|
array_mask,
|
||||||
|
array_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PyCalcJob(aiida.engine.CalcJob):
|
||||||
|
"""CalcJob that runs a single Python function."""
|
||||||
|
|
||||||
|
@aiida.common.lang.override
|
||||||
|
def out(self, output_port, value=None) -> None:
|
||||||
|
"""Attach output to output port."""
|
||||||
|
# This hack is necessary to work around a bug with output namespace naming.
|
||||||
|
# Some parts of Aiida consider the namespace/port separator to be '__',
|
||||||
|
# but others think it is '.'.
|
||||||
|
return super().out(output_port.replace("__", "."), value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define(cls, spec: aiida.engine.CalcJobProcessSpec): # noqa: D102
|
||||||
|
super().define(spec)
|
||||||
|
|
||||||
|
spec.input("func", valid_type=PyFunction, help="The function to execute")
|
||||||
|
spec.input_namespace(
|
||||||
|
"kwargs", dynamic=True, help="The (keyword) arguments to the function"
|
||||||
|
)
|
||||||
|
|
||||||
|
spec.output_namespace(
|
||||||
|
"return_values", dynamic=True, help="The return value(s) of the function"
|
||||||
|
)
|
||||||
|
spec.output(
|
||||||
|
"exception", required=False, help="The exception raised (if any)",
|
||||||
|
)
|
||||||
|
|
||||||
|
spec.inputs["metadata"]["options"][
|
||||||
|
"parser_name"
|
||||||
|
].default = "dynamic_workflows.PyCalcParser"
|
||||||
|
spec.inputs["metadata"]["options"]["resources"].default = dict(
|
||||||
|
num_machines=1, num_mpiprocs_per_machine=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: add error codes
|
||||||
|
spec.exit_code(
|
||||||
|
401,
|
||||||
|
"USER_CODE_RAISED",
|
||||||
|
invalidates_cache=True,
|
||||||
|
message="User code raised an Exception.",
|
||||||
|
)
|
||||||
|
spec.exit_code(
|
||||||
|
402,
|
||||||
|
"NONZERO_EXIT_CODE",
|
||||||
|
invalidates_cache=True,
|
||||||
|
message="Script returned non-zero exit code.",
|
||||||
|
)
|
||||||
|
spec.exit_code(
|
||||||
|
403,
|
||||||
|
"MISSING_OUTPUT",
|
||||||
|
invalidates_cache=True,
|
||||||
|
message="Script returned zero exit code, but no output generated.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: refactor this; it is a bit of a mess
|
||||||
|
def prepare_for_submission(
|
||||||
|
self, folder: aiida.common.folders.Folder,
|
||||||
|
) -> aiida.common.CalcInfo: # noqa: D102
|
||||||
|
|
||||||
|
# TODO: update "resources" given the resources specified on "py_func"
|
||||||
|
codeinfo = aiida.common.CodeInfo()
|
||||||
|
codeinfo.code_uuid = self.inputs.code.uuid
|
||||||
|
|
||||||
|
calcinfo = aiida.common.CalcInfo()
|
||||||
|
calcinfo.codes_info = [codeinfo]
|
||||||
|
calcinfo.remote_copy_list = []
|
||||||
|
calcinfo.remote_symlink_list = []
|
||||||
|
|
||||||
|
py_function = self.inputs.func
|
||||||
|
computer = self.inputs.code.computer
|
||||||
|
kwargs = getattr(self.inputs, "kwargs", dict())
|
||||||
|
|
||||||
|
remaining_kwargs_file = "__kwargs__/__remaining__.pickle"
|
||||||
|
kwargs_array_folder_template = "__kwargs__/{}"
|
||||||
|
kwargs_filename_template = "__kwargs__/{}.pickle"
|
||||||
|
function_file = "__func__.pickle"
|
||||||
|
exception_file = "__exception__.pickle"
|
||||||
|
return_value_files = [
|
||||||
|
f"__return_values__/{r}.pickle" for r in py_function.returns
|
||||||
|
]
|
||||||
|
|
||||||
|
folder.get_subfolder("__kwargs__", create=True)
|
||||||
|
folder.get_subfolder("__return_values__", create=True)
|
||||||
|
|
||||||
|
calcinfo.retrieve_list = [exception_file]
|
||||||
|
|
||||||
|
# TODO: figure out how to do this with "folder.copy_file" or whatever
|
||||||
|
with folder.open(function_file, "wb") as f:
|
||||||
|
f.write(py_function.pickle)
|
||||||
|
|
||||||
|
literal_kwargs = dict()
|
||||||
|
local_kwargs = dict()
|
||||||
|
remote_kwargs = dict()
|
||||||
|
remote_array_kwargs = dict()
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
# TODO: refactor this to allow more generic / customizable dispatch
|
||||||
|
if isinstance(v, aiida.orm.BaseType):
|
||||||
|
literal_kwargs[k] = v.value
|
||||||
|
elif isinstance(v, PyArray):
|
||||||
|
literal_kwargs[k] = v.get_array()
|
||||||
|
elif isinstance(v, PyRemoteData):
|
||||||
|
remote_kwargs[k] = v
|
||||||
|
elif isinstance(v, PyRemoteArray):
|
||||||
|
remote_array_kwargs[k] = v
|
||||||
|
elif isinstance(v, Nil):
|
||||||
|
literal_kwargs[k] = None
|
||||||
|
elif isinstance(v, PyData):
|
||||||
|
local_kwargs[k] = v
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsure how to treat '{k}' ({type(v)})")
|
||||||
|
|
||||||
|
for k, v in remote_kwargs.items():
|
||||||
|
# TODO: move the data as needed.
|
||||||
|
if v.computer.uuid != self.inputs.code.computer.uuid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Data passed as '{k}' to '{py_function.name}' is stored "
|
||||||
|
f"on '{v.computer.label}', which is not directly accessible "
|
||||||
|
f"from '{computer.label}'."
|
||||||
|
)
|
||||||
|
calcinfo.remote_symlink_list.append(
|
||||||
|
(computer.uuid, v.pickle_path, kwargs_filename_template.format(k))
|
||||||
|
)
|
||||||
|
|
||||||
|
for k, v in remote_array_kwargs.items():
|
||||||
|
# TODO: move the data as needed.
|
||||||
|
if v.computer.uuid != self.inputs.code.computer.uuid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Data passed as '{k}' to '{py_function.name}' is stored "
|
||||||
|
f"on '{v.computer.label}', which is not directly accessible "
|
||||||
|
f"from '{computer.label}'."
|
||||||
|
)
|
||||||
|
calcinfo.remote_symlink_list.append(
|
||||||
|
(computer.uuid, v.pickle_path, kwargs_array_folder_template.format(k))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not local_kwargs
|
||||||
|
kwarg_filenames = [kwargs_filename_template.format(k) for k in remote_kwargs]
|
||||||
|
kwarg_array_folders = [
|
||||||
|
kwargs_array_folder_template.format(k) for k in remote_array_kwargs
|
||||||
|
]
|
||||||
|
kwarg_array_shapes = [v.shape for v in remote_array_kwargs.values()]
|
||||||
|
separate_kwargs = list(remote_kwargs.keys())
|
||||||
|
separate_array_kwargs = list(remote_array_kwargs.keys())
|
||||||
|
|
||||||
|
if literal_kwargs:
|
||||||
|
common.dump(literal_kwargs, remaining_kwargs_file, opener=folder.open)
|
||||||
|
|
||||||
|
# Add the '.common' subpackage as a package called 'common'.
|
||||||
|
# This can therefore be used directly from the script.
|
||||||
|
common_package_folder = folder.get_subfolder("common", create=True)
|
||||||
|
for filename, contents in common.package_module_contents():
|
||||||
|
with common_package_folder.open(filename, "w") as f:
|
||||||
|
f.write(contents)
|
||||||
|
|
||||||
|
# TODO: factor this out
|
||||||
|
script = textwrap.dedent(
|
||||||
|
f"""\
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
|
import common
|
||||||
|
|
||||||
|
# Define paths for I/O
|
||||||
|
|
||||||
|
function_file = "{function_file}"
|
||||||
|
separate_kwargs = {separate_kwargs}
|
||||||
|
separate_kwarg_filenames = {kwarg_filenames}
|
||||||
|
separate_array_kwargs = {separate_array_kwargs}
|
||||||
|
separate_array_folders = {kwarg_array_folders}
|
||||||
|
separate_array_shapes = {kwarg_array_shapes}
|
||||||
|
remaining_kwargs_file = "{remaining_kwargs_file}"
|
||||||
|
exception_file = "{exception_file}"
|
||||||
|
return_value_files = {return_value_files}
|
||||||
|
assert return_value_files
|
||||||
|
|
||||||
|
# Load code
|
||||||
|
|
||||||
|
func = common.load(function_file)
|
||||||
|
|
||||||
|
# Load kwargs
|
||||||
|
|
||||||
|
kwargs = dict()
|
||||||
|
# TODO: hard-code this when we switch to a Jinja template
|
||||||
|
# TODO: parallel load using a threadpool
|
||||||
|
for pname, fname in zip(separate_kwargs, separate_kwarg_filenames):
|
||||||
|
kwargs[pname] = common.load(fname)
|
||||||
|
for pname, fname, shape in zip(
|
||||||
|
separate_array_kwargs, separate_array_folders, separate_array_shapes,
|
||||||
|
):
|
||||||
|
kwargs[pname] = common.FileBasedObjectArray(fname, shape=shape)
|
||||||
|
if os.path.exists(remaining_kwargs_file):
|
||||||
|
kwargs.update(common.load(remaining_kwargs_file))
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
|
||||||
|
try:
|
||||||
|
return_values = func(**kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
common.dump(e, exception_file)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Output
|
||||||
|
|
||||||
|
if len(return_value_files) == 1:
|
||||||
|
common.dump(return_values, return_value_files[0])
|
||||||
|
else:
|
||||||
|
for r, f in zip(return_values, return_value_files):
|
||||||
|
common.dump(r, f)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with folder.open("__in__.py", "w", encoding="utf8") as handle:
|
||||||
|
handle.write(script)
|
||||||
|
codeinfo.stdin_name = "__in__.py"
|
||||||
|
|
||||||
|
return calcinfo
|
||||||
|
|
||||||
|
|
||||||
|
class PyMapJob(PyCalcJob):
|
||||||
|
"""CalcJob that maps a Python function over (a subset of) its parameters."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define(cls, spec: aiida.engine.CalcJobProcessSpec): # noqa: D102
|
||||||
|
super().define(spec)
|
||||||
|
|
||||||
|
spec.input(
|
||||||
|
"metadata.options.mapspec",
|
||||||
|
valid_type=str,
|
||||||
|
help=(
|
||||||
|
"A specification for which parameters to map over, "
|
||||||
|
"and how to map them"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setting 1 as the default means people won't accidentally
|
||||||
|
# overload the cluster with jobs.
|
||||||
|
spec.input(
|
||||||
|
"metadata.options.max_concurrent_machines",
|
||||||
|
valid_type=int,
|
||||||
|
default=1,
|
||||||
|
help="How many machines to use for this map, maximally.",
|
||||||
|
)
|
||||||
|
spec.input(
|
||||||
|
"metadata.options.cores_per_machine",
|
||||||
|
valid_type=int,
|
||||||
|
help="How many cores per machines to use for this map.",
|
||||||
|
)
|
||||||
|
spec.inputs["metadata"]["options"][
|
||||||
|
"parser_name"
|
||||||
|
].default = "dynamic_workflows.PyMapParser"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mapspec(self) -> common.MapSpec:
|
||||||
|
"""Parameter and shape specification for this map job."""
|
||||||
|
return common.MapSpec.from_string(self.metadata.options.mapspec)
|
||||||
|
|
||||||
|
# TODO: refactor / merge this with PyCalcJob
|
||||||
|
def prepare_for_submission( # noqa: C901
|
||||||
|
self, folder: aiida.common.folders.Folder
|
||||||
|
) -> aiida.common.CalcInfo: # noqa: D102
|
||||||
|
# TODO: update "resources" given the resources specified on "py_func"
|
||||||
|
codeinfo = aiida.common.CodeInfo()
|
||||||
|
codeinfo.code_uuid = self.inputs.code.uuid
|
||||||
|
|
||||||
|
calcinfo = aiida.common.CalcInfo()
|
||||||
|
calcinfo.codes_info = [codeinfo]
|
||||||
|
calcinfo.remote_copy_list = []
|
||||||
|
calcinfo.remote_symlink_list = []
|
||||||
|
|
||||||
|
py_function = self.inputs.func
|
||||||
|
kwargs = self.inputs.kwargs
|
||||||
|
computer = self.inputs.code.computer
|
||||||
|
|
||||||
|
spec = self.mapspec
|
||||||
|
mapped_kwargs = {
|
||||||
|
k: v for k, v in self.inputs.kwargs.items() if k in spec.parameters
|
||||||
|
}
|
||||||
|
mapped_kwarg_shapes = toolz.valmap(array_shape, mapped_kwargs)
|
||||||
|
# This will raise an exception if the shapes are not compatible.
|
||||||
|
spec.shape(mapped_kwarg_shapes)
|
||||||
|
|
||||||
|
function_file = "__func__.pickle"
|
||||||
|
exceptions_folder = "__exceptions__"
|
||||||
|
remaining_kwargs_file = "__kwargs__/__remaining__.pickle"
|
||||||
|
kwarg_file_template = "__kwargs__/{}.pickle"
|
||||||
|
mapped_kwarg_folder_template = "__kwargs__/{}"
|
||||||
|
return_value_folders = [f"__return_values__/{r}" for r in py_function.returns]
|
||||||
|
|
||||||
|
calcinfo.retrieve_list = [exceptions_folder]
|
||||||
|
|
||||||
|
# TODO: figure out how to do this with "folder.copy_file" or whatever
|
||||||
|
with folder.open(function_file, "wb") as f:
|
||||||
|
f.write(py_function.pickle)
|
||||||
|
|
||||||
|
folder.get_subfolder(exceptions_folder, create=True)
|
||||||
|
folder.get_subfolder("__kwargs__", create=True)
|
||||||
|
folder.get_subfolder("__return_values__", create=True)
|
||||||
|
|
||||||
|
folder.get_subfolder(exceptions_folder, create=True)
|
||||||
|
for rv in return_value_folders:
|
||||||
|
folder.get_subfolder(rv, create=True)
|
||||||
|
|
||||||
|
valid_sequence_types = (
|
||||||
|
aiida.orm.List,
|
||||||
|
PyArray,
|
||||||
|
PyRemoteArray,
|
||||||
|
)
|
||||||
|
for k in mapped_kwargs:
|
||||||
|
v = kwargs[k]
|
||||||
|
if not isinstance(v, valid_sequence_types):
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected one of {valid_sequence_types} for {k}, "
|
||||||
|
f"but received {type(v)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
remaining_kwargs = dict()
|
||||||
|
mapped_literal_kwargs = dict()
|
||||||
|
remote_kwargs = dict()
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
# TODO: refactor this to allow more generic / customizable dispatch
|
||||||
|
if isinstance(v, (PyRemoteData, PyRemoteArray)):
|
||||||
|
remote_kwargs[k] = v
|
||||||
|
elif isinstance(v, aiida.orm.List) and k in mapped_kwargs:
|
||||||
|
mapped_literal_kwargs[k] = v.get_list()
|
||||||
|
elif isinstance(v, PyArray) and k in mapped_kwargs:
|
||||||
|
mapped_literal_kwargs[k] = v.get_array()
|
||||||
|
elif isinstance(v, aiida.orm.List):
|
||||||
|
remaining_kwargs[k] = v.get_list()
|
||||||
|
elif isinstance(v, PyArray):
|
||||||
|
remaining_kwargs[k] = v.get_array()
|
||||||
|
elif isinstance(v, Nil):
|
||||||
|
remaining_kwargs[k] = None
|
||||||
|
elif isinstance(v, PyData):
|
||||||
|
assert False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
remaining_kwargs[k] = v.value
|
||||||
|
except AttributeError:
|
||||||
|
raise RuntimeError(f"Unsure how to treat values of type {type(v)}")
|
||||||
|
|
||||||
|
if remaining_kwargs:
|
||||||
|
common.dump(remaining_kwargs, remaining_kwargs_file, opener=folder.open)
|
||||||
|
|
||||||
|
for k, v in mapped_literal_kwargs.items():
|
||||||
|
common.dump(v, kwarg_file_template.format(k), opener=folder.open)
|
||||||
|
|
||||||
|
for k, v in remote_kwargs.items():
|
||||||
|
# TODO: move the data as needed.
|
||||||
|
if v.computer.uuid != self.inputs.code.computer.uuid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Data passed as '{k}' to '{py_function.name}' is stored "
|
||||||
|
f"on '{v.computer.label}', which is not directly accessible "
|
||||||
|
f"from '{computer.label}'."
|
||||||
|
)
|
||||||
|
if k in mapped_kwargs:
|
||||||
|
template = mapped_kwarg_folder_template
|
||||||
|
else:
|
||||||
|
template = kwarg_file_template
|
||||||
|
calcinfo.remote_symlink_list.append(
|
||||||
|
(computer.uuid, v.pickle_path, template.format(k))
|
||||||
|
)
|
||||||
|
|
||||||
|
separate_kwargs = [k for k in remote_kwargs if k not in mapped_kwargs]
|
||||||
|
|
||||||
|
# Add the '.common' subpackage as a package called 'common'.
|
||||||
|
# This can therefore be used directly from the script.
|
||||||
|
common_package_folder = folder.get_subfolder("common", create=True)
|
||||||
|
for filename, contents in common.package_module_contents():
|
||||||
|
with common_package_folder.open(filename, "w") as f:
|
||||||
|
f.write(contents)
|
||||||
|
|
||||||
|
# TODO: factor this out
|
||||||
|
script = textwrap.dedent(
|
||||||
|
f"""\
|
||||||
|
import functools
|
||||||
|
import operator
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
|
import common
|
||||||
|
|
||||||
|
# hard-coded to 1 job per map element for now
|
||||||
|
element_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
|
||||||
|
|
||||||
|
def tails(seq):
|
||||||
|
while seq:
|
||||||
|
seq = seq[1:]
|
||||||
|
yield seq
|
||||||
|
|
||||||
|
def make_strides(shape):
|
||||||
|
return tuple(functools.reduce(operator.mul, s, 1) for s in tails(shape))
|
||||||
|
|
||||||
|
mapspec = common.MapSpec.from_string("{self.metadata.options.mapspec}")
|
||||||
|
kwarg_shapes = {mapped_kwarg_shapes}
|
||||||
|
map_shape = mapspec.shape(kwarg_shapes)
|
||||||
|
output_key = mapspec.output_key(map_shape, element_id)
|
||||||
|
input_keys = {{
|
||||||
|
k: v[0] if len(v) == 1 else v
|
||||||
|
for k, v in mapspec.input_keys(map_shape, element_id).items()
|
||||||
|
}}
|
||||||
|
|
||||||
|
# Define paths for I/O
|
||||||
|
|
||||||
|
function_file = "{function_file}"
|
||||||
|
mapped_kwargs = {spec.parameters}
|
||||||
|
mapped_literal_kwargs = {list(mapped_literal_kwargs.keys())}
|
||||||
|
separate_kwargs = {separate_kwargs}
|
||||||
|
|
||||||
|
kwarg_file_template = "{kwarg_file_template}"
|
||||||
|
mapped_kwarg_folder_template = "{mapped_kwarg_folder_template}"
|
||||||
|
|
||||||
|
remaining_kwargs_file = "{remaining_kwargs_file}"
|
||||||
|
exceptions_folder = "{exceptions_folder}"
|
||||||
|
return_value_folders = {return_value_folders}
|
||||||
|
assert return_value_folders
|
||||||
|
|
||||||
|
# Load code
|
||||||
|
|
||||||
|
func = common.load(function_file)
|
||||||
|
|
||||||
|
# Load kwargs
|
||||||
|
|
||||||
|
kwargs = dict()
|
||||||
|
# TODO: hard-code this when we switch to a Jinja template
|
||||||
|
# TODO: parallel load using a threadpool
|
||||||
|
for pname in separate_kwargs:
|
||||||
|
kwargs[pname] = common.load(kwarg_file_template.format(pname))
|
||||||
|
for pname in mapped_kwargs:
|
||||||
|
if pname in mapped_literal_kwargs:
|
||||||
|
values = common.load(kwarg_file_template.format(pname))
|
||||||
|
else:
|
||||||
|
values = common.FileBasedObjectArray(
|
||||||
|
mapped_kwarg_folder_template.format(pname),
|
||||||
|
shape=kwarg_shapes[pname],
|
||||||
|
)
|
||||||
|
kwargs[pname] = values[input_keys[pname]]
|
||||||
|
if os.path.exists(remaining_kwargs_file):
|
||||||
|
kwargs.update(common.load(remaining_kwargs_file))
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
|
||||||
|
try:
|
||||||
|
return_values = func(**kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
exceptions = common.FileBasedObjectArray(
|
||||||
|
exceptions_folder, shape=map_shape
|
||||||
|
)
|
||||||
|
exceptions.dump(output_key, e)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Output
|
||||||
|
|
||||||
|
if len(return_value_folders) == 1:
|
||||||
|
return_values = (return_values,)
|
||||||
|
|
||||||
|
for r, f in zip(return_values, return_value_folders):
|
||||||
|
output_array = common.FileBasedObjectArray(f, shape=map_shape)
|
||||||
|
output_array.dump(output_key, r)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with folder.open("__in__.py", "w", encoding="utf8") as handle:
|
||||||
|
handle.write(script)
|
||||||
|
codeinfo.stdin_name = "__in__.py"
|
||||||
|
|
||||||
|
return calcinfo
|
||||||
|
|
||||||
|
|
||||||
|
@aiida.engine.calcfunction
|
||||||
|
def merge_remote_arrays(**kwargs: PyRemoteArray) -> PyRemoteArray:
|
||||||
|
"""Merge several remote arrays into a single array.
|
||||||
|
|
||||||
|
This is most commonly used for combining the results of
|
||||||
|
several PyMapJobs, where each job only produced a subset of
|
||||||
|
the results (e.g. some tasks failed).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
**kwargs
|
||||||
|
The arrays to merge. The arrays will be merged in the same
|
||||||
|
order as 'kwargs' (i.e. lexicographically by key).
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the input arrays are not on the same computer.
|
||||||
|
If the input arrays are not the same shape
|
||||||
|
"""
|
||||||
|
arrays = [kwargs[k] for k in sorted(kwargs.keys())]
|
||||||
|
|
||||||
|
computer, *other_computers = [x.computer for x in arrays]
|
||||||
|
if any(computer.uuid != x.uuid for x in other_computers):
|
||||||
|
raise ValueError("Need to be on same computer")
|
||||||
|
|
||||||
|
shape, *other_shapes = [x.shape for x in arrays]
|
||||||
|
if any(shape != x for x in other_shapes):
|
||||||
|
raise ValueError("Arrays need to be same shape")
|
||||||
|
|
||||||
|
output_array = PyRemoteArray(
|
||||||
|
computer=computer,
|
||||||
|
shape=shape,
|
||||||
|
filename_template=common.array.filename_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
with computer.get_transport() as transport:
|
||||||
|
f = create_remote_folder(transport, computer.get_workdir(), output_array.uuid)
|
||||||
|
for arr in arrays:
|
||||||
|
array_files = os.path.join(arr.get_attribute("remote_path"), "*")
|
||||||
|
transport.copy(array_files, f, recursive=False)
|
||||||
|
|
||||||
|
output_array.attributes["remote_path"] = f
|
||||||
|
return output_array
|
||||||
|
|
||||||
|
|
||||||
|
def create_remote_folder(transport, workdir_template, uuid):
|
||||||
|
"""Create a folder in the Aiida working directory on a remote computer.
|
||||||
|
|
||||||
|
Params
|
||||||
|
------
|
||||||
|
transport
|
||||||
|
A transport to the remote computer.
|
||||||
|
workdir_template
|
||||||
|
Template string for the Aiida working directory on the computer.
|
||||||
|
Must expect a 'username' argument.
|
||||||
|
uuid
|
||||||
|
A UUID uniquely identifying the remote folder. This will be
|
||||||
|
combined with 'workdir_template' to provide a sharded folder
|
||||||
|
structure.
|
||||||
|
"""
|
||||||
|
path = workdir_template.format(username=transport.whoami())
|
||||||
|
# Create a sharded path, e.g. 'ab1234ef...' -> 'ab/12/34ef...'.
|
||||||
|
for segment in (uuid[:2], uuid[2:4], uuid[4:]):
|
||||||
|
path = os.path.join(path, segment)
|
||||||
|
transport.mkdir(path, ignore_existing=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def num_mapjob_tasks(p: aiida.orm.ProcessNode) -> int:
|
||||||
|
"""Return the number of tasks that will be executed by a mapjob."""
|
||||||
|
mapspec = common.MapSpec.from_string(p.get_option("mapspec"))
|
||||||
|
mapped_kwargs = {
|
||||||
|
k: v for k, v in p.inputs.kwargs.items() if k in mapspec.parameters
|
||||||
|
}
|
||||||
|
return np.sum(~expected_mask(mapspec, mapped_kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
def expected_mask(mapspec: common.MapSpec, inputs: Dict[str, Any]) -> np.ndarray:
|
||||||
|
"""Return the result mask that one should expect, given a MapSpec and inputs.
|
||||||
|
|
||||||
|
When executing a PyMapJob over inputs that have a mask applied, we expect the
|
||||||
|
output to be masked also. This function returns the expected mask.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
mapspec
|
||||||
|
MapSpec that determines how inputs should be combined.
|
||||||
|
inputs
|
||||||
|
Inputs to map over
|
||||||
|
"""
|
||||||
|
kwarg_shapes = toolz.valmap(array_shape, inputs)
|
||||||
|
kwarg_masks = toolz.valmap(array_mask, inputs)
|
||||||
|
# This will raise an exception if the shapes are incompatible.
|
||||||
|
map_shape = mapspec.shape(kwarg_shapes)
|
||||||
|
map_size = np.prod(map_shape)
|
||||||
|
|
||||||
|
# We only want to run tasks for _unmasked_ map elements.
|
||||||
|
# Additionally, instead of a task array specified like "0,1,2,...",
|
||||||
|
# we want to group tasks into 'runs': "0-30,35-38,...".
|
||||||
|
def is_masked(i):
|
||||||
|
return any(
|
||||||
|
kwarg_masks[k][v] for k, v in mapspec.input_keys(map_shape, i).items()
|
||||||
|
)
|
||||||
|
|
||||||
|
return np.array([is_masked(x) for x in range(map_size)]).reshape(map_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def array_job_spec(mapspec: common.MapSpec, inputs: Dict[str, Any]) -> str:
|
||||||
|
"""Return a job-array task specification, given a MapSpec and inputs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
mapspec
|
||||||
|
MapSpec that determines how inputs should be combined.
|
||||||
|
inputs
|
||||||
|
Inputs to map over
|
||||||
|
"""
|
||||||
|
# We only want tasks in the array job corresponding to the _unmasked_
|
||||||
|
# elements in the map.
|
||||||
|
unmasked_elements = ~expected_mask(mapspec, inputs).reshape(-1)
|
||||||
|
return array_job_spec_from_booleans(unmasked_elements)
|
||||||
|
|
||||||
|
|
||||||
|
def array_job_spec_from_booleans(should_run_task: Sequence[bool]) -> str:
|
||||||
|
"""Return a job-array task specification, given a sequence of booleans.
|
||||||
|
|
||||||
|
If element 'i' in the sequence is 'True', then task 'i' will be included
|
||||||
|
in the job array spec
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> array_job_spec_from_booleans([False, True, True, True, False, True])
|
||||||
|
"1-3,5"
|
||||||
|
"""
|
||||||
|
return ",".join(
|
||||||
|
str(start) if start == stop else f"{start}-{stop}"
|
||||||
|
for start, stop in _group_runs(should_run_task)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _group_runs(s: Sequence[bool]):
|
||||||
|
"""Yield (start, stop) pairs for runs of 'True' in 's'.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> list(_group_runs([True, True, True]))
|
||||||
|
[(0,2)]
|
||||||
|
>>> list(_group_runs(
|
||||||
|
... [False, True, True, True, False, False, True, False, True, True]
|
||||||
|
... )
|
||||||
|
...
|
||||||
|
[(1,3), (6, 6), (8,9)]
|
||||||
|
"""
|
||||||
|
prev_unmasked = False
|
||||||
|
start = None
|
||||||
|
for i, unmasked in enumerate(s):
|
||||||
|
if unmasked and not prev_unmasked:
|
||||||
|
start = i
|
||||||
|
if prev_unmasked and not unmasked:
|
||||||
|
assert start is not None
|
||||||
|
yield (start, i - 1)
|
||||||
|
start = None
|
||||||
|
prev_unmasked = unmasked
|
||||||
|
|
||||||
|
if prev_unmasked and start is not None:
|
||||||
|
yield (start, i)
|
||||||
|
|
||||||
|
|
||||||
|
def all_equal(seq):
|
||||||
|
"""Return True iff all elements of the input are equal."""
|
||||||
|
fst, *rest = seq
|
||||||
|
if not rest:
|
||||||
|
return True
|
||||||
|
return all(r == fst for r in rest)
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
# Common code used both by the plugin and by the runtime that wraps usercode.
|
||||||
|
|
||||||
|
import importlib.resources
|
||||||
|
|
||||||
|
from .array import FileBasedObjectArray
|
||||||
|
from .mapspec import MapSpec
|
||||||
|
from .serialize import dump, load
|
||||||
|
|
||||||
|
__all__ = ["dump", "load", "FileBasedObjectArray", "MapSpec", "package_module_contents"]
|
||||||
|
|
||||||
|
|
||||||
|
def package_module_contents():
|
||||||
|
"""Yield (filename, contents) pairs for each module in this subpackage."""
|
||||||
|
for filename in importlib.resources.contents(__package__):
|
||||||
|
if filename.endswith(".py"):
|
||||||
|
yield filename, importlib.resources.read_text(__package__, filename)
|
|
@ -0,0 +1,136 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
import concurrent.futures
|
||||||
|
import functools
|
||||||
|
import itertools
|
||||||
|
import operator
|
||||||
|
import pathlib
|
||||||
|
from typing import Any, List, Sequence, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from . import serialize
|
||||||
|
|
||||||
|
filename_template = "__{:d}__.pickle"
|
||||||
|
|
||||||
|
|
||||||
|
class FileBasedObjectArray:
|
||||||
|
"""Array interface to a folder of files on disk.
|
||||||
|
|
||||||
|
__getitem__ returns "np.ma.masked" for non-existant files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, folder, shape, strides=None, filename_template=filename_template,
|
||||||
|
):
|
||||||
|
self.folder = pathlib.Path(folder).absolute()
|
||||||
|
self.shape = tuple(shape)
|
||||||
|
self.strides = _make_strides(self.shape) if strides is None else tuple(strides)
|
||||||
|
self.filename_template = str(filename_template)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Return number of elements in the array."""
|
||||||
|
return functools.reduce(operator.mul, self.shape, 1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rank(self) -> int:
|
||||||
|
"""Return the rank of the array."""
|
||||||
|
return len(self.shape)
|
||||||
|
|
||||||
|
def _normalize_key(self, key: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||||
|
if not isinstance(key, tuple):
|
||||||
|
key = (key,)
|
||||||
|
if len(key) != self.rank:
|
||||||
|
raise IndexError(
|
||||||
|
f"too many indices for array: array is {self.rank}-dimensional, "
|
||||||
|
"but {len(key)} were indexed"
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(isinstance(k, slice) for k in key):
|
||||||
|
raise NotImplementedError("Cannot yet slice subarrays")
|
||||||
|
|
||||||
|
normalized_key = []
|
||||||
|
for axis, k in enumerate(key):
|
||||||
|
axis_size = self.shape[axis]
|
||||||
|
normalized_k = k if k >= 0 else (axis_size - k)
|
||||||
|
if not (0 <= normalized_k < axis_size):
|
||||||
|
raise IndexError(
|
||||||
|
"index {k} is out of bounds for axis {axis} with size {axis_size}"
|
||||||
|
)
|
||||||
|
normalized_key.append(k)
|
||||||
|
|
||||||
|
return tuple(normalized_key)
|
||||||
|
|
||||||
|
def _index_to_file(self, index: int) -> pathlib.Path:
|
||||||
|
"""Return the filename associated with the given index."""
|
||||||
|
return self.folder / self.filename_template.format(index)
|
||||||
|
|
||||||
|
def _key_to_file(self, key: Tuple[int, ...]) -> pathlib.Path:
|
||||||
|
"""Return the filename associated with the given key."""
|
||||||
|
index = sum(k * s for k, s in zip(key, self.strides))
|
||||||
|
return self._index_to_file(index)
|
||||||
|
|
||||||
|
def _files(self):
|
||||||
|
"""Yield all the filenames that constitute the data in this array."""
|
||||||
|
return map(self._key_to_file, itertools.product(*map(range, self.shape)))
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
key = self._normalize_key(key)
|
||||||
|
if any(isinstance(x, slice) for x in key):
|
||||||
|
# XXX: need to figure out strides in order to implement this.
|
||||||
|
raise NotImplementedError("Cannot yet slice subarrays")
|
||||||
|
|
||||||
|
f = self._key_to_file(key)
|
||||||
|
if not f.is_file():
|
||||||
|
return np.ma.core.masked
|
||||||
|
return serialize.load(f)
|
||||||
|
|
||||||
|
def to_array(self) -> np.ma.core.MaskedArray:
|
||||||
|
"""Return a masked numpy array containing all the data.
|
||||||
|
|
||||||
|
The returned numpy array has dtype "object" and a mask for
|
||||||
|
masking out missing data.
|
||||||
|
"""
|
||||||
|
items = _load_all(map(self._index_to_file, range(self.size)))
|
||||||
|
mask = [not self._index_to_file(i).is_file() for i in range(self.size)]
|
||||||
|
return np.ma.array(items, mask=mask, dtype=object).reshape(self.shape)
|
||||||
|
|
||||||
|
def dump(self, key, value):
|
||||||
|
"""Dump 'value' into the file associated with 'key'.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> arr = FileBasedObjectArray(...)
|
||||||
|
>>> arr.dump((2, 1, 5), dict(a=1, b=2))
|
||||||
|
"""
|
||||||
|
key = self._normalize_key(key)
|
||||||
|
if not any(isinstance(x, slice) for x in key):
|
||||||
|
return serialize.dump(value, self._key_to_file(key))
|
||||||
|
|
||||||
|
raise NotImplementedError("Cannot yet dump subarrays")
|
||||||
|
|
||||||
|
|
||||||
|
def _tails(seq):
|
||||||
|
while seq:
|
||||||
|
seq = seq[1:]
|
||||||
|
yield seq
|
||||||
|
|
||||||
|
|
||||||
|
def _make_strides(shape):
|
||||||
|
return tuple(functools.reduce(operator.mul, s, 1) for s in _tails(shape))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_all(filenames: Sequence[str]) -> List[Any]:
|
||||||
|
def maybe_read(f):
|
||||||
|
return serialize.read(f) if f.is_file() else None
|
||||||
|
|
||||||
|
def maybe_load(x):
|
||||||
|
return serialize.loads(x) if x is not None else None
|
||||||
|
|
||||||
|
# Delegate file reading to the threadpool but deserialize sequentially,
|
||||||
|
# as this is pure Python and CPU bound
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as tex:
|
||||||
|
return [maybe_load(x) for x in tex.map(maybe_read, filenames)]
|
|
@ -0,0 +1,226 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import functools
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from .array import _make_strides
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ArraySpec:
|
||||||
|
"""Specification for a named array, with some axes indexed by named indices."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
axes: Tuple[Optional[str]]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.name.isidentifier():
|
||||||
|
raise ValueError(
|
||||||
|
f"Array name '{self.name}' is not a valid Python identifier"
|
||||||
|
)
|
||||||
|
for i in self.axes:
|
||||||
|
if not (i is None or i.isidentifier()):
|
||||||
|
raise ValueError(f"Index name '{i}' is not a valid Python identifier")
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
indices = [":" if x is None else x for x in self.axes]
|
||||||
|
return f"{self.name}[{', '.join(indices)}]"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def indices(self) -> Tuple[str]:
|
||||||
|
"""Return the names of the indices for this array spec."""
|
||||||
|
return tuple(x for x in self.axes if x is not None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rank(self) -> int:
|
||||||
|
"""Return the rank of this array spec."""
|
||||||
|
return len(self.axes)
|
||||||
|
|
||||||
|
def validate(self, shape: Tuple[int, ...]):
|
||||||
|
"""Raise an exception if 'shape' is not compatible with this array spec."""
|
||||||
|
if len(shape) != self.rank:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expecting array of rank {self.rank}, but got array of shape {shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MapSpec:
|
||||||
|
"""Specification for how to map input axes to output axes.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> mapped = MapSpec.from_string("a[i, j], b[i, j], c[k] -> q[i, j, k]")
|
||||||
|
>>> partial_reduction = MapSpec.from_string("a[i, :], b[:, k] -> q[i, k]")
|
||||||
|
"""
|
||||||
|
|
||||||
|
inputs: Tuple[ArraySpec]
|
||||||
|
output: ArraySpec
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if any(x is None for x in self.output.axes):
|
||||||
|
raise ValueError("Output array must have all axes indexed (no ':').")
|
||||||
|
|
||||||
|
output_indices = set(self.output.indices)
|
||||||
|
input_indices = functools.reduce(
|
||||||
|
set.union, (x.indices for x in self.inputs), set()
|
||||||
|
)
|
||||||
|
|
||||||
|
if extra_indices := output_indices - input_indices:
|
||||||
|
raise ValueError(
|
||||||
|
"Output array has indices that do not appear "
|
||||||
|
f"in the input: {extra_indices}"
|
||||||
|
)
|
||||||
|
if unused_indices := input_indices - output_indices:
|
||||||
|
raise ValueError(
|
||||||
|
"Input array have indices that do not appear "
|
||||||
|
f"in the output: {unused_indices}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> Tuple[str, ...]:
|
||||||
|
"""Return the parameter names of this mapspec."""
|
||||||
|
return tuple(x.name for x in self.inputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def indices(self) -> Tuple[str, ...]:
|
||||||
|
"""Return the index names for this MapSpec."""
|
||||||
|
return self.output.indices
|
||||||
|
|
||||||
|
def shape(self, shapes: Dict[str, Tuple[int, ...]]) -> Tuple[int, ...]:
|
||||||
|
"""Return the shape of the output of this MapSpec.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
shapes
|
||||||
|
Shapes of the inputs, keyed by name.
|
||||||
|
"""
|
||||||
|
input_names = set(x.name for x in self.inputs)
|
||||||
|
|
||||||
|
if extra_names := set(shapes.keys()) - input_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got extra array {extra_names} that are not accepted by this map."
|
||||||
|
)
|
||||||
|
if missing_names := input_names - set(shapes.keys()):
|
||||||
|
raise ValueError(
|
||||||
|
f"Inputs expected by this map were not provided: {missing_names}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Each individual array is of the appropriate rank
|
||||||
|
for x in self.inputs:
|
||||||
|
x.validate(shapes[x.name])
|
||||||
|
|
||||||
|
# Shapes match between array sharing a named index
|
||||||
|
|
||||||
|
def get_dim(array, index):
|
||||||
|
axis = array.axes.index(index)
|
||||||
|
return shapes[array.name][axis]
|
||||||
|
|
||||||
|
shape = []
|
||||||
|
for index in self.output.indices:
|
||||||
|
relevant_arrays = [x for x in self.inputs if index in x.indices]
|
||||||
|
dim, *rest = [get_dim(x, index) for x in relevant_arrays]
|
||||||
|
if any(dim != x for x in rest):
|
||||||
|
raise ValueError(
|
||||||
|
f"Dimension mismatch for arrays {relevant_arrays} "
|
||||||
|
f"along {index} axis."
|
||||||
|
)
|
||||||
|
shape.append(dim)
|
||||||
|
|
||||||
|
return tuple(shape)
|
||||||
|
|
||||||
|
def output_key(self, shape: Tuple[int, ...], linear_index: int) -> Tuple[int, ...]:
|
||||||
|
"""Return a key used for indexing the output of this map.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
shape
|
||||||
|
The shape of the map output.
|
||||||
|
linear_index
|
||||||
|
The index of the element for which to return the key.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> spec = MapSpec.from_string("x[i, j], y[j, :, k] -> z[i, j, k]")
|
||||||
|
>>> spec.output_key((5, 2, 3), 23)
|
||||||
|
(3, 1, 2)
|
||||||
|
"""
|
||||||
|
if len(shape) != len(self.indices):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected a shape of length {len(self.indices)}, got {shape}"
|
||||||
|
)
|
||||||
|
return tuple(
|
||||||
|
(linear_index // stride) % dim
|
||||||
|
for stride, dim in zip(_make_strides(shape), shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
def input_keys(
|
||||||
|
self, shape: Tuple[int, ...], linear_index: int,
|
||||||
|
) -> Dict[str, Tuple[Union[slice, int]]]:
|
||||||
|
"""Return keys for indexing inputs of this map.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
shape
|
||||||
|
The shape of the map output.
|
||||||
|
linear_index
|
||||||
|
The index of the element for which to return the keys.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> spec = MapSpec("x[i, j], y[j, :, k] -> z[i, j, k]")
|
||||||
|
>>> spec.input_keys((5, 2, 3), 23)
|
||||||
|
{'x': (3, 1), 'y': (1, slice(None, None, None), 2)}
|
||||||
|
"""
|
||||||
|
output_key = self.output_key(shape, linear_index)
|
||||||
|
if len(output_key) != len(self.indices):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected a key of shape {len(self.indices)}, got {output_key}"
|
||||||
|
)
|
||||||
|
ids = dict(zip(self.indices, output_key))
|
||||||
|
return {
|
||||||
|
x.name: tuple(slice(None) if ax is None else ids[ax] for ax in x.axes)
|
||||||
|
for x in self.inputs
|
||||||
|
}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{', '.join(map(str, self.inputs))} -> {self.output}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, expr):
|
||||||
|
"""Construct an MapSpec from a string."""
|
||||||
|
try:
|
||||||
|
in_, out_ = expr.split("->")
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Expected expression of form 'a -> b', but got '{expr}''")
|
||||||
|
|
||||||
|
inputs = _parse_indexed_arrays(in_)
|
||||||
|
outputs = _parse_indexed_arrays(out_)
|
||||||
|
if len(outputs) != 1:
|
||||||
|
raise ValueError(f"Expected a single output, but got {len(outputs)}")
|
||||||
|
(output,) = outputs
|
||||||
|
|
||||||
|
return cls(inputs, output)
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
"""Return a faithful representation of a MapSpec as a string."""
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_index_string(index_string) -> List[Optional[str]]:
|
||||||
|
indices = [idx.strip() for idx in index_string.split(",")]
|
||||||
|
return [i if i != ":" else None for i in indices]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_indexed_arrays(expr) -> List[ArraySpec]:
|
||||||
|
array_pattern = r"(\w+?)\[(.+?)\]"
|
||||||
|
return [
|
||||||
|
ArraySpec(name, _parse_index_string(indices))
|
||||||
|
for name, indices in re.findall(array_pattern, expr)
|
||||||
|
]
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
|
|
||||||
|
def read(name, opener=open):
|
||||||
|
"""Load file contents as a bytestring."""
|
||||||
|
with opener(name, "rb") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
loads = cloudpickle.loads
|
||||||
|
dumps = cloudpickle.dumps
|
||||||
|
|
||||||
|
|
||||||
|
def load(name, opener=open):
|
||||||
|
"""Load a cloudpickled object from the named file."""
|
||||||
|
with opener(name, "rb") as f:
|
||||||
|
return cloudpickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def dump(obj, name, opener=open):
|
||||||
|
"""Dump an object to the named file using cloudpickle."""
|
||||||
|
with opener(name, "wb") as f:
|
||||||
|
cloudpickle.dump(obj, f)
|
|
@ -0,0 +1,136 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from aiida import get_config_option
|
||||||
|
from aiida.cmdline.commands.cmd_process import process_kill, process_pause, process_play
|
||||||
|
from aiida.cmdline.utils import common, daemon, echo
|
||||||
|
from aiida.engine.daemon.client import get_daemon_client
|
||||||
|
from aiida.orm import ProcessNode, load_node
|
||||||
|
|
||||||
|
|
||||||
|
def kill(process: Union[ProcessNode, int, str], timeout: int = 5) -> bool:
|
||||||
|
"""Kill the specified process.
|
||||||
|
|
||||||
|
Params
|
||||||
|
------
|
||||||
|
process
|
||||||
|
The process to kill.
|
||||||
|
timeout
|
||||||
|
Timeout (in seconds) to wait for confirmation that the process was killed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
True only if the process is now terminated.
|
||||||
|
"""
|
||||||
|
process = _ensure_process_node(process)
|
||||||
|
process_kill.callback([process], timeout=timeout, wait=True)
|
||||||
|
return process.is_terminated
|
||||||
|
|
||||||
|
|
||||||
|
def pause(process: Union[ProcessNode, int, str], timeout: int = 5) -> bool:
|
||||||
|
"""Pause the specified process.
|
||||||
|
|
||||||
|
Paused processes will not continue execution, and can be unpaused later.
|
||||||
|
|
||||||
|
Params
|
||||||
|
------
|
||||||
|
process
|
||||||
|
The process to kill.
|
||||||
|
timeout
|
||||||
|
Timeout (in seconds) to wait for confirmation that the process was killed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
True only if the process is now paused.
|
||||||
|
"""
|
||||||
|
process = _ensure_process_node(process)
|
||||||
|
if process.is_terminated:
|
||||||
|
raise RuntimeError("Cannot pause terminated process {process.pk}.")
|
||||||
|
process_pause.callback([process], all_entries=False, timeout=timeout, wait=True)
|
||||||
|
return process.paused
|
||||||
|
|
||||||
|
|
||||||
|
def unpause(process: Union[ProcessNode, int, str], timeout: int = 5) -> bool:
|
||||||
|
"""Unpause the specified process.
|
||||||
|
|
||||||
|
Params
|
||||||
|
------
|
||||||
|
process
|
||||||
|
The process to kill.
|
||||||
|
timeout
|
||||||
|
Timeout (in seconds) to wait for confirmation that the process was killed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
True only if the process is now unpaused.
|
||||||
|
"""
|
||||||
|
process = _ensure_process_node(process)
|
||||||
|
if process.is_terminated:
|
||||||
|
raise RuntimeError("Cannot unpause terminated process {process.pk}.")
|
||||||
|
process_play.callback([process], all_entries=False, timeout=timeout, wait=True)
|
||||||
|
return not process.paused
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_daemon_restarted(n_workers: Optional[int] = None):
|
||||||
|
"""Restart the daemon (if it is running), or start it (if it is stopped).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_workers
|
||||||
|
The number of daemon workers to start. If not provided, the default
|
||||||
|
number of workers for this profile is used.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
If the daemon is running this is equivalent to running
|
||||||
|
'verdi daemon restart --reset', i.e. we fully restart the daemon, including
|
||||||
|
the circus controller. This ensures that any changes in the environment are
|
||||||
|
properly picked up by the daemon.
|
||||||
|
"""
|
||||||
|
client = get_daemon_client()
|
||||||
|
n_workers = n_workers or get_config_option("daemon.default_workers")
|
||||||
|
|
||||||
|
if client.is_daemon_running:
|
||||||
|
echo.echo("Stopping the daemon...", nl=False)
|
||||||
|
response = client.stop_daemon(wait=True)
|
||||||
|
retcode = daemon.print_client_response_status(response)
|
||||||
|
if retcode:
|
||||||
|
raise RuntimeError(f"Problem restarting Aiida daemon: {response['status']}")
|
||||||
|
|
||||||
|
echo.echo("Starting the daemon...", nl=False)
|
||||||
|
|
||||||
|
# We have to run this in a subprocess because it daemonizes, and we do not
|
||||||
|
# want to daemonize _this_ process.
|
||||||
|
command = [
|
||||||
|
"verdi",
|
||||||
|
"-p",
|
||||||
|
client.profile.name,
|
||||||
|
"daemon",
|
||||||
|
"start-circus",
|
||||||
|
str(n_workers),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
currenv = common.get_env_with_venv_bin()
|
||||||
|
subprocess.check_output(command, env=currenv, stderr=subprocess.STDOUT)
|
||||||
|
except subprocess.CalledProcessError as exception:
|
||||||
|
echo.echo("FAILED", fg="red", bold=True)
|
||||||
|
raise RuntimeError("Failed to start the daemon") from exception
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
response = client.get_status()
|
||||||
|
|
||||||
|
retcode = daemon.print_client_response_status(response)
|
||||||
|
if retcode:
|
||||||
|
raise RuntimeError(f"Problem starting Aiida daemon: {response['status']}")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_process_node(node_or_id: Union[ProcessNode, int, str]) -> ProcessNode:
|
||||||
|
if isinstance(node_or_id, ProcessNode):
|
||||||
|
return node_or_id
|
||||||
|
else:
|
||||||
|
return load_node(node_or_id)
|
|
@ -0,0 +1,458 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""Aiida data plugins for running arbitrary Python functions."""
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import io
|
||||||
|
from itertools import repeat
|
||||||
|
import operator
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import aiida.orm
|
||||||
|
import cloudpickle
|
||||||
|
import numpy as np
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
# To get Aiida's caching to be useful we need to have a stable way to hash Python
|
||||||
|
# functions. The "default" is to hash the cloudpickle blob, but this is not
|
||||||
|
# typically stable for functions defined in a Jupyter notebook.
|
||||||
|
# TODO: insert something useful here.
|
||||||
|
function_hasher = None
|
||||||
|
|
||||||
|
|
||||||
|
class PyFunction(aiida.orm.Data):
|
||||||
|
"""Aiida representation of a Python function."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
# TODO: basic typechecks on these
|
||||||
|
func = kwargs.pop("func")
|
||||||
|
assert callable(func)
|
||||||
|
returns = kwargs.pop("returns")
|
||||||
|
if isinstance(returns, str):
|
||||||
|
returns = [returns]
|
||||||
|
resources = kwargs.pop("resources", None)
|
||||||
|
if resources is None:
|
||||||
|
resources = dict()
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.put_object_from_filelike(
|
||||||
|
path="function.pickle", handle=io.BytesIO(cloudpickle.dumps(func)),
|
||||||
|
)
|
||||||
|
self.set_attribute("resources", resources)
|
||||||
|
self.set_attribute("returns", returns)
|
||||||
|
self.set_attribute("parameters", _parameters(func))
|
||||||
|
|
||||||
|
# If 'function_hasher' is available then we store the
|
||||||
|
# function hash directly, and _get_objects_to_hash will
|
||||||
|
# _not_ use the pickle blob (which is not stable e.g.
|
||||||
|
# for functions defined in a notebook).
|
||||||
|
if callable(function_hasher):
|
||||||
|
self.set_attribute("_function_hash", function_hasher(func))
|
||||||
|
|
||||||
|
try:
|
||||||
|
source = inspect.getsource(func)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.set_attribute("source", source)
|
||||||
|
|
||||||
|
name = getattr(func, "__name__", None)
|
||||||
|
if name:
|
||||||
|
self.set_attribute("name", name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resources(self) -> Dict[str, str]:
|
||||||
|
"""Resources required by this function."""
|
||||||
|
return self.get_attribute("resources")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def source(self) -> str:
|
||||||
|
"""Source code of this function."""
|
||||||
|
return self.get_attribute("source")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Name of this function."""
|
||||||
|
return self.get_attribute("name")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> List[str]:
|
||||||
|
"""Parameters of this function."""
|
||||||
|
return self.get_attribute("parameters")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def returns(self) -> Optional[List[str]]:
|
||||||
|
"""List of names returned by this function."""
|
||||||
|
return self.get_attribute("returns")
|
||||||
|
|
||||||
|
# TODO: use better caching for this (maybe on the class level?)
|
||||||
|
@functools.cached_property
|
||||||
|
def pickle(self) -> bytes:
|
||||||
|
"""Pickled function."""
|
||||||
|
return self.get_object_content("function.pickle", "rb")
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def callable(self) -> Callable:
|
||||||
|
"""Return the function stored in this object."""
|
||||||
|
return cloudpickle.loads(self.pickle)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __signature__(self):
|
||||||
|
return inspect.signature(self.callable)
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any):
|
||||||
|
"""Call the function stored in this object."""
|
||||||
|
return self.callable(*args, **kwargs)
|
||||||
|
|
||||||
|
def _get_objects_to_hash(self) -> List[Any]:
|
||||||
|
objects = super()._get_objects_to_hash()
|
||||||
|
|
||||||
|
# XXX: this depends on the specifics of the implementation
|
||||||
|
# of super()._get_objects_to_hash(). The second-to-last
|
||||||
|
# elements in 'objects' is the hash of the file repository.
|
||||||
|
# For 'PyFunction' nodes this contains the cloudpickle blob,
|
||||||
|
# which we _do not_ want hashed.
|
||||||
|
if "_function_hash" in self.attributes:
|
||||||
|
*a, _, x = objects
|
||||||
|
return [*a, x]
|
||||||
|
else:
|
||||||
|
return objects
|
||||||
|
|
||||||
|
|
||||||
|
def _parameters(f: Callable) -> List[str]:
|
||||||
|
valid_kinds = [
|
||||||
|
getattr(inspect.Parameter, k) for k in ("POSITIONAL_OR_KEYWORD", "KEYWORD_ONLY")
|
||||||
|
]
|
||||||
|
params = inspect.signature(f).parameters.values()
|
||||||
|
if any(p.kind not in valid_kinds for p in params):
|
||||||
|
raise TypeError("Invalid signature")
|
||||||
|
return [p.name for p in params]
|
||||||
|
|
||||||
|
|
||||||
|
class Nil(aiida.orm.Data):
|
||||||
|
"""Trivial representation of the None type in Aiida."""
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: make this JSON serializable so it can go directly in the DB
|
||||||
|
class PyOutline(aiida.orm.Data):
|
||||||
|
"""Naive Aiida representation of a workflow outline."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
outline = kwargs.pop("outline")
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.put_object_from_filelike(
|
||||||
|
path="outline.pickle", handle=io.BytesIO(cloudpickle.dumps(outline)),
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def value(self):
|
||||||
|
"""Python object loaded from the stored pickle."""
|
||||||
|
return cloudpickle.loads(self.get_object_content("outline.pickle", "rb"))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Annotate these with the class name (useful for visualization)
|
||||||
|
class PyData(aiida.orm.Data):
|
||||||
|
"""Naive Aiida representation of an arbitrary Python object."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pickle_path = kwargs.pop("pickle_path")
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.put_object_from_file(filepath=pickle_path, path="object.pickle")
|
||||||
|
|
||||||
|
# TODO: do caching more intelligently: we could attach a cache to the
|
||||||
|
# _class_ instead so that if we create 2 PyData objects that
|
||||||
|
# point to the _same_ database entry (pk) then we only have to
|
||||||
|
# load the data once.
|
||||||
|
# (does Aiida provide some tooling for this?)
|
||||||
|
@functools.cached_property
|
||||||
|
def value(self):
|
||||||
|
"""Python object loaded from the stored pickle."""
|
||||||
|
return cloudpickle.loads(self.get_object_content("object.pickle", "rb"))
|
||||||
|
|
||||||
|
|
||||||
|
class PyRemoteData(aiida.orm.RemoteData):
|
||||||
|
"""Naive Aiida representation of an arbitrary Python object on a remote computer."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pickle_path = str(kwargs.pop("pickle_path"))
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.set_attribute("pickle_path", pickle_path)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pickle_path(self):
|
||||||
|
"""Return the remote path that contains the pickle."""
|
||||||
|
return os.path.join(self.get_remote_path(), self.get_attribute("pickle_path"))
|
||||||
|
|
||||||
|
def fetch_value(self):
|
||||||
|
"""Load Python object from the remote pickle."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode="rb") as f:
|
||||||
|
self.getfile(self.get_attribute("pickle_path"), f.name)
|
||||||
|
return cloudpickle.load(f)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_remote_data(cls, rd: aiida.orm.RemoteData, pickle_path: str):
|
||||||
|
"""Return a new PyRemoteData, given an existing RemoteData.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rd
|
||||||
|
RemoteData folder.
|
||||||
|
pickle_path
|
||||||
|
Relative path in the RemoteData that contains pickle data.
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
remote_path=rd.get_remote_path(),
|
||||||
|
pickle_path=pickle_path,
|
||||||
|
computer=rd.computer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PyRemoteArray(aiida.orm.RemoteData):
|
||||||
|
"""Naive Aiida representation of a remote array of arbitrary Python objects.
|
||||||
|
|
||||||
|
Each object is stored in a separate file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
shape = kwargs.pop("shape")
|
||||||
|
filename_template = kwargs.pop("filename_template")
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.set_attribute("shape", tuple(shape))
|
||||||
|
self.set_attribute("filename_template", str(filename_template))
|
||||||
|
|
||||||
|
def _file(self, i: int) -> str:
|
||||||
|
return self.get_attribute("filename_template").format(i)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pickle_path(self):
|
||||||
|
"""Return the remote path that contains the pickle files."""
|
||||||
|
return self.get_remote_path()
|
||||||
|
|
||||||
|
def _fetch_buffer(self, local_files=False):
|
||||||
|
"""Return iterator over Python objects in this array."""
|
||||||
|
|
||||||
|
def _load(dir: Path, pickle_file: str):
|
||||||
|
path = dir / pickle_file
|
||||||
|
if not path.is_file():
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
return cloudpickle.load(f)
|
||||||
|
|
||||||
|
def _iter_files(dir):
|
||||||
|
with ThreadPoolExecutor() as ex:
|
||||||
|
file_gen = map(self._file, range(self.size))
|
||||||
|
yield from ex.map(_load, repeat(dir), file_gen)
|
||||||
|
|
||||||
|
if local_files:
|
||||||
|
yield from _iter_files(Path(self.get_remote_path()))
|
||||||
|
else:
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
dir = Path(os.path.join(temp_dir, "values"))
|
||||||
|
# TODO: do this with chunks, rather than all files at once.
|
||||||
|
with self.get_authinfo().get_transport() as transport:
|
||||||
|
transport.gettree(self.get_remote_path(), dir)
|
||||||
|
yield from _iter_files(dir)
|
||||||
|
|
||||||
|
def fetch_value(self, local_files=False) -> np.ma.core.MaskedArray:
|
||||||
|
"""Return a numpy array with dtype 'object' for this array."""
|
||||||
|
# Objects that have a bogus '__array__' implementation fool
|
||||||
|
# 'buff[:] = xs', so we need to manually fill the array.
|
||||||
|
buff = np.empty((self.size,), dtype=object)
|
||||||
|
for i, x in enumerate(self._fetch_buffer(local_files)):
|
||||||
|
buff[i] = x
|
||||||
|
buff = buff.reshape(self.shape)
|
||||||
|
return np.ma.array(buff, mask=self.mask)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> Tuple[int, ...]:
|
||||||
|
"""Shape of this remote array."""
|
||||||
|
return tuple(self.get_attribute("shape"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_masked(self) -> bool:
|
||||||
|
"""Return True if some elements of the array are 'masked' (missing)."""
|
||||||
|
return np.any(self.mask)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mask(self) -> np.ndarray:
|
||||||
|
"""Return the mask for the missing elements of the array."""
|
||||||
|
existing_files = set(
|
||||||
|
v["name"] for v in self.listdir_withattributes() if not v["isdir"]
|
||||||
|
)
|
||||||
|
return np.array(
|
||||||
|
[self._file(i) not in existing_files for i in range(self.size)], dtype=bool,
|
||||||
|
).reshape(self.shape)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Size of this remote array (product of the shape)."""
|
||||||
|
return toolz.reduce(operator.mul, self.shape, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class PyArray(PyData):
|
||||||
|
"""Wrapper around PyData for storing a single array."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
array = np.asarray(kwargs.pop("array"))
|
||||||
|
with tempfile.NamedTemporaryFile() as handle:
|
||||||
|
cloudpickle.dump(array, handle)
|
||||||
|
handle.flush()
|
||||||
|
handle.seek(0)
|
||||||
|
super().__init__(pickle_path=handle.name, **kwargs)
|
||||||
|
self.set_attribute("shape", array.shape)
|
||||||
|
self.set_attribute("dtype", str(array.dtype))
|
||||||
|
self._cached = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> Tuple[int, ...]:
|
||||||
|
"""Shape of this remote array."""
|
||||||
|
return tuple(self.get_attribute("shape"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> Tuple[int, ...]:
|
||||||
|
"""Shape of this remote array."""
|
||||||
|
return np.dtype(self.get_attribute("dtype"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Size of this remote array (product of the shape)."""
|
||||||
|
return toolz.reduce(operator.mul, self.shape, 1)
|
||||||
|
|
||||||
|
def get_array(self) -> np.ndarray:
|
||||||
|
"""Return the array."""
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class PyException(aiida.orm.Data):
|
||||||
|
"""Aiida representation of a Python exception."""
|
||||||
|
|
||||||
|
# - Exception type
|
||||||
|
# - message
|
||||||
|
# - traceback
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# Register automatic conversion from lists and numpy arrays
|
||||||
|
# to the appropriate Aiida datatypes
|
||||||
|
|
||||||
|
|
||||||
|
@aiida.orm.to_aiida_type.register(type(None))
|
||||||
|
def _(_: None):
|
||||||
|
return Nil()
|
||||||
|
|
||||||
|
|
||||||
|
# Aiida Lists can only handle built-in types, which is not general
|
||||||
|
# enough for our purposes. We therefore convert Python lists into
|
||||||
|
# 1D PyArray types with 'object' dtype.
|
||||||
|
@aiida.orm.to_aiida_type.register(list)
|
||||||
|
def _(xs: list):
|
||||||
|
arr = np.empty((len(xs),), dtype=object)
|
||||||
|
# Objects that have a bogus '__array__' implementation fool
|
||||||
|
# 'arr[:] = xs', so we need to manually fill the array.
|
||||||
|
for i, x in enumerate(xs):
|
||||||
|
arr[i] = x
|
||||||
|
return PyArray(array=arr)
|
||||||
|
|
||||||
|
|
||||||
|
@aiida.orm.to_aiida_type.register(np.ndarray)
|
||||||
|
def _(x):
|
||||||
|
return PyArray(array=x)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_aiida_type(x: Any) -> aiida.orm.Data:
|
||||||
|
"""Return a new Aiida value containing 'x', if not already of an Aiida datatype.
|
||||||
|
|
||||||
|
If 'x' is already an Aiida datatype, then return 'x'.
|
||||||
|
"""
|
||||||
|
if isinstance(x, aiida.orm.Data):
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
r = aiida.orm.to_aiida_type(x)
|
||||||
|
if not isinstance(r, aiida.orm.Data):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Expected 'to_aiida_type' to return an Aiida data node, but "
|
||||||
|
f"got an object of type '{type(r)}' instead (when passed "
|
||||||
|
f"an object of type '{type(x)}')."
|
||||||
|
)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
# Register handlers for getting native Python objects from their
|
||||||
|
# Aiida equivalents
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def from_aiida_type(x):
|
||||||
|
"""Turn Aiida types into their corresponding native Python types."""
|
||||||
|
raise TypeError(f"Do not know how to convert {type(x)} to native Python type")
|
||||||
|
|
||||||
|
|
||||||
|
@from_aiida_type.register(Nil)
|
||||||
|
def _(_):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@from_aiida_type.register(aiida.orm.BaseType)
|
||||||
|
def _(x):
|
||||||
|
return x.value
|
||||||
|
|
||||||
|
|
||||||
|
@from_aiida_type.register(PyData)
|
||||||
|
def _(x):
|
||||||
|
return x.value
|
||||||
|
|
||||||
|
|
||||||
|
@from_aiida_type.register(PyArray)
|
||||||
|
def _(x):
|
||||||
|
return x.get_array()
|
||||||
|
|
||||||
|
|
||||||
|
# Register handlers for figuring out array shapes for different datatypes
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def array_shape(x) -> Tuple[int, ...]:
|
||||||
|
"""Return the shape of 'x'."""
|
||||||
|
try:
|
||||||
|
return tuple(map(int, x.shape))
|
||||||
|
except AttributeError:
|
||||||
|
raise TypeError(f"No array shape defined for type {type(x)}")
|
||||||
|
|
||||||
|
|
||||||
|
@array_shape.register(aiida.orm.List)
|
||||||
|
def _(x):
|
||||||
|
return (len(x),)
|
||||||
|
|
||||||
|
|
||||||
|
# Register handlers for figuring out array masks for different datatypes
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def array_mask(x) -> np.ndarray:
|
||||||
|
"""Return the mask applied to 'x'."""
|
||||||
|
try:
|
||||||
|
return x.mask
|
||||||
|
except AttributeError:
|
||||||
|
raise TypeError(f"No array mask defined for type {type(x)}")
|
||||||
|
|
||||||
|
|
||||||
|
@array_mask.register(aiida.orm.List)
|
||||||
|
def _(x):
|
||||||
|
return np.full((len(x),), False)
|
||||||
|
|
||||||
|
|
||||||
|
@array_mask.register(PyArray)
|
||||||
|
@array_mask.register(np.ndarray)
|
||||||
|
def _(x):
|
||||||
|
return np.full(x.shape, False)
|
|
@ -0,0 +1,429 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
import copy
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import aiida.engine
|
||||||
|
import aiida.orm
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
from .calculations import PyCalcJob, PyMapJob, array_job_spec
|
||||||
|
from .common import MapSpec
|
||||||
|
from .data import PyFunction, ensure_aiida_type
|
||||||
|
from .workchains import RestartedPyCalcJob, RestartedPyMapJob
|
||||||
|
|
||||||
|
__all__ = ["apply", "map_"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExecutionEnvironment:
|
||||||
|
"""An execution environment in which to run a PyFunction as a PyCalcJob."""
|
||||||
|
|
||||||
|
code_label: str
|
||||||
|
computer_label: str
|
||||||
|
queue: Optional[Tuple[str, int]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def code(self):
|
||||||
|
return aiida.orm.load_code("@".join((self.code_label, self.computer_label)))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def computer(self):
|
||||||
|
return aiida.orm.load_computer(self.computer_label)
|
||||||
|
|
||||||
|
|
||||||
|
def code_from_conda_env(conda_env: str, computer_name: str) -> aiida.orm.Code:
|
||||||
|
c = aiida.orm.load_computer(computer_name)
|
||||||
|
with c.get_transport() as t:
|
||||||
|
username = t.whoami()
|
||||||
|
try:
|
||||||
|
conda_dir = c.get_property("conda_dir").format(username=username)
|
||||||
|
except AttributeError:
|
||||||
|
raise RuntimeError(f"'conda_dir' is not set for {computer_name}.")
|
||||||
|
|
||||||
|
conda_initscript = os.path.join(conda_dir, "etc", "profile.d", "conda.sh")
|
||||||
|
python_path = os.path.join(conda_dir, "envs", conda_env, "bin", "python")
|
||||||
|
|
||||||
|
prepend_text = "\n".join(
|
||||||
|
[f"source {conda_initscript}", f"conda activate {conda_env}"]
|
||||||
|
)
|
||||||
|
|
||||||
|
r, stdout, stderr = t.exec_command_wait(prepend_text)
|
||||||
|
|
||||||
|
if r != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to find Conda environment '{conda_env}' on '{computer_name}':"
|
||||||
|
f"\n{stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
code = aiida.orm.Code((c, python_path), label=conda_env)
|
||||||
|
code.set_prepend_text(prepend_text)
|
||||||
|
code.store()
|
||||||
|
return code
|
||||||
|
|
||||||
|
|
||||||
|
def current_conda_environment() -> str:
|
||||||
|
"""Return current conda environment name."""
|
||||||
|
# from https://stackoverflow.com/a/57716519/3447047
|
||||||
|
return sys.exec_prefix.split(os.sep)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def execution_environment(conda_env: Optional[str], computer: str, queue=None):
|
||||||
|
if conda_env is None:
|
||||||
|
conda_env = current_conda_environment()
|
||||||
|
code_id = "@".join([conda_env, computer])
|
||||||
|
try:
|
||||||
|
aiida.orm.load_code(code_id)
|
||||||
|
except aiida.common.NotExistent:
|
||||||
|
code = code_from_conda_env(conda_env, computer)
|
||||||
|
code.store()
|
||||||
|
|
||||||
|
if queue and (queue[0] not in get_queues(computer)):
|
||||||
|
raise ValueError(f"Queue '{queue[0]}' does not exist on '{computer}'")
|
||||||
|
|
||||||
|
return ExecutionEnvironment(conda_env, computer, queue)
|
||||||
|
|
||||||
|
|
||||||
|
def get_queues(computer_name) -> List[str]:
|
||||||
|
"""Return a list of valid queue names for the named computer."""
|
||||||
|
computer = aiida.orm.load_computer(computer_name)
|
||||||
|
with computer.get_transport() as t:
|
||||||
|
command = "sinfo --summarize"
|
||||||
|
retval, stdout, stderr = t.exec_command_wait(command)
|
||||||
|
if retval != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"'{command}' failed on on '{computer_name}' "
|
||||||
|
f"with exit code {retval}: {stderr}"
|
||||||
|
)
|
||||||
|
_, *lines = stdout.splitlines()
|
||||||
|
return [line.split(" ")[0] for line in lines]
|
||||||
|
|
||||||
|
|
||||||
|
def local_current_execution_environment() -> ExecutionEnvironment:
|
||||||
|
return execution_environment(None, "localhost")
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessBuilder(aiida.engine.ProcessBuilder):
|
||||||
|
"""ProcessBuilder that is serializable."""
|
||||||
|
|
||||||
|
def on(
|
||||||
|
self, env: ExecutionEnvironment, max_concurrent_machines: Optional[int] = None
|
||||||
|
) -> ProcessBuilder:
|
||||||
|
"""Return a new ProcessBuilder, setting it up for execution on 'env'."""
|
||||||
|
r = copy.deepcopy(self)
|
||||||
|
|
||||||
|
r.code = env.code
|
||||||
|
|
||||||
|
if env.queue is not None:
|
||||||
|
queue_name, cores_per_machine = env.queue
|
||||||
|
r.metadata.options.queue_name = queue_name
|
||||||
|
|
||||||
|
if issubclass(r.process_class, (PyMapJob, RestartedPyMapJob)):
|
||||||
|
# NOTE: We are using a feature of the scheduler (Slurm in our case) to
|
||||||
|
# use array jobs. We could probably figure a way to do this with
|
||||||
|
# the 'direct' scheduler (GNU parallel or sth), but that is out
|
||||||
|
# of scope for now.
|
||||||
|
if env.computer.scheduler_type != "dynamic_workflows.slurm":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Mapping is currently only supported in an environment that "
|
||||||
|
f"supports Slurm array jobs, but {env.computer.label} is "
|
||||||
|
f" configured to use '{env.computer.scheduler_type}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if env.queue is None:
|
||||||
|
raise ValueError(
|
||||||
|
"A queue specification (e.g. ('my-queue', 24) ) is required"
|
||||||
|
)
|
||||||
|
|
||||||
|
r.metadata.options.cores_per_machine = cores_per_machine
|
||||||
|
|
||||||
|
if max_concurrent_machines is not None:
|
||||||
|
r.metadata.options.max_concurrent_machines = max_concurrent_machines
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
def finalize(self, **kwargs) -> ProcessBuilder:
|
||||||
|
"""Return a new ProcessBuilder, setting its 'kwargs' to those provided."""
|
||||||
|
r = copy.deepcopy(self)
|
||||||
|
r.kwargs = toolz.valmap(ensure_aiida_type, kwargs)
|
||||||
|
|
||||||
|
opts = r.metadata.options
|
||||||
|
|
||||||
|
custom_scheduler_commands = ["#SBATCH --requeue"]
|
||||||
|
|
||||||
|
if issubclass(r.process_class, (PyMapJob, RestartedPyMapJob)):
|
||||||
|
mapspec = MapSpec.from_string(opts.mapspec)
|
||||||
|
mapped_kwargs = {
|
||||||
|
k: v for k, v in r.kwargs.items() if k in mapspec.parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
cores_per_job = opts.resources.get(
|
||||||
|
"num_cores_per_mpiproc", 1
|
||||||
|
) * opts.resources.get("num_mpiprocs_per_machine", 1)
|
||||||
|
jobs_per_machine = opts.cores_per_machine // cores_per_job
|
||||||
|
max_concurrent_jobs = jobs_per_machine * opts.max_concurrent_machines
|
||||||
|
|
||||||
|
task_spec = array_job_spec(mapspec, mapped_kwargs)
|
||||||
|
# NOTE: This assumes that we are running on Slurm.
|
||||||
|
custom_scheduler_commands.append(
|
||||||
|
f"#SBATCH --array={task_spec}%{max_concurrent_jobs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
opts.custom_scheduler_commands = "\n".join(custom_scheduler_commands)
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
def with_restarts(self, max_restarts: int) -> ProcessBuilder:
|
||||||
|
"""Return a new builder for a RestartedPyCalcJob or RestartedPyMapJob."""
|
||||||
|
if issubclass(self.process_class, (PyMapJob, RestartedPyMapJob)):
|
||||||
|
r = ProcessBuilder(RestartedPyMapJob)
|
||||||
|
elif issubclass(self.process_class, (PyCalcJob, RestartedPyCalcJob)):
|
||||||
|
r = ProcessBuilder(RestartedPyCalcJob)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Do not know how to add restarts to {self.process_class}")
|
||||||
|
_copy_builder_contents(to=r, frm=self)
|
||||||
|
r.metadata.options.max_restarts = max_restarts
|
||||||
|
return r
|
||||||
|
|
||||||
|
# XXX: This is a complete hack to be able to serialize "Outline".
|
||||||
|
# We should think this through more carefully when we come to refactor.
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
def serialized_aiida_nodes(x):
|
||||||
|
if isinstance(x, aiida.orm.Data):
|
||||||
|
if not x.is_stored:
|
||||||
|
x.store()
|
||||||
|
return _AiidaData(x.uuid)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
serialized_data = traverse_mapping(serialized_aiida_nodes, self._data)
|
||||||
|
return self._process_class, serialized_data
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
process_class, serialized_data = state
|
||||||
|
self.__init__(process_class)
|
||||||
|
|
||||||
|
def deserialize_aiida_nodes(x):
|
||||||
|
if isinstance(x, _AiidaData):
|
||||||
|
return aiida.orm.load_node(x.uuid)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
deserialized_data = traverse_mapping(deserialize_aiida_nodes, serialized_data)
|
||||||
|
|
||||||
|
for k, v in deserialized_data.items():
|
||||||
|
if isinstance(v, Mapping):
|
||||||
|
getattr(self, k)._update(v)
|
||||||
|
else:
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
# XXX: This is part of the __getstate__/__setstate__ hack for our custom ProcessBuilder
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _AiidaData:
|
||||||
|
uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_builder_contents(
|
||||||
|
to: aiida.engine.ProcessBuilderNamespace, frm: aiida.engine.ProcessBuilderNamespace,
|
||||||
|
):
|
||||||
|
"""Recursively copy the contents of 'frm' into 'to'.
|
||||||
|
|
||||||
|
This mutates 'to'.
|
||||||
|
"""
|
||||||
|
for k, v in frm.items():
|
||||||
|
if isinstance(v, aiida.engine.ProcessBuilderNamespace):
|
||||||
|
_copy_builder_contents(to[k], v)
|
||||||
|
else:
|
||||||
|
setattr(to, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def traverse_mapping(f: Callable[[Any], Any], d: Mapping):
|
||||||
|
"""Traverse a nested Mapping, applying 'f' to all non-mapping values."""
|
||||||
|
return {
|
||||||
|
k: traverse_mapping(f, v) if isinstance(v, Mapping) else f(v)
|
||||||
|
for k, v in d.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def apply(f: PyFunction, *, max_restarts: int = 1, **kwargs) -> ProcessBuilder:
|
||||||
|
"""Apply f to **kwargs as a PyCalcJob or RestartedPyCalcJob.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
f
|
||||||
|
The function to apply
|
||||||
|
max_restarts
|
||||||
|
The number of times to run 'f'. If >1 then a builder
|
||||||
|
for a RestartedPyCalcJob is returned, otherwise
|
||||||
|
a builder for a PyCalcJob is returned.
|
||||||
|
**kwargs
|
||||||
|
Keyword arguments to pass to 'f'. Will be converted
|
||||||
|
to Aiida types using "aiida.orm.to_aiida_type" if
|
||||||
|
not already a subtype of "aiida.orm.Data".
|
||||||
|
"""
|
||||||
|
# TODO: check that 'f' applies cleanly to '**kwargs'
|
||||||
|
if max_restarts > 1:
|
||||||
|
builder = ProcessBuilder(RestartedPyCalcJob)
|
||||||
|
builder.metadata.options.max_restarts = int(max_restarts)
|
||||||
|
else:
|
||||||
|
builder = ProcessBuilder(PyCalcJob)
|
||||||
|
|
||||||
|
builder.func = f
|
||||||
|
builder.metadata.label = f.name
|
||||||
|
if kwargs:
|
||||||
|
builder.kwargs = toolz.valmap(ensure_aiida_type, kwargs)
|
||||||
|
if f.resources:
|
||||||
|
_apply_pyfunction_resources(f.resources, builder.metadata.options)
|
||||||
|
return builder
|
||||||
|
|
||||||
|
|
||||||
|
def apply_some(f: PyFunction, *, max_restarts: int = 1, **kwargs) -> ProcessBuilder:
|
||||||
|
"""Apply f to **kwargs as a PyCalcJob or RestartedPyCalcJob.
|
||||||
|
|
||||||
|
'kwargs' may contain _more_ inputs than what 'f' requires: extra
|
||||||
|
inputs are ignored.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
f
|
||||||
|
The function to apply
|
||||||
|
max_restarts
|
||||||
|
The number of times to run 'f'. If >1 then a builder
|
||||||
|
for a RestartedPyCalcJob is returned, otherwise
|
||||||
|
a builder for a PyCalcJob is returned.
|
||||||
|
**kwargs
|
||||||
|
Keyword arguments to pass to 'f'. Will be converted
|
||||||
|
to Aiida types using "aiida.orm.to_aiida_type" if
|
||||||
|
not already a subtype of "aiida.orm.Data".
|
||||||
|
"""
|
||||||
|
if max_restarts > 1:
|
||||||
|
builder = ProcessBuilder(RestartedPyCalcJob)
|
||||||
|
builder.metadata.options.max_restarts = int(max_restarts)
|
||||||
|
else:
|
||||||
|
builder = ProcessBuilder(PyCalcJob)
|
||||||
|
|
||||||
|
builder.func = f
|
||||||
|
builder.metadata.label = f.name
|
||||||
|
relevant_kwargs = toolz.keyfilter(lambda k: k in f.parameters, kwargs)
|
||||||
|
if relevant_kwargs:
|
||||||
|
builder.kwargs = toolz.valmap(ensure_aiida_type, relevant_kwargs)
|
||||||
|
if f.resources:
|
||||||
|
_apply_pyfunction_resources(f.resources, builder.metadata.options)
|
||||||
|
return builder
|
||||||
|
|
||||||
|
|
||||||
|
def map_(
|
||||||
|
f: PyFunction,
|
||||||
|
spec: Union[str, MapSpec],
|
||||||
|
*,
|
||||||
|
max_concurrent_machines: Optional[int] = None,
|
||||||
|
max_restarts: int = 1,
|
||||||
|
**kwargs,
|
||||||
|
) -> aiida.engine.ProcessBuilder:
|
||||||
|
"""Map 'f' over (a subset of) its inputs as a PyMapJob.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
f
|
||||||
|
Function to map over
|
||||||
|
spec
|
||||||
|
Specification for which parameters to map over, and how to map them.
|
||||||
|
max_concurrent_machines
|
||||||
|
The maximum number of machines to use concurrently.
|
||||||
|
max_restarts
|
||||||
|
The maximum number of times to restart the PyMapJob before returning
|
||||||
|
a partial (masked) result and a non-zero exit code.
|
||||||
|
**kwargs
|
||||||
|
Keyword arguments to 'f'. Any arguments that are to be mapped over
|
||||||
|
must by Aiida lists.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> from aiida.orm import List
|
||||||
|
>>> import aiida_dynamic_workflows as flow
|
||||||
|
>>>
|
||||||
|
>>> f = flow.step(lambda x, y: x + y, returns="sum")
|
||||||
|
>>>
|
||||||
|
>>> # We can map over _all_ inputs
|
||||||
|
>>> sums = flow.engine.map_(
|
||||||
|
... f, "x[i], y[i] -> sum[i]", x=List([1, 2, 3]), y=List([4, 5, 6])
|
||||||
|
... )
|
||||||
|
>>> # or we can map over a _subset_ of inputs
|
||||||
|
>>> only_one = flow.engine.map_(f, "x[i] -> sum[i]", x=List([1, 2, 3]), y=5)
|
||||||
|
>>> # or we can do an "outer product":
|
||||||
|
>>> outer= flow.engine.map_(
|
||||||
|
... f, "x[i], y[j] -> sum[i, j]", x=List([1, 2, 3]), y=List([4, 5, 6])
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
if max_restarts > 1:
|
||||||
|
builder = ProcessBuilder(RestartedPyMapJob)
|
||||||
|
builder.metadata.options.max_restarts = int(max_restarts)
|
||||||
|
else:
|
||||||
|
builder = ProcessBuilder(PyMapJob)
|
||||||
|
|
||||||
|
builder.func = f
|
||||||
|
builder.metadata.label = f.name
|
||||||
|
|
||||||
|
if isinstance(spec, str):
|
||||||
|
spec = MapSpec.from_string(spec)
|
||||||
|
elif not isinstance(spec, MapSpec):
|
||||||
|
raise TypeError(f"Expected single string or MapSpec, got {spec}")
|
||||||
|
if unknown_params := set(x.name for x in spec.inputs) - set(f.parameters):
|
||||||
|
raise ValueError(
|
||||||
|
f"{f} cannot be mapped over parameters that "
|
||||||
|
f"it does not take: {unknown_params}"
|
||||||
|
)
|
||||||
|
builder.metadata.options.mapspec = spec.to_string()
|
||||||
|
|
||||||
|
if max_concurrent_machines is not None:
|
||||||
|
builder.metadata.options.max_concurrent_machines = max_concurrent_machines
|
||||||
|
|
||||||
|
if f.resources:
|
||||||
|
_apply_pyfunction_resources(f.resources, builder.metadata.options)
|
||||||
|
|
||||||
|
if not kwargs:
|
||||||
|
return builder
|
||||||
|
|
||||||
|
return builder.finalize(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_pyfunction_resources(
|
||||||
|
resources: Dict, options: aiida.engine.ProcessBuilderNamespace
|
||||||
|
) -> None:
|
||||||
|
"""Apply the resource specification in 'resources' to the CalcJob options 'options'.
|
||||||
|
|
||||||
|
This mutates 'options'.
|
||||||
|
"""
|
||||||
|
memory = resources.get("memory")
|
||||||
|
if memory is not None:
|
||||||
|
# The Aiida Slurm plugin erroneously uses the multiplyer "1024" when converting
|
||||||
|
# to MegaBytes and passing to "--mem", so we must use it here also.
|
||||||
|
multiplier = {"kB": 1, "MB": 1024, "GB": 1000 * 1024}
|
||||||
|
amount, unit = memory[:-2], memory[-2:]
|
||||||
|
options.max_memory_kb = int(amount) * multiplier[unit]
|
||||||
|
|
||||||
|
cores = resources.get("cores")
|
||||||
|
if cores is not None:
|
||||||
|
options.resources["num_cores_per_mpiproc"] = int(cores)
|
||||||
|
|
||||||
|
|
||||||
|
def all_equal(seq):
|
||||||
|
"""Return True iff all elements of 'seq' are equal.
|
||||||
|
|
||||||
|
Returns 'True' if the sequence contains 0 or 1 elements.
|
||||||
|
"""
|
||||||
|
seq = list(seq)
|
||||||
|
if len(seq) in (0, 1):
|
||||||
|
return True
|
||||||
|
fst, *rest = seq
|
||||||
|
return all(r == fst for r in rest)
|
|
@ -0,0 +1,128 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""Aiida Parsers for interpreting the output of arbitrary Python functions."""
|
||||||
|
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import aiida.engine
|
||||||
|
import aiida.parsers
|
||||||
|
|
||||||
|
from . import common
|
||||||
|
from .common import MapSpec
|
||||||
|
from .data import PyRemoteArray, PyRemoteData, array_shape
|
||||||
|
|
||||||
|
# TODO: unify 'PyCalcParser' and 'PyMapParser': they are identical except
|
||||||
|
# for the type of the outputs (PyRemoteData vs. PyRemoteArray).
|
||||||
|
|
||||||
|
|
||||||
|
class PyCalcParser(aiida.parsers.Parser):
|
||||||
|
"""Parser for a PyCalcJob."""
|
||||||
|
|
||||||
|
def parse(self, **kwargs): # noqa: D102
|
||||||
|
|
||||||
|
calc = self.node
|
||||||
|
|
||||||
|
def retrieve(value_file):
|
||||||
|
# No actual retrieval occurs; we just store a reference
|
||||||
|
# to the remote value.
|
||||||
|
return PyRemoteData.from_remote_data(
|
||||||
|
calc.outputs.remote_folder, value_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
exception_file = "__exception__.pickle"
|
||||||
|
remote_folder = calc.outputs["remote_folder"]
|
||||||
|
remote_files = remote_folder.listdir()
|
||||||
|
has_exception = exception_file in remote_files
|
||||||
|
|
||||||
|
exit_code = None
|
||||||
|
|
||||||
|
# If any data was produced we create the appropriate outputs.
|
||||||
|
# If something went wrong the exit code will still be non-zero.
|
||||||
|
output_folder = remote_folder.listdir("__return_values__")
|
||||||
|
for r in calc.inputs.func.returns:
|
||||||
|
filename = f"{r}.pickle"
|
||||||
|
path = os.path.join("__return_values__", filename)
|
||||||
|
if filename in output_folder:
|
||||||
|
self.out(f"return_values.{r}", retrieve(path))
|
||||||
|
else:
|
||||||
|
exit_code = self.exit_codes.MISSING_OUTPUT
|
||||||
|
|
||||||
|
try:
|
||||||
|
job_infos = calc.computer.get_scheduler().parse_detailed_job_info(
|
||||||
|
calc.get_detailed_job_info()
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
(job_info,) = job_infos
|
||||||
|
if job_info["State"] == "FAILED":
|
||||||
|
exit_code = self.exit_codes.NONZERO_EXIT_CODE
|
||||||
|
|
||||||
|
if has_exception:
|
||||||
|
self.out("exception", retrieve(exception_file))
|
||||||
|
exit_code = self.exit_codes.USER_CODE_RAISED
|
||||||
|
|
||||||
|
if exit_code is not None:
|
||||||
|
calc.set_exit_status(exit_code.status)
|
||||||
|
calc.set_exit_message(exit_code.message)
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
|
||||||
|
class PyMapParser(aiida.parsers.Parser):
|
||||||
|
"""Parser for a PyMapJob."""
|
||||||
|
|
||||||
|
def parse(self, **kwargs): # noqa: D102
|
||||||
|
|
||||||
|
calc = self.node
|
||||||
|
|
||||||
|
mapspec = MapSpec.from_string(calc.get_option("mapspec"))
|
||||||
|
mapped_parameter_shapes = {
|
||||||
|
k: array_shape(v)
|
||||||
|
for k, v in calc.inputs.kwargs.items()
|
||||||
|
if k in mapspec.parameters
|
||||||
|
}
|
||||||
|
expected_shape = mapspec.shape(mapped_parameter_shapes)
|
||||||
|
remote_folder = calc.outputs["remote_folder"]
|
||||||
|
has_exceptions = bool(remote_folder.listdir("__exceptions__"))
|
||||||
|
|
||||||
|
def retrieve(return_value_name):
|
||||||
|
return PyRemoteArray(
|
||||||
|
computer=calc.computer,
|
||||||
|
remote_path=os.path.join(
|
||||||
|
calc.outputs.remote_folder.get_remote_path(), return_value_name,
|
||||||
|
),
|
||||||
|
shape=expected_shape,
|
||||||
|
filename_template=common.array.filename_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
exit_code = None
|
||||||
|
|
||||||
|
# If any data was produced we create the appropriate outputs.
|
||||||
|
# Users can still tell something went wrong from the exit code.
|
||||||
|
for r in calc.inputs.func.returns:
|
||||||
|
path = os.path.join("__return_values__", r)
|
||||||
|
has_data = remote_folder.listdir(path)
|
||||||
|
if has_data:
|
||||||
|
self.out(f"return_values.{r}", retrieve(path))
|
||||||
|
else:
|
||||||
|
exit_code = self.exit_codes.MISSING_OUTPUT
|
||||||
|
|
||||||
|
try:
|
||||||
|
job_infos = calc.computer.get_scheduler().parse_detailed_job_info(
|
||||||
|
calc.get_detailed_job_info()
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if any(j["State"] == "FAILED" for j in job_infos):
|
||||||
|
exit_code = self.exit_codes.NONZERO_EXIT_CODE
|
||||||
|
|
||||||
|
if has_exceptions:
|
||||||
|
self.out("exception", retrieve("__exceptions__"))
|
||||||
|
exit_code = self.exit_codes.USER_CODE_RAISED
|
||||||
|
|
||||||
|
if exit_code is not None:
|
||||||
|
calc.set_exit_status(exit_code.status)
|
||||||
|
calc.set_exit_message(exit_code.message)
|
||||||
|
return exit_code
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import aiida.common
|
||||||
|
import aiida.engine
|
||||||
|
import aiida.orm
|
||||||
|
|
||||||
|
from .workflow import PyWorkChain
|
||||||
|
|
||||||
|
|
||||||
|
def workflows() -> aiida.orm.QueryBuilder:
|
||||||
|
"""Return an Aiida database query that will return all workflows."""
|
||||||
|
q = aiida.orm.QueryBuilder()
|
||||||
|
q.append(cls=PyWorkChain, tag="flow")
|
||||||
|
q.order_by({"flow": [{"ctime": {"order": "desc"}}]})
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
def running_workflows() -> aiida.orm.QueryBuilder:
|
||||||
|
"""Return an Aiida database query that will return all running workflows."""
|
||||||
|
r = workflows()
|
||||||
|
r.add_filter(
|
||||||
|
"flow",
|
||||||
|
{
|
||||||
|
"attributes.process_state": {
|
||||||
|
"in": [
|
||||||
|
aiida.engine.ProcessState.RUNNING.value,
|
||||||
|
aiida.engine.ProcessState.WAITING.value,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
def recent_workflows(
|
||||||
|
days: int = 0, hours: int = 0, minutes: int = 0
|
||||||
|
) -> aiida.orm.QueryBuilder:
|
||||||
|
"""Return an Aiida database query for all recently started workflows.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
days, hours, minutes
|
||||||
|
Any workflows started more recently than this many days/minutes/hours
|
||||||
|
will be included in the result of the query.
|
||||||
|
"""
|
||||||
|
delta = aiida.common.timezone.now() - datetime.timedelta(
|
||||||
|
days=days, hours=hours, minutes=minutes
|
||||||
|
)
|
||||||
|
r = workflows()
|
||||||
|
r.add_filter("flow", {"ctime": {">": delta}})
|
||||||
|
return r
|
|
@ -0,0 +1,271 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from collections import Counter
|
||||||
|
import textwrap
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from IPython.display import Image
|
||||||
|
import aiida.cmdline.utils.common as cmd
|
||||||
|
from aiida.cmdline.utils.query.formatting import format_relative_time
|
||||||
|
import aiida.orm
|
||||||
|
from aiida.tools.visualization import Graph
|
||||||
|
import graphviz
|
||||||
|
|
||||||
|
from . import query
|
||||||
|
from .calculations import PyCalcJob, PyMapJob, num_mapjob_tasks
|
||||||
|
from .data import PyRemoteArray, PyRemoteData
|
||||||
|
from .utils import render_png
|
||||||
|
from .workchains import RestartedPyCalcJob, RestartedPyMapJob
|
||||||
|
from .workflow import PyWorkChain
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"log",
|
||||||
|
"graph",
|
||||||
|
"progress",
|
||||||
|
"running_workflows",
|
||||||
|
"recent_workflows",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
ProcessType = Union[aiida.orm.ProcessNode, int, str]
|
||||||
|
|
||||||
|
|
||||||
|
def log(proc: ProcessType) -> str:
|
||||||
|
"""Return the output of 'verdi process report' for the given process.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
proc
|
||||||
|
The Aiida node for the process, or a numeric ID, or a UUID.
|
||||||
|
"""
|
||||||
|
proc = _ensure_process_node(proc)
|
||||||
|
if isinstance(proc, aiida.orm.CalcJobNode):
|
||||||
|
return cmd.get_calcjob_report(proc)
|
||||||
|
elif isinstance(proc, aiida.orm.WorkChainNode):
|
||||||
|
return cmd.get_workchain_report(proc, levelname="REPORT")
|
||||||
|
elif isinstance(proc, (aiida.orm.CalcFunctionNode, aiida.orm.WorkFunctionNode)):
|
||||||
|
return cmd.get_process_function_report(proc)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Cannot get report for processes of type '{type(proc)}'")
|
||||||
|
|
||||||
|
|
||||||
|
def graph(
|
||||||
|
proc: ProcessType, size=(20, 20), as_png=False
|
||||||
|
) -> Union[graphviz.Digraph, Image]:
|
||||||
|
"""Return a graph visualization of a calculation or workflow.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
proc
|
||||||
|
The Aiida node for the process, or a numeric ID, or a UUID.
|
||||||
|
"""
|
||||||
|
proc = _ensure_process_node(proc)
|
||||||
|
graph = Graph(
|
||||||
|
graph_attr={"size": ",".join(map(str, size)), "rankdir": "LR"},
|
||||||
|
node_sublabel_fn=_node_sublabel,
|
||||||
|
)
|
||||||
|
graph.recurse_descendants(proc, include_process_inputs=True)
|
||||||
|
if as_png:
|
||||||
|
return render_png(graph.graphviz)
|
||||||
|
return graph.graphviz
|
||||||
|
|
||||||
|
|
||||||
|
def progress(proc: ProcessType) -> str:
|
||||||
|
"""Return a progress report of the given calculation or workflow.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
proc
|
||||||
|
The Aiida node for the process, or a numeric ID, or a UUID.
|
||||||
|
"""
|
||||||
|
proc = _ensure_process_node(proc)
|
||||||
|
if isinstance(proc, aiida.orm.CalcJobNode):
|
||||||
|
return _calcjob_progress(proc)
|
||||||
|
elif isinstance(proc, aiida.orm.WorkChainNode):
|
||||||
|
if issubclass(proc.process_class, PyWorkChain):
|
||||||
|
return _workflow_progress(proc)
|
||||||
|
elif issubclass(proc.process_class, (RestartedPyCalcJob, RestartedPyMapJob)):
|
||||||
|
return _restarted_calcjob_progress(proc)
|
||||||
|
elif isinstance(proc, (aiida.orm.CalcFunctionNode, aiida.orm.WorkFunctionNode)):
|
||||||
|
return _function_progress(proc)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Cannot get a progress report for processes of type '{type(proc)}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def running_workflows() -> str:
|
||||||
|
"""Return a progress report of the running workflows."""
|
||||||
|
r = _flatten(query.running_workflows().iterall())
|
||||||
|
return "\n\n".join(map(_workflow_progress, r))
|
||||||
|
|
||||||
|
|
||||||
|
def recent_workflows(days: int = 0, hours: int = 0, minutes: int = 0) -> str:
|
||||||
|
"""Return a progress report of all workflows that were started recently.
|
||||||
|
|
||||||
|
This also includes workflows that are already complete.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
days, hours, minutes
|
||||||
|
Any workflows started more recently than this many days/minutes/hours
|
||||||
|
will be included in the result of the query.
|
||||||
|
"""
|
||||||
|
r = _flatten(query.recent_workflows(**locals()).iterall())
|
||||||
|
return "\n\n".join(map(_workflow_progress, r))
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten(xs):
|
||||||
|
for ys in xs:
|
||||||
|
yield from ys
|
||||||
|
|
||||||
|
|
||||||
|
def _workflow_progress(p: aiida.orm.WorkChainNode) -> str:
|
||||||
|
assert issubclass(p.process_class, PyWorkChain)
|
||||||
|
lines = [
|
||||||
|
# This is a _single_ output line
|
||||||
|
f"{p.label or '<No label>'} (pk: {p.id}) "
|
||||||
|
f"[{_process_status(p)}, created {format_relative_time(p.ctime)}]"
|
||||||
|
]
|
||||||
|
for c in p.called:
|
||||||
|
lines.append(textwrap.indent(progress(c), " "))
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _restarted_calcjob_progress(p: aiida.orm.WorkChainNode) -> str:
|
||||||
|
assert issubclass(p.process_class, (RestartedPyCalcJob, RestartedPyMapJob))
|
||||||
|
lines = [
|
||||||
|
f"with_restarts({p.get_option('max_restarts')}) "
|
||||||
|
f"(pk: {p.id}) [{_process_status(p)}]"
|
||||||
|
]
|
||||||
|
for i, c in enumerate(p.called, 1):
|
||||||
|
if c.label == p.label:
|
||||||
|
# The launched process is the payload that we are running with restarts
|
||||||
|
s = f"attempt {i}: {progress(c)}"
|
||||||
|
else:
|
||||||
|
# Some post-processing (for RestartedPyMapJob)
|
||||||
|
s = progress(c)
|
||||||
|
lines.append(textwrap.indent(s, " "))
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _calcjob_progress(p: aiida.orm.CalcJobNode) -> str:
|
||||||
|
assert issubclass(p.process_class, PyCalcJob)
|
||||||
|
s = p.get_state() or p.process_state
|
||||||
|
|
||||||
|
# Show more detailed info while we're waiting for the Slurm job.
|
||||||
|
if s == aiida.common.CalcJobState.WITHSCHEDULER:
|
||||||
|
sections = [
|
||||||
|
f"created {format_relative_time(p.ctime)}",
|
||||||
|
]
|
||||||
|
if p.get_scheduler_state():
|
||||||
|
sections.append(f"{p.get_scheduler_state().value} job {p.get_job_id()}")
|
||||||
|
|
||||||
|
# Show total number of tasks and states of remaining tasks in mapjobs.
|
||||||
|
job_states = _slurm_job_states(p)
|
||||||
|
if job_states:
|
||||||
|
if issubclass(p.process_class, PyMapJob):
|
||||||
|
task_counts = Counter(job_states)
|
||||||
|
task_states = ", ".join(f"{k}: {v}" for k, v in task_counts.items())
|
||||||
|
task_summary = f"{sum(task_counts.values())} / {num_mapjob_tasks(p)}"
|
||||||
|
sections.extend(
|
||||||
|
[
|
||||||
|
f"remaining tasks ({task_summary})",
|
||||||
|
f"task states: {task_states}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sections.append(f"job state: {job_states[0]}")
|
||||||
|
msg = ", ".join(sections)
|
||||||
|
else:
|
||||||
|
msg = _process_status(p)
|
||||||
|
|
||||||
|
return f"{p.label} (pk: {p.id}) [{msg}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _process_status(p: aiida.orm.ProcessNode) -> str:
|
||||||
|
|
||||||
|
generic_failure = (
|
||||||
|
f"failed, run 'aiida_dynamic_workflows.report.log({p.id})' "
|
||||||
|
"for more information"
|
||||||
|
)
|
||||||
|
|
||||||
|
if p.is_finished and not p.is_finished_ok:
|
||||||
|
# 's.value' is "finished", even if the process finished with a non-zero exit
|
||||||
|
# code. We prefer the more informative 'failed' + next steps.
|
||||||
|
msg = generic_failure
|
||||||
|
elif p.is_killed:
|
||||||
|
# Process was killed: 'process_status' includes the reason why.
|
||||||
|
msg = f"killed, {p.process_status}"
|
||||||
|
elif p.is_excepted:
|
||||||
|
# Process failed, and the error occured in the Aiida layers
|
||||||
|
msg = generic_failure
|
||||||
|
elif p.is_created_from_cache:
|
||||||
|
msg = (
|
||||||
|
f"{p.process_state.value} "
|
||||||
|
f"(created from cache, uuid: {p.get_cache_source()})"
|
||||||
|
)
|
||||||
|
elif p.is_finished_ok:
|
||||||
|
msg = "success"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# Calcjobs have 'get_state', which gives more fine-grained information
|
||||||
|
msg = p.get_state().value
|
||||||
|
except AttributeError:
|
||||||
|
msg = p.process_state.value
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _function_progress(
|
||||||
|
p: Union[aiida.orm.CalcFunctionNode, aiida.orm.WorkFunctionNode]
|
||||||
|
) -> str:
|
||||||
|
return f"{p.label} (pk: {p.id}) [{p.process_state.value}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _slurm_job_states(process):
|
||||||
|
info = process.get_last_job_info()
|
||||||
|
if not info:
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
return [x[1] for x in info.raw_data]
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_process_node(
|
||||||
|
node_or_id: Union[aiida.orm.ProcessNode, int, str]
|
||||||
|
) -> aiida.orm.ProcessNode:
|
||||||
|
if isinstance(node_or_id, aiida.orm.ProcessNode):
|
||||||
|
return node_or_id
|
||||||
|
else:
|
||||||
|
return aiida.orm.load_node(node_or_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _node_sublabel(node):
|
||||||
|
if isinstance(node, aiida.orm.CalcJobNode) and issubclass(
|
||||||
|
node.process_class, PyCalcJob
|
||||||
|
):
|
||||||
|
labels = [f"function: {node.inputs.func.name}"]
|
||||||
|
if state := node.get_state():
|
||||||
|
labels.append(f"State: {state.value}")
|
||||||
|
if (job_id := node.get_job_id()) and (state := node.get_scheduler_state()):
|
||||||
|
labels.append(f"Job: {job_id} ({state.value})")
|
||||||
|
if node.exit_status is not None:
|
||||||
|
labels.append(f"Exit Code: {node.exit_status}")
|
||||||
|
if node.exception:
|
||||||
|
labels.append("excepted")
|
||||||
|
return "\n".join(labels)
|
||||||
|
elif isinstance(node, (PyRemoteData, PyRemoteArray)):
|
||||||
|
try:
|
||||||
|
create_link = node.get_incoming().one()
|
||||||
|
except Exception:
|
||||||
|
return aiida.tools.visualization.graph.default_node_sublabels(node)
|
||||||
|
if create_link.link_label.startswith("return_values"):
|
||||||
|
return create_link.link_label.split("__")[1]
|
||||||
|
else:
|
||||||
|
return create_link.link_label
|
||||||
|
else:
|
||||||
|
return aiida.tools.visualization.graph.default_node_sublabels(node)
|
|
@ -0,0 +1,104 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from typing import Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
|
import aiida.orm
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
from .calculations import PyCalcJob, PyMapJob
|
||||||
|
from .common import MapSpec
|
||||||
|
from .data import PyRemoteArray, from_aiida_type
|
||||||
|
|
||||||
|
|
||||||
|
def input_samples(result: PyRemoteArray) -> Iterable[Dict]:
|
||||||
|
"""Return an iterable of samples, given a result from a PyMapJob.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result
|
||||||
|
The array resulting from the execution of a PyMapJob.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
An iterable of dictionaries, ordered as 'result' (flattened, if
|
||||||
|
'result' is a >1D array). Each dictionary has the same keys (the
|
||||||
|
names of the parameters that produced 'result').
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import pandas as pd
|
||||||
|
>>> # In the following we assume 'charge' is a PyRemoteArray output from a PyMapJob.
|
||||||
|
>>> df = pd.DataFrame(input_samples(charge))
|
||||||
|
>>> # Add a 'charge' column showing the result associated with each sample.
|
||||||
|
>>> df.assign(charge=charge.reshape(-1))
|
||||||
|
"""
|
||||||
|
if result.creator is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot generate sample plan from data that was not produced from a CalcJob"
|
||||||
|
)
|
||||||
|
job = result.creator
|
||||||
|
if not issubclass(job.process_class, PyMapJob):
|
||||||
|
raise TypeError("Expected data that was produced from a MapJob")
|
||||||
|
output_axes = MapSpec.from_string(job.attributes["mapspec"]).output.axes
|
||||||
|
sp = _parameter_spec(result)
|
||||||
|
|
||||||
|
consts = {k: from_aiida_type(v) for k, (v, axes) in sp.items() if axes is None}
|
||||||
|
mapped = {
|
||||||
|
k: (from_aiida_type(v), axes) for k, (v, axes) in sp.items() if axes is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
# This could be done more efficiently if we return instead a dictionary of arrays.
|
||||||
|
|
||||||
|
for el in itertools.product(*map(range, result.shape)):
|
||||||
|
el = dict(zip(output_axes, el))
|
||||||
|
d = {k: v[tuple(el[ax] for ax in axes)] for k, (v, axes) in mapped.items()}
|
||||||
|
yield toolz.merge(consts, d)
|
||||||
|
|
||||||
|
|
||||||
|
def _parameter_spec(result: aiida.orm.Data, axes: Optional[Tuple[str]] = None) -> Dict:
|
||||||
|
"""Return a dictionary specifying the parameters that produced a given 'result'.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
result
|
||||||
|
Data produced from a PyCalcJob or PyMapJob.
|
||||||
|
axes
|
||||||
|
Labels for each axis of 'result', used to rename input axis labels.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dictionary mapping parameter names (strings) to pairs: (Aiida node, axis names).
|
||||||
|
"""
|
||||||
|
job = result.creator
|
||||||
|
job_type = job.process_class
|
||||||
|
|
||||||
|
if not issubclass(job_type, PyCalcJob):
|
||||||
|
raise TypeError(f"Don't know what to do with {job_type}")
|
||||||
|
|
||||||
|
if issubclass(job_type, PyMapJob):
|
||||||
|
mapspec = MapSpec.from_string(job.attributes["mapspec"])
|
||||||
|
if axes:
|
||||||
|
assert len(axes) == len(mapspec.output.axes)
|
||||||
|
translation = dict(zip(mapspec.output.axes, axes))
|
||||||
|
else:
|
||||||
|
translation = dict()
|
||||||
|
input_axes = {
|
||||||
|
spec.name: [translation.get(ax, ax) for ax in spec.axes]
|
||||||
|
for spec in mapspec.inputs
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
input_axes = dict()
|
||||||
|
assert axes is None
|
||||||
|
|
||||||
|
kwargs = job.inputs.kwargs if hasattr(job.inputs, "kwargs") else {}
|
||||||
|
# Inputs that were _not_ created by another CalcJob are the parameters we seek.
|
||||||
|
parameters = {k: (v, input_axes.get(k)) for k, v in kwargs.items() if not v.creator}
|
||||||
|
# Inputs that _were_ created by another Calcjob need to have
|
||||||
|
# _their_ inputs inspected, in turn.
|
||||||
|
other_inputs = [(v, input_axes.get(k)) for k, v in kwargs.items() if v.creator]
|
||||||
|
upstream_params = [_parameter_spec(v, ax) for v, ax in other_inputs]
|
||||||
|
|
||||||
|
return toolz.merge(parameters, *upstream_params)
|
|
@ -0,0 +1,187 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
import datetime
|
||||||
|
from typing import List, Optional, T
|
||||||
|
|
||||||
|
from aiida.common.lang import type_check
|
||||||
|
from aiida.schedulers import JobInfo, JobState
|
||||||
|
from aiida.schedulers.plugins.slurm import SlurmScheduler
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
__all__ = ["SlurmSchedulerWithJobArray"]
|
||||||
|
|
||||||
|
|
||||||
|
class SlurmSchedulerWithJobArray(SlurmScheduler):
|
||||||
|
"""A Slurm scheduler that reports only a single JobInfo for job arrays."""
|
||||||
|
|
||||||
|
def _parse_joblist_output(self, retval, stdout, stderr):
|
||||||
|
# Aiida assumes that there is a single job associated with each call
|
||||||
|
# to 'sbatch', but this is not true in the case of job arrays.
|
||||||
|
# In order to meet this requirement we merge the JobInfos for each job
|
||||||
|
# in the array.
|
||||||
|
return merge_job_arrays(super()._parse_joblist_output(retval, stdout, stderr))
|
||||||
|
|
||||||
|
# Return only the necessary fields for 'parse_output' to do its job.
|
||||||
|
# Our fat array jobs mean the response from 'sacct' can be pretty huge.
|
||||||
|
_detailed_job_info_fields = [
|
||||||
|
"JobID",
|
||||||
|
"ExitCode",
|
||||||
|
"State",
|
||||||
|
"Reason",
|
||||||
|
"CPUTime",
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_detailed_job_info_command(self, job_id):
|
||||||
|
fields = ",".join(self._detailed_job_info_fields)
|
||||||
|
# --parsable2 separates fields with pipes, with no trailing pipe
|
||||||
|
return f"sacct --format={fields} --parsable2 --jobs={job_id}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_detailed_job_info(cls, detailed_job_info):
|
||||||
|
"""Parse output from 'sacct', issued after the completion of the job."""
|
||||||
|
type_check(detailed_job_info, dict)
|
||||||
|
|
||||||
|
retval = detailed_job_info["retval"]
|
||||||
|
if retval != 0:
|
||||||
|
stderr = detailed_job_info["stderr"]
|
||||||
|
raise ValueError(f"Error code {retval} returned by 'sacct': {stderr}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
detailed_stdout = detailed_job_info["stdout"]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(
|
||||||
|
"the `detailed_job_info` does not contain the required key `stdout`."
|
||||||
|
)
|
||||||
|
|
||||||
|
type_check(detailed_stdout, str)
|
||||||
|
|
||||||
|
lines = detailed_stdout.splitlines()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fields, *job_infos = lines
|
||||||
|
except IndexError:
|
||||||
|
raise ValueError("`detailed_job_info.stdout` does not contain enough lines")
|
||||||
|
fields = fields.split("|")
|
||||||
|
|
||||||
|
if fields != cls._detailed_job_info_fields:
|
||||||
|
raise ValueError(
|
||||||
|
"Fields returned by 'sacct' do not match fields specified."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the individual job outputs
|
||||||
|
job_infos = [dict(zip(fields, info.split("|"))) for info in job_infos]
|
||||||
|
# Each job has a 'batch' entry also, which we ignore
|
||||||
|
job_infos = [j for j in job_infos if not j["JobID"].endswith(".batch")]
|
||||||
|
|
||||||
|
return job_infos
|
||||||
|
|
||||||
|
def parse_output(self, detailed_job_info, stdout, stderr):
|
||||||
|
"""Parse output from 'sacct', issued after the completion of the job."""
|
||||||
|
from aiida.engine import CalcJob
|
||||||
|
|
||||||
|
job_infos = self.parse_detailed_job_info(detailed_job_info)
|
||||||
|
|
||||||
|
# TODO: figure out how to return richer information to the calcjob, so
|
||||||
|
# that a workchain could in principle reschedule with only the
|
||||||
|
# failed jobs.
|
||||||
|
if any(j["State"] == "OUT_OF_MEMORY" for j in job_infos):
|
||||||
|
return CalcJob.exit_codes.ERROR_SCHEDULER_OUT_OF_MEMORY
|
||||||
|
if any(j["State"] == "TIMEOUT" for j in job_infos):
|
||||||
|
return CalcJob.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME
|
||||||
|
|
||||||
|
|
||||||
|
def merge_job_arrays(jobs: List[JobInfo]) -> List[JobInfo]:
|
||||||
|
"""Merge JobInfos from jobs in the same Slurm Array into a single JobInfo."""
|
||||||
|
mergers = {
|
||||||
|
"job_id": toolz.compose(job_array_id, toolz.first),
|
||||||
|
"dispatch_time": min,
|
||||||
|
"finish_time": toolz.compose(
|
||||||
|
max, toolz.curried.map(with_default(datetime.datetime.min)),
|
||||||
|
),
|
||||||
|
"job_state": total_job_state,
|
||||||
|
"raw_data": toolz.identity,
|
||||||
|
}
|
||||||
|
|
||||||
|
job_array_id_from_info = toolz.compose(
|
||||||
|
job_array_id, toolz.functoolz.attrgetter("job_id")
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
merge_with_functions(*jobs, mergers=mergers, factory=JobInfo)
|
||||||
|
for jobs in toolz.groupby(job_array_id_from_info, jobs).values()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def total_job_state(states: List[JobState]) -> JobState:
|
||||||
|
# Order is important here
|
||||||
|
possible_states = [
|
||||||
|
JobState.UNDETERMINED,
|
||||||
|
JobState.RUNNING,
|
||||||
|
JobState.SUSPENDED,
|
||||||
|
JobState.QUEUED_HELD,
|
||||||
|
JobState.QUEUED,
|
||||||
|
]
|
||||||
|
for ps in possible_states:
|
||||||
|
if any(state == ps for state in states):
|
||||||
|
return ps
|
||||||
|
|
||||||
|
if all(state == JobState.DONE for state in states):
|
||||||
|
return JobState.DONE
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Invalid state encountered")
|
||||||
|
|
||||||
|
|
||||||
|
def job_array_id(job_id: str) -> str:
|
||||||
|
"""Return the ID of the associated array job.
|
||||||
|
|
||||||
|
If the provided job is not part of a job array then
|
||||||
|
the job ID is returned.
|
||||||
|
"""
|
||||||
|
return toolz.first(job_id.split("_"))
|
||||||
|
|
||||||
|
|
||||||
|
@toolz.curry
|
||||||
|
def with_default(default: T, v: Optional[T]) -> T:
|
||||||
|
"""Return 'v' if it is not 'None', otherwise return 'default'."""
|
||||||
|
return default if v is None else v
|
||||||
|
|
||||||
|
|
||||||
|
def merge_with_functions(*dicts, mergers, factory=dict):
|
||||||
|
"""Merge 'dicts', using 'mergers'.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
*dicts
|
||||||
|
The dictionaries / mappings to merge
|
||||||
|
mergers
|
||||||
|
Mapping from keys in 'dicts' to functions. Each function
|
||||||
|
accepts a list of values and returns a single value.
|
||||||
|
factory
|
||||||
|
Function that returns a new instance of the mapping
|
||||||
|
type that we would like returned
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> merge_with_functions(
|
||||||
|
... {"a": 1, "b": 10, "c": "hello"},
|
||||||
|
... {"a": 5, "b": 20, "c": "goodbye"},
|
||||||
|
... mergers={"a": min, "b": max},
|
||||||
|
... )
|
||||||
|
{"a": 1, "b": 20, "c": "goodbye"}
|
||||||
|
"""
|
||||||
|
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
|
||||||
|
dicts = dicts[0]
|
||||||
|
|
||||||
|
result = factory()
|
||||||
|
for d in dicts:
|
||||||
|
for k, v in d.items():
|
||||||
|
if k not in result:
|
||||||
|
result[k] = [v]
|
||||||
|
else:
|
||||||
|
result[k].append(v)
|
||||||
|
return toolz.itemmap(
|
||||||
|
lambda kv: (kv[0], mergers.get(kv[0], toolz.last)(kv[1])), result, factory
|
||||||
|
)
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
from .data import PyFunction
|
||||||
|
|
||||||
|
__all__ = ["step"]
|
||||||
|
|
||||||
|
|
||||||
|
@toolz.curry
|
||||||
|
def step(
|
||||||
|
f: Callable,
|
||||||
|
*,
|
||||||
|
returns: Union[str, Tuple[str]] = "_return_value",
|
||||||
|
resources: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> PyFunction:
|
||||||
|
"""Construct a PyFunction from a Python function.
|
||||||
|
|
||||||
|
This function is commonly used as a decorator.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
f
|
||||||
|
The function to transform into a PyFunction
|
||||||
|
returns
|
||||||
|
The name of the output of this function.
|
||||||
|
If multiple names are provided, then 'f' is assumed to return
|
||||||
|
as many values (as a tuple) as there are names.
|
||||||
|
resources
|
||||||
|
Optional specification of computational resources that this
|
||||||
|
function needs. Possible resources are: "memory", "cores".
|
||||||
|
"memory" must be a string containing an integer value followed
|
||||||
|
by one of the following suffixes: "kB", "MB", "GB".
|
||||||
|
"cores" must be a positive integer.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> f = step(lambda x, y: x + y, returns="sum")
|
||||||
|
>>>
|
||||||
|
>>> @step(returns="other_sum", resources={"memory": "10GB", cores=2})
|
||||||
|
... def g(x: int, y: int) -> int:
|
||||||
|
... return x + y
|
||||||
|
...
|
||||||
|
>>> @step(returns=("a", "b"))
|
||||||
|
... def h(x):
|
||||||
|
... return (x + 1, x + 2)
|
||||||
|
...
|
||||||
|
>>>
|
||||||
|
"""
|
||||||
|
# TODO: First query the Aiida DB to see if this function already exists.
|
||||||
|
# This will require having a good hash for Python functions.
|
||||||
|
# This is a hard problem.
|
||||||
|
if resources:
|
||||||
|
_validate_resources(resources)
|
||||||
|
|
||||||
|
node = PyFunction(func=f, returns=returns, resources=resources)
|
||||||
|
node.store()
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_resources(resources) -> Dict:
|
||||||
|
resources = copy.deepcopy(resources)
|
||||||
|
if "memory" in resources:
|
||||||
|
_validate_memory(resources.pop("memory"))
|
||||||
|
if "cores" in resources:
|
||||||
|
_validate_cores(resources.pop("cores"))
|
||||||
|
if resources:
|
||||||
|
raise ValueError(f"Unexpected resource specifications: {list(resources)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_memory(memory: str):
|
||||||
|
mem, unit = memory[:-2], memory[-2:]
|
||||||
|
if not mem.isnumeric():
|
||||||
|
raise ValueError(f"Expected an integer amount of memory, got: '{mem}'")
|
||||||
|
elif int(mem) == 0:
|
||||||
|
raise ValueError("Cannot specify zero memory")
|
||||||
|
valid_units = ("kB", "MB", "GB")
|
||||||
|
if unit not in valid_units:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid memory unit: '{unit}' (expected one of {valid_units})."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_cores(cores: int):
|
||||||
|
if int(cores) != cores:
|
||||||
|
raise ValueError(f"Expected an integer number of cores, got: {cores}")
|
||||||
|
elif cores <= 0:
|
||||||
|
raise ValueError(f"Expected a positive number of cores, got: {cores}")
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from IPython.display import Image
|
||||||
|
import aiida
|
||||||
|
import graphviz
|
||||||
|
|
||||||
|
|
||||||
|
def block_until_done(chain: aiida.orm.WorkChainNode, interval=1) -> int:
|
||||||
|
"""Block a running chain until an exit code is set.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
chain : aiida.orm.WorkChainNode
|
||||||
|
interval : int, optional
|
||||||
|
Checking interval, by default 1
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
Exit code.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
async def wait_until_done(chain: aiida.orm.WorkChainNode) -> None:
|
||||||
|
while chain.exit_status is None:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
coro = wait_until_done(chain)
|
||||||
|
loop.run_until_complete(coro)
|
||||||
|
return chain.exit_status
|
||||||
|
|
||||||
|
|
||||||
|
def render_png(g: graphviz.Digraph) -> Image:
|
||||||
|
"""Render 'graphviz.Digraph' as png."""
|
||||||
|
return Image(g.render(format="png"))
|
|
@ -0,0 +1,348 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from aiida.engine import WorkChain, append_, if_, while_
|
||||||
|
import aiida.orm
|
||||||
|
import numpy as np
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
from . import common
|
||||||
|
from .calculations import (
|
||||||
|
PyCalcJob,
|
||||||
|
PyMapJob,
|
||||||
|
array_job_spec_from_booleans,
|
||||||
|
expected_mask,
|
||||||
|
merge_remote_arrays,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Subclass needed for "option" getters/setters, so that a WorkChain
|
||||||
|
# can transparently wrap a CalcJob.
|
||||||
|
class WorkChainNode(aiida.orm.WorkChainNode):
|
||||||
|
"""ORM class for nodes representing the execution of a WorkChain."""
|
||||||
|
|
||||||
|
def get_option(self, name: str) -> Optional[Any]:
|
||||||
|
"""Return the value of an option that was set for this CalcJobNode."""
|
||||||
|
return self.get_attribute(name, None)
|
||||||
|
|
||||||
|
def set_option(self, name: str, value: Any) -> None:
|
||||||
|
"""Set an option to the given value."""
|
||||||
|
self.set_attribute(name, value)
|
||||||
|
|
||||||
|
def get_options(self) -> Dict[str, Any]:
|
||||||
|
"""Return the dictionary of options set for this CalcJobNode."""
|
||||||
|
options = {}
|
||||||
|
for name in self.process_class.spec_options.keys():
|
||||||
|
value = self.get_option(name)
|
||||||
|
if value is not None:
|
||||||
|
options[name] = value
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
def set_options(self, options: Dict[str, Any]) -> None:
|
||||||
|
"""Set the options for this CalcJobNode."""
|
||||||
|
for name, value in options.items():
|
||||||
|
self.set_option(name, value)
|
||||||
|
|
||||||
|
|
||||||
|
# Hack to make this new node type use the Aiida logger.
|
||||||
|
# This is important so that WorkChains that use this node type also
|
||||||
|
# use the Aiida logger.
|
||||||
|
WorkChainNode._logger = aiida.orm.WorkChainNode._logger
|
||||||
|
|
||||||
|
|
||||||
|
class RestartedPyMapJob(WorkChain):
|
||||||
|
"""Workchain that resubmits a PyMapJob until all the tasks are complete.
|
||||||
|
|
||||||
|
Tasks in the PyMapJob that succeeded on previous runs will not be resubmitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_node_class = WorkChainNode
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define(cls, spec): # noqa: D102
|
||||||
|
super().define(spec)
|
||||||
|
spec.expose_inputs(PyMapJob)
|
||||||
|
spec.expose_outputs(PyMapJob, include=["return_values", "exception"])
|
||||||
|
spec.input(
|
||||||
|
"metadata.options.max_restarts",
|
||||||
|
valid_type=int,
|
||||||
|
default=5,
|
||||||
|
help=(
|
||||||
|
"Maximum number of iterations the work chain will "
|
||||||
|
"restart the process to finish successfully."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
spec.exit_code(
|
||||||
|
410,
|
||||||
|
"MAXIMUM_RESTARTS_EXCEEDED",
|
||||||
|
message="The maximum number of restarts was exceeded.",
|
||||||
|
)
|
||||||
|
|
||||||
|
spec.outline(
|
||||||
|
cls.setup,
|
||||||
|
while_(cls.should_run)(cls.run_mapjob, cls.inspect_result),
|
||||||
|
if_(cls.was_restarted)(cls.merge_arrays, cls.extract_merged_arrays).else_(
|
||||||
|
cls.pass_through_arrays
|
||||||
|
),
|
||||||
|
cls.output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def setup(self): # noqa: D102
|
||||||
|
self.report("Setting up")
|
||||||
|
|
||||||
|
mapspec = common.MapSpec.from_string(self.inputs.metadata.options.mapspec)
|
||||||
|
mapped_inputs = {
|
||||||
|
k: v for k, v in self.inputs.kwargs.items() if k in mapspec.parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ctx.required_mask = expected_mask(mapspec, mapped_inputs)
|
||||||
|
self.ctx.total_output_mask = np.full_like(self.ctx.required_mask, True)
|
||||||
|
|
||||||
|
self.ctx.job_shape = self.ctx.required_mask.shape
|
||||||
|
self.ctx.total_num_tasks = np.sum(~self.ctx.required_mask)
|
||||||
|
|
||||||
|
self.ctx.iteration = 0
|
||||||
|
self.ctx.launched_mapjobs = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_tasks_remaining(self) -> int:
|
||||||
|
"""Return the number of tasks that remain to be run."""
|
||||||
|
return self.ctx.total_num_tasks - np.sum(~self.ctx.total_output_mask)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remaining_task_array(self) -> np.ndarray:
|
||||||
|
"""Return a boolean array indicating which tasks still need to be run."""
|
||||||
|
return np.logical_xor(self.ctx.required_mask, self.ctx.total_output_mask)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_all_results(self) -> bool:
|
||||||
|
"""Return True iff all the necessary outputs are present."""
|
||||||
|
return np.all(self.ctx.total_output_mask == self.ctx.required_mask)
|
||||||
|
|
||||||
|
def should_run(self): # noqa: D102
|
||||||
|
return (
|
||||||
|
not self.has_all_results
|
||||||
|
and self.ctx.iteration < self.inputs.metadata.options.max_restarts
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_mapjob(self): # noqa: D102
|
||||||
|
# Run failed elements only, using custom
|
||||||
|
# Slurm parameters: -A 1,3-10,20%24
|
||||||
|
self.ctx.iteration += 1
|
||||||
|
|
||||||
|
self.report(f"Running MapJob for {self.n_tasks_remaining} tasks")
|
||||||
|
|
||||||
|
inputs = self.exposed_inputs(PyMapJob)
|
||||||
|
|
||||||
|
# Modify "metadata.options.custom_scheduler_commands" so that the
|
||||||
|
# correct tasks in the Slurm Job Array are run.
|
||||||
|
# NOTE: This assumes we are running on Slurm
|
||||||
|
options = inputs["metadata"]["options"]
|
||||||
|
csc = options.custom_scheduler_commands
|
||||||
|
# Remove the existing Array Job specification
|
||||||
|
commands = [x for x in csc.split("\n") if "--array" not in x]
|
||||||
|
# Add an updated Array Job specification
|
||||||
|
task_spec = array_job_spec_from_booleans(self.remaining_task_array.reshape(-1))
|
||||||
|
max_concurrent_jobs = (
|
||||||
|
options.cores_per_machine * options.max_concurrent_machines
|
||||||
|
)
|
||||||
|
commands.append(f"#SBATCH --array={task_spec}%{max_concurrent_jobs}")
|
||||||
|
inputs = toolz.assoc_in(
|
||||||
|
inputs,
|
||||||
|
("metadata", "options", "custom_scheduler_commands"),
|
||||||
|
"\n".join(commands),
|
||||||
|
)
|
||||||
|
|
||||||
|
# "max_restarts" does not apply to PyMapJobs
|
||||||
|
del inputs["metadata"]["options"]["max_restarts"]
|
||||||
|
|
||||||
|
fut = self.submit(PyMapJob, **inputs)
|
||||||
|
return self.to_context(launched_mapjobs=append_(fut))
|
||||||
|
|
||||||
|
def inspect_result(self): # noqa: D102
|
||||||
|
self.report("Inspecting result")
|
||||||
|
|
||||||
|
job = self.ctx.launched_mapjobs[-1]
|
||||||
|
|
||||||
|
m = result_mask(job, self.ctx.job_shape)
|
||||||
|
self.ctx.total_output_mask[~m] = False
|
||||||
|
|
||||||
|
self.report(
|
||||||
|
f"{np.sum(~m)} tasks succeeded, "
|
||||||
|
f"{self.n_tasks_remaining} / {self.ctx.total_num_tasks} remaining"
|
||||||
|
)
|
||||||
|
|
||||||
|
def was_restarted(self): # noqa: D102
|
||||||
|
return self.ctx.iteration > 1
|
||||||
|
|
||||||
|
def merge_arrays(self): # noqa: D102
|
||||||
|
self.report(f"Gathering arrays from {self.ctx.iteration} mapjobs.")
|
||||||
|
assert self.ctx.iteration > 1
|
||||||
|
|
||||||
|
exception_arrays = []
|
||||||
|
return_value_arrays = defaultdict(list)
|
||||||
|
for j in self.ctx.launched_mapjobs:
|
||||||
|
if "exception" in j.outputs:
|
||||||
|
exception_arrays.append(j.outputs.exception)
|
||||||
|
if "return_values" in j.outputs:
|
||||||
|
for k, v in j.outputs.return_values.items():
|
||||||
|
return_value_arrays[k].append(v)
|
||||||
|
|
||||||
|
# 'merge_remote_array' must take **kwargs (this is a limitation of Aiida), so
|
||||||
|
# we convert a list of inputs into a dictionary with keys 'x0', 'x1' etc.
|
||||||
|
def list_to_dict(lst):
|
||||||
|
return {f"x{i}": x for i, x in enumerate(lst)}
|
||||||
|
|
||||||
|
context_update = dict()
|
||||||
|
|
||||||
|
# TODO: switch 'runner.run_get_node' to 'submit' once WorkChain.submit
|
||||||
|
# allows CalcFunctions (it should already; this appears to be a
|
||||||
|
# bug in Aiida).
|
||||||
|
|
||||||
|
if exception_arrays:
|
||||||
|
r = self.runner.run_get_node(
|
||||||
|
merge_remote_arrays, **list_to_dict(exception_arrays),
|
||||||
|
)
|
||||||
|
context_update["exception"] = r.node
|
||||||
|
|
||||||
|
for k, arrays in return_value_arrays.items():
|
||||||
|
r = self.runner.run_get_node(merge_remote_arrays, **list_to_dict(arrays),)
|
||||||
|
context_update[f"return_values.{k}"] = r.node
|
||||||
|
|
||||||
|
return self.to_context(**context_update)
|
||||||
|
|
||||||
|
def extract_merged_arrays(self): # noqa: D102
|
||||||
|
if "exception" in self.ctx:
|
||||||
|
self.ctx.exception = self.ctx.exception.outputs.result
|
||||||
|
if "return_values" in self.ctx:
|
||||||
|
for k, v in self.ctx.return_values.items():
|
||||||
|
self.ctx.return_values[k] = v.outputs.result
|
||||||
|
|
||||||
|
def pass_through_arrays(self): # noqa: D102
|
||||||
|
self.report("Passing through results from single mapjob")
|
||||||
|
assert self.ctx.iteration == 1
|
||||||
|
(job,) = self.ctx.launched_mapjobs
|
||||||
|
if "exception" in job.outputs:
|
||||||
|
self.ctx.exception = job.outputs.exception
|
||||||
|
if "return_values" in job.outputs:
|
||||||
|
for k, v in job.outputs.return_values.items():
|
||||||
|
self.ctx[f"return_values.{k}"] = v
|
||||||
|
|
||||||
|
def output(self): # noqa: D102
|
||||||
|
self.report("Setting outputs")
|
||||||
|
if "exception" in self.ctx:
|
||||||
|
self.out("exception", self.ctx.exception)
|
||||||
|
for k, v in self.ctx.items():
|
||||||
|
if k.startswith("return_values"):
|
||||||
|
self.out(k, v)
|
||||||
|
|
||||||
|
max_restarts = self.inputs.metadata.options.max_restarts
|
||||||
|
if not self.has_all_results and self.ctx.iteration >= max_restarts:
|
||||||
|
self.report(f"Restarted the maximum number of times {max_restarts}")
|
||||||
|
return self.exit_codes.MAXIMUM_RESTARTS_EXCEEDED
|
||||||
|
|
||||||
|
|
||||||
|
def result_mask(job, expected_shape) -> np.ndarray:
|
||||||
|
"""Return the result mask for a PyMapJob that potentially has multiple outputs."""
|
||||||
|
if "return_values" not in job.outputs:
|
||||||
|
return np.full(expected_shape, True)
|
||||||
|
rvs = job.outputs.return_values
|
||||||
|
masks = [getattr(rvs, x).mask for x in rvs]
|
||||||
|
if len(masks) == 1:
|
||||||
|
return masks[0]
|
||||||
|
else:
|
||||||
|
# If for some reason one of the outputs is missing elements (i.e. the
|
||||||
|
# mask value is True) then we need to re-run the corresponding task.
|
||||||
|
return np.logical_or(*masks)
|
||||||
|
|
||||||
|
|
||||||
|
class RestartedPyCalcJob(WorkChain):
|
||||||
|
"""Workchain that resubmits a PyCalcJOb until it succeeds."""
|
||||||
|
|
||||||
|
_node_class = WorkChainNode
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define(cls, spec): # noqa: D102
|
||||||
|
super().define(spec)
|
||||||
|
spec.expose_inputs(PyCalcJob)
|
||||||
|
spec.expose_outputs(PyCalcJob, include=["return_values", "exception"])
|
||||||
|
spec.input(
|
||||||
|
"metadata.options.max_restarts",
|
||||||
|
valid_type=int,
|
||||||
|
default=5,
|
||||||
|
help=(
|
||||||
|
"Maximum number of iterations the work chain will "
|
||||||
|
"restart the process to finish successfully."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
spec.exit_code(
|
||||||
|
410,
|
||||||
|
"MAXIMUM_RESTARTS_EXCEEDED",
|
||||||
|
message="The maximum number of restarts was exceeded.",
|
||||||
|
)
|
||||||
|
spec.exit_code(
|
||||||
|
411, "CHILD_PROCESS_EXCEPTED", message="The child process excepted.",
|
||||||
|
)
|
||||||
|
spec.outline(
|
||||||
|
cls.setup,
|
||||||
|
while_(cls.should_run)(cls.run_calcjob, cls.inspect_result),
|
||||||
|
cls.output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def setup(self): # noqa: D102
|
||||||
|
self.ctx.iteration = 0
|
||||||
|
self.ctx.function_name = self.inputs.func.name
|
||||||
|
self.ctx.children = []
|
||||||
|
self.ctx.is_finished = False
|
||||||
|
|
||||||
|
def should_run(self): # noqa: D102
|
||||||
|
return (
|
||||||
|
not self.ctx.is_finished
|
||||||
|
and self.ctx.iteration < self.inputs.metadata.options.max_restarts
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_calcjob(self): # noqa: D102
|
||||||
|
self.ctx.iteration += 1
|
||||||
|
inputs = self.exposed_inputs(PyCalcJob)
|
||||||
|
del inputs["metadata"]["options"]["max_restarts"]
|
||||||
|
node = self.submit(PyCalcJob, **inputs)
|
||||||
|
|
||||||
|
self.report(
|
||||||
|
f"Launching {self.ctx.function_name}<{node.pk}> "
|
||||||
|
f"iteration #{self.ctx.iteration}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.to_context(children=append_(node))
|
||||||
|
|
||||||
|
def inspect_result(self): # noqa: D102
|
||||||
|
node = self.ctx.children[-1]
|
||||||
|
|
||||||
|
if node.is_excepted:
|
||||||
|
self.report(f"{self.ctx.function_name}<{node.pk}> excepted; aborting")
|
||||||
|
return self.exit_codes.CHILD_PROCESS_EXCEPTED
|
||||||
|
|
||||||
|
self.ctx.is_finished = node.exit_status == 0
|
||||||
|
|
||||||
|
def output(self): # noqa: D102
|
||||||
|
node = self.ctx.children[-1]
|
||||||
|
label = f"{self.ctx.function_name}<{node.pk}>"
|
||||||
|
|
||||||
|
self.out_many(self.exposed_outputs(node, PyCalcJob))
|
||||||
|
|
||||||
|
max_restarts = self.inputs.metadata.options.max_restarts
|
||||||
|
if not self.ctx.is_finished and self.ctx.iteration >= max_restarts:
|
||||||
|
self.report(
|
||||||
|
f"Reached the maximum number of iterations {max_restarts}: "
|
||||||
|
f"last ran {label}"
|
||||||
|
)
|
||||||
|
return self.exit_codes.MAXIMUM_RESTARTS_EXCEEDED
|
||||||
|
else:
|
||||||
|
self.report(
|
||||||
|
f"Succeeded after {self.ctx.iteration} submissions: "
|
||||||
|
f"last ran {label}"
|
||||||
|
)
|
|
@ -0,0 +1,610 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import copy
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
import aiida.engine
|
||||||
|
import graphviz
|
||||||
|
import toolz
|
||||||
|
|
||||||
|
from . import common, engine
|
||||||
|
from .calculations import PyCalcJob, PyMapJob
|
||||||
|
from .data import PyFunction, PyOutline, ensure_aiida_type
|
||||||
|
from .utils import render_png
|
||||||
|
|
||||||
|
# TODO: this will all need to be refactored when we grok
|
||||||
|
# Aiida's 'Process' and 'Port' concepts.
|
||||||
|
|
||||||
|
|
||||||
|
class Step(metaclass=abc.ABCMeta):
|
||||||
|
"""Abstract base class for steps."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Single(Step):
|
||||||
|
"""A single workflow step."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Action(Single):
|
||||||
|
"""Step that will be run with the current workchain passed as argument."""
|
||||||
|
|
||||||
|
def do(self, workchain):
|
||||||
|
"""Do the action on the workchain."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Concurrent(Step):
|
||||||
|
"""Step consisting of several concurrent steps."""
|
||||||
|
|
||||||
|
steps: List[Step]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Sequential(Step):
|
||||||
|
"""Step consisting of several sequential steps."""
|
||||||
|
|
||||||
|
steps: List[Step]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Process(Single):
|
||||||
|
"""Step consisting of a single Aiida Process."""
|
||||||
|
|
||||||
|
builder: aiida.engine.ProcessBuilder
|
||||||
|
parameters: Tuple[str]
|
||||||
|
returns: Tuple[str]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
kind = self.builder.process_class
|
||||||
|
if issubclass(kind, PyCalcJob):
|
||||||
|
func = self.builder.func
|
||||||
|
return f"{kind.__name__}[{func.name}(pk: {func.pk})]"
|
||||||
|
else:
|
||||||
|
return kind.__name__
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class OutputAction(Action):
|
||||||
|
"""Action step that outputs values from the workflow context."""
|
||||||
|
|
||||||
|
outputs: Dict[str, str]
|
||||||
|
|
||||||
|
def do(self, workchain):
|
||||||
|
"""Return the named outputs from this workflow."""
|
||||||
|
for from_name, to_name in self.outputs.items():
|
||||||
|
if from_name in workchain.ctx:
|
||||||
|
workchain.out(f"return_values.{to_name}", workchain.ctx[from_name])
|
||||||
|
else:
|
||||||
|
workchain.report(
|
||||||
|
f"Failed to set output '{to_name}': '{from_name}' "
|
||||||
|
"does not exist on the workchain context (did "
|
||||||
|
"the step that produces this output fail?"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PyAction(Action):
|
||||||
|
"""Action step defined by a PyFunction."""
|
||||||
|
|
||||||
|
action: PyFunction
|
||||||
|
|
||||||
|
def do(self, workchain):
|
||||||
|
"""Do the action on the workchain."""
|
||||||
|
self.action(workchain)
|
||||||
|
|
||||||
|
|
||||||
|
def single_steps(step: Step) -> Iterator[Single]:
|
||||||
|
"""Yield all Single steps in a given step."""
|
||||||
|
if isinstance(step, Single):
|
||||||
|
yield step
|
||||||
|
elif isinstance(step, (Concurrent, Sequential)):
|
||||||
|
yield from toolz.mapcat(single_steps, step.steps)
|
||||||
|
else:
|
||||||
|
assert False, f"Unknown step type {type(step)}"
|
||||||
|
|
||||||
|
|
||||||
|
def single_processes(step: Step) -> Iterator[Process]:
|
||||||
|
"""Yield all Process steps in a given step."""
|
||||||
|
return filter(lambda s: isinstance(s, Process), single_steps(step))
|
||||||
|
|
||||||
|
|
||||||
|
def _check_valid_pyfunction(f: PyFunction):
|
||||||
|
"""Check that the provided PyFunction may be used as part of a workflow."""
|
||||||
|
if not isinstance(f, PyFunction):
|
||||||
|
raise TypeError()
|
||||||
|
if any(r.startswith("_") for r in f.returns):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot use functions with return names containing underscores "
|
||||||
|
"in workflows."
|
||||||
|
)
|
||||||
|
if set(f.parameters).intersection(f.returns):
|
||||||
|
raise ValueError(
|
||||||
|
"Function has outputs that are named identically to its input(s)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_pyfunctions_compatible(a: PyFunction, b: PyFunction):
|
||||||
|
"""Check that Pyfunction 'b' has enough inputs/outputs to be compatible with 'a'."""
|
||||||
|
_check_valid_pyfunction(a)
|
||||||
|
_check_valid_pyfunction(b)
|
||||||
|
if missing_parameters := set(a.parameters) - set(b.parameters):
|
||||||
|
raise ValueError(f"'{b.name}' is missing parameters: {missing_parameters}")
|
||||||
|
if missing_returns := set(a.returns) - set(b.returns):
|
||||||
|
raise ValueError(f"'{b.name}' is missing return values: {missing_returns}")
|
||||||
|
|
||||||
|
|
||||||
|
def from_pyfunction(f: PyFunction) -> Step:
|
||||||
|
"""Construct a Step corresponding to applying a PyFunction."""
|
||||||
|
_check_valid_pyfunction(f)
|
||||||
|
return Process(builder=engine.apply(f), parameters=f.parameters, returns=f.returns,)
|
||||||
|
|
||||||
|
|
||||||
|
def map_(f: PyFunction, *args, **kwargs) -> Step:
|
||||||
|
"""Construct a Step corresponding to mapping a PyFunction.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
*args, **kwargs
|
||||||
|
Positional/keyword arguments to pass to 'aiida_dynamic_workflows.engine.map_'.
|
||||||
|
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
aiida_dynamic_workflows.engine.map_
|
||||||
|
"""
|
||||||
|
_check_valid_pyfunction(f)
|
||||||
|
return Process(
|
||||||
|
builder=engine.map_(f, *args, **kwargs),
|
||||||
|
parameters=f.parameters,
|
||||||
|
returns=f.returns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def concurrently(*fs: Union[PyFunction, Step]) -> Step:
|
||||||
|
"""Construct a Step for several tasks executing concurrently."""
|
||||||
|
if len(fs) < 2:
|
||||||
|
raise ValueError("Expected at least 2 steps")
|
||||||
|
|
||||||
|
for i, f in enumerate(fs):
|
||||||
|
for g in fs[i + 1 :]:
|
||||||
|
if set(f.returns).intersection(g.returns):
|
||||||
|
raise ValueError("Steps return values that are named the same")
|
||||||
|
|
||||||
|
returns = [set(f.returns) for f in fs]
|
||||||
|
|
||||||
|
parameters = [set(f.parameters) for f in fs]
|
||||||
|
if any(a.intersection(b) for a in parameters for b in returns):
|
||||||
|
raise ValueError("Steps cannot be run concurrently")
|
||||||
|
|
||||||
|
def ensure_single(f):
|
||||||
|
if isinstance(f, PyFunction):
|
||||||
|
return from_pyfunction(f)
|
||||||
|
elif isinstance(f, Single):
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Expected PyFunction or Single, got {type(f)}")
|
||||||
|
|
||||||
|
return Concurrent([ensure_single(f) for f in fs])
|
||||||
|
|
||||||
|
|
||||||
|
def new_workflow(name: str) -> Outline:
|
||||||
|
"""Return an Outline with no steps , and the given name."""
|
||||||
|
return Outline(steps=(), label=name)
|
||||||
|
|
||||||
|
|
||||||
|
def first(s: Union[PyFunction, Step]) -> Outline:
|
||||||
|
"""Return an Outline consisting of a single Step."""
|
||||||
|
return Outline(steps=(ensure_step(s),))
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_step(s: Union[Step, PyFunction]) -> Step:
|
||||||
|
"""Return a Step, given a Step or a PyFunction."""
|
||||||
|
if isinstance(s, Step):
|
||||||
|
return s
|
||||||
|
elif isinstance(s, PyFunction):
|
||||||
|
return from_pyfunction(s)
|
||||||
|
elif isinstance(s, Outline):
|
||||||
|
return Sequential(s.steps)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Expected PyFunction, Step, or Outline, got {type(s)}")
|
||||||
|
|
||||||
|
|
||||||
|
def output(*names: str, **mappings: str) -> OutputAction:
|
||||||
|
"""Return an OutputAction that can be used in an outline."""
|
||||||
|
outputs = {name: name for name in names}
|
||||||
|
outputs.update({from_: to_ for from_, to_ in mappings.items()})
|
||||||
|
|
||||||
|
return OutputAction(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Outline:
|
||||||
|
"""Outline of the steps to be executed.
|
||||||
|
|
||||||
|
Each step kicks off either a _single_ process, or several processes
|
||||||
|
concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
steps: Tuple[Step]
|
||||||
|
#: Sequence of steps constituting the workflow
|
||||||
|
label: Optional[str] = None
|
||||||
|
#: Optional label identifying the workflow
|
||||||
|
|
||||||
|
def rename(self, name: str) -> Outline:
|
||||||
|
"""Return a new outline with a new name."""
|
||||||
|
return replace(self, label=name)
|
||||||
|
|
||||||
|
def then(self, step: Union[PyFunction, Step, Outline]) -> Outline:
|
||||||
|
"""Add the provided Step to the outline.
|
||||||
|
|
||||||
|
If a PyFunction is provided it is added as a single step.
|
||||||
|
"""
|
||||||
|
return replace(self, steps=self.steps + (ensure_step(step),))
|
||||||
|
|
||||||
|
def join(self, other: Outline) -> Outline:
|
||||||
|
"""Return a new outline consisting of this and 'other' joined together."""
|
||||||
|
return replace(self, steps=self.steps + other.steps)
|
||||||
|
|
||||||
|
def returning(self, *names, **mappings) -> Outline:
|
||||||
|
"""Return the named values from this workflow."""
|
||||||
|
possible_names = self.parameters.union(self.all_outputs)
|
||||||
|
existing_names = self.returns
|
||||||
|
requested_names = set(names).union(mappings.keys())
|
||||||
|
|
||||||
|
if invalid_names := requested_names - possible_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot return any of {invalid_names}; "
|
||||||
|
"they do not appear in this outline."
|
||||||
|
)
|
||||||
|
|
||||||
|
if already_returned := requested_names.intersection(existing_names):
|
||||||
|
raise ValueError(
|
||||||
|
"The following names are already returned "
|
||||||
|
f"by this outline: {already_returned}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return replace(self, steps=self.steps + (output(*names, **mappings),))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _single_processes(self) -> Iterator[Process]:
|
||||||
|
for step in self.steps:
|
||||||
|
yield from single_processes(step)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _single_steps(self) -> Iterator[Single]:
|
||||||
|
for step in self.steps:
|
||||||
|
yield from single_steps(step)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> Set[str]:
|
||||||
|
"""Parameters of the Outline."""
|
||||||
|
raw_parameters = toolz.reduce(
|
||||||
|
set.union, (s.parameters for s in self._single_processes), set(),
|
||||||
|
)
|
||||||
|
return raw_parameters - self.all_outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def returns(self) -> Set[str]:
|
||||||
|
"""Values returned by this Outline."""
|
||||||
|
ret = set()
|
||||||
|
for step in self._single_steps:
|
||||||
|
if isinstance(step, OutputAction):
|
||||||
|
ret.update(step.outputs.values())
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_outputs(self) -> Set[str]:
|
||||||
|
"""All outputs of this outline."""
|
||||||
|
return toolz.reduce(
|
||||||
|
set.union, (s.returns for s in self._single_processes), set(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def visualize(self, as_png=False) -> Union[graphviz.Digraph]:
|
||||||
|
"""Return a Graphviz visualization of this outline."""
|
||||||
|
g = graphviz.Digraph(graph_attr=dict(rankdir="LR"))
|
||||||
|
|
||||||
|
mapped_inputs = set()
|
||||||
|
|
||||||
|
for proc in self._single_processes:
|
||||||
|
proc_id = str(id(proc))
|
||||||
|
is_mapjob = issubclass(proc.builder.process_class, PyMapJob)
|
||||||
|
|
||||||
|
opts = dict(shape="rectangle")
|
||||||
|
output_opts = dict()
|
||||||
|
if is_mapjob:
|
||||||
|
for d in (opts, output_opts):
|
||||||
|
d["style"] = "filled"
|
||||||
|
d["fillcolor"] = "#ffaaaaaa"
|
||||||
|
|
||||||
|
g.node(proc_id, label=proc.builder.func.name, **opts)
|
||||||
|
|
||||||
|
if is_mapjob:
|
||||||
|
spec = common.MapSpec.from_string(proc.builder.metadata.options.mapspec)
|
||||||
|
for p in spec.parameters:
|
||||||
|
mapped_inputs.add(p)
|
||||||
|
g.node(p, **output_opts)
|
||||||
|
|
||||||
|
for r in proc.returns:
|
||||||
|
g.node(r, **output_opts)
|
||||||
|
g.edge(proc_id, r)
|
||||||
|
|
||||||
|
for p in self.parameters - mapped_inputs:
|
||||||
|
g.node(p, style="filled", fillcolor="#aaaaaa")
|
||||||
|
|
||||||
|
for proc in self._single_processes:
|
||||||
|
proc_id = str(id(proc))
|
||||||
|
for p in proc.parameters:
|
||||||
|
g.edge(p, proc_id)
|
||||||
|
if as_png:
|
||||||
|
return render_png(g)
|
||||||
|
return g
|
||||||
|
|
||||||
|
def traverse(self, f: Callable[[Single], Single]) -> Outline:
|
||||||
|
"""Return a copy of this Outline, with 'f' applied to all Single steps."""
|
||||||
|
|
||||||
|
def transform(x: Step) -> Step:
|
||||||
|
if isinstance(x, Single):
|
||||||
|
return f(x)
|
||||||
|
elif isinstance(x, (Concurrent, Sequential)):
|
||||||
|
return type(x)(steps=tuple(map(transform, x.steps)))
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unknown step type {type(x)}")
|
||||||
|
|
||||||
|
return replace(self, steps=tuple(map(transform, self.steps)))
|
||||||
|
|
||||||
|
def with_restarts(self, step_restarts: Dict[PyFunction, int]) -> Outline:
|
||||||
|
"""Return a copy of this Outline with restarts added to all specified steps.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> # Set up the original flow
|
||||||
|
>>> import aiida_dynamic_workflows as flows
|
||||||
|
>>> a = flows.step(lambda x, y: x + y, returning="z")
|
||||||
|
>>> b = flows.step(lambda z: 2 * z)
|
||||||
|
>>> flow = flows.workflow.first(a).then(b)
|
||||||
|
>>> # Apply restarts: a restarted up to 2 times, b up to 3.
|
||||||
|
>>> new_flow = flow.with_restarts({a: 2, b: 3})
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mapper(step):
|
||||||
|
try:
|
||||||
|
max_restarts = step_restarts[step.builder.func]
|
||||||
|
except (AttributeError, KeyError):
|
||||||
|
return step
|
||||||
|
else:
|
||||||
|
return replace(step, builder=step.builder.with_restarts(max_restarts))
|
||||||
|
|
||||||
|
return self.traverse(mapper)
|
||||||
|
|
||||||
|
def replace_steps(self, step_map: Dict[PyFunction, PyFunction]) -> Outline:
|
||||||
|
"""Return a copy of this Outline, replacing the step functions specified.
|
||||||
|
|
||||||
|
Any steps that are PyCalcJobs or PyMapJobs executing a PyFunction specified
|
||||||
|
in 'step_map' will have the function executed replaced by the corresponding
|
||||||
|
value in 'step_map'.
|
||||||
|
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
traverse
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> # Set up the original flow
|
||||||
|
>>> import aiida_dynamic_workflows as flows
|
||||||
|
>>> a = flows.step(lambda x, y: x + y, returning="z")
|
||||||
|
>>> b = flows.step(lambda z: 2 * z)
|
||||||
|
>>> flow = flows.workflow.first(a).then(b)
|
||||||
|
>>> # Create the new steps
|
||||||
|
>>> new_a = flows.step(lambda x, y: x * y, returning="z")
|
||||||
|
>>> new_b = flows.step(lambda z: 5 * z
|
||||||
|
>>> # Replace the old steps with new ones!
|
||||||
|
>>> new_flow = flow.replacing_steps({a: new_a, b: new_b})
|
||||||
|
"""
|
||||||
|
for a, b in step_map.items():
|
||||||
|
_check_pyfunctions_compatible(a, b)
|
||||||
|
|
||||||
|
def mapper(step):
|
||||||
|
try:
|
||||||
|
new_func = step_map[step.builder.func]
|
||||||
|
except (AttributeError, KeyError):
|
||||||
|
return step
|
||||||
|
else:
|
||||||
|
b = copy.deepcopy(step.builder)
|
||||||
|
b.func = new_func
|
||||||
|
return Process(
|
||||||
|
builder=b, parameters=new_func.parameters, returns=new_func.returns
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.traverse(mapper)
|
||||||
|
|
||||||
|
def on(
|
||||||
|
self,
|
||||||
|
env: engine.ExecutionEnvironment,
|
||||||
|
max_concurrent_machines: Optional[int] = None,
|
||||||
|
) -> Outline:
|
||||||
|
"""Return a new Outline with the execution environment set for all steps."""
|
||||||
|
|
||||||
|
def transform(s: Single):
|
||||||
|
if not isinstance(s, Process):
|
||||||
|
return s
|
||||||
|
return replace(s, builder=s.builder.on(env, max_concurrent_machines))
|
||||||
|
|
||||||
|
return self.traverse(transform)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: See if we can come up with a cleaner separation of "logical data flow"
|
||||||
|
# and "error handling flow".
|
||||||
|
|
||||||
|
# TODO: see if we can do this more "directly" with the Aiida/Plumpy
|
||||||
|
# "process" interface. As-is we are running our own "virtual machine"
|
||||||
|
# on top of Aiida's!.
|
||||||
|
class PyWorkChain(aiida.engine.WorkChain):
|
||||||
|
"""WorkChain for executing Outlines."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define(cls, spec): # noqa: D102
|
||||||
|
super().define(spec)
|
||||||
|
spec.input("outline", valid_type=PyOutline)
|
||||||
|
spec.input_namespace("kwargs", dynamic=True)
|
||||||
|
spec.output_namespace("return_values", dynamic=True)
|
||||||
|
spec.outline(
|
||||||
|
cls.setup,
|
||||||
|
aiida.engine.while_(cls.is_not_done)(cls.do_step, cls.check_output),
|
||||||
|
cls.finalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
spec.exit_code(401, "INVALID_STEP", message="Invalid step definition")
|
||||||
|
spec.exit_code(
|
||||||
|
450, "STEP_RETURNED_ERROR_CODE", message="A step returned an error code"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_builder(cls): # noqa: D102
|
||||||
|
return engine.ProcessBuilder(cls)
|
||||||
|
|
||||||
|
# TODO: have the outline persisted into "self.ctx"; this way
|
||||||
|
# we don't need to reload it from the DB on every step.
|
||||||
|
|
||||||
|
def setup(self): # noqa: D102
|
||||||
|
"""Set up the state for the workchain."""
|
||||||
|
outline = self.inputs.outline.value
|
||||||
|
self.ctx._this_step = 0
|
||||||
|
self.ctx._num_steps = len(outline.steps)
|
||||||
|
self.ctx._had_errors = False
|
||||||
|
|
||||||
|
if "kwargs" in self.inputs:
|
||||||
|
self.ctx.update(self.inputs.kwargs)
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
"""Finalize the workchain."""
|
||||||
|
if self.ctx._had_errors:
|
||||||
|
return self.exit_codes.STEP_RETURNED_ERROR_CODE
|
||||||
|
|
||||||
|
def is_not_done(self) -> bool:
|
||||||
|
"""Return True when there are no more steps in the workchain."""
|
||||||
|
return self.ctx._this_step < self.ctx._num_steps
|
||||||
|
|
||||||
|
def do_step(self):
|
||||||
|
"""Execute the current step in the workchain."""
|
||||||
|
this_step = self.ctx._this_step
|
||||||
|
self.report(f"doing step {this_step} of {self.ctx._num_steps}")
|
||||||
|
step = self.inputs.outline.value.steps[this_step]
|
||||||
|
|
||||||
|
if isinstance(step, (Single, Sequential)):
|
||||||
|
concurrent_steps = [step]
|
||||||
|
elif isinstance(step, Concurrent):
|
||||||
|
concurrent_steps = list(step.steps)
|
||||||
|
else:
|
||||||
|
self.report(f"Unknown step type {type(step)}")
|
||||||
|
return self.exit_codes.INVALID_STEP
|
||||||
|
|
||||||
|
for s in concurrent_steps:
|
||||||
|
self._base_step(s)
|
||||||
|
|
||||||
|
self.ctx._this_step += 1
|
||||||
|
|
||||||
|
def _base_step(self, s: Step):
|
||||||
|
if isinstance(s, Process):
|
||||||
|
try:
|
||||||
|
inputs = get_keys(self.ctx, s.parameters)
|
||||||
|
except KeyError as err:
|
||||||
|
self.report(f"Skipping step {s} due to missing inputs: {err.args}")
|
||||||
|
self.ctx._had_errors = True
|
||||||
|
return
|
||||||
|
|
||||||
|
finalized_builder = s.builder.finalize(**inputs)
|
||||||
|
|
||||||
|
fut = self.submit(finalized_builder)
|
||||||
|
self.report(f"Submitted {s} (pk: {fut.pk})")
|
||||||
|
self.to_context(_futures=aiida.engine.append_(fut))
|
||||||
|
elif isinstance(s, Sequential):
|
||||||
|
ol = Outline(steps=tuple(s.steps))
|
||||||
|
try:
|
||||||
|
inputs = get_keys(self.ctx, ol.parameters)
|
||||||
|
except KeyError as err:
|
||||||
|
self.report(f"Skipping step {s} due to missing inputs: {err.args}")
|
||||||
|
self.ctx._had_errors = True
|
||||||
|
return
|
||||||
|
|
||||||
|
builder = PyWorkChain.get_builder()
|
||||||
|
builder.outline = PyOutline(outline=ol)
|
||||||
|
builder.kwargs = inputs
|
||||||
|
fut = self.submit(builder)
|
||||||
|
self.report(f"Submitted sub-workchain: {fut.pk}")
|
||||||
|
self.to_context(_futures=aiida.engine.append_(fut))
|
||||||
|
elif isinstance(s, Action):
|
||||||
|
return s.do(self)
|
||||||
|
|
||||||
|
def check_output(self):
|
||||||
|
"""Check the output of the current step in the workchain."""
|
||||||
|
if "_futures" not in self.ctx:
|
||||||
|
return
|
||||||
|
|
||||||
|
for step in self.ctx._futures:
|
||||||
|
if step.exit_status != 0:
|
||||||
|
self.report(f"Step {step} reported a problem: {step.exit_message}")
|
||||||
|
self.ctx._had_errors = True
|
||||||
|
for name, value in return_values(step):
|
||||||
|
self.ctx[name] = value
|
||||||
|
|
||||||
|
del self.ctx["_futures"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_keys(dictionary, keys):
|
||||||
|
"""Select all keys in 'keys' from 'dictionary'."""
|
||||||
|
missing = []
|
||||||
|
r = dict()
|
||||||
|
for k in keys:
|
||||||
|
if k in dictionary:
|
||||||
|
r[k] = dictionary[k]
|
||||||
|
else:
|
||||||
|
missing.append(k)
|
||||||
|
if missing:
|
||||||
|
raise KeyError(*missing)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
# XXX: This is all very tightly coupled to the definitions of "PyCalcJob"
|
||||||
|
# and "PyMapJob".
|
||||||
|
def return_values(calc: aiida.orm.ProcessNode):
|
||||||
|
"""Yield (name, node) tuples of return values of the given ProcessNode.
|
||||||
|
|
||||||
|
This assumes an output port namespace called "return_values".
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return calc.outputs.return_values.items()
|
||||||
|
except AttributeError:
|
||||||
|
return ()
|
||||||
|
|
||||||
|
|
||||||
|
def build(outline: Outline, **kwargs) -> PyWorkChain:
|
||||||
|
"""Return a ProcessBuilder for launching the given Outline."""
|
||||||
|
# TODO: validate that all ProcessBuilders in 'outline' are fully specified
|
||||||
|
_check_outline(outline)
|
||||||
|
builder = PyWorkChain.get_builder()
|
||||||
|
builder.outline = PyOutline(outline=outline)
|
||||||
|
if outline.label:
|
||||||
|
builder.metadata.label = outline.label
|
||||||
|
if missing := set(outline.parameters) - set(kwargs):
|
||||||
|
raise ValueError(f"Missing parameters: {missing}")
|
||||||
|
if superfluous := set(kwargs) - set(outline.parameters):
|
||||||
|
raise ValueError(f"Too many parameters: {superfluous}")
|
||||||
|
builder.kwargs = toolz.valmap(ensure_aiida_type, kwargs)
|
||||||
|
return builder
|
||||||
|
|
||||||
|
|
||||||
|
def _check_outline(outline: Outline):
|
||||||
|
for proc in outline._single_processes:
|
||||||
|
if proc.builder.code is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Execution environment not specified for {proc.builder.func.name}. "
|
||||||
|
"Did you remember to call 'on(env)' on the workflow?"
|
||||||
|
)
|
Загрузка…
Ссылка в новой задаче