Remove any references to torch/tabnet, since we are not currently running those models. Replace meta packages such as tidyverse and tidymodels with their underlying packages. Change how weighted MAPE is calculated so it behaves better when actuals or forecasts are negative.
Parent: fe87c4254f
Commit: 9f26dc1da5
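
For reviewers, a rough sketch of how the meta packages map onto the packages now listed explicitly. The lists are taken from the hunks in this diff, so they cover only the pieces finnts actually uses, not everything tidyverse or tidymodels would attach; treat it as a summary, not a definitive mapping.

# Minimal sketch of the meta-package replacement applied throughout this change.
# meta_package_map is illustrative only; the package names come from the diff below.
meta_package_map <- list(
  tidyverse  = c("dplyr", "tibble", "tidyr", "purrr", "stringr"),
  tidymodels = c("parsnip", "tune", "recipes", "rsample", "workflows", "dials")
)

# e.g. instead of install.packages("tidyverse"):
install.packages(meta_package_map$tidyverse)
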
@@ -53,9 +53,8 @@ jobs:
       run: |
         install.packages(c('rcmdcheck', 'knitr', 'rmarkdown', 'lubridate', 'foreach', 'modeltime',
                            'Cubist', 'doParallel', 'devtools', 'dplyr', 'earth', 'glmnet',
-                           'gtools', 'hts', 'kernlab',
-                           'modeltime.gluonts', 'modeltime.resample', 'plyr', 'purrr', 'rules',
-                           'tabnet', 'tibble', 'tidyr', 'timetk', 'torch'))
+                           'gtools', 'hts', 'kernlab', 'modeltime.gluonts', 'modeltime.resample', 'plyr', 'purrr',
+                           'rules', 'tibble', 'tidyr', 'timetk'))

         devtools::install_github(c("Azure/rAzureBatch", "Azure/doAzureParallel"))

@@ -51,7 +51,6 @@ Imports:
     rsample,
     rules,
     stringr,
-    tabnet,
     tibble,
     tidyr,
     tidyselect,
@@ -66,6 +65,5 @@ Suggests:
 Config/testthat/edition: 3
 Depends:
     R (>= 3.6.0),
-    modeltime,
-    torch
+    modeltime
 VignetteBuilder: knitr

@@ -21,12 +21,10 @@ export(stlm_arima)
 export(stlm_ets)
 export(svm_poly)
 export(svm_rbf)
-export(tabnet)
 export(tbats)
 export(theta)
 export(xgboost)
 import(modeltime)
-import(torch)
 importFrom(Cubist,cubist)
 importFrom(Cubist,cubistControl)
 importFrom(earth,earth)

@@ -245,8 +245,8 @@ get_back_test_scenario_hist_periods<- function(full_data_tbl,
 get_export_packages <- function(){
   c('modeltime', 'modeltime.gluonts', 'modeltime.resample',
     'timetk', 'rules', 'Cubist', 'glmnet', 'earth', 'kernlab', 'xgboost',
-    'tidyverse', 'lubridate', 'prophet', 'torch', 'tabnet',
-    "doParallel", "parallel")
+    'dplyr', 'tibble', 'tidyr', 'purrr', 'stringr', 'lubridate', 'prophet',
+    'doParallel', 'parallel')
 }

 #' Fetches a list of parallel transfer functions
@@ -256,7 +256,7 @@ get_export_packages <- function(){
 get_transfer_functions <- function(){
   c("arima", "arima_boost", "croston", "cubist", "deepar", "ets", "glmnet", "mars",
     "meanf", "nbeats", "nnetar", "nnetar_xregs", "prophet", "prophet_boost", "prophet_xregs",
-    "snaive", "stlm_arima", "stlm_ets", "svm_poly", "svm_rbf", "tbats", "tabnet", "theta", "xgboost",
+    "snaive", "stlm_arima", "stlm_ets", "svm_poly", "svm_rbf", "tbats", "theta", "xgboost",
     "multivariate_prep_recipe_1", "multivariate_prep_recipe_2","combo_specific_filter",
     "construct_forecast_models", "get_model_functions", "get_not_all_data_models",
     "get_r1_data_models", "get_r2_data_models", "get_deep_learning_models", "get_frequency_adjustment_models",

@@ -377,8 +377,9 @@ forecast_time_series <- function(input_data,
   }

   combinations_tbl <- foreach::foreach(i = model_combinations[[1]], .combine = 'rbind',
-                                       .packages = c('tidyverse', 'lubridate',
-                                                     "doParallel", "parallel", "gtools"),
+                                       .packages = c('dplyr', 'tibble', 'tidyr', 'purrr',
+                                                     'stringr', 'lubridate',
+                                                     'doParallel', 'parallel', "gtools"),
                                        .export = c("fcst_prep")) %dopar% {

     fcst_combination_temp <- fcst_prep %>%

@@ -468,8 +469,8 @@ forecast_time_series <- function(input_data,
     dplyr::mutate(Target = ifelse(Target == 0, 0.1, Target)) %>%
     dplyr::mutate(MAPE = round(abs((FCST - Target) / Target), digits = 4)) %>%
     dplyr::group_by(Model, Combo) %>%
-    dplyr::mutate(Combo_Total = sum(Target, na.rm = TRUE),
-                  weighted_MAPE = (Target/Combo_Total)*MAPE) %>%
+    dplyr::mutate(Combo_Total = sum(abs(Target), na.rm = TRUE),
+                  weighted_MAPE = (abs(Target)/Combo_Total)*MAPE) %>%
     dplyr::summarise(Rolling_MAPE = sum(weighted_MAPE, na.rm=TRUE)) %>%
     dplyr::arrange(Rolling_MAPE) %>%
     dplyr::ungroup()
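
A small, hypothetical illustration of the weighting change above (the numbers and the old/new weight column names are made up for this sketch; Target, MAPE, and Rolling_MAPE match the pipeline): with a negative actual, the old Target/Combo_Total weights can turn negative and can blow up when a combo's total nets toward zero, while the abs(Target) weights stay non-negative and sum to 1.

library(dplyr)
library(tibble)

# Hypothetical back-test rows for a single Model/Combo with one negative actual.
example_tbl <- tibble(
  Target = c(100, -50, 25),
  MAPE   = c(0.10, 0.20, 0.40)
)

example_tbl %>%
  dplyr::mutate(
    old_weight = Target / sum(Target, na.rm = TRUE),           # 1.33, -0.67, 0.33 (one negative weight)
    new_weight = abs(Target) / sum(abs(Target), na.rm = TRUE)  # 0.57,  0.29, 0.14 (non-negative, sums to 1)
  ) %>%
  dplyr::summarise(
    old_Rolling_MAPE = sum(old_weight * MAPE, na.rm = TRUE),   # ~0.13: the negative weight cancels part of the error
    new_Rolling_MAPE = sum(new_weight * MAPE, na.rm = TRUE)    # ~0.17: every row contributes in proportion to |Target|
  )
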
R/models.R (52 changed lines)
@@ -1399,58 +1399,6 @@ svm_rbf <- function(train_data,

 }

-#' TabNet
-#'
-#' @param train_data Training Data
-#' @param parallel Parallel
-#' @param date_rm_regex Date RM Regex
-#' @param fiscal_year_start Fiscal Year Start
-#'
-#' @return Get Tab Net
-#' @keywords internal
-#' @export
-tabnet <- function(train_data,
-                   parallel,
-                   fiscal_year_start,
-                   date_rm_regex,
-                   pca) {
-
-  date_rm_regex_final <- "(.xts$)|(.iso$)|(hour)|(minute)|(second)|(am.pm)|(day)|(week)"
-
-  #create model recipe
-  recipe_spec_tabnet <- train_data %>%
-    get_recipie_configurable(fiscal_year_start,
-                             date_rm_regex_final,
-                             mutate_adj_half = FALSE,
-                             step_nzv = "none",
-                             one_hot = TRUE,
-                             pca = pca)
-
-  model_spec_tabnet <- tabnet::tabnet(
-    mode = "regression",
-    batch_size = tune::tune(),
-    virtual_batch_size = tune::tune(),
-    epochs = tune::tune()
-  ) %>%
-    parsnip::set_engine("torch")
-
-  wflw_spec_tune_tabnet <- get_workflow_simple(model_spec_tabnet,
-                                               recipe_spec_tabnet)
-
-  tune_results_tabnet <- train_data %>%
-    get_kfold_tune_grid(wkflw = wflw_spec_tune_tabnet,
-                        parallel = parallel)
-
-  wflw_fit_tabnet <- train_data %>%
-    get_fit_wkflw_best(tune_results_tabnet,
-                       wflw_spec_tune_tabnet)
-
-  cli::cli_alert_success("tabnet")
-
-  return(wflw_fit_tabnet)
-
-}
-
 #' Tbats
 #'
 #' @param train_data Training Data

@@ -27,8 +27,6 @@ utils::globalVariables(c(".id", ".key", ".model_desc", ".pred", ".resample_id",

 #' @importFrom glmnet glmnet

-#' @import torch
-
 # * cbind.fill custom function ----
 #create function to cbind dataframes that contain different amounts of rows
 #https://github.com/cvarrichio/rowr/blob/master/R/rowr.R

@@ -1,24 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/models.R
-\name{tabnet}
-\alias{tabnet}
-\title{TabNet}
-\usage{
-tabnet(train_data, parallel, fiscal_year_start, date_rm_regex, pca)
-}
-\arguments{
-\item{train_data}{Training Data}
-
-\item{parallel}{Parallel}
-
-\item{fiscal_year_start}{Fiscal Year Start}
-
-\item{date_rm_regex}{Date RM Regex}
-}
-\value{
-Get Tab Net
-}
-\description{
-TabNet
-}
-\keyword{internal}

@@ -80,7 +80,7 @@ azure_batch_cluster_config <- list(
   ),
   "containerImage" = "mftokic/finn-azure-batch-dev", # docker image you can use that automatically downloads software needed for Finn to run in cloud
   "rPackages" = list(
-    "cran" = c('Rcpp', 'modeltime', 'modeltime.resample', 'tidymodels', 'lubridate', 'rules', 'Cubist', 'earth', 'kernlab', 'doParallel', 'tidyverse', 'torch', 'tabnet', 'prophet', 'glmnet', 'gtools'), # finnts package dependencies
+    "cran" = c('Rcpp', 'modeltime', 'modeltime.resample', 'parsnip', 'tune', 'recipes', 'rsample', 'workflows', 'dials', 'lubridate', 'rules', 'Cubist', 'earth', 'kernlab', 'doParallel', 'dplyr', 'tibble', 'tidyr', 'purrr', 'stringr', 'prophet', 'glmnet', 'gtools'), # finnts package dependencies
     "github" = list(),
     "bioconductor" = list()
   ),