Grid Search
This quick guide shows how to use grid search to find the best hyperparameters for each group (here, each store) in a ForecastFlowML model.
Import Packages
from forecastflowml import ForecastFlowML
from forecastflowml import FeatureExtractor
from forecastflowml.data.loader import load_walmart_m5
from lightgbm import LGBMRegressor
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
Initialize Spark
# Start a local Spark session using 4 cores. Shuffle partitions are kept
# small to match the local core count, and Arrow-based transfer between
# Spark and pandas is enabled.
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "4")
    # `spark.sql.execution.arrow.enabled` is deprecated since Spark 3.0;
    # this PySpark-specific key is its documented replacement.
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)
Sample Dataset
# Load the bundled Walmart M5 sample as a Spark DataFrame and preview it.
df = load_walmart_m5(spark)
df.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows
Feature Engineering
# Target lags at 7, 14, 21 and 28 days (the lag_7 ... lag_28 columns in the
# preview), plus calendar features derived from the date column.
weekly_lags = [7 * week for week in range(1, 5)]
feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={"lag": weekly_lags},
    date_features=["day_of_week", "weekend", "week_of_year", "month", "year"],
)
# Checkpoint the feature table so later filters reuse it instead of
# re-running the feature transform through the full lineage.
df_features = feature_extractor.transform(df)
df_features = df_features.localCheckpoint()
df_features.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|day_of_week|weekend|week_of_year|month|year|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| 2| 0| 5| 1|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| 3| 0| 5| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| 4| 0| 5| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| 5| 0| 5| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| 6| 0| 5| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| 7| 1| 5| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| 1| 1| 5| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| 2| 0| 6| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| 3| 0| 6| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| 4| 0| 6| 2|2011|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+
only showing top 10 rows
Train/Test Split
# Everything before the cutoff is used for training; rows on or after it
# are held out for testing.
split_date = "2016-04-25"
date_column = F.col("date")
df_train = df_features.filter(date_column < split_date)
df_test = df_features.filter(date_column >= split_date)
Initialize Model
# Column mapping and horizon settings. Models are fitted per store_id group;
# model_horizon=7 with max_forecast_horizon=28 presumably means each model
# covers 7 horizon steps out of 28 (confirm against the ForecastFlowML docs).
search_config = dict(
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
)
forecast_flow = ForecastFlowML(
    **search_config,
    model=LGBMRegressor(random_state=42),
)
Search Hyperparameters with Grid Search
# Evaluate five tree sizes for every group using a single CV split. With
# neg_mean_squared_error, a higher (less negative) score is better.
candidate_leaves = list(range(10, 60, 10))
trials = forecast_flow.grid_search(
    df_train,
    param_grid={"num_leaves": candidate_leaves},
    n_cv_splits=1,
    scoring_metric="neg_mean_squared_error",
)
trials.head(n=10)
|  | group | score | num_leaves |
|---|---|---|---|
| 0 | WI_3 | -16.795271 | 10 |
| 1 | WI_3 | -17.216249 | 20 |
| 2 | WI_3 | -17.413006 | 30 |
| 3 | WI_3 | -17.590740 | 40 |
| 4 | WI_3 | -17.617151 | 50 |
| 5 | WI_2 | -30.413006 | 10 |
| 6 | WI_2 | -30.922466 | 20 |
| 7 | WI_2 | -31.298466 | 30 |
| 8 | WI_2 | -31.920683 | 40 |
| 9 | WI_2 | -31.998882 | 50 |
# For each group keep the single trial with the highest score. Scores are
# negated MSE, so the highest score is the lowest error; on ties, idxmax
# keeps the first trial. Selecting rows via idxmax() avoids
# groupby().apply() over the grouping column, which pandas >= 2.2 deprecates.
best_trial = trials.loc[trials.groupby("group")["score"].idxmax()]

# Map each group to its winning hyperparameters, e.g.
# {"CA_1": {"num_leaves": 10}, ...}. Every column other than group/score is
# treated as a searched hyperparameter.
best_params = (
    best_trial.set_index("group").drop("score", axis=1).to_dict(orient="index")
)
best_params
{'CA_1': {'num_leaves': 10},
'CA_2': {'num_leaves': 10},
'CA_3': {'num_leaves': 20},
'CA_4': {'num_leaves': 40},
'TX_1': {'num_leaves': 10},
'TX_2': {'num_leaves': 10},
'TX_3': {'num_leaves': 20},
'WI_1': {'num_leaves': 10},
'WI_2': {'num_leaves': 10},
'WI_3': {'num_leaves': 10}}
# Build one LightGBM model per group from its best hyperparameters.
# random_state=42 is carried over from the model used during the search so
# the final models match the configuration that was scored and stay
# reproducible. Tuned values take precedence if they ever include the seed.
group_models = {
    group: LGBMRegressor(**{"random_state": 42, **params})
    for group, params in best_params.items()
}
group_models
{'CA_1': LGBMRegressor(num_leaves=10),
'CA_2': LGBMRegressor(num_leaves=10),
'CA_3': LGBMRegressor(num_leaves=20),
'CA_4': LGBMRegressor(num_leaves=40),
'TX_1': LGBMRegressor(num_leaves=10),
'TX_2': LGBMRegressor(num_leaves=10),
'TX_3': LGBMRegressor(num_leaves=20),
'WI_1': LGBMRegressor(num_leaves=10),
'WI_2': LGBMRegressor(num_leaves=10),
'WI_3': LGBMRegressor(num_leaves=10)}
Training with Optimized Hyperparameters
# Same forecasting setup as the grid search, but `model` is now a dict keyed
# by store_id value, so each group trains with its own tuned regressor.
train_config = dict(
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
)
forecast_flow = ForecastFlowML(**train_config, model=group_models)
training_summary = forecast_flow.train(df_train)
training_summary.show()
+-----+--------------------+--------------------+--------------------+--------------------+---------------+
|group| forecast_horizon| model| start_time| end_time|elapsed_seconds|
+-----+--------------------+--------------------+--------------------+--------------------+---------------+
| CA_2|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 0.8|
| CA_3|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.4|
| WI_2|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 5.0|
| WI_3|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 0.7|
| CA_1|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.3|
| CA_4|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.5|
| TX_1|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.0|
| TX_3|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.1|
| WI_1|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.0|
| TX_2|[[1, 2, 3, 4, 5, ...|[clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...| 1.3|
+-----+--------------------+--------------------+--------------------+--------------------+---------------+