{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"application/vnd.databricks.v1+cell":{"cellMetadata":{},"inputWidgets":{},"nuid":"a02bb3a2-e5d7-4d14-b966-827457675b75","showTitle":false,"title":""}},"source":["# Feature Importance\n","\n","This quick guide shows how the feature importances can be reached from ``ForecastFlowML``."]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Import packages"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["from forecastflowml import ForecastFlowML\n","from forecastflowml import FeatureExtractor\n","from forecastflowml.data.loader import load_walmart_m5\n","from lightgbm import LGBMRegressor\n","from pyspark.sql import SparkSession"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Initialize Spark"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["spark = (\n"," SparkSession.builder.master(\"local[4]\")\n"," .config(\"spark.driver.memory\", \"8g\")\n"," .config(\"spark.sql.shuffle.partitions\", \"4\")\n"," .config(\"spark.sql.execution.arrow.enabled\", \"true\")\n"," .getOrCreate()\n",")"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Sample Dataset"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["+--------------------+-----------+-------+------+--------+--------+----------+-----+\n","| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|\n","|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+\n","only showing top 10 rows\n","\n"]}],"source":["df = load_walmart_m5(spark)\n","df.show(10)"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Feature Engineering"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n","| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| 31| 2| 5| 5| 0| 1| 1|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| 1| 3| 5| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| 2| 4| 5| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| 3| 5| 5| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| 4| 6| 5| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| 5| 7| 5| 1| 1| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| 6| 1| 5| 1| 1| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| 7| 2| 6| 1| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| 8| 3| 6| 2| 0| 1| 2|2011|\n","|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| 9| 4| 6| 2| 0| 1| 2|2011|\n","+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n","only showing top 10 rows\n","\n"]}],"source":["feature_extractor = FeatureExtractor(\n"," id_col=\"id\",\n"," date_col=\"date\",\n"," target_col=\"sales\",\n"," lag_window_features={\n"," \"lag\": [7 * (i + 1) for i in range(4)],\n"," },\n"," date_features=[\n"," \"day_of_month\",\n"," \"day_of_week\",\n"," \"week_of_year\",\n"," \"week_of_month\",\n"," \"weekend\",\n"," \"quarter\",\n"," \"month\",\n"," \"year\",\n"," ],\n",")\n","df_train = feature_extractor.transform(df).localCheckpoint()\n","df_train.show(10)"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Feature Importance"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["forecast_flow = ForecastFlowML(\n"," group_col=\"store_id\",\n"," id_col=\"id\",\n"," date_col=\"date\",\n"," target_col=\"sales\",\n"," date_frequency=\"days\",\n"," model_horizon=7,\n"," max_forecast_horizon=28,\n"," model=LGBMRegressor(),\n",")"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["### PySpark DataFrame with Distributed Results"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
groupforecast_horizonfeatureimportance
0CA_2[1, 2, 3, 4, 5, 6, 7]year407.0
1CA_2[1, 2, 3, 4, 5, 6, 7]day_of_week287.0
2CA_2[1, 2, 3, 4, 5, 6, 7]quarter46.0
3CA_2[1, 2, 3, 4, 5, 6, 7]week_of_year534.0
4CA_2[1, 2, 3, 4, 5, 6, 7]week_of_month86.0
...............
355TX_2[22, 23, 24, 25, 26, 27, 28]week_of_month70.0
356TX_2[22, 23, 24, 25, 26, 27, 28]month79.0
357TX_2[22, 23, 24, 25, 26, 27, 28]weekend57.0
358TX_2[22, 23, 24, 25, 26, 27, 28]day_of_month403.0
359TX_2[22, 23, 24, 25, 26, 27, 28]lag_281129.0
\n","

360 rows × 4 columns

\n","
"],"text/plain":[" group forecast_horizon feature importance\n","0 CA_2 [1, 2, 3, 4, 5, 6, 7] year 407.0\n","1 CA_2 [1, 2, 3, 4, 5, 6, 7] day_of_week 287.0\n","2 CA_2 [1, 2, 3, 4, 5, 6, 7] quarter 46.0\n","3 CA_2 [1, 2, 3, 4, 5, 6, 7] week_of_year 534.0\n","4 CA_2 [1, 2, 3, 4, 5, 6, 7] week_of_month 86.0\n",".. ... ... ... ...\n","355 TX_2 [22, 23, 24, 25, 26, 27, 28] week_of_month 70.0\n","356 TX_2 [22, 23, 24, 25, 26, 27, 28] month 79.0\n","357 TX_2 [22, 23, 24, 25, 26, 27, 28] weekend 57.0\n","358 TX_2 [22, 23, 24, 25, 26, 27, 28] day_of_month 403.0\n","359 TX_2 [22, 23, 24, 25, 26, 27, 28] lag_28 1129.0\n","\n","[360 rows x 4 columns]"]},"execution_count":6,"metadata":{},"output_type":"execute_result"}],"source":["trained_models = forecast_flow.train(df_train).localCheckpoint()\n","forecast_flow.get_feature_importance(trained_models)"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["### PySpark DataFrame with Local Results"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
groupforecast_horizonfeatureimportance
0CA_1[1, 2, 3, 4, 5, 6, 7]year396
1CA_1[1, 2, 3, 4, 5, 6, 7]day_of_week247
2CA_1[1, 2, 3, 4, 5, 6, 7]quarter48
3CA_1[1, 2, 3, 4, 5, 6, 7]week_of_year574
4CA_1[1, 2, 3, 4, 5, 6, 7]week_of_month80
...............
355WI_3[22, 23, 24, 25, 26, 27, 28]week_of_month78
356WI_3[22, 23, 24, 25, 26, 27, 28]month62
357WI_3[22, 23, 24, 25, 26, 27, 28]weekend61
358WI_3[22, 23, 24, 25, 26, 27, 28]day_of_month478
359WI_3[22, 23, 24, 25, 26, 27, 28]lag_281159
\n","

360 rows × 4 columns

\n","
"],"text/plain":[" group forecast_horizon feature importance\n","0 CA_1 [1, 2, 3, 4, 5, 6, 7] year 396\n","1 CA_1 [1, 2, 3, 4, 5, 6, 7] day_of_week 247\n","2 CA_1 [1, 2, 3, 4, 5, 6, 7] quarter 48\n","3 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_year 574\n","4 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_month 80\n",".. ... ... ... ...\n","355 WI_3 [22, 23, 24, 25, 26, 27, 28] week_of_month 78\n","356 WI_3 [22, 23, 24, 25, 26, 27, 28] month 62\n","357 WI_3 [22, 23, 24, 25, 26, 27, 28] weekend 61\n","358 WI_3 [22, 23, 24, 25, 26, 27, 28] day_of_month 478\n","359 WI_3 [22, 23, 24, 25, 26, 27, 28] lag_28 1159\n","\n","[360 rows x 4 columns]"]},"execution_count":7,"metadata":{},"output_type":"execute_result"}],"source":["forecast_flow.train(df_train, local_result=True)\n","forecast_flow.get_feature_importance()"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["### Pandas DataFrame"]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
groupforecast_horizonfeatureimportance
0CA_1[1, 2, 3, 4, 5, 6, 7]year396
1CA_1[1, 2, 3, 4, 5, 6, 7]day_of_week247
2CA_1[1, 2, 3, 4, 5, 6, 7]quarter48
3CA_1[1, 2, 3, 4, 5, 6, 7]week_of_year574
4CA_1[1, 2, 3, 4, 5, 6, 7]week_of_month80
...............
355WI_3[22, 23, 24, 25, 26, 27, 28]week_of_month78
356WI_3[22, 23, 24, 25, 26, 27, 28]month62
357WI_3[22, 23, 24, 25, 26, 27, 28]weekend61
358WI_3[22, 23, 24, 25, 26, 27, 28]day_of_month478
359WI_3[22, 23, 24, 25, 26, 27, 28]lag_281159
\n","

360 rows × 4 columns

\n","
"],"text/plain":[" group forecast_horizon feature importance\n","0 CA_1 [1, 2, 3, 4, 5, 6, 7] year 396\n","1 CA_1 [1, 2, 3, 4, 5, 6, 7] day_of_week 247\n","2 CA_1 [1, 2, 3, 4, 5, 6, 7] quarter 48\n","3 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_year 574\n","4 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_month 80\n",".. ... ... ... ...\n","355 WI_3 [22, 23, 24, 25, 26, 27, 28] week_of_month 78\n","356 WI_3 [22, 23, 24, 25, 26, 27, 28] month 62\n","357 WI_3 [22, 23, 24, 25, 26, 27, 28] weekend 61\n","358 WI_3 [22, 23, 24, 25, 26, 27, 28] day_of_month 478\n","359 WI_3 [22, 23, 24, 25, 26, 27, 28] lag_28 1159\n","\n","[360 rows x 4 columns]"]},"execution_count":8,"metadata":{},"output_type":"execute_result"}],"source":["forecast_flow.train(df_train.toPandas(), spark=spark)\n","forecast_flow.get_feature_importance()"]}],"metadata":{"application/vnd.databricks.v1+notebook":{"dashboards":[{"elements":[],"globalVars":{},"guid":"ef82ffd4-2993-4b79-8327-f644b750f2dd","layoutOption":{"grid":true,"stack":true},"nuid":"a172d56a-d964-4505-ba66-5a7011220dbf","origId":1859120955398731,"title":"Untitled","version":"DashboardViewV1","width":1024}],"language":"python","notebookMetadata":{"pythonIndentUnit":4},"notebookName":"ForecastFlowML Demo","notebookOrigID":2597536912577418,"widgets":{}},"kernelspec":{"display_name":"spark","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.16"}},"nbformat":4,"nbformat_minor":0}