# finnts/R/parallel_util.R

#' Get number of cores to use when registering parallel back end
#'
#' @param num_cores number of cores for parallel processing
#'
#' @noRd
get_cores <- function(num_cores) {
  if (is.null(num_cores)) {
    parallel::detectCores() - 1
  } else {
    min(num_cores, parallel::detectCores() - 1)
  }
}
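
# Example (not run): a minimal sketch of how get_cores() resolves the worker
# count; the value 4 is illustrative only. On an 8-core machine:
# get_cores(NULL) # 7 (one core is always left free)
# get_cores(4)    # 4 (user request, still capped at detectCores() - 1)
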
#' Function to submit tasks sequentially, in parallel on a local machine, or in Spark
#'
#' @param run_info run info list for the current run
#' @param parallel_processing type of parallel processing to run: NULL for
#'   sequential execution, "local_machine", or "spark"
#' @param num_cores number of cores to use on a local machine
#' @param task_length number of time series to submit to the parallel cluster
#'
#' @noRd
par_start <- function(run_info,
                      parallel_processing,
                      num_cores,
                      task_length) {
  cl <- NULL

  base_packages <- c(
    "tibble", "dplyr", "timetk", "hts", "tidyselect", "stringr", "foreach",
    "doParallel", "parallel", "lubridate", "parsnip", "tune", "dials", "workflows",
    "Cubist", "earth", "glmnet", "kernlab", "purrr",
    "recipes", "rules", "modeltime", "fs", "digest", "tidyr",
    "vroom", "utils", "cli"
  )

  parallel_packages <- c(
    "gtools", "hts", "magrittr", "methods", "base",
    "plyr", "rsample"
  )

  add_packages <- NULL

  if (inherits(run_info$storage_object, "blob_container")) {
    add_packages <- c(add_packages, "AzureStor")
  } else if (inherits(run_info$storage_object, "ms_drive")) {
    add_packages <- c(add_packages, "Microsoft365R")
  }

  if (run_info$data_output == "parquet") {
    add_packages <- c(add_packages, "arrow")
  }

  if (run_info$object_output == "qs") {
    add_packages <- c(add_packages, "qs")
  }
  if (is.null(parallel_processing)) {
    # run sequentially within the current R session
    `%op%` <- foreach::`%do%`
    packages <- c(base_packages, add_packages)
  } else if (parallel_processing == "spark") {
    if (!exists("sc")) {
      stop("Ensure that you are connected to a spark cluster using an object called 'sc'")
    }
    `%op%` <- foreach::`%dopar%`
    sparklyr::registerDoSpark(sc, parallelism = task_length)
    packages <- NULL
  } else if (parallel_processing == "local_machine") {
    cores <- get_cores(num_cores)
    cl <- parallel::makeCluster(min(cores, task_length))

    if (substr(run_info$path, start = 1, stop = 6) == "/synfs") {
      # make sure libraries are exported properly in Azure Synapse
      e <- new.env()
      e$libs <- .libPaths()
      parallel::clusterExport(cl, "libs", envir = e)
      parallel::clusterEvalQ(cl, .libPaths(libs))
    }

    doParallel::registerDoParallel(cl)
    `%op%` <- foreach::`%dopar%`
    packages <- c(base_packages, add_packages, parallel_packages)
  } else {
    stop("parallel_processing must be NULL, 'local_machine', or 'spark'")
  }

  return(list(packages = packages, foreach_operator = `%op%`, cl = cl))
}
#' Function to clean up after submitting tasks sequentially, in parallel on a local machine, or in Spark
#'
#' @param cl cluster object returned by par_start(), or NULL if no local
#'   cluster was created
#'
#' @noRd
par_end <- function(cl) {
  foreach::registerDoSEQ()

  if (!is.null(cl)) {
    parallel::stopCluster(cl)
  }
}
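
# Example (not run): closing out the par_start() sketch above; par_end() both
# resets foreach to sequential mode and stops the local cluster if one exists.
# Wrapping the loop in tryCatch(..., finally = par_end(par_info$cl)) would
# guarantee cleanup even when a worker errors.
# par_end(par_info$cl)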