{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": {}, "inputWidgets": {}, "nuid": "a02bb3a2-e5d7-4d14-b966-827457675b75", "showTitle": false, "title": "" } }, "source": [ "# Quick Start\n", "\n", "ForecastFlowML is designed for scalable forecasting and uses Spark for both \n", "feature engineering and training, prediction, and hyperparameter optimization.\n", "\n", "## Use Cases\n", "ForecastFlowML covers three main use cases:\n", "- Data is stored in a ``PySpark DataFrame``, and we need to build many or large \n", "group models in parallel that do not fit into driver memory.\n", "- Data is stored in a ``PySpark DataFrame``, and we need to build few or small\n", "group models in parallel that fit into driver memory.\n", "- Data is stored in a ``Pandas DataFrame``, and we need to build few or small\n", "group models in parallel that fit into driver memory.\n", "\n", "This quick guide shows how to develop a scalable forecasting system on the \n", "sample dataset from the Kaggle Walmart M5 Competition.\n", "\n", "## Goal\n", "- Build an independent model for each store in the dataset.\n", "- Parallelize the training and inference steps.\n", "- Use LightGBM as the machine learning algorithm.\n", "- Use a direct multi-step forecasting approach.\n", "- Perform backtesting." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Import Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from forecastflowml import ForecastFlowML\n", "from forecastflowml import FeatureExtractor\n", "from forecastflowml.data.loader import load_walmart_m5\n", "from lightgbm import LGBMRegressor\n", "from pyspark.sql import SparkSession\n", "import pyspark.sql.functions as F\n", "import plotly.express as px\n", "import plotly.io as pio\n", "import pandas as pd\n", "\n", "pd.set_option(\"display.max_columns\", 40)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Spark" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "spark = (\n", " SparkSession.builder.master(\"local[4]\")\n", " .config(\"spark.driver.memory\", \"8g\")\n", " .config(\"spark.sql.shuffle.partitions\", \"4\")\n", " .config(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " .getOrCreate()\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Sample Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| 
TX|2011-02-03| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df = load_walmart_m5(spark)\n", "df.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Engineering" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " \"lag\": [7 * (i + 1) for i in range(8)],\n", " \"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n", " },\n", " date_features=[\n", " \"day_of_month\",\n", " \"day_of_week\",\n", " \"week_of_year\",\n", " \"quarter\",\n", " \"month\",\n", " \"year\",\n", " ],\n", " count_consecutive_values={\n", " \"value\": 0,\n", " \"lags\": [7, 14, 21, 28],\n", " },\n", " history_length=True,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Pandas DataFrame" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " id item_id dept_id \\\n", "0 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n", "1 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n", "2 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n", "3 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n", "4 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n", "... ... ... ... \n", "1470899 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n", "1470900 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n", "1470901 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n", "1470902 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n", "1470903 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n", "\n", " cat_id store_id state_id date sales lag_7 lag_14 lag_21 \\\n", "0 FOODS WI_2 WI 2011-01-31 2.0 NaN NaN NaN \n", "1 FOODS WI_2 WI 2011-02-01 0.0 NaN NaN NaN \n", "2 FOODS WI_2 WI 2011-02-02 0.0 NaN NaN NaN \n", "3 FOODS WI_2 WI 2011-02-03 0.0 NaN NaN NaN \n", "4 FOODS WI_2 WI 2011-02-04 0.0 NaN NaN NaN \n", "... ... ... ... ... ... ... ... ... \n", "1470899 HOUSEHOLD WI_3 WI 2016-05-18 0.0 0.0 0.0 0.0 \n", "1470900 HOUSEHOLD WI_3 WI 2016-05-19 0.0 0.0 0.0 0.0 \n", "1470901 HOUSEHOLD WI_3 WI 2016-05-20 1.0 0.0 0.0 0.0 \n", "1470902 HOUSEHOLD WI_3 WI 2016-05-21 0.0 0.0 0.0 0.0 \n", "1470903 HOUSEHOLD WI_3 WI 2016-05-22 0.0 0.0 0.0 0.0 \n", "\n", " lag_28 lag_35 lag_42 lag_49 lag_56 window_7_lag_7_mean \\\n", "0 NaN NaN NaN NaN NaN NaN \n", "1 NaN NaN NaN NaN NaN NaN \n", "2 NaN NaN NaN NaN NaN NaN \n", "3 NaN NaN NaN NaN NaN NaN \n", "4 NaN NaN NaN NaN NaN NaN \n", "... ... ... ... ... ... ... 
\n", "1470899 0.0 0.0 0.0 0.0 0.0 0.0 \n", "1470900 0.0 1.0 0.0 0.0 0.0 0.0 \n", "1470901 1.0 0.0 0.0 0.0 0.0 0.0 \n", "1470902 0.0 0.0 0.0 0.0 0.0 0.0 \n", "1470903 0.0 0.0 0.0 0.0 0.0 0.0 \n", "\n", " window_14_lag_7_mean window_30_lag_7_mean window_7_lag_14_mean \\\n", "0 NaN NaN NaN \n", "1 NaN NaN NaN \n", "2 NaN NaN NaN \n", "3 NaN NaN NaN \n", "4 NaN NaN NaN \n", "... ... ... ... \n", "1470899 0.071429 0.166667 0.142857 \n", "1470900 0.071429 0.100000 0.142857 \n", "1470901 0.071429 0.100000 0.142857 \n", "1470902 0.071429 0.066667 0.142857 \n", "1470903 0.071429 0.066667 0.142857 \n", "\n", " window_14_lag_14_mean window_30_lag_14_mean window_7_lag_21_mean \\\n", "0 NaN NaN NaN \n", "1 NaN NaN NaN \n", "2 NaN NaN NaN \n", "3 NaN NaN NaN \n", "4 NaN NaN NaN \n", "... ... ... ... \n", "1470899 0.142857 0.166667 0.142857 \n", "1470900 0.142857 0.166667 0.142857 \n", "1470901 0.071429 0.166667 0.000000 \n", "1470902 0.071429 0.166667 0.000000 \n", "1470903 0.071429 0.166667 0.000000 \n", "\n", " window_14_lag_21_mean window_30_lag_21_mean window_7_lag_28_mean \\\n", "0 NaN NaN NaN \n", "1 NaN NaN NaN \n", "2 NaN NaN NaN \n", "3 NaN NaN NaN \n", "4 NaN NaN NaN \n", "... ... ... ... \n", "1470899 0.142857 0.166667 0.142857 \n", "1470900 0.071429 0.166667 0.000000 \n", "1470901 0.071429 0.166667 0.142857 \n", "1470902 0.071429 0.166667 0.142857 \n", "1470903 0.071429 0.166667 0.142857 \n", "\n", " window_14_lag_28_mean window_30_lag_28_mean \\\n", "0 NaN NaN \n", "1 NaN NaN \n", "2 NaN NaN \n", "3 NaN NaN \n", "4 NaN NaN \n", "... ... ... \n", "1470899 0.214286 0.133333 \n", "1470900 0.214286 0.133333 \n", "1470901 0.285714 0.166667 \n", "1470902 0.285714 0.166667 \n", "1470903 0.285714 0.166667 \n", "\n", " count_consecutive_value_lag_7 count_consecutive_value_lag_14 \\\n", "0 NaN NaN \n", "1 NaN NaN \n", "2 NaN NaN \n", "3 NaN NaN \n", "4 NaN NaN \n", "... ... ... 
\n", "1470899 9.0 2.0 \n", "1470900 10.0 3.0 \n", "1470901 11.0 4.0 \n", "1470902 12.0 5.0 \n", "1470903 13.0 6.0 \n", "\n", " count_consecutive_value_lag_21 count_consecutive_value_lag_28 \\\n", "0 NaN NaN \n", "1 NaN NaN \n", "2 NaN NaN \n", "3 NaN NaN \n", "4 NaN NaN \n", "... ... ... \n", "1470899 5.0 6.0 \n", "1470900 6.0 7.0 \n", "1470901 7.0 0.0 \n", "1470902 8.0 1.0 \n", "1470903 9.0 2.0 \n", "\n", " history_length day_of_month day_of_week week_of_year quarter \\\n", "0 1 31 2 5 1 \n", "1 2 1 3 5 1 \n", "2 3 2 4 5 1 \n", "3 4 3 5 5 1 \n", "4 5 4 6 5 1 \n", "... ... ... ... ... ... \n", "1470899 1936 18 4 20 2 \n", "1470900 1937 19 5 20 2 \n", "1470901 1938 20 6 20 2 \n", "1470902 1939 21 7 20 2 \n", "1470903 1940 22 1 20 2 \n", "\n", " month year \n", "0 1 2011 \n", "1 2 2011 \n", "2 2 2011 \n", "3 2 2011 \n", "4 2 2011 \n", "... ... ... \n", "1470899 5 2016 \n", "1470900 5 2016 \n", "1470901 5 2016 \n", "1470902 5 2016 \n", "1470903 5 2016 \n", "\n", "[1470904 rows x 39 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_extractor.transform(df.toPandas(), spark=spark)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n", "| id| 
item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|history_length|day_of_month|day_of_week|week_of_year|quarter|month|year|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 1| 31| 2| 5| 1| 1|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 2| 1| 3| 5| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 3| 2| 4| 5| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| 
null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 4| 3| 5| 5| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 5| 4| 6| 5| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 6| 5| 7| 5| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 7| 6| 1| 5| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| null| null| null| null| 2.0| 2.0| 2.0| null| null| null| null| null| null| null| null| null| 0| null| null| null| 8| 7| 2| 6| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| null| null| null| null| 1.0| 1.0| 1.0| null| null| null| null| null| null| null| null| null| 1| null| null| null| 9| 8| 3| 6| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| null| null| null| null| 0.6666666666666666| 0.6666666666666666| 0.6666666666666666| null| null| null| null| null| null| null| null| null| 2| null| null| null| 10| 9| 4| 6| 1| 2|2011|\n", 
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df_features = feature_extractor.transform(df).localCheckpoint()\n", "df_features.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Training/Test Dataset" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "df_train = df_features.filter(F.col(\"date\") < \"2016-04-25\")\n", "df_test = df_features.filter(F.col(\"date\") >= \"2016-04-25\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "forecast_flow = ForecastFlowML(\n", " group_col=\"store_id\",\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " categorical_cols=[\"item_id\", \"dept_id\", \"cat_id\"],\n", " date_frequency=\"days\",\n", " model_horizon=7,\n", " max_forecast_horizon=28,\n", " model=LGBMRegressor(),\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark DataFrame with Distributed Results" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n", "|group| forecast_horizon| model| 
start_time| end_time|elapsed_seconds|\n", "+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n", "| CA_2|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.8|\n", "| CA_3|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.2|\n", "| WI_2|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.2|\n", "| WI_3|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 2.9|\n", "| CA_1|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 4.3|\n", "| CA_4|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.5|\n", "| TX_1|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.2|\n", "| TX_3|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.0|\n", "| WI_1|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 2.0|\n", "| TX_2|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|01-May-2023 (03:2...|01-May-2023 (03:2...| 3.8|\n", "+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n", "\n" ] } ], "source": [ "trained_models = forecast_flow.train(df_train).localCheckpoint()\n", "trained_models.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark DataFrame with Local Results" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " group forecast_horizon \\\n", "0 CA_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "1 CA_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "2 WI_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "3 WI_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "4 CA_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "5 CA_4 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "6 TX_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "7 TX_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "8 WI_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "9 TX_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "\n", " model start_time \\\n", "0 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:31) \n", "1 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:38) \n", "2 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:42) \n", "3 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:47) \n", "4 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:30) \n", "5 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:38) \n", "6 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:42) \n", "7 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:48) \n", "8 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:22:51) \n", "9 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 
01-May-2023 (03:22:28) \n", "\n", " end_time elapsed_seconds \n", "0 01-May-2023 (03:22:38) 6.6 \n", "1 01-May-2023 (03:22:42) 3.6 \n", "2 01-May-2023 (03:22:47) 5.1 \n", "3 01-May-2023 (03:22:51) 3.2 \n", "4 01-May-2023 (03:22:37) 7.5 \n", "5 01-May-2023 (03:22:41) 3.8 \n", "6 01-May-2023 (03:22:47) 5.3 \n", "7 01-May-2023 (03:22:51) 3.4 \n", "8 01-May-2023 (03:22:54) 2.4 \n", "9 01-May-2023 (03:22:33) 4.7 " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forecast_flow.train(df_train, local_result=True)\n", "forecast_flow.model_" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Pandas DataFrame" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " group forecast_horizon \\\n", "0 CA_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "1 CA_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "2 WI_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "3 WI_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "4 CA_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "5 CA_4 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "6 TX_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "7 TX_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "8 WI_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "9 TX_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "\n", " model start_time \\\n", "0 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:16) \n", "1 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:21) \n", "2 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:25) \n", "3 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:28) \n", "4 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:14) \n", "5 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:21) \n", "6 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:24) \n", "7 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:28) \n", "8 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (03:23:32) \n", "9 [€\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 
01-May-2023 (03:23:12) \n", "\n", " end_time elapsed_seconds \n", "0 01-May-2023 (03:23:21) 4.4 \n", "1 01-May-2023 (03:23:25) 3.4 \n", "2 01-May-2023 (03:23:28) 3.0 \n", "3 01-May-2023 (03:23:32) 3.3 \n", "4 01-May-2023 (03:23:20) 5.8 \n", "5 01-May-2023 (03:23:24) 3.3 \n", "6 01-May-2023 (03:23:28) 3.4 \n", "7 01-May-2023 (03:23:32) 3.4 \n", "8 01-May-2023 (03:23:34) 2.2 \n", "9 01-May-2023 (03:23:17) 5.0 " ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forecast_flow.train(df_train.toPandas(), spark=spark)\n", "forecast_flow.model_" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+----------+----------+\n", "|group| id| date|prediction|\n", "+-----+--------------------+----------+----------+\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-25| 0.481568|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-26|0.46724537|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.41596597|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.40775877|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.43439913|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-30| 0.4951446|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-05-01| 0.4308696|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-25| 0.2172628|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-26| 0.1687214|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-27| 0.1687214|\n", "+-----+--------------------+----------+----------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "forecast = forecast_flow.predict(df_test, trained_models).localCheckpoint()\n", "forecast.show(10)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|history_length|day_of_month|day_of_week|week_of_year|quarter|month|year|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-25| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.5714285714285714| 0.8| 0.14285714285714285| 0.7857142857142857| 0.8| 1.4285714285714286| 0.7142857142857143| 
0.9666666666666667| 0.0| 0.5714285714285714| 0.7333333333333333| 1| 5| 0| 8| 1912| 25| 2| 17| 2| 4|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-26| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 1.0| 0.5714285714285714| 0.6| 0.14285714285714285| 0.7857142857142857| 0.6666666666666666| 1.4285714285714286| 0.7142857142857143| 0.9666666666666667| 0.0| 0.5714285714285714| 0.6333333333333333| 2| 6| 1| 9| 1913| 26| 3| 17| 2| 4|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-27| 0.0| 2.0| 0.0| 1.0| 0.0| 0.0| 1.0| 2.0| 0.0| 1.2857142857142858| 0.6428571428571429| 0.6666666666666666| 0.0| 0.7857142857142857| 0.6333333333333333| 1.5714285714285714| 0.7857142857142857| 1.0| 0.0| 0.5| 0.6333333333333333| 0| 7| 0| 10| 1914| 27| 4| 17| 2| 4|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-28| 1.0| 0.0| 4.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 1.0714285714285714| 0.7666666666666667| 1.5714285714285714| 0.7857142857142857| 0.8666666666666667| 0.0| 0.42857142857142855| 0.6333333333333333| 1| 0| 1| 11| 1915| 28| 5| 17| 2| 4|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-29| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 0.0| 0.0| 0.0| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.7857142857142857| 0.8| 0.5714285714285714| 0.7142857142857143| 0.7666666666666667| 2| 1| 2| 0| 1916| 29| 6| 17| 2| 4|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-04-30| 1.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 1.0| 0.7857142857142857| 0.7333333333333333| 0.5714285714285714| 0.7857142857142857| 0.7| 1.0| 0.7857142857142857| 0.8| 0.5714285714285714| 0.7142857142857143| 0.7666666666666667| 0| 2| 3| 1| 1917| 30| 7| 17| 2| 4|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-01| 4.0| 0.0| 3.0| 0.0| 
5.0| 0.0| 6.0| 4.0| 0.0| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.6428571428571429| 0.8| 0.2857142857142857| 0.7857142857142857| 0.8| 1.2857142857142858| 0.6428571428571429| 0.9333333333333333| 1| 0| 4| 0| 1918| 1| 1| 17| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-02| 0.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.0| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.5714285714285714| 0.8| 0.14285714285714285| 0.7857142857142857| 0.8| 1.4285714285714286| 0.7142857142857143| 0.9666666666666667| 2| 1| 5| 0| 1919| 2| 2| 18| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-03| 0.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.8571428571428571| 0.9285714285714286| 0.8| 1.0| 0.5714285714285714| 0.6| 0.14285714285714285| 0.7857142857142857| 0.6666666666666666| 1.4285714285714286| 0.7142857142857143| 0.9666666666666667| 0| 2| 6| 1| 1920| 3| 3| 18| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-04| 0.0| 0.0| 2.0| 0.0| 1.0| 0.0| 0.0| 1.0| 2.0| 0.5714285714285714| 0.9285714285714286| 0.8| 1.2857142857142858| 0.6428571428571429| 0.6666666666666666| 0.0| 0.7857142857142857| 0.6333333333333333| 1.5714285714285714| 0.7857142857142857| 1.0| 1| 0| 7| 0| 1921| 4| 4| 18| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-05| 0.0| 1.0| 0.0| 4.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 1.0714285714285714| 0.7666666666666667| 1.5714285714285714| 0.7857142857142857| 0.8666666666666667| 0| 1| 0| 1| 1922| 5| 5| 18| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-06| 4.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 0.0| 0.0| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 
0.7857142857142857| 0.7333333333333333| 1.0| 0.7857142857142857| 0.8| 1| 2| 1| 2| 1923| 6| 6| 18| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-07| 0.0| 1.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.5714285714285714| 0.7857142857142857| 0.8666666666666667| 1.0| 0.7857142857142857| 0.7333333333333333| 0.5714285714285714| 0.7857142857142857| 0.7| 1.0| 0.7857142857142857| 0.8| 0| 0| 2| 3| 1924| 7| 7| 18| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-08| 0.0| 4.0| 0.0| 3.0| 0.0| 5.0| 0.0| 6.0| 4.0| 1.1428571428571428| 0.8571428571428571| 0.8666666666666667| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.6428571428571429| 0.8| 0.2857142857142857| 0.7857142857142857| 0.8| 0| 1| 0| 4| 1925| 8| 1| 18| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-09| 1.0| 0.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 1.0| 1.1428571428571428| 0.8571428571428571| 0.8666666666666667| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 1.0| 0.5714285714285714| 0.8| 0.14285714285714285| 0.7857142857142857| 0.8| 1| 2| 1| 5| 1926| 9| 2| 19| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-10| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.8571428571428571| 0.8571428571428571| 0.7| 0.8571428571428571| 0.9285714285714286| 0.8| 1.0| 0.5714285714285714| 0.6| 0.14285714285714285| 0.7857142857142857| 0.6666666666666666| 2| 0| 2| 6| 1927| 10| 3| 19| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-11| 0.0| 0.0| 0.0| 2.0| 0.0| 1.0| 0.0| 0.0| 1.0| 0.8571428571428571| 0.7142857142857143| 0.6666666666666666| 0.5714285714285714| 0.9285714285714286| 0.8| 1.2857142857142858| 0.6428571428571429| 0.6666666666666666| 0.0| 0.7857142857142857| 0.6333333333333333| 3| 1| 0| 7| 1928| 11| 4| 19| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-12| 1.0| 0.0| 1.0| 0.0| 4.0| 0.0| 0.0| 0.0| 1.0| 
0.7142857142857143| 0.7142857142857143| 0.6666666666666666| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 1.0714285714285714| 0.7666666666666667| 4| 0| 1| 0| 1929| 12| 5| 19| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-13| 0.0| 4.0| 0.0| 0.0| 0.0| 0.0| 4.0| 0.0| 0.0| 1.2857142857142858| 1.0| 0.7666666666666667| 0.7142857142857143| 0.7142857142857143| 0.8333333333333334| 0.7142857142857143| 0.6428571428571429| 0.6666666666666666| 0.5714285714285714| 0.7857142857142857| 0.7333333333333333| 0| 1| 2| 1| 1930| 13| 6| 19| 2| 5|2016|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2016-05-14| 0.0| 0.0| 1.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 1.1428571428571428| 0.8571428571428571| 0.7666666666666667| 0.5714285714285714| 0.7857142857142857| 0.8666666666666667| 1.0| 0.7857142857142857| 0.7333333333333333| 0.5714285714285714| 0.7857142857142857| 0.7| 1| 0| 0| 2| 1931| 14| 7| 19| 2| 5|2016|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------+-----+----+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "df_test.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Pandas DataFrame" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
" ], "text/plain": [ " group id date prediction\n", "0 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-25 0.481568\n", "1 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-26 0.467245\n", "2 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-27 0.415966\n", "3 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-28 0.407759\n", "4 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-29 0.434399\n", "... ... ... ... ...\n", "26427 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-18 0.215980\n", "26428 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-19 0.215980\n", "26429 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-20 0.222249\n", "26430 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-21 0.334569\n", "26431 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-22 0.313987\n", "\n", "[26432 rows x 4 columns]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forecast_flow.predict(df_test.toPandas(), spark=spark)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize Predictions" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "past_future = (\n", " df.select(\"id\", \"store_id\", \"date\", \"sales\")\n", " .join(forecast, on=[\"id\", \"date\"], how=\"left\")\n", " .groupBy(\"store_id\", \"date\")\n", " .agg(\n", " F.sum(\"sales\").alias(\"sales\"),\n", " F.sum(\"prediction\").alias(\"prediction\"),\n", " )\n", " .orderBy(\"store_id\", \"date\")\n", " .toPandas()\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pio.renderers.default = \"notebook\"\n", "fig = px.line(\n", " past_future,\n", " x=\"date\",\n", " y=[\"sales\", \"prediction\"],\n", " facet_row_spacing=0.04,\n", " facet_col=\"store_id\",\n", " facet_col_wrap=2,\n", " height=1000,\n", " width=720,\n", ")\n", "fig.update_layout(\n", " legend=dict(orientation=\"h\", yanchor=\"top\", y=1.07, xanchor=\"center\", x=0.5),\n", " margin=dict(l=0, r=10, t=5, b=5),\n", " legend_title=\"\",\n", ")\n", "fig.update_traces(line=dict(width=1.7))\n", "fig.update_yaxes(matches=None, title=\"\")\n", "fig.update_xaxes(type=\"date\", range=[\"2015-11-01\", \"2016-05-22\"])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Backtesting" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+----------+---+------+-----------+\n", "|group| id| date| cv|target| prediction|\n", "+-----+--------------------+----------+---+------+-----------+\n", "| CA_2|FOODS_1_179_CA_2_...|2016-03-28| 0| 0.0| 0.44766802|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-03-29| 0| 0.0| 0.43386874|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-03-30| 0| 0.0| 0.40635538|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-03-31| 0| 1.0| 0.3618364|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-01| 0| 0.0| 0.40051356|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-02| 0| 1.0| 0.42851403|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-03| 0| 0.0| 0.40656742|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-03-28| 0| 0.0| 0.13468084|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-03-29| 0| 0.0|0.103752814|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-03-30| 0| 2.0|0.103752814|\n", "+-----+--------------------+----------+---+------+-----------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "cv_forecast = forecast_flow.cross_validate(df_train).localCheckpoint()\n", "cv_forecast.show(10)" ] }, { "attachments": {}, 
"cell_type": "markdown", "metadata": {}, "source": [ "### Visualize Cross Validation" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "cv_forecast = (\n", " df_train.select(\"id\", \"store_id\", \"date\", \"sales\")\n", " .join(\n", " cv_forecast.select(\"id\", \"date\", \"cv\", \"prediction\"),\n", " on=[\"id\", \"date\"],\n", " how=\"left\",\n", " )\n", " .groupBy(\"id\", \"store_id\", \"date\", \"sales\")\n", " .pivot(\"cv\")\n", " .sum(\"prediction\")\n", " .groupBy(\"store_id\", \"date\")\n", " .agg(\n", " F.sum(\"sales\").alias(\"sales\"),\n", " *[F.sum(f\"{i}\").alias(f\"cv_{i}\") for i in range(3)],\n", " )\n", " .orderBy(\"store_id\", \"date\")\n", ").toPandas()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pio.renderers.default = \"notebook\"\n", "fig = px.line(\n", " cv_forecast,\n", " x=\"date\",\n", " y=[\"sales\", *[f\"cv_{i}\" for i in range(3)]],\n", " facet_row_spacing=0.04,\n", " facet_col=\"store_id\",\n", " facet_col_wrap=2,\n", " height=1000,\n", " width=720,\n", ")\n", "fig.update_layout(\n", " legend=dict(orientation=\"h\", yanchor=\"top\", y=1.07, xanchor=\"center\", x=0.5),\n", " margin=dict(l=0, r=10, t=5, b=5),\n", " legend_title=\"\",\n", ")\n", "fig.update_traces(line=dict(width=1.7))\n", "fig.update_yaxes(matches=None, title=\"\")\n", "fig.update_xaxes(type=\"date\", range=[\"2015-11-01\", \"2016-04-24\"])" ] } ], "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [ { "elements": [], "globalVars": {}, "guid": "ef82ffd4-2993-4b79-8327-f644b750f2dd", "layoutOption": { "grid": true, "stack": true }, "nuid": "a172d56a-d964-4505-ba66-5a7011220dbf", "origId": 1859120955398731, "title": "Untitled", "version": "DashboardViewV1", "width": 1024 } ], "language": "python", "notebookMetadata": { "pythonIndentUnit": 4 }, "notebookName": "ForecastFlowML Demo", "notebookOrigID": 2597536912577418, "widgets": {} }, "kernelspec": { "display_name": "spark", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.16" } }, "nbformat": 4, "nbformat_minor": 0 }