{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": {}, "inputWidgets": {}, "nuid": "a02bb3a2-e5d7-4d14-b966-827457675b75", "showTitle": false, "title": "" } }, "source": [ "# Grid Search\n", "\n", "This quick guide shows how grid search can be used to find the best hyperparameters for ``ForecastFlowML``." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Import packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from forecastflowml import ForecastFlowML\n", "from forecastflowml import FeatureExtractor\n", "from forecastflowml.data.loader import load_walmart_m5\n", "from lightgbm import LGBMRegressor\n", "from pyspark.sql import SparkSession\n", "import pyspark.sql.functions as F" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Spark" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Local 4-core session. The Arrow option uses the Spark 3 key name;\n", "# the older `spark.sql.execution.arrow.enabled` key is deprecated.\n", "spark = (\n", "    SparkSession.builder.master(\"local[4]\")\n", "    .config(\"spark.driver.memory\", \"8g\")\n", "    .config(\"spark.sql.shuffle.partitions\", \"4\")\n", "    .config(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", "    .getOrCreate()\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Sample Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| 
TX|2011-01-31| 3.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df = load_walmart_m5(spark)\n", "df.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Engineering" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|day_of_week|weekend|week_of_year|month|year|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| 2| 0| 5| 1|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| 3| 0| 5| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| 4| 0| 5| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| 5| 0| 5| 2|2011|\n", 
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| 6| 0| 5| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| 7| 1| 5| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| 1| 1| 5| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| 2| 0| 6| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| 3| 0| 6| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| 4| 0| 6| 2|2011|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " \"lag\": [7 * (i + 1) for i in range(4)],\n", " },\n", " date_features=[\"day_of_week\", \"weekend\", \"week_of_year\", \"month\", \"year\"],\n", ")\n", "df_features = feature_extractor.transform(df).localCheckpoint()\n", "df_features.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Train/Test Dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "df_train = df_features.filter(F.col(\"date\") < \"2016-04-25\")\n", "df_test = df_features.filter(F.col(\"date\") >= \"2016-04-25\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "forecast_flow = ForecastFlowML(\n", " group_col=\"store_id\",\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " 
target_col=\"sales\",\n", " date_frequency=\"days\",\n", " model_horizon=7,\n", " max_forecast_horizon=28,\n", " model=LGBMRegressor(random_state=42),\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Search Hyperparameters with Grid Search" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
groupscorenum_leaves
0WI_3-16.79527110
1WI_3-17.21624920
2WI_3-17.41300630
3WI_3-17.59074040
4WI_3-17.61715150
5WI_2-30.41300610
6WI_2-30.92246620
7WI_2-31.29846630
8WI_2-31.92068340
9WI_2-31.99888250
\n", "
" ], "text/plain": [ " group score num_leaves\n", "0 WI_3 -16.795271 10\n", "1 WI_3 -17.216249 20\n", "2 WI_3 -17.413006 30\n", "3 WI_3 -17.590740 40\n", "4 WI_3 -17.617151 50\n", "5 WI_2 -30.413006 10\n", "6 WI_2 -30.922466 20\n", "7 WI_2 -31.298466 30\n", "8 WI_2 -31.920683 40\n", "9 WI_2 -31.998882 50" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trials = forecast_flow.grid_search(\n", " df_train,\n", " param_grid={\"num_leaves\": [10, 20, 30, 40, 50]},\n", " n_cv_splits=1,\n", " scoring_metric=\"neg_mean_squared_error\",\n", ")\n", "trials.head(10)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'CA_1': {'num_leaves': 10},\n", " 'CA_2': {'num_leaves': 10},\n", " 'CA_3': {'num_leaves': 20},\n", " 'CA_4': {'num_leaves': 40},\n", " 'TX_1': {'num_leaves': 10},\n", " 'TX_2': {'num_leaves': 10},\n", " 'TX_3': {'num_leaves': 20},\n", " 'WI_1': {'num_leaves': 10},\n", " 'WI_2': {'num_leaves': 10},\n", " 'WI_3': {'num_leaves': 10}}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Keep the highest-scoring trial of every group (scores are negated errors, so larger is better).\n", "# Sorting + drop_duplicates avoids groupby.apply, which stops passing the grouping column\n", "# to the applied function in newer pandas and would break set_index(\"group\") below.\n", "best_trial = (\n", "    trials.sort_values(\"score\", ascending=False)\n", "    .drop_duplicates(\"group\")\n", "    .sort_values(\"group\")\n", ")\n", "# Map each group to its winning hyperparameters: {group: {param: value}}.\n", "best_params = (\n", "    best_trial.set_index(\"group\").drop(\"score\", axis=1).to_dict(orient=\"index\")\n", ")\n", "best_params" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'CA_1': LGBMRegressor(num_leaves=10, random_state=42),\n", " 'CA_2': LGBMRegressor(num_leaves=10, random_state=42),\n", " 'CA_3': LGBMRegressor(num_leaves=20, random_state=42),\n", " 'CA_4': LGBMRegressor(num_leaves=40, random_state=42),\n", " 'TX_1': LGBMRegressor(num_leaves=10, random_state=42),\n", " 'TX_2': LGBMRegressor(num_leaves=10, random_state=42),\n", " 'TX_3': LGBMRegressor(num_leaves=20, random_state=42),\n", " 'WI_1': LGBMRegressor(num_leaves=10, random_state=42),\n", " 'WI_2': LGBMRegressor(num_leaves=10, random_state=42),\n", " 'WI_3': LGBMRegressor(num_leaves=10, random_state=42)}" ] }, "execution_count": 9, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "# One estimator per group, built from its best parameters and the same random_state\n", "# as the base model that was grid-searched, so the final fit matches the tuned setup.\n", "group_models = {k: LGBMRegressor(random_state=42, **v) for k, v in best_params.items()}\n", "group_models" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Training with Optimized Hyperparameters" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "forecast_flow = ForecastFlowML(\n", " group_col=\"store_id\",\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " date_frequency=\"days\",\n", " model_horizon=7,\n", " max_forecast_horizon=28,\n", " model=group_models,\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n", "|group| forecast_horizon| model| start_time| end_time|elapsed_seconds|\n", "+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n", "| CA_2|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 0.8|\n", "| CA_3|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.4|\n", "| WI_2|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 5.0|\n", "| WI_3|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 0.7|\n", "| CA_1|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.3|\n", "| CA_4|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.5|\n", "| TX_1|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.0|\n", "| TX_3|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.1|\n", "| WI_1|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 
(18:4...| 1.0|\n", "| TX_2|[[1, 2, 3, 4, 5, ...|[€\u0003clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.3|\n", "+-----+--------------------+--------------------+--------------------+--------------------+---------------+\n", "\n" ] } ], "source": [ "forecast_flow.train(df_train).show()" ] } ], "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [ { "elements": [], "globalVars": {}, "guid": "ef82ffd4-2993-4b79-8327-f644b750f2dd", "layoutOption": { "grid": true, "stack": true }, "nuid": "a172d56a-d964-4505-ba66-5a7011220dbf", "origId": 1859120955398731, "title": "Untitled", "version": "DashboardViewV1", "width": 1024 } ], "language": "python", "notebookMetadata": { "pythonIndentUnit": 4 }, "notebookName": "ForecastFlowML Demo", "notebookOrigID": 2597536912577418, "widgets": {} }, "kernelspec": { "display_name": "spark", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.16" } }, "nbformat": 4, "nbformat_minor": 0 }