improvement: using Ray to parallelize arima fitting (#159)
* using Ray to parallelize arima fitting * added ray as dependency * text about ray, disable warnings, and minor stuff * scipy 1.4.1 or above * reverting scipy, azuremlsdk issue * minor mod Co-authored-by: Vanja Paunic <15053814+vapaunic@users.noreply.github.com>
This commit is contained in:
Родитель
bed41898be
Коммит
55dabffee2
|
@ -43,7 +43,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"System version: 3.6.10 |Anaconda, Inc.| (default, Jan 7 2020, 21:14:29) \n",
|
||||
"System version: 3.6.10 |Anaconda, Inc.| (default, Mar 23 2020, 23:13:11) \n",
|
||||
"[GCC 7.3.0]\n"
|
||||
]
|
||||
}
|
||||
|
@ -53,6 +53,7 @@
|
|||
"import sys\n",
|
||||
"import math\n",
|
||||
"import warnings\n",
|
||||
"import ray\n",
|
||||
"import itertools\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
|
@ -172,22 +173,12 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Process training data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Process training data\n",
|
||||
"\n",
|
||||
"Our time series data is not complete, since we have missing sales for some stores/products and weeks. We will fill in those missing values by propagating the last valid observation forward to next available value. We will define functions for data frame processing, then use these functions within a loop that loops over each forecasting rounds.\n",
|
||||
"\n",
|
||||
"Note that our time series are grouped by `store` and `brand`, while `week` represents a time step, and `logmove` represents the value to predict."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that our time series are grouped by `store` and `brand`, while `week` represents a time step, and `logmove` represents the value to predict.\n",
|
||||
"\n",
|
||||
"Let's first process the training data. Note that the training data runs from `FIRST_WEEK` to `LAST_WEEK - HORIZON - GAP + 1` as defined in Parameters section above."
|
||||
]
|
||||
},
|
||||
|
@ -475,13 +466,8 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Process test data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Process test data\n",
|
||||
"\n",
|
||||
"Let's now process the test data. Note that the test data runs from `LAST_WEEK - HORIZON + 1` to `LAST_WEEK`. Note that, in addition to filling out missing values, we also convert unit sales from logarithmic scale to the counts. We will do model training on the log scale, due to improved performance, however, we will transfrom the test data back into the unit scale (counts) by applying `math.exp()`, so that we can evaluate the performance on the unit scale."
|
||||
]
|
||||
},
|
||||
|
@ -828,10 +814,10 @@
|
|||
" <th>Method:</th> <td>css-mle</td> <th> S.D. of innovations</th> <td>0.292</td> \n",
|
||||
"</tr>\n",
|
||||
"<tr>\n",
|
||||
" <th>Date:</th> <td>Mon, 23 Mar 2020</td> <th> AIC </th> <td>42.669</td> \n",
|
||||
" <th>Date:</th> <td>Wed, 25 Mar 2020</td> <th> AIC </th> <td>42.669</td> \n",
|
||||
"</tr>\n",
|
||||
"<tr>\n",
|
||||
" <th>Time:</th> <td>10:20:48</td> <th> BIC </th> <td>50.362</td> \n",
|
||||
" <th>Time:</th> <td>20:39:11</td> <th> BIC </th> <td>50.362</td> \n",
|
||||
"</tr>\n",
|
||||
"<tr>\n",
|
||||
" <th>Sample:</th> <td>0</td> <th> HQIC </th> <td>45.779</td> \n",
|
||||
|
@ -869,8 +855,8 @@
|
|||
"Dep. Variable: y No. Observations: 96\n",
|
||||
"Model: ARMA(1, 0) Log Likelihood -18.335\n",
|
||||
"Method: css-mle S.D. of innovations 0.292\n",
|
||||
"Date: Mon, 23 Mar 2020 AIC 42.669\n",
|
||||
"Time: 10:20:48 BIC 50.362\n",
|
||||
"Date: Wed, 25 Mar 2020 AIC 42.669\n",
|
||||
"Time: 20:39:11 BIC 50.362\n",
|
||||
"Sample: 0 HQIC 45.779\n",
|
||||
" \n",
|
||||
"==============================================================================\n",
|
||||
|
@ -949,13 +935,8 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model evaluation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model evaluation\n",
|
||||
"\n",
|
||||
"Let's now take a look at the predictions. Since auto_arima model makes consecutive forecasts from the last time point, we want to forecast the next `n_periods = GAP + HORIZON - 1` points, so that we can account for the GAP, as described in the data setup. As mentioned above, we are also transforming our predictions from logarithmic scale to counts, for calculating evaluation metric."
|
||||
]
|
||||
},
|
||||
|
@ -1133,16 +1114,16 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model training for all stores and brands"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model training for all stores and brands\n",
|
||||
"\n",
|
||||
"Now let's run model training across all the stores and brands. We will re-run the same code to automatically search for the best parameters, simply wrapped in a for loop iterating over stores and brands.\n",
|
||||
"\n",
|
||||
"> NOTE: Since we are building a model for each time series sequentially, it will take ~11 minutes to iterate over 900+ time series for each store and brand. To execute the next cell faster, you can run the below code on a subset of stores, by setting the `STORE_SUBSET` parameter to `True` in the **Parameters** section on top."
|
||||
"Note that we will be using [Ray](https://ray.readthedocs.io/en/latest/#) to distribute the computation to the cores available on your machine. To do this, we use the following:\n",
|
||||
"- `ray.init()` will start all the relevant Ray processes\n",
|
||||
"- we define a function to run an ARIMA model on a single brand and single store. To turn this function into a function that can be executed remotely, we declare the function with the ` @ray.remote` decorator.\n",
|
||||
"- `ray.get()` collects the results, and `ray.shutdown()` will stop Ray.\n",
|
||||
"\n",
|
||||
"It will take around 2.5 minutes to run the below cell on a machine with 4 cores. If you would like to further reduce the run time, you can run the below code on a subset of stores, by setting the `STORE_SUBSET` parameter to `True` in the *Parameters* section on top. This will limit the modeling to the first 20 stores."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1152,114 +1133,37 @@
|
|||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2020-03-25 20:39:12,463\tWARNING services.py:586 -- setpgrp failed, processes may not be cleaned up properly: [Errno 1] Operation not permitted.\n",
|
||||
"2020-03-25 20:39:12,466\tINFO resource_spec.py:212 -- Starting Ray with 7.28 GiB memory available for workers and up to 3.65 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).\n",
|
||||
"2020-03-25 20:39:12,884\tINFO services.py:1078 -- View the Ray dashboard at \u001b[1m\u001b[32mlocalhost:8265\u001b[39m\u001b[22m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training ARIMA model...\n",
|
||||
"10:20:49.334417 - Forecasting for store: 2\n",
|
||||
"10:20:53.401574 - Forecasting for store: 5\n",
|
||||
"10:20:58.115340 - Forecasting for store: 8\n",
|
||||
"10:21:02.542300 - Forecasting for store: 9\n",
|
||||
"10:21:07.662120 - Forecasting for store: 12\n",
|
||||
"10:21:13.272879 - Forecasting for store: 14\n",
|
||||
"10:21:20.360329 - Forecasting for store: 18\n",
|
||||
"10:21:24.616698 - Forecasting for store: 21\n",
|
||||
"10:21:31.726077 - Forecasting for store: 28\n",
|
||||
"10:21:36.421752 - Forecasting for store: 32\n",
|
||||
"10:21:43.706641 - Forecasting for store: 33\n",
|
||||
"10:21:48.973600 - Forecasting for store: 40\n",
|
||||
"10:21:53.454550 - Forecasting for store: 44\n",
|
||||
"10:22:05.987141 - Forecasting for store: 45\n",
|
||||
"10:22:11.369793 - Forecasting for store: 47\n",
|
||||
"10:22:23.237898 - Forecasting for store: 48\n",
|
||||
"10:22:27.139243 - Forecasting for store: 49\n",
|
||||
"10:22:32.829189 - Forecasting for store: 50\n",
|
||||
"10:22:36.883287 - Forecasting for store: 51\n",
|
||||
"10:22:41.747701 - Forecasting for store: 52\n",
|
||||
"10:22:46.448277 - Forecasting for store: 53\n",
|
||||
"10:22:51.055589 - Forecasting for store: 54\n",
|
||||
"10:22:58.821450 - Forecasting for store: 56\n",
|
||||
"10:23:07.506682 - Forecasting for store: 59\n",
|
||||
"10:23:12.778405 - Forecasting for store: 62\n",
|
||||
"10:23:17.544591 - Forecasting for store: 64\n",
|
||||
"10:23:20.571268 - Forecasting for store: 67\n",
|
||||
"10:23:26.563757 - Forecasting for store: 68\n",
|
||||
"10:23:31.146425 - Forecasting for store: 70\n",
|
||||
"10:23:35.123765 - Forecasting for store: 71\n",
|
||||
"10:23:38.712292 - Forecasting for store: 72\n",
|
||||
"10:23:45.099950 - Forecasting for store: 73\n",
|
||||
"10:23:52.571869 - Forecasting for store: 74\n",
|
||||
"10:23:58.752658 - Forecasting for store: 75\n",
|
||||
"10:24:03.179671 - Forecasting for store: 76\n",
|
||||
"10:24:16.749856 - Forecasting for store: 77\n",
|
||||
"10:24:21.546209 - Forecasting for store: 78\n",
|
||||
"10:24:30.336712 - Forecasting for store: 80\n",
|
||||
"10:24:36.486470 - Forecasting for store: 81\n",
|
||||
"10:24:42.409050 - Forecasting for store: 83\n",
|
||||
"10:24:48.298496 - Forecasting for store: 84\n",
|
||||
"10:24:54.153607 - Forecasting for store: 86\n",
|
||||
"10:25:02.059484 - Forecasting for store: 88\n",
|
||||
"10:25:07.806313 - Forecasting for store: 89\n",
|
||||
"10:25:19.482612 - Forecasting for store: 90\n",
|
||||
"10:25:24.955881 - Forecasting for store: 91\n",
|
||||
"10:25:32.753247 - Forecasting for store: 92\n",
|
||||
"10:25:38.835709 - Forecasting for store: 93\n",
|
||||
"10:25:52.138178 - Forecasting for store: 94\n",
|
||||
"10:26:09.610388 - Forecasting for store: 95\n",
|
||||
"10:26:17.710164 - Forecasting for store: 97\n",
|
||||
"10:26:32.846734 - Forecasting for store: 98\n",
|
||||
"10:26:42.350089 - Forecasting for store: 100\n",
|
||||
"10:26:53.629710 - Forecasting for store: 101\n",
|
||||
"10:26:59.951894 - Forecasting for store: 102\n",
|
||||
"10:27:06.786594 - Forecasting for store: 103\n",
|
||||
"10:27:16.239215 - Forecasting for store: 104\n",
|
||||
"10:27:40.024204 - Forecasting for store: 105\n",
|
||||
"10:27:46.505146 - Forecasting for store: 106\n",
|
||||
"10:27:56.103137 - Forecasting for store: 107\n",
|
||||
"10:28:02.083994 - Forecasting for store: 109\n",
|
||||
"10:28:11.977216 - Forecasting for store: 110\n",
|
||||
"10:28:20.027559 - Forecasting for store: 111\n",
|
||||
"10:28:29.607314 - Forecasting for store: 112\n",
|
||||
"10:28:38.379475 - Forecasting for store: 113\n",
|
||||
"10:28:47.224587 - Forecasting for store: 114\n",
|
||||
"10:28:56.517500 - Forecasting for store: 115\n",
|
||||
"10:29:04.452412 - Forecasting for store: 116\n",
|
||||
"10:29:12.507529 - Forecasting for store: 117\n",
|
||||
"10:29:22.483425 - Forecasting for store: 118\n",
|
||||
"10:29:29.267235 - Forecasting for store: 119\n",
|
||||
"10:29:38.062543 - Forecasting for store: 121\n",
|
||||
"10:29:44.519399 - Forecasting for store: 122\n",
|
||||
"10:29:53.129138 - Forecasting for store: 123\n",
|
||||
"10:30:04.086351 - Forecasting for store: 124\n",
|
||||
"10:30:11.941476 - Forecasting for store: 126\n",
|
||||
"10:30:17.497423 - Forecasting for store: 128\n",
|
||||
"10:30:23.574140 - Forecasting for store: 129\n",
|
||||
"10:30:30.009697 - Forecasting for store: 130\n",
|
||||
"10:30:35.644882 - Forecasting for store: 131\n",
|
||||
"10:30:39.275225 - Forecasting for store: 132\n",
|
||||
"10:30:52.546664 - Forecasting for store: 134\n",
|
||||
"10:30:59.428141 - Forecasting for store: 137\n",
|
||||
"CPU times: user 26min 35s, sys: 1min 29s, total: 28min 5s\n",
|
||||
"Wall time: 10min 14s\n"
|
||||
"CPU times: user 3.43 s, sys: 677 ms, total: 4.11 s\n",
|
||||
"Wall time: 2min 32s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"# initialize Ray\n",
|
||||
"ray.init(log_to_driver=False)\n",
|
||||
"\n",
|
||||
"if STORE_SUBSET:\n",
|
||||
" store_list = store_list[0:10]\n",
|
||||
" store_list = store_list[0:20]\n",
|
||||
"\n",
|
||||
"result_df = pd.DataFrame(None, columns=[\"predictions\", \"store\", \"brand\", \"week\", \"actuals\"])\n",
|
||||
"\n",
|
||||
"print(\"Training ARIMA model...\")\n",
|
||||
"for store, brand in itertools.product(store_list, brand_list):\n",
|
||||
"\n",
|
||||
" if brand == 1:\n",
|
||||
" print(f\"{datetime.now().time()} - Forecasting for store: {store}\")\n",
|
||||
"\n",
|
||||
" train_ts = train_filled.loc[(train_filled.store == store) & (train_filled.brand == brand)]\n",
|
||||
"@ray.remote\n",
|
||||
"def train_store_brand(data, store, brand):\n",
|
||||
" train_ts = data.loc[(data.store == store) & (data.brand == brand)]\n",
|
||||
" train_ts = np.array(train_ts[\"logmove\"])\n",
|
||||
"\n",
|
||||
" model = auto_arima(\n",
|
||||
|
@ -1270,6 +1174,7 @@
|
|||
" max_p=params[\"max_p\"],\n",
|
||||
" max_q=params[\"max_q\"],\n",
|
||||
" stepwise=True,\n",
|
||||
" error_action=\"ignore\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" model.fit(train_ts)\n",
|
||||
|
@ -1279,9 +1184,22 @@
|
|||
" pred_df = pd.DataFrame({\"predictions\": predictions, \"store\": store, \"brand\": brand, \"week\": test_week_list})\n",
|
||||
" test_ts = test_filled.loc[(test_filled.store == store) & (test_filled.brand == brand)]\n",
|
||||
"\n",
|
||||
" combined_df = pd.merge(pred_df, test_ts, on=[\"store\", \"brand\", \"week\"], how=\"left\")\n",
|
||||
" return pd.merge(pred_df, test_ts, on=[\"store\", \"brand\", \"week\"], how=\"left\")\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"print(\"Training ARIMA model...\")\n",
|
||||
"\n",
|
||||
" result_df = result_df.append(combined_df, ignore_index=True)"
|
||||
"# persist input data into Ray shared memory\n",
|
||||
"train_filled_id = ray.put(train_filled)\n",
|
||||
"\n",
|
||||
"# train for each store/brand\n",
|
||||
"results = [train_store_brand.remote(train_filled_id, store, brand) \n",
|
||||
" for store, brand in itertools.product(store_list, brand_list)]\n",
|
||||
"\n",
|
||||
"result_df = pd.concat(ray.get(results), ignore_index=True)\n",
|
||||
"\n",
|
||||
"# stop Ray\n",
|
||||
"ray.shutdown()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1299,7 +1217,7 @@
|
|||
{
|
||||
"data": {
|
||||
"application/scrapbook.scrap.json+json": {
|
||||
"data": 69.74536523104287,
|
||||
"data": 69.73950700142306,
|
||||
"encoder": "json",
|
||||
"name": "MAPE",
|
||||
"version": 1
|
||||
|
@ -1318,7 +1236,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"MAPE of the forecasts is 69.74536523104287 %\n"
|
||||
"MAPE of the forecasts is 69.73950700142306 %\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -45,3 +45,4 @@ dependencies:
|
|||
- statsmodels==0.11.1
|
||||
- pmdarima==1.1.1
|
||||
- gitpython==3.0.8
|
||||
- ray==0.8.2
|
||||
|
|
Загрузка…
Ссылка в новой задаче