ForecastFlowML documentation

Grid Search

This guide shows how to use grid search to find the best hyperparameters for ForecastFlowML, separately for each group of time series (here, each store).
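Conceptually, a grid search tries every combination of the candidate hyperparameter values and records a score for each one; ForecastFlowML reports one such trial per group and parameter combination, as the trials table in this guide shows. The plain-Python sketch below illustrates that idea only. `expand_grid`, `grid_search` and `fake_score` are hypothetical names, and `fake_score` is a stand-in for the cross-validated score the library computes, not its implementation.

```python
from itertools import product


def expand_grid(param_grid):
    """Return every combination of the values in param_grid as a list of dicts."""
    keys = sorted(param_grid)
    return [dict(zip(keys, values)) for values in product(*(param_grid[k] for k in keys))]


def grid_search(groups, param_grid, score_fn):
    """Score every parameter combination for every group.

    score_fn(group, params) must return a value where higher is better,
    as with scikit-learn's "neg_" metrics.
    """
    trials = []
    for group in groups:
        for params in expand_grid(param_grid):
            trials.append({"group": group, "score": score_fn(group, params), **params})
    return trials


# Hypothetical scorer for illustration: pretend smaller trees always score better.
def fake_score(group, params):
    return -float(params["num_leaves"])


trials = grid_search(["WI_2", "WI_3"], {"num_leaves": [10, 20, 30, 40, 50]}, fake_score)
print(len(trials))  # 10
```

With one parameter and five candidate values, each group gets five trials, which is the same shape as the real `trials` table.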

Import packages

from forecastflowml import ForecastFlowML
from forecastflowml import FeatureExtractor
from forecastflowml.data.loader import load_walmart_m5
from lightgbm import LGBMRegressor
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

Initialize Spark

spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "4")
    # Spark 3.x key; "spark.sql.execution.arrow.enabled" is its deprecated alias
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

Sample Dataset

df = load_walmart_m5(spark)
df.show(10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-29|  2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-30|  5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-31|  3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-01|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-02|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-03|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-04|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-05|  1.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-06|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-07|  3.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows

Feature Engineering

feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={
        "lag": [7 * (i + 1) for i in range(4)],  # lags of 7, 14, 21 and 28 days
    },
    date_features=["day_of_week", "weekend", "week_of_year", "month", "year"],
)
df_features = feature_extractor.transform(df).localCheckpoint()
df_features.show(10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|lag_7|lag_14|lag_21|lag_28|day_of_week|weekend|week_of_year|month|year|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-01-31|  2.0| null|  null|  null|  null|          2|      0|           5|    1|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-01|  0.0| null|  null|  null|  null|          3|      0|           5|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-02|  0.0| null|  null|  null|  null|          4|      0|           5|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-03|  0.0| null|  null|  null|  null|          5|      0|           5|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-04|  0.0| null|  null|  null|  null|          6|      0|           5|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-05|  0.0| null|  null|  null|  null|          7|      1|           5|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-06|  1.0| null|  null|  null|  null|          1|      1|           5|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-07|  0.0|  2.0|  null|  null|  null|          2|      0|           6|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-08|  0.0|  0.0|  null|  null|  null|          3|      0|           6|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-09|  0.0|  0.0|  null|  null|  null|          4|      0|           6|    2|2011|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+-----------+-------+------------+-----+----+
only showing top 10 rows
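To make the generated columns concrete, the plain-Python sketch below recomputes the lag and date features for the ten `FOODS_1_011_WI_2_...` rows printed above. It assumes each series has exactly one row per day (so a lag of k rows is a lag of k days) and that the date features follow Spark's conventions: `dayofweek` returns 1 for Sunday through 7 for Saturday, and `weekofyear` returns the ISO 8601 week. Both assumptions are inferred from the output, not taken from the `FeatureExtractor` source.

```python
from datetime import date, timedelta

# Sales of the ten FOODS_1_011_WI_2_... rows printed above, starting on 2011-01-31.
start = date(2011, 1, 31)
sales = [2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
dates = [start + timedelta(days=i) for i in range(len(sales))]


def lag(values, k):
    """Shift a daily series by k >= 1 steps; positions without k days of history get None."""
    return [values[i - k] if i >= k else None for i in range(len(values))]


def spark_day_of_week(d):
    """Assumed Spark dayofweek convention: 1 = Sunday, 2 = Monday, ..., 7 = Saturday."""
    return d.isoweekday() % 7 + 1


rows = []
for d, s, lag_7, lag_14 in zip(dates, sales, lag(sales, 7), lag(sales, 14)):
    dow = spark_day_of_week(d)
    rows.append(
        {
            "date": d.isoformat(),
            "sales": s,
            "lag_7": lag_7,
            "lag_14": lag_14,
            "day_of_week": dow,
            "weekend": int(dow in (1, 7)),  # Saturday or Sunday
            "week_of_year": d.isocalendar()[1],  # ISO 8601 week number
            "month": d.month,
            "year": d.year,
        }
    )

for row in rows:
    print(row)
```

The printed rows agree with the `lag_7`, `lag_14`, `day_of_week`, `weekend`, `week_of_year`, `month` and `year` columns in the table above: `lag_7` first appears on 2011-02-07 with the 2.0 sold on 2011-01-31, and `lag_14` stays null for all ten rows.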

Train/Test Dataset

df_train = df_features.filter(F.col("date") < "2016-04-25")
df_test = df_features.filter(F.col("date") >= "2016-04-25")
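The cutoff date above is hard-coded. A hypothetical helper such as `holdout_cutoff` below derives it from the last observed date instead, so that the test set holds exactly `max_forecast_horizon` days per series. The sketch assumes the sample ends on 2016-05-22, the last day of the M5 evaluation training data; check the actual maximum date of `df` before relying on that.

```python
from datetime import date, timedelta


def holdout_cutoff(last_date, horizon):
    """Hypothetical helper: first test date such that [cutoff, last_date] spans `horizon` days."""
    return last_date - timedelta(days=horizon - 1)


# Assumption: the Walmart M5 sample ends on 2016-05-22. In PySpark the actual last date
# could be read with df.agg(F.max("date")).first()[0].
last_date = date(2016, 5, 22)
cutoff = holdout_cutoff(last_date, horizon=28)
print(cutoff)  # 2016-04-25

# The split would then be written as:
# df_train = df_features.filter(F.col("date") < cutoff.isoformat())
# df_test = df_features.filter(F.col("date") >= cutoff.isoformat())
```

Under that assumption the helper reproduces the hard-coded `"2016-04-25"`, and the test period matches `max_forecast_horizon=28` used below.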

Initialize Model

forecast_flow = ForecastFlowML(
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
    model=LGBMRegressor(random_state=42),
)
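The training output at the end of this guide lists a `forecast_horizon` array next to a `model` array for each group, which suggests that the 28 horizons are split into blocks of `model_horizon=7`, each handled by its own model. The sketch below reproduces that partition in plain Python; `split_horizons` and `usable_lags` are hypothetical helpers, and the exact rule ForecastFlowML applies internally is an assumption. It also shows the usual reason for such blocks in direct multi-step forecasting: a lag of k days is only observed at forecast time for horizons up to k, so a model for horizons 8 to 14 could not rely on `lag_7`.

```python
def split_horizons(model_horizon, max_forecast_horizon):
    """Hypothetical helper: partition horizons 1..max_forecast_horizon into consecutive blocks."""
    return [
        list(range(start, min(start + model_horizon, max_forecast_horizon + 1)))
        for start in range(1, max_forecast_horizon + 1, model_horizon)
    ]


def usable_lags(horizon_block, lags):
    """Lags still observed at the furthest horizon of a block (lag >= horizon)."""
    return [lag for lag in lags if lag >= max(horizon_block)]


lags = [7, 14, 21, 28]
for block in split_horizons(model_horizon=7, max_forecast_horizon=28):
    print(block[0], "to", block[-1], "->", usable_lags(block, lags))
```

With these settings there are four blocks (1 to 7, 8 to 14, 15 to 21 and 22 to 28), and only `lag_28` remains usable for the last one.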

Search Hyperparameters with Grid Search

trials = forecast_flow.grid_search(
    df_train,
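    param_grid={"num_leaves": [10, 20, 30, 40, 50]},  # candidate values to try
    n_cv_splits=1,  # number of time series cross-validation splits
    scoring_metric="neg_mean_squared_error",  # negated MSE: higher is better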
)
trials.head(10)
  group      score  num_leaves
0  WI_3 -16.795271          10
1  WI_3 -17.216249          20
2  WI_3 -17.413006          30
3  WI_3 -17.590740          40
4  WI_3 -17.617151          50
5  WI_2 -30.413006          10
6  WI_2 -30.922466          20
7  WI_2 -31.298466          30
8  WI_2 -31.920683          40
9  WI_2 -31.998882          50
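Because `neg_mean_squared_error` follows scikit-learn's convention of negating error metrics, a larger (less negative) score is better, which is why the selection step sorts with `ascending=False`. The plain-Python sketch below applies the same rule to the ten trials printed above and converts the winning scores back to RMSE, in units of daily sales. It only mirrors the pandas code for illustration.

```python
import math

# The ten trials printed above, as (group, score, num_leaves).
trials = [
    ("WI_3", -16.795271, 10),
    ("WI_3", -17.216249, 20),
    ("WI_3", -17.413006, 30),
    ("WI_3", -17.590740, 40),
    ("WI_3", -17.617151, 50),
    ("WI_2", -30.413006, 10),
    ("WI_2", -30.922466, 20),
    ("WI_2", -31.298466, 30),
    ("WI_2", -31.920683, 40),
    ("WI_2", -31.998882, 50),
]

best = {}
for group, score, num_leaves in trials:
    # Keep the highest (least negative) score seen so far for each group.
    if group not in best or score > best[group][0]:
        best[group] = (score, num_leaves)

best_params = {g: {"num_leaves": n} for g, (_, n) in sorted(best.items())}
rmse = {g: round(math.sqrt(-s), 3) for g, (s, _) in sorted(best.items())}
print(best_params)  # {'WI_2': {'num_leaves': 10}, 'WI_3': {'num_leaves': 10}}
print(rmse)
```

The result agrees with the `WI_2` and `WI_3` entries of `best_params` computed with pandas.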
# Keep the highest-scoring (least negative) trial for each group
best_trial = trials.groupby("group", group_keys=False).apply(
    lambda x: x.sort_values("score", ascending=False).head(1)
)
# Map each group to its best hyperparameters
best_params = (
    best_trial.set_index("group").drop("score", axis=1).to_dict(orient="index")
)
best_params
{'CA_1': {'num_leaves': 10},
 'CA_2': {'num_leaves': 10},
 'CA_3': {'num_leaves': 20},
 'CA_4': {'num_leaves': 40},
 'TX_1': {'num_leaves': 10},
 'TX_2': {'num_leaves': 10},
 'TX_3': {'num_leaves': 20},
 'WI_1': {'num_leaves': 10},
 'WI_2': {'num_leaves': 10},
 'WI_3': {'num_leaves': 10}}
# One LGBMRegressor per store, built from its tuned hyperparameters
group_models = {k: LGBMRegressor(**v) for k, v in best_params.items()}
group_models
{'CA_1': LGBMRegressor(num_leaves=10),
 'CA_2': LGBMRegressor(num_leaves=10),
 'CA_3': LGBMRegressor(num_leaves=20),
 'CA_4': LGBMRegressor(num_leaves=40),
 'TX_1': LGBMRegressor(num_leaves=10),
 'TX_2': LGBMRegressor(num_leaves=10),
 'TX_3': LGBMRegressor(num_leaves=20),
 'WI_1': LGBMRegressor(num_leaves=10),
 'WI_2': LGBMRegressor(num_leaves=10),
 'WI_3': LGBMRegressor(num_leaves=10)}
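`LGBMRegressor(**v)` builds each group model from the tuned values only, so settings of the estimator used during the search, such as `random_state=42`, are not carried over (the printed models show only `num_leaves`). If you want to keep them, one option is to merge the base parameters with the tuned ones before constructing each model. The sketch below shows the merge with plain dictionaries; `base_params` and `merged_params` are illustrative names, and the final commented line shows how the merged dicts would be passed to `LGBMRegressor`.

```python
# Parameters of the estimator used during grid search (illustrative).
base_params = {"random_state": 42}

# A subset of the tuned parameters returned above.
best_params = {"CA_4": {"num_leaves": 40}, "TX_3": {"num_leaves": 20}}

# Later keys win, so tuned values override any base value with the same name.
merged_params = {group: {**base_params, **params} for group, params in best_params.items()}
print(merged_params["CA_4"])  # {'random_state': 42, 'num_leaves': 40}

# With LightGBM installed, the models would then be built as:
# group_models = {k: LGBMRegressor(**v) for k, v in merged_params.items()}
```

With an estimator object at hand, scikit-learn's `clone(estimator).set_params(**params)` achieves the same effect without listing the base parameters by hand.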

Training with Optimized Hyperparameters

forecast_flow = ForecastFlowML(
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
    model=group_models,
)
forecast_flow.train(df_train).show()
+-----+--------------------+--------------------+--------------------+--------------------+---------------+
|group|    forecast_horizon|               model|          start_time|            end_time|elapsed_seconds|
+-----+--------------------+--------------------+--------------------+--------------------+---------------+
| CA_2|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            0.8|
| CA_3|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            1.4|
| WI_2|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            5.0|
| WI_3|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            0.7|
| CA_1|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            1.3|
| CA_4|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            1.5|
| TX_1|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            1.0|
| TX_3|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            1.1|
| WI_1|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            1.0|
| TX_2|[[1, 2, 3, 4, 5, ...|[€clightgbm.skle...|02-May-2023 (18:4...|02-May-2023 (18:4...|            1.3|
+-----+--------------------+--------------------+--------------------+--------------------+---------------+
© Copyright 2023, Caner Turkseven.