Feature Importance

This quick guide shows how to access the feature importances of models trained with ForecastFlowML.

Import packages

from forecastflowml import ForecastFlowML
from forecastflowml import FeatureExtractor
from forecastflowml.data.loader import load_walmart_m5
from lightgbm import LGBMRegressor
from pyspark.sql import SparkSession

Initialize Spark

# Start a local Spark session using 4 cores.
# - `spark.sql.shuffle.partitions` is lowered to 4 to match the local cores,
#   so small shuffles are not split into many tiny tasks.
# - Arrow is enabled to speed up Spark <-> pandas conversions, such as the
#   `toPandas()` call used later in this guide. Since Spark 3.0 the key is
#   `spark.sql.execution.arrow.pyspark.enabled`; the older
#   `spark.sql.execution.arrow.enabled` key is deprecated.
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

Sample Dataset

# Load the Walmart M5 sample bundled with forecastflowml and preview the
# first rows.
df = load_walmart_m5(spark)
df.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-29|  2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-30|  5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-31|  3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-01|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-02|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-03|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-04|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-05|  1.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-06|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-07|  3.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows

Feature Engineering

# Build model inputs from the raw daily sales:
#   * weekly lags of `sales` (7, 14, 21 and 28 rows back per `id`);
#   * calendar features derived from `date`.
feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={"lag": list(range(7, 29, 7))},  # -> [7, 14, 21, 28]
    date_features=[
        "day_of_month",
        "day_of_week",
        "week_of_year",
        "week_of_month",
        "weekend",
        "quarter",
        "month",
        "year",
    ],
)

# localCheckpoint() materializes the engineered features and truncates the
# lineage, so the training runs below reuse them instead of recomputing the
# feature pipeline each time.
df_train = feature_extractor.transform(df).localCheckpoint()
df_train.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|lag_7|lag_14|lag_21|lag_28|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-01-31|  2.0| null|  null|  null|  null|          31|          2|           5|            5|      0|      1|    1|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-01|  0.0| null|  null|  null|  null|           1|          3|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-02|  0.0| null|  null|  null|  null|           2|          4|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-03|  0.0| null|  null|  null|  null|           3|          5|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-04|  0.0| null|  null|  null|  null|           4|          6|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-05|  0.0| null|  null|  null|  null|           5|          7|           5|            1|      1|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-06|  1.0| null|  null|  null|  null|           6|          1|           5|            1|      1|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-07|  0.0|  2.0|  null|  null|  null|           7|          2|           6|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-08|  0.0|  0.0|  null|  null|  null|           8|          3|           6|            2|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-09|  0.0|  0.0|  null|  null|  null|           9|          4|           6|            2|      0|      1|    2|2011|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+
only showing top 10 rows

Feature Importance

# Configure the forecaster. Models are grouped by store (`group_col`), and
# the 28-day horizon (`max_forecast_horizon`) is covered in 7-day windows
# (`model_horizon`). The importance tables below report one set of
# importances per `group` and `forecast_horizon` window.
forecast_flow = ForecastFlowML(
    model=LGBMRegressor(),
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
)

PySpark DataFrame with Distributed Results

# Train on the Spark DataFrame. The trained models are returned as a Spark
# DataFrame and checkpointed so that reading them again does not re-run
# training. They are then passed explicitly to get_feature_importance.
trained_models = forecast_flow.train(df_train)
trained_models = trained_models.localCheckpoint()
forecast_flow.get_feature_importance(trained_models)
group forecast_horizon feature importance
0 CA_2 [1, 2, 3, 4, 5, 6, 7] year 407.0
1 CA_2 [1, 2, 3, 4, 5, 6, 7] day_of_week 287.0
2 CA_2 [1, 2, 3, 4, 5, 6, 7] quarter 46.0
3 CA_2 [1, 2, 3, 4, 5, 6, 7] week_of_year 534.0
4 CA_2 [1, 2, 3, 4, 5, 6, 7] week_of_month 86.0
... ... ... ... ...
355 TX_2 [22, 23, 24, 25, 26, 27, 28] week_of_month 70.0
356 TX_2 [22, 23, 24, 25, 26, 27, 28] month 79.0
357 TX_2 [22, 23, 24, 25, 26, 27, 28] weekend 57.0
358 TX_2 [22, 23, 24, 25, 26, 27, 28] day_of_month 403.0
359 TX_2 [22, 23, 24, 25, 26, 27, 28] lag_28 1129.0

360 rows × 4 columns

PySpark DataFrame with Local Results

# Train with `local_result=True`. The return value of `train` is not kept
# here, and `get_feature_importance()` is called without arguments. It
# presumably reads the models stored on `forecast_flow` by this mode.
# TODO: confirm against the ForecastFlowML API.
forecast_flow.train(df_train, local_result=True)
forecast_flow.get_feature_importance()
group forecast_horizon feature importance
0 CA_1 [1, 2, 3, 4, 5, 6, 7] year 396
1 CA_1 [1, 2, 3, 4, 5, 6, 7] day_of_week 247
2 CA_1 [1, 2, 3, 4, 5, 6, 7] quarter 48
3 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_year 574
4 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_month 80
... ... ... ... ...
355 WI_3 [22, 23, 24, 25, 26, 27, 28] week_of_month 78
356 WI_3 [22, 23, 24, 25, 26, 27, 28] month 62
357 WI_3 [22, 23, 24, 25, 26, 27, 28] weekend 61
358 WI_3 [22, 23, 24, 25, 26, 27, 28] day_of_month 478
359 WI_3 [22, 23, 24, 25, 26, 27, 28] lag_28 1159

360 rows × 4 columns

Pandas DataFrame

# Train from a pandas DataFrame. The Spark session is passed explicitly via
# `spark=`, presumably because a pandas input has no session to distribute
# training with. As in the local-result case, get_feature_importance() is
# called without arguments.
# TODO: confirm against the ForecastFlowML API.
forecast_flow.train(df_train.toPandas(), spark=spark)
forecast_flow.get_feature_importance()
group forecast_horizon feature importance
0 CA_1 [1, 2, 3, 4, 5, 6, 7] year 396
1 CA_1 [1, 2, 3, 4, 5, 6, 7] day_of_week 247
2 CA_1 [1, 2, 3, 4, 5, 6, 7] quarter 48
3 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_year 574
4 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_month 80
... ... ... ... ...
355 WI_3 [22, 23, 24, 25, 26, 27, 28] week_of_month 78
356 WI_3 [22, 23, 24, 25, 26, 27, 28] month 62
357 WI_3 [22, 23, 24, 25, 26, 27, 28] weekend 61
358 WI_3 [22, 23, 24, 25, 26, 27, 28] day_of_month 478
359 WI_3 [22, 23, 24, 25, 26, 27, 28] lag_28 1159

360 rows × 4 columns