{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"application/vnd.databricks.v1+cell":{"cellMetadata":{},"inputWidgets":{},"nuid":"a02bb3a2-e5d7-4d14-b966-827457675b75","showTitle":false,"title":""}},"source":["# Feature Importance\n","\n","This quick guide shows how to retrieve feature importances from ``ForecastFlowML``."]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Import packages"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["from forecastflowml import ForecastFlowML\n","from forecastflowml import FeatureExtractor\n","from forecastflowml.data.loader import load_walmart_m5\n","from lightgbm import LGBMRegressor\n","from pyspark.sql import SparkSession"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Initialize Spark"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["spark = (\n"," SparkSession.builder.master(\"local[4]\")\n"," .config(\"spark.driver.memory\", \"8g\")\n"," .config(\"spark.sql.shuffle.partitions\", \"4\")\n"," .config(\"spark.sql.execution.arrow.enabled\", \"true\")\n"," .getOrCreate()\n",")"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Sample Dataset"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["+--------------------+-----------+-------+------+--------+--------+----------+-----+\n","| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 
0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+\n","only showing top 10 rows\n","\n"]}],"source":["df = load_walmart_m5(spark)\n","df.show(10)"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Feature Engineering"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n","| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| 31| 2| 5| 5| 0| 1| 1|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| 1| 3| 5| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| 2| 4| 5| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| 3| 5| 5| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| 4| 6| 5| 1| 0| 1| 
2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| 5| 7| 5| 1| 1| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| 6| 1| 5| 1| 1| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| 7| 2| 6| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| 8| 3| 6| 2| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| 9| 4| 6| 2| 0| 1| 2|2011|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n","only showing top 10 rows\n","\n"]}],"source":["feature_extractor = FeatureExtractor(\n"," id_col=\"id\",\n"," date_col=\"date\",\n"," target_col=\"sales\",\n"," lag_window_features={\n"," \"lag\": [7 * (i + 1) for i in range(4)],\n"," },\n"," date_features=[\n"," \"day_of_month\",\n"," \"day_of_week\",\n"," \"week_of_year\",\n"," \"week_of_month\",\n"," \"weekend\",\n"," \"quarter\",\n"," \"month\",\n"," \"year\",\n"," ],\n",")\n","df_train = feature_extractor.transform(df).localCheckpoint()\n","df_train.show(10)"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Feature Importance"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["forecast_flow = ForecastFlowML(\n"," group_col=\"store_id\",\n"," id_col=\"id\",\n"," date_col=\"date\",\n"," target_col=\"sales\",\n"," date_frequency=\"days\",\n"," model_horizon=7,\n"," max_forecast_horizon=28,\n"," model=LGBMRegressor(),\n",")"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["### PySpark DataFrame with Distributed Results"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"data":{"text/html":["
| \n"," | group | \n","forecast_horizon | \n","feature | \n","importance | \n","
|---|---|---|---|---|
| 0 | \n","CA_2 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","year | \n","407.0 | \n","
| 1 | \n","CA_2 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","day_of_week | \n","287.0 | \n","
| 2 | \n","CA_2 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","quarter | \n","46.0 | \n","
| 3 | \n","CA_2 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","week_of_year | \n","534.0 | \n","
| 4 | \n","CA_2 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","week_of_month | \n","86.0 | \n","
| ... | \n","... | \n","... | \n","... | \n","... | \n","
| 355 | \n","TX_2 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","week_of_month | \n","70.0 | \n","
| 356 | \n","TX_2 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","month | \n","79.0 | \n","
| 357 | \n","TX_2 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","weekend | \n","57.0 | \n","
| 358 | \n","TX_2 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","day_of_month | \n","403.0 | \n","
| 359 | \n","TX_2 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","lag_28 | \n","1129.0 | \n","
360 rows × 4 columns
\n","| \n"," | group | \n","forecast_horizon | \n","feature | \n","importance | \n","
|---|---|---|---|---|
| 0 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","year | \n","396 | \n","
| 1 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","day_of_week | \n","247 | \n","
| 2 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","quarter | \n","48 | \n","
| 3 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","week_of_year | \n","574 | \n","
| 4 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","week_of_month | \n","80 | \n","
| ... | \n","... | \n","... | \n","... | \n","... | \n","
| 355 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","week_of_month | \n","78 | \n","
| 356 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","month | \n","62 | \n","
| 357 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","weekend | \n","61 | \n","
| 358 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","day_of_month | \n","478 | \n","
| 359 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","lag_28 | \n","1159 | \n","
360 rows × 4 columns
\n","| \n"," | group | \n","forecast_horizon | \n","feature | \n","importance | \n","
|---|---|---|---|---|
| 0 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","year | \n","396 | \n","
| 1 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","day_of_week | \n","247 | \n","
| 2 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","quarter | \n","48 | \n","
| 3 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","week_of_year | \n","574 | \n","
| 4 | \n","CA_1 | \n","[1, 2, 3, 4, 5, 6, 7] | \n","week_of_month | \n","80 | \n","
| ... | \n","... | \n","... | \n","... | \n","... | \n","
| 355 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","week_of_month | \n","78 | \n","
| 356 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","month | \n","62 | \n","
| 357 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","weekend | \n","61 | \n","
| 358 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","day_of_month | \n","478 | \n","
| 359 | \n","WI_3 | \n","[22, 23, 24, 25, 26, 27, 28] | \n","lag_28 | \n","1159 | \n","
360 rows × 4 columns
\n","