{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {},
"inputWidgets": {},
"nuid": "a02bb3a2-e5d7-4d14-b966-827457675b75",
"showTitle": false,
"title": ""
}
},
"source": [
"# Quick Start\n",
"\n",
"ForecastFlowML is designed for scalable forecasting and uses Spark for both\n",
"feature engineering and training/prediction/hyperparameter optimisation.\n",
"\n",
"## Use Cases\n",
"ForecastFlowML generally covers three use cases:\n",
"- Data is stored in a ``PySpark DataFrame``, and we need to build, in parallel,\n",
"many or large group models that do not fit into driver memory.\n",
"- Data is stored in a ``PySpark DataFrame``, and we need to build, in parallel,\n",
"a few or small group models that fit into driver memory.\n",
"- Data is stored in a ``Pandas DataFrame``, and we need to build, in parallel,\n",
"a few or small group models that fit into driver memory.\n",
"\n",
"This quick guide shows how to develop a scalable forecasting system on the\n",
"Kaggle Walmart M5 Competition sample dataset.\n",
"\n",
"## Goal\n",
"- Build independent models for each of the stores in the dataset.\n",
"- Parallelize training/inference steps.\n",
"- Use LightGBM as the machine learning algorithm.\n",
"- Use the direct multi-step forecasting approach.\n",
"- Perform backtesting."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from forecastflowml import ForecastFlowML\n",
"from forecastflowml import FeatureExtractor\n",
"from forecastflowml.data.loader import load_walmart_m5\n",
"from lightgbm import LGBMRegressor\n",
"from pyspark.sql import SparkSession\n",
"import pyspark.sql.functions as F\n",
"import plotly.express as px\n",
"import plotly.io as pio\n",
"import pandas as pd\n",
"\n",
"pd.set_option(\"display.max_columns\", 40)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize Spark"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"spark = (\n",
" SparkSession.builder.master(\"local[4]\")\n",
" .config(\"spark.driver.memory\", \"8g\")\n",
" .config(\"spark.sql.shuffle.partitions\", \"4\")\n",
" .config(\"spark.sql.execution.arrow.enabled\", \"true\")\n",
" .getOrCreate()\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample Dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"df = load_walmart_m5(spark)\n",
"df.show(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature Engineering"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"feature_extractor = FeatureExtractor(\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" lag_window_features={\n",
" \"lag\": [7 * (i + 1) for i in range(8)],\n",
" \"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n",
" },\n",
" date_features=[\n",
" \"day_of_month\",\n",
" \"day_of_week\",\n",
" \"week_of_year\",\n",
" \"quarter\",\n",
" \"month\",\n",
" \"year\",\n",
" ],\n",
" count_consecutive_values={\n",
" \"value\": 0,\n",
" \"lags\": [7, 14, 21, 28],\n",
" },\n",
" history_length=True,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pandas DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" id item_id dept_id \\\n",
"0 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"1 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"2 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"3 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"4 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"... ... ... ... \n",
"1470899 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470900 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470901 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470902 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470903 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"\n",
" cat_id store_id state_id date sales lag_7 lag_14 lag_21 \\\n",
"0 FOODS WI_2 WI 2011-01-31 2.0 NaN NaN NaN \n",
"1 FOODS WI_2 WI 2011-02-01 0.0 NaN NaN NaN \n",
"2 FOODS WI_2 WI 2011-02-02 0.0 NaN NaN NaN \n",
"3 FOODS WI_2 WI 2011-02-03 0.0 NaN NaN NaN \n",
"4 FOODS WI_2 WI 2011-02-04 0.0 NaN NaN NaN \n",
"... ... ... ... ... ... ... ... ... \n",
"1470899 HOUSEHOLD WI_3 WI 2016-05-18 0.0 0.0 0.0 0.0 \n",
"1470900 HOUSEHOLD WI_3 WI 2016-05-19 0.0 0.0 0.0 0.0 \n",
"1470901 HOUSEHOLD WI_3 WI 2016-05-20 1.0 0.0 0.0 0.0 \n",
"1470902 HOUSEHOLD WI_3 WI 2016-05-21 0.0 0.0 0.0 0.0 \n",
"1470903 HOUSEHOLD WI_3 WI 2016-05-22 0.0 0.0 0.0 0.0 \n",
"\n",
" lag_28 lag_35 lag_42 lag_49 lag_56 window_7_lag_7_mean \\\n",
"0 NaN NaN NaN NaN NaN NaN \n",
"1 NaN NaN NaN NaN NaN NaN \n",
"2 NaN NaN NaN NaN NaN NaN \n",
"3 NaN NaN NaN NaN NaN NaN \n",
"4 NaN NaN NaN NaN NaN NaN \n",
"... ... ... ... ... ... ... \n",
"1470899 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"1470900 0.0 1.0 0.0 0.0 0.0 0.0 \n",
"1470901 1.0 0.0 0.0 0.0 0.0 0.0 \n",
"1470902 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"1470903 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" window_14_lag_7_mean window_30_lag_7_mean window_7_lag_14_mean \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"1470899 0.071429 0.166667 0.142857 \n",
"1470900 0.071429 0.100000 0.142857 \n",
"1470901 0.071429 0.100000 0.142857 \n",
"1470902 0.071429 0.066667 0.142857 \n",
"1470903 0.071429 0.066667 0.142857 \n",
"\n",
" window_14_lag_14_mean window_30_lag_14_mean window_7_lag_21_mean \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"1470899 0.142857 0.166667 0.142857 \n",
"1470900 0.142857 0.166667 0.142857 \n",
"1470901 0.071429 0.166667 0.000000 \n",
"1470902 0.071429 0.166667 0.000000 \n",
"1470903 0.071429 0.166667 0.000000 \n",
"\n",
" window_14_lag_21_mean window_30_lag_21_mean window_7_lag_28_mean \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"1470899 0.142857 0.166667 0.142857 \n",
"1470900 0.071429 0.166667 0.000000 \n",
"1470901 0.071429 0.166667 0.142857 \n",
"1470902 0.071429 0.166667 0.142857 \n",
"1470903 0.071429 0.166667 0.142857 \n",
"\n",
" window_14_lag_28_mean window_30_lag_28_mean \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"1470899 0.214286 0.133333 \n",
"1470900 0.214286 0.133333 \n",
"1470901 0.285714 0.166667 \n",
"1470902 0.285714 0.166667 \n",
"1470903 0.285714 0.166667 \n",
"\n",
" count_consecutive_value_lag_7 count_consecutive_value_lag_14 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"1470899 9.0 2.0 \n",
"1470900 10.0 3.0 \n",
"1470901 11.0 4.0 \n",
"1470902 12.0 5.0 \n",
"1470903 13.0 6.0 \n",
"\n",
" count_consecutive_value_lag_21 count_consecutive_value_lag_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"1470899 5.0 6.0 \n",
"1470900 6.0 7.0 \n",
"1470901 7.0 0.0 \n",
"1470902 8.0 1.0 \n",
"1470903 9.0 2.0 \n",
"\n",
" history_length day_of_month day_of_week week_of_year quarter \\\n",
"0 1 31 2 5 1 \n",
"1 2 1 3 5 1 \n",
"2 3 2 4 5 1 \n",
"3 4 3 5 5 1 \n",
"4 5 4 6 5 1 \n",
"... ... ... ... ... ... \n",
"1470899 1936 18 4 20 2 \n",
"1470900 1937 19 5 20 2 \n",
"1470901 1938 20 6 20 2 \n",
"1470902 1939 21 7 20 2 \n",
"1470903 1940 22 1 20 2 \n",
"\n",
" month year \n",
"0 1 2011 \n",
"1 2 2011 \n",
"2 2 2011 \n",
"3 2 2011 \n",
"4 2 2011 \n",
"... ... ... \n",
"1470899 5 2016 \n",
"1470900 5 2016 \n",
"1470901 5 2016 \n",
"1470902 5 2016 \n",
"1470903 5 2016 \n",
"\n",
"[1470904 rows x 39 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"feature_extractor.transform(df.toPandas(), spark=spark)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### PySpark DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|history_length|day_of_month|day_of_week|week_of_year|quarter|month|year|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 1| 31| 2| 5| 1| 1|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 2| 1| 3| 5| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 3| 2| 4| 5| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 4| 3| 5| 5| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 5| 4| 6| 5| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 6| 5| 7| 5| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 7| 6| 1| 5| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| null| null| null| null| 2.0| 2.0| 2.0| null| null| null| null| null| null| null| null| null| 0| null| null| null| 8| 7| 2| 6| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| null| null| null| null| 1.0| 1.0| 1.0| null| null| null| null| null| null| null| null| null| 1| null| null| null| 9| 8| 3| 6| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| null| null| null| null| 0.6666666666666666| 0.6666666666666666| 0.6666666666666666| null| null| null| null| null| null| null| null| null| 2| null| null| null| 10| 9| 4| 6| 1| 2|2011|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"df_features = feature_extractor.transform(df).localCheckpoint()\n",
"df_features.show(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training/Test Dataset"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"df_train = df_features.filter(F.col(\"date\") < \"2016-04-25\")\n",
"df_test = df_features.filter(F.col(\"date\") >= \"2016-04-25\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"forecast_flow = ForecastFlowML(\n",
" group_col=\"store_id\",\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" categorical_cols=[\"item_id\", \"dept_id\", \"cat_id\"],\n",
" date_frequency=\"days\",\n",
" model_horizon=7,\n",
" max_forecast_horizon=28,\n",
" model=LGBMRegressor(),\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### PySpark DataFrame with Distributed Results"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n",
"|group| forecast_horizon| model| start_time| end_time|elapsed_seconds|\n",
"+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n",
"| CA_2|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.8|\n",
"| CA_3|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.2|\n",
"| WI_2|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.2|\n",
"| WI_3|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 2.9|\n",
"| CA_1|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 4.3|\n",
"| CA_4|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.5|\n",
"| TX_1|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.2|\n",
"| TX_3|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.0|\n",
"| WI_1|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 2.0|\n",
"| TX_2|[[1, 2, 3, 4, 5, ...|[\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.8|\n",
"+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n",
"\n"
]
}
],
"source": [
"trained_models = forecast_flow.train(df_train).localCheckpoint()\n",
"trained_models.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### PySpark DataFrame with Local Results"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" group forecast_horizon \\\n",
"0 CA_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"1 CA_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"2 WI_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"3 WI_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"4 CA_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"5 CA_4 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"6 TX_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"7 TX_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"8 WI_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"9 TX_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"\n",
" model start_time \\\n",
"0 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:31) \n",
"1 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:38) \n",
"2 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:42) \n",
"3 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:47) \n",
"4 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:30) \n",
"5 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:38) \n",
"6 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:42) \n",
"7 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:48) \n",
"8 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:51) \n",
"9 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:28) \n",
"\n",
" end_time elapsed_seconds \n",
"0 01-May-2023 (03:22:38) 6.6 \n",
"1 01-May-2023 (03:22:42) 3.6 \n",
"2 01-May-2023 (03:22:47) 5.1 \n",
"3 01-May-2023 (03:22:51) 3.2 \n",
"4 01-May-2023 (03:22:37) 7.5 \n",
"5 01-May-2023 (03:22:41) 3.8 \n",
"6 01-May-2023 (03:22:47) 5.3 \n",
"7 01-May-2023 (03:22:51) 3.4 \n",
"8 01-May-2023 (03:22:54) 2.4 \n",
"9 01-May-2023 (03:22:33) 4.7 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forecast_flow.train(df_train, local_result=True)\n",
"forecast_flow.model_"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pandas DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" group forecast_horizon \\\n",
"0 CA_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"1 CA_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"2 WI_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"3 WI_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"4 CA_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"5 CA_4 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"6 TX_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"7 TX_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"8 WI_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"9 TX_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"\n",
" model start_time \\\n",
"0 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:16) \n",
"1 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:21) \n",
"2 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:25) \n",
"3 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:28) \n",
"4 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:14) \n",
"5 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:21) \n",
"6 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:24) \n",
"7 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:28) \n",
"8 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:32) \n",
"9 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:12) \n",
"\n",
" end_time elapsed_seconds \n",
"0 01-May-2023 (03:23:21) 4.4 \n",
"1 01-May-2023 (03:23:25) 3.4 \n",
"2 01-May-2023 (03:23:28) 3.0 \n",
"3 01-May-2023 (03:23:32) 3.3 \n",
"4 01-May-2023 (03:23:20) 5.8 \n",
"5 01-May-2023 (03:23:24) 3.3 \n",
"6 01-May-2023 (03:23:28) 3.4 \n",
"7 01-May-2023 (03:23:32) 3.4 \n",
"8 01-May-2023 (03:23:34) 2.2 \n",
"9 01-May-2023 (03:23:17) 5.0 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
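],
"source": [
"# Train on a pandas DataFrame; the Spark session is still passed so the group models can be built in parallel.\n",
"forecast_flow.train(df_train.toPandas(), spark=spark)\n",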
"forecast_flow.model_"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### PySpark DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+--------------------+----------+----------+\n",
"|group| id| date|prediction|\n",
"+-----+--------------------+----------+----------+\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-25| 0.481568|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-26|0.46724537|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.41596597|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.40775877|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.43439913|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-30| 0.4951446|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-05-01| 0.4308696|\n",
"| CA_2|FOODS_1_192_CA_2_...|2016-04-25| 0.2172628|\n",
"| CA_2|FOODS_1_192_CA_2_...|2016-04-26| 0.1687214|\n",
"| CA_2|FOODS_1_192_CA_2_...|2016-04-27| 0.1687214|\n",
"+-----+--------------------+----------+----------+\n",
"only showing top 10 rows\n",
"\n"
]
}
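],
"source": [
"# Predict with the trained models; localCheckpoint() materializes the result and truncates its lineage.\n",
"forecast = forecast_flow.predict(df_test, trained_models).localCheckpoint()\n",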
"forecast.show(10)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|history_length|day_of_month|day_of_week|week_of_year|quarter|month|year|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-25| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.5714285714285714| 0.8| 0.14285714285714285| 0.7857142857142857| 0.8| 1.4285714285714286| 0.7142857142857143| 0.9666666666666667| 0.0| 0.5714285714285714| 0.7333333333333333| 1| 5| 0| 8| 1912| 25| 2| 17| 2| 4|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-26| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 1.0| 0.5714285714285714| 0.6| 0.14285714285714285| 0.7857142857142857| 0.6666666666666666| 1.4285714285714286| 0.7142857142857143| 0.9666666666666667| 0.0| 0.5714285714285714| 0.6333333333333333| 2| 6| 1| 9| 1913| 26| 3| 17| 2| 4|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-27| 0.0| 2.0| 0.0| 1.0| 0.0| 0.0| 1.0| 2.0| 0.0| 1.2857142857142858| 0.6428571428571429| 0.6666666666666666| 0.0| 0.7857142857142857| 0.6333333333333333| 1.5714285714285714| 0.7857142857142857| 1.0| 0.0| 0.5| 0.6333333333333333| 0| 7| 0| 10| 1914| 27| 4| 17| 2| 4|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-28| 1.0| 0.0| 4.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 1.0714285714285714| 0.7666666666666667| 1.5714285714285714| 0.7857142857142857| 0.8666666666666667| 0.0| 0.42857142857142855| 0.6333333333333333| 1| 0| 1| 11| 1915| 28| 5| 17| 2| 4|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-29| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 0.0| 0.0| 0.0| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.7857142857142857| 0.8| 0.5714285714285714| 0.7142857142857143| 0.7666666666666667| 2| 1| 2| 0| 1916| 29| 6| 17| 2| 4|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-30| 1.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 1.0| 0.7857142857142857| 0.7333333333333333| 0.5714285714285714| 0.7857142857142857| 0.7| 1.0| 0.7857142857142857| 0.8| 0.5714285714285714| 0.7142857142857143| 0.7666666666666667| 0| 2| 3| 1| 1917| 30| 7| 17| 2| 4|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-01| 4.0| 0.0| 3.0| 0.0| 5.0| 0.0| 6.0| 4.0| 0.0| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.6428571428571429| 0.8| 0.2857142857142857| 0.7857142857142857| 0.8| 1.2857142857142858| 0.6428571428571429| 0.9333333333333333| 1| 0| 4| 0| 1918| 1| 1| 17| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-02| 0.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.0| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.5714285714285714| 0.8| 0.14285714285714285| 0.7857142857142857| 0.8| 1.4285714285714286| 0.7142857142857143| 0.9666666666666667| 2| 1| 5| 0| 1919| 2| 2| 18| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-03| 0.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.8571428571428571| 0.9285714285714286| 0.8| 1.0| 0.5714285714285714| 0.6| 0.14285714285714285| 0.7857142857142857| 0.6666666666666666| 1.4285714285714286| 0.7142857142857143| 0.9666666666666667| 0| 2| 6| 1| 1920| 3| 3| 18| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-04| 0.0| 0.0| 2.0| 0.0| 1.0| 0.0| 0.0| 1.0| 2.0| 0.5714285714285714| 0.9285714285714286| 0.8| 1.2857142857142858| 0.6428571428571429| 0.6666666666666666| 0.0| 0.7857142857142857| 0.6333333333333333| 1.5714285714285714| 0.7857142857142857| 1.0| 1| 0| 7| 0| 1921| 4| 4| 18| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-05| 0.0| 1.0| 0.0| 4.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 1.0714285714285714| 0.7666666666666667| 1.5714285714285714| 0.7857142857142857| 0.8666666666666667| 0| 1| 0| 1| 1922| 5| 5| 18| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-06| 4.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 0.0| 0.0| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.7857142857142857| 0.8| 1| 2| 1| 2| 1923| 6| 6| 18| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-07| 0.0| 1.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.5714285714285714| 0.7857142857142857| 0.8666666666666667| 1.0| 0.7857142857142857| 0.7333333333333333| 0.5714285714285714| 0.7857142857142857| 0.7| 1.0| 0.7857142857142857| 0.8| 0| 0| 2| 3| 1924| 7| 7| 18| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-08| 0.0| 4.0| 0.0| 3.0| 0.0| 5.0| 0.0| 6.0| 4.0| 1.1428571428571428| 0.8571428571428571| 0.8666666666666667| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.6428571428571429| 0.8| 0.2857142857142857| 0.7857142857142857| 0.8| 0| 1| 0| 4| 1925| 8| 1| 18| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-09| 1.0| 0.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 1.1428571428571428| 0.8571428571428571| 0.8666666666666667| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.5714285714285714| 0.8| 0.14285714285714285| 0.7857142857142857| 0.8| 1| 2| 1| 5| 1926| 9| 2| 19| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-10| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.8571428571428571| 0.8571428571428571| 0.7| 0.8571428571428571| 0.9285714285714286| 0.8| 1.0| 0.5714285714285714| 0.6| 0.14285714285714285| 0.7857142857142857| 0.6666666666666666| 2| 0| 2| 6| 1927| 10| 3| 19| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-11| 0.0| 0.0| 0.0| 2.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.8571428571428571| 0.7142857142857143| 0.6666666666666666| 0.5714285714285714| 0.9285714285714286| 0.8| 1.2857142857142858| 0.6428571428571429| 0.6666666666666666| 0.0| 0.7857142857142857| 0.6333333333333333| 3| 1| 0| 7| 1928| 11| 4| 19| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-12| 1.0| 0.0| 1.0| 0.0| 4.0| 0.0| 0.0| 0.0| 1.0| 0.7142857142857143| 0.7142857142857143| 0.6666666666666666| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 1.0714285714285714| 0.7666666666666667| 4| 0| 1| 0| 1929| 12| 5| 19| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-13| 0.0| 4.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 0.0| 1.2857142857142858| 1.0| 0.7666666666666667| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 0| 1| 2| 1| 1930| 13| 6| 19| 2| 5|2016|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-14| 0.0| 0.0| 1.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 1.1428571428571428| 0.8571428571428571| 0.7666666666666667| 0.5714285714285714| 0.7857142857142857| 0.8666666666666667| 1.0| 0.7857142857142857| 0.7333333333333333| 0.5714285714285714| 0.7857142857142857| 0.7| 1| 0| 0| 2| 1931| 14| 7| 19| 2| 5|2016|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
"df_test.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pandas DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" group id date prediction\n",
"0 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-25 0.481568\n",
"1 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-26 0.467245\n",
"2 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-27 0.415966\n",
"3 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-28 0.407759\n",
"4 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-29 0.434399\n",
"... ... ... ... ...\n",
"26427 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-18 0.215980\n",
"26428 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-19 0.215980\n",
"26429 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-20 0.222249\n",
"26430 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-21 0.334569\n",
"26431 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-22 0.313987\n",
"\n",
"[26432 rows x 4 columns]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forecast_flow.predict(df_test.toPandas(), spark=spark)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize Predictions"
]
},
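{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Left-join the forecasts onto the actual sales, then aggregate both to store level for plotting.\n",
"past_future = (\n",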
" df.select(\"id\", \"store_id\", \"date\", \"sales\")\n",
" .join(forecast, on=[\"id\", \"date\"], how=\"left\")\n",
" .groupBy(\"store_id\", \"date\")\n",
" .agg(\n",
" F.sum(\"sales\").alias(\"sales\"),\n",
" F.sum(\"prediction\").alias(\"prediction\"),\n",
" )\n",
" .orderBy(\"store_id\", \"date\")\n",
" .toPandas()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
],
"source": [
"# Plot actual sales against predictions, one facet per store.\n",
"pio.renderers.default = \"notebook\"\n",
"fig = px.line(\n",
" past_future,\n",
" x=\"date\",\n",
" y=[\"sales\", \"prediction\"],\n",
" facet_row_spacing=0.04,\n",
" facet_col=\"store_id\",\n",
" facet_col_wrap=2,\n",
" height=1000,\n",
" width=720,\n",
")\n",
"fig.update_layout(\n",
" legend=dict(orientation=\"h\", yanchor=\"top\", y=1.07, xanchor=\"center\", x=0.5),\n",
" margin=dict(l=0, r=10, t=5, b=5),\n",
" legend_title=\"\",\n",
")\n",
"fig.update_traces(line=dict(width=1.7))\n",
"fig.update_yaxes(matches=None, title=\"\")\n",
"fig.update_xaxes(type=\"date\", range=[\"2015-11-01\", \"2016-05-22\"])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Backtesting"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+--------------------+----------+---+------+-----------+\n",
"|group| id| date| cv|target| prediction|\n",
"+-----+--------------------+----------+---+------+-----------+\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-03-28| 0| 0.0| 0.44766802|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-03-29| 0| 0.0| 0.43386874|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-03-30| 0| 0.0| 0.40635538|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-03-31| 0| 1.0| 0.3618364|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-01| 0| 0.0| 0.40051356|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-02| 0| 1.0| 0.42851403|\n",
"| CA_2|FOODS_1_179_CA_2_...|2016-04-03| 0| 0.0| 0.40656742|\n",
"| CA_2|FOODS_1_192_CA_2_...|2016-03-28| 0| 0.0| 0.13468084|\n",
"| CA_2|FOODS_1_192_CA_2_...|2016-03-29| 0| 0.0|0.103752814|\n",
"| CA_2|FOODS_1_192_CA_2_...|2016-03-30| 0| 2.0|0.103752814|\n",
"+-----+--------------------+----------+---+------+-----------+\n",
"only showing top 10 rows\n",
"\n"
]
}
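],
"source": [
"# Backtest on the training data; the `cv` column identifies the cross-validation fold of each prediction.\n",
"cv_forecast = forecast_flow.cross_validate(df_train).localCheckpoint()\n",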
"cv_forecast.show(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize Cross Validation"
]
},
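{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Pivot each cross-validation fold into its own column, then aggregate to store level.\n",
"# The range(3) below assumes three folds (cv = 0, 1, 2).\n",
"cv_forecast = (\n",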
" df_train.select(\"id\", \"store_id\", \"date\", \"sales\")\n",
" .join(\n",
" cv_forecast.select(\"id\", \"date\", \"cv\", \"prediction\"),\n",
" on=[\"id\", \"date\"],\n",
" how=\"left\",\n",
" )\n",
" .groupBy(\"id\", \"store_id\", \"date\", \"sales\")\n",
" .pivot(\"cv\")\n",
" .sum(\"prediction\")\n",
" .groupBy(\"store_id\", \"date\")\n",
" .agg(\n",
" F.sum(\"sales\").alias(\"sales\"),\n",
" *[F.sum(f\"{i}\").alias(f\"cv_{i}\") for i in range(3)],\n",
" )\n",
" .orderBy(\"store_id\", \"date\")\n",
").toPandas()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
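],
"source": [
"# Plot actual sales against the predictions of each cross-validation fold, one facet per store.\n",
"pio.renderers.default = \"notebook\"\n",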
"fig = px.line(\n",
" cv_forecast,\n",
" x=\"date\",\n",
" y=[\"sales\", *[f\"cv_{i}\" for i in range(3)]],\n",
" facet_row_spacing=0.04,\n",
" facet_col=\"store_id\",\n",
" facet_col_wrap=2,\n",
" height=1000,\n",
" width=720,\n",
")\n",
"fig.update_layout(\n",
" legend=dict(orientation=\"h\", yanchor=\"top\", y=1.07, xanchor=\"center\", x=0.5),\n",
" margin=dict(l=0, r=10, t=5, b=5),\n",
" legend_title=\"\",\n",
")\n",
"fig.update_traces(line=dict(width=1.7))\n",
"fig.update_yaxes(matches=None, title=\"\")\n",
"fig.update_xaxes(type=\"date\", range=[\"2015-11-01\", \"2016-04-24\"])"
]
}
],
"metadata": {
"application/vnd.databricks.v1+notebook": {
"dashboards": [
{
"elements": [],
"globalVars": {},
"guid": "ef82ffd4-2993-4b79-8327-f644b750f2dd",
"layoutOption": {
"grid": true,
"stack": true
},
"nuid": "a172d56a-d964-4505-ba66-5a7011220dbf",
"origId": 1859120955398731,
"title": "Untitled",
"version": "DashboardViewV1",
"width": 1024
}
],
"language": "python",
"notebookMetadata": {
"pythonIndentUnit": 4
},
"notebookName": "ForecastFlowML Demo",
"notebookOrigID": 2597536912577418,
"widgets": {}
},
"kernelspec": {
"display_name": "spark",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 0
}