add prophet model notebook
Parent: 65a5354c44
Commit: 3fd9ad8f73

@@ -0,0 +1,221 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Prophet\n",
"\n",
"### Uses the Prophet model to forecast future values of the target variable."
],
"metadata": {}
},
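{
"cell_type": "code",
"execution_count": null,
"source": [
"# Illustrative sketch only, not part of the forecasting pipeline below: a minimal\n",
"# fbprophet fit/predict round trip on a synthetic daily series, showing the ds/y input\n",
"# format and the ds/yhat output columns that the helper functions below rely on.\n",
"# Assumes fbprophet is installed; the toy series and variable names are hypothetical.\n",
"import pandas as pd\n",
"import numpy as np\n",
"from fbprophet import Prophet\n",
"\n",
"toy = pd.DataFrame({\n",
"    \"ds\": pd.date_range(\"2000-01-01\", periods=730, freq=\"D\"),\n",
"    \"y\": np.sin(np.arange(730) * 2 * np.pi / 365.25)  # synthetic yearly cycle\n",
"})\n",
"m = Prophet(yearly_seasonality=True, weekly_seasonality=False)\n",
"m.fit(toy)\n",
"future = m.make_future_dataframe(periods=14)  # extend 14 days past the history\n",
"m.predict(future)[[\"ds\", \"yhat\"]].tail(14)"
],
"outputs": [],
"metadata": {}
},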
{
"cell_type": "code",
"execution_count": null,
"source": [
"import os, sys\n",
"from tqdm import tqdm\n",
"from subseasonal_toolkit.utils.notebook_util import isnotebook\n",
"if isnotebook():\n",
"    # Autoreload packages that are modified\n",
"    %load_ext autoreload\n",
"    %autoreload 2\n",
"else:\n",
"    from argparse import ArgumentParser\n",
"import pandas as pd\n",
"import numpy as np\n",
"from scipy.spatial.distance import cdist, euclidean\n",
"from datetime import datetime, timedelta\n",
"from ttictoc import tic, toc\n",
"from subseasonal_data.utils import get_measurement_variable\n",
"from subseasonal_toolkit.utils.general_util import printf\n",
"from subseasonal_toolkit.utils.experiments_util import get_id_name, get_th_name, get_first_year, get_start_delta\n",
"from subseasonal_toolkit.utils.models_util import (get_submodel_name, start_logger, log_params, get_forecast_filename,\n",
"                                                    save_forecasts)\n",
"from subseasonal_toolkit.utils.eval_util import get_target_dates, mean_rmse_to_score, save_metric\n",
"from sklearn.linear_model import *\n",
"\n",
"from subseasonal_data import data_loaders"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"#\n",
"# Specify model parameters\n",
"#\n",
"if not isnotebook():\n",
"    # If notebook run as a script, parse command-line arguments\n",
"    parser = ArgumentParser()\n",
"    parser.add_argument(\"pos_vars\",nargs=\"*\")  # gt_id and horizon\n",
"    parser.add_argument('--target_dates', '-t', default=\"std_test\")\n",
"    args, opt = parser.parse_known_args()\n",
"\n",
"    # Assign variables\n",
"    gt_id = get_id_name(args.pos_vars[0])  # \"contest_precip\" or \"contest_tmp2m\"\n",
"    horizon = get_th_name(args.pos_vars[1])  # \"12w\", \"34w\", or \"56w\"\n",
"    target_dates = args.target_dates\n",
"else:\n",
"    # Otherwise, specify arguments interactively\n",
"    gt_id = \"contest_tmp2m\"\n",
"    horizon = \"34w\"\n",
"    target_dates = \"std_contest\"\n",
"\n",
"#\n",
"# Process model parameters\n",
"#\n",
"# One can subtract this number from a target date to find the last viable training date.\n",
"start_delta = timedelta(days=get_start_delta(horizon, gt_id))\n",
"\n",
"# Record model and submodel name\n",
"model_name = \"prophet\"\n",
"submodel_name = get_submodel_name(model_name)\n",
"\n",
"FIRST_SAVE_YEAR = 2007  # Don't save forecasts from years prior to FIRST_SAVE_YEAR\n",
"\n",
"if not isnotebook():\n",
"    # Save output to log file\n",
"    logger = start_logger(model=model_name,submodel=submodel_name,gt_id=gt_id,\n",
"                          horizon=horizon,target_dates=target_dates)\n",
"    # Store parameter values in log\n",
"    params_names = ['gt_id', 'horizon', 'target_dates']\n",
"    params_values = [eval(param) for param in params_names]\n",
"    log_params(params_names, params_values)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"printf('Loading target variable and dropping extraneous columns')\n",
"tic()\n",
"var = get_measurement_variable(gt_id)\n",
"gt = data_loaders.get_ground_truth(gt_id).loc[:,[\"start_date\",\"lat\",\"lon\",var]]\n",
"toc()"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"printf('Pivoting dataframe to have one column per lat-lon pair and one row per start_date')\n",
"tic()\n",
"gt = gt.set_index(['lat','lon','start_date']).squeeze().unstack(['lat','lon'])\n",
"toc()"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"#\n",
"# Make predictions for each target date\n",
"#\n",
"from fbprophet import Prophet\n",
"from pandas.tseries.offsets import DateOffset\n",
"def forecast(df, pred_date, num_periods, start_delta, m):\n",
"    df = df[df.ds < pred_date - start_delta]\n",
"    m.fit(df)\n",
"    future = m.make_future_dataframe(periods=num_periods + start_delta.days)\n",
"    forecast = m.predict(future)\n",
"    return forecast.tail(num_periods)[[\"ds\", \"yhat\"]], m\n",
"\n",
"def get_first_fourth_month(date):\n",
"    targets = {(4, 30), (8, 31), (12, 31)}\n",
"    while (date.month, date.day) not in targets:\n",
"        date = date - DateOffset(days=1)\n",
"    return date\n",
"\n",
"def get_predictions(df: pd.DataFrame, date, start_delta):\n",
"    # Take the most recent (4/30, 8/31, 12/31) cutoff on or before the date\n",
"    true_date = get_first_fourth_month(date)\n",
"    # Truncate the training data to that cutoff (could also be cached locally)\n",
"    df = df.loc[:true_date]\n",
"    # Compute the predictions for the target date\n",
"    all_coords_arr = []\n",
"    all_preds = []\n",
"    for column in df:\n",
"        new_df = df[column]\n",
"        print(new_df.values)\n",
"        cur_date = true_date\n",
"        df_single = pd.DataFrame(\n",
"            {\"ds\": new_df.reset_index()[\"start_date\"], \"y\": new_df.values})\n",
"        print(df_single)\n",
"        m = Prophet(yearly_seasonality=True, weekly_seasonality=False)\n",
"        num_periods = ((cur_date + DateOffset(months=4)) - cur_date).days\n",
"        preds, _ = forecast(df_single, date, num_periods, start_delta, m)\n",
"        all_coords_arr.append(column)\n",
"        all_preds.append(preds)\n",
"    res = [z[z.ds == date][\"yhat\"].item() for z in all_preds]\n",
"    return res\n",
"\n",
"\n",
"tic()\n",
"target_date_objs = pd.Series(get_target_dates(date_str=target_dates,horizon=horizon))\n",
"rmses = pd.Series(index=target_date_objs, dtype=np.float64)\n",
"preds = pd.DataFrame(index = target_date_objs, columns = gt.columns,\n",
"                     dtype=np.float64)\n",
"preds.index.name = \"start_date\"\n",
"# Sort target_date_objs by day of week\n",
"target_date_objs = target_date_objs[target_date_objs.dt.weekday.argsort(kind='stable')]\n",
"toc()\n",
"for target_date_obj in target_date_objs:\n",
"    tic()\n",
"    target_date_str = datetime.strftime(target_date_obj, '%Y%m%d')\n",
"    # Find the last observable training date for this target\n",
"    last_train_date = target_date_obj - start_delta\n",
"    if last_train_date not in gt.index:\n",
"        printf(f'-Warning: no Prophet prediction for {target_date_str}; skipping')\n",
"        continue\n",
"    printf(f'Forming Prophet prediction for {target_date_obj}')\n",
"\n",
"    # Form the Prophet prediction for this target date\n",
"    preds.loc[target_date_obj,:] = get_predictions(gt.loc[:last_train_date,:], target_date_obj, start_delta)\n",
"\n",
"    # Save prediction to file in standard format\n",
"    if target_date_obj.year >= FIRST_SAVE_YEAR:\n",
"        save_forecasts(\n",
"            preds.loc[[target_date_obj],:].unstack().rename(\"pred\").reset_index(),\n",
"            model=model_name, submodel=submodel_name,\n",
"            gt_id=gt_id, horizon=horizon,\n",
"            target_date_str=target_date_str)\n",
"    # Evaluate and store error if we have ground truth data\n",
"    if target_date_obj in gt.index:\n",
"        rmse = np.sqrt(np.square(preds.loc[target_date_obj,:] - gt.loc[target_date_obj,:]).mean())\n",
"        rmses.loc[target_date_obj] = rmse\n",
"        print(\"-rmse: {}, score: {}\".format(rmse, mean_rmse_to_score(rmse)))\n",
"        mean_rmse = rmses.mean()\n",
"        print(\"-mean rmse: {}, running score: {}\".format(mean_rmse, mean_rmse_to_score(mean_rmse)))\n",
"    toc()\n",
"\n",
"printf(\"Save rmses in standard format\")\n",
"rmses = rmses.sort_index().reset_index()\n",
"rmses.columns = ['start_date','rmse']\n",
"save_metric(rmses, model=model_name, submodel=submodel_name, gt_id=gt_id, horizon=horizon, target_dates=target_dates, metric=\"rmse\")"
],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"language_info": {
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}