Feature Importance#
This quick guide shows how to retrieve feature importances from a trained ForecastFlowML model.
Import packages#
from forecastflowml import ForecastFlowML
from forecastflowml import FeatureExtractor
from forecastflowml.data.loader import load_walmart_m5
from lightgbm import LGBMRegressor
from pyspark.sql import SparkSession
Initialize Spark#
# Start (or reuse) a local Spark session that runs on 4 threads of this machine.
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "4")
    # "spark.sql.execution.arrow.enabled" is deprecated since Spark 3.0;
    # "spark.sql.execution.arrow.pyspark.enabled" is its replacement.
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)
Sample Dataset#
# Load the bundled Walmart M5 sample data through the Spark session.
df = load_walmart_m5(spark)
# Print the first 10 rows to inspect the columns (ids, date, and the sales target).
df.show(10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows
Feature Engineering#
# Describe the features to derive for every series identified by `id`.
feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    # Weekly lags of the target: 7, 14, 21 and 28 steps back
    # (they appear as lag_7 ... lag_28 in the transformed frame).
    lag_window_features={"lag": list(range(7, 29, 7))},
    # Calendar features computed from the date column.
    date_features=[
        "day_of_month",
        "day_of_week",
        "week_of_year",
        "week_of_month",
        "weekend",
        "quarter",
        "month",
        "year",
    ],
)

# Build the training frame and checkpoint it locally before it is reused below.
df_train = feature_extractor.transform(df).localCheckpoint()
# Early rows of each series have null lags because no history exists yet.
df_train.show(10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| 31| 2| 5| 5| 0| 1| 1|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| 1| 3| 5| 1| 0| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| 2| 4| 5| 1| 0| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| 3| 5| 5| 1| 0| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| 4| 6| 5| 1| 0| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| 5| 7| 5| 1| 1| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| 6| 1| 5| 1| 1| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| 7| 2| 6| 1| 0| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| 8| 3| 6| 2| 0| 1| 2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| 9| 4| 6| 2| 0| 1| 2|2011|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+
only showing top 10 rows
Feature Importance#
# Configure the forecaster used by all three examples below.
# - group_col="store_id": the `group` column of the feature-importance output
#   holds store ids (e.g. CA_2, TX_2).
# - model_horizon=7 with max_forecast_horizon=28: the output rows carry
#   forecast_horizon lists of 7 consecutive steps, from [1..7] up to [22..28].
# - model=LGBMRegressor(): a LightGBM regressor with default parameters.
forecast_flow = ForecastFlowML(
group_col="store_id",
id_col="id",
date_col="date",
target_col="sales",
date_frequency="days",
model_horizon=7,
max_forecast_horizon=28,
model=LGBMRegressor(),
)
PySpark DataFrame with Distributed Results#
# Train on the Spark DataFrame; the returned DataFrame of trained models is
# checkpointed locally and kept in `trained_models`.
trained_models = forecast_flow.train(df_train).localCheckpoint()
# Pass the trained-models DataFrame explicitly to obtain the importances.
forecast_flow.get_feature_importance(trained_models)
| | group | forecast_horizon | feature | importance |
|---|---|---|---|---|
| 0 | CA_2 | [1, 2, 3, 4, 5, 6, 7] | year | 407.0 |
| 1 | CA_2 | [1, 2, 3, 4, 5, 6, 7] | day_of_week | 287.0 |
| 2 | CA_2 | [1, 2, 3, 4, 5, 6, 7] | quarter | 46.0 |
| 3 | CA_2 | [1, 2, 3, 4, 5, 6, 7] | week_of_year | 534.0 |
| 4 | CA_2 | [1, 2, 3, 4, 5, 6, 7] | week_of_month | 86.0 |
| ... | ... | ... | ... | ... |
| 355 | TX_2 | [22, 23, 24, 25, 26, 27, 28] | week_of_month | 70.0 |
| 356 | TX_2 | [22, 23, 24, 25, 26, 27, 28] | month | 79.0 |
| 357 | TX_2 | [22, 23, 24, 25, 26, 27, 28] | weekend | 57.0 |
| 358 | TX_2 | [22, 23, 24, 25, 26, 27, 28] | day_of_month | 403.0 |
| 359 | TX_2 | [22, 23, 24, 25, 26, 27, 28] | lag_28 | 1129.0 |
360 rows × 4 columns
PySpark DataFrame with Local Results#
# Train on the same Spark DataFrame with local_result=True; the return value
# is not captured here.
forecast_flow.train(df_train, local_result=True)
# No argument is passed: presumably the trained models are kept on
# `forecast_flow` itself in this mode (confirm against the API reference).
forecast_flow.get_feature_importance()
| | group | forecast_horizon | feature | importance |
|---|---|---|---|---|
| 0 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | year | 396 |
| 1 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | day_of_week | 247 |
| 2 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | quarter | 48 |
| 3 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | week_of_year | 574 |
| 4 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | week_of_month | 80 |
| ... | ... | ... | ... | ... |
| 355 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | week_of_month | 78 |
| 356 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | month | 62 |
| 357 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | weekend | 61 |
| 358 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | day_of_month | 478 |
| 359 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | lag_28 | 1159 |
360 rows × 4 columns
Pandas DataFrame#
# Train on a pandas DataFrame (converted from the Spark frame); the Spark
# session is supplied through the `spark` keyword.
forecast_flow.train(df_train.toPandas(), spark=spark)
# As in the local-result example, importances are requested without arguments.
forecast_flow.get_feature_importance()
| | group | forecast_horizon | feature | importance |
|---|---|---|---|---|
| 0 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | year | 396 |
| 1 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | day_of_week | 247 |
| 2 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | quarter | 48 |
| 3 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | week_of_year | 574 |
| 4 | CA_1 | [1, 2, 3, 4, 5, 6, 7] | week_of_month | 80 |
| ... | ... | ... | ... | ... |
| 355 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | week_of_month | 78 |
| 356 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | month | 62 |
| 357 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | weekend | 61 |
| 358 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | day_of_month | 478 |
| 359 | WI_3 | [22, 23, 24, 25, 26, 27, 28] | lag_28 | 1159 |
360 rows × 4 columns