{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Feature Engineering\n",
"\n",
"ForecastFlowML includes a preprocessing module that creates features from a time \n",
"series dataset. This user guide shows how the features can be created in a scalable way\n",
"before the modelling phase."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from forecastflowml import FeatureExtractor\n",
"from forecastflowml import ForecastFlowML\n",
"from forecastflowml.data.loader import load_walmart_m5\n",
"from pyspark.sql import SparkSession\n",
"from lightgbm import LGBMRegressor\n",
"import pandas as pd\n",
"\n",
"pd.set_option(\"display.max_columns\", 100)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize Spark"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"spark = (\n",
" SparkSession.builder.master(\"local[4]\")\n",
" .config(\"spark.driver.memory\", \"8g\")\n",
" .config(\"spark.sql.shuffle.partitions\", \"4\")\n",
" .config(\"spark.sql.execution.arrow.enabled\", \"true\")\n",
" .getOrCreate()\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample Dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"df = load_walmart_m5(spark).localCheckpoint()\n",
"df.show(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature Overview\n",
"\n",
"\n",
"With ``FeatureExtractor``, we can extract:\n",
"- Lag features\n",
"- Rolling statistics (mean, standard deviation, etc.) with specified lags\n",
"- Counts of consecutive occurrences of a specific value, which can be used to count out-of-stock periods\n",
"- History length, i.e. the number of periods since the beginning of the time series\n",
"- Date features\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Lags\n",
"\n",
"When extracting the features, we should be careful about which lags we create.\n",
"In this example, we are going to prepare features for 4 weekly models.\n",
"\n",
"- Model 1 will predict days 1–7, so it cannot use the 6 most recent lags.\n",
"- Model 2 will predict days 8–14, so it cannot use the 13 most recent lags.\n",
"- Model 3 will predict days 15–21, so it cannot use the 20 most recent lags.\n",
"- Model 4 will predict days 22–28, so it cannot use the 27 most recent lags.\n",
"\n",
"For lag features, each model uses the sales on the same weekday from the 4 most recent weeks available to it.\n",
"\n",
"Since each model has a different horizon, each one is allowed to use different lags in the modelling phase. In summary, we need to extract ``lag_7``, ``lag_14``, ``lag_21``, ``lag_28``, ``lag_35``, ``lag_42`` and ``lag_49`` as features. The cell below also extracts ``lag_56``, one week more than this minimum."
]
},
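{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The lag arithmetic above can be checked with a few lines of plain Python. This is only an illustrative sketch: the helper ``usable_lags`` is hypothetical and not part of ForecastFlowML, and it assumes each model takes four same-weekday lags, starting from the first multiple of 7 that is at least as large as the last day of its horizon."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def usable_lags(horizon_end, n_weeks=4, period=7):\n",
"    # Hypothetical helper: same-weekday lags a model may use when its\n",
"    # last forecast day is `horizon_end` days ahead.\n",
"    # Smallest multiple of `period` that is >= horizon_end (ceiling division)\n",
"    first_lag = -(-horizon_end // period) * period\n",
"    return [first_lag + period * i for i in range(n_weeks)]\n",
"\n",
"\n",
"# Horizons of the four weekly models end on days 7, 14, 21 and 28\n",
"lags_per_model = {m: usable_lags(7 * m) for m in range(1, 5)}\n",
"for model, lags in lags_per_model.items():\n",
"    print(f\"model {model}: {lags}\")\n",
"\n",
"# Union of all lags that have to be extracted for the four models\n",
"all_lags = sorted({lag for lags in lags_per_model.values() for lag in lags})\n",
"print(all_lags)"
]
},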
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| null| null| null| null|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"feature_extractor = FeatureExtractor(\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" lag_window_features={\n",
" \"lag\": [7 * (i + 1) for i in range(8)],\n",
" },\n",
")\n",
"df_features = feature_extractor.transform(df)\n",
"df_features.show(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Rolling Statistics\n",
"\n",
"For rolling statistics, we are going to calculate the mean over **windows** of 7, 14 and 30 days, each ending at the **most recent lag** that a model can use: 7 days for model 1, 14 days for model 2, 21 days for model 3 and 28 days for model 4."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| 2.0| 2.0| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 1.0| 1.0| 1.0| null| null| null| null| null| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.6666666666666666| 0.6666666666666666| 0.6666666666666666| null| null| null| null| null| null| null| null| null|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"feature_extractor = FeatureExtractor(\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" lag_window_features={\n",
" \"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n",
" },\n",
")\n",
"df_features = feature_extractor.transform(df)\n",
"df_features.show(10)"
]
},
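{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the column names concrete, ``window_7_lag_7_mean`` can be read as: shift the series back by 7 periods, then average up to 7 of the shifted values. The plain-Python sketch below follows that reading. It is an assumption inferred from the output above (partial windows appear to be averaged rather than returned as null), not the library's implementation, and ``lagged_rolling_mean`` is a hypothetical helper. The input is the ``sales`` column of the ten rows shown above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lagged_rolling_mean(values, window, lag):\n",
"    # Hypothetical helper: mean of up to `window` values ending `lag` periods ago\n",
"    result = []\n",
"    for t in range(len(values)):\n",
"        end = t - lag  # index of the most recent value the model may see\n",
"        if end < 0:\n",
"            result.append(None)  # no history available yet at this lag\n",
"            continue\n",
"        start = max(0, end - window + 1)  # partial windows at the series start\n",
"        chunk = values[start : end + 1]\n",
"        result.append(sum(chunk) / len(chunk))\n",
"    return result\n",
"\n",
"\n",
"sales = [2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]\n",
"means = lagged_rolling_mean(sales, window=7, lag=7)\n",
"print(means)"
]
},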
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Out-of-stock Periods\n",
"\n",
"Sometimes a product might be out of stock for a certain period. We are now going to\n",
"count the consecutive periods without any sales, shifted by the **most recent lags** \n",
"that the models can use."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----------------------------+------------------------------+------------------------------+------------------------------+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----------------------------+------------------------------+------------------------------+------------------------------+\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 0| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 1| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 2| null| null| null|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----------------------------+------------------------------+------------------------------+------------------------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"feature_extractor = FeatureExtractor(\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" count_consecutive_values={\n",
" \"value\": 0,\n",
" \"lags\": [7, 14, 21, 28],\n",
" },\n",
")\n",
"df_features = feature_extractor.transform(df).localCheckpoint()\n",
"df_features.show(10)"
]
},
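{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to read ``count_consecutive_value_lag_7``: for each date, how long the current run of the chosen value was 7 periods earlier. The sketch below implements that reading in plain Python. It is inferred from the output above rather than taken from the library, and ``count_consecutive`` is a hypothetical helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def count_consecutive(values, value, lag):\n",
"    # Hypothetical helper: running streak length of `value`, shifted by `lag`\n",
"    streaks, current = [], 0\n",
"    for v in values:\n",
"        # Extend the streak on a match, reset it otherwise\n",
"        current = current + 1 if v == value else 0\n",
"        streaks.append(current)\n",
"    # Shift so that each position only sees the streak known `lag` periods ago\n",
"    return ([None] * lag + streaks)[: len(values)]\n",
"\n",
"\n",
"sales = [2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]\n",
"zero_streak_lag_7 = count_consecutive(sales, value=0, lag=7)\n",
"print(zero_streak_lag_7)"
]
},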
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## History Length\n",
"\n",
"We can also count the number of periods that have passed since the start of each time series."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+--------------+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|history_length|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+--------------+\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| 1|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| 2|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| 3|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| 4|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| 5|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| 6|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| 7|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 8|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 9|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 10|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+--------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"feature_extractor = FeatureExtractor(\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" history_length=True,\n",
")\n",
"df_features = feature_extractor.transform(df).localCheckpoint()\n",
"df_features.show(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Date Features\n",
"\n",
"Finally, we can also include features derived from the date."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+------------+-----------+------------+-------------+-------+-------+-----+----+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+------------+-----------+------------+-------------+-------+-------+-----+----+\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0| 29| 7| 4| 5| 1| 1| 1|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0| 30| 1| 4| 5| 1| 1| 1|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0| 31| 2| 5| 5| 0| 1| 1|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0| 1| 3| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0| 2| 4| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0| 3| 5| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0| 4| 6| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0| 5| 7| 5| 1| 1| 1| 2|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0| 6| 1| 5| 1| 1| 1| 2|2011|\n",
"|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0| 7| 2| 6| 1| 0| 1| 2|2011|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+------------+-----------+------------+-------------+-------+-------+-----+----+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"feature_extractor = FeatureExtractor(\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" date_features=[\n",
" \"day_of_month\",\n",
" \"day_of_week\",\n",
" \"week_of_year\",\n",
" \"week_of_month\",\n",
" \"weekend\",\n",
" \"quarter\",\n",
" \"month\",\n",
" \"year\",\n",
" ],\n",
")\n",
"df_features = feature_extractor.transform(df).localCheckpoint()\n",
"df_features.show(10)"
]
},
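{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The conventions behind these columns can be read off the output above: ``day_of_week`` runs from 1 (Sunday) to 7 (Saturday), ``week_of_year`` matches the ISO week number, and ``week_of_month`` increases every 7 days of the month. The sketch below reproduces those values with the standard library only. It is a reconstruction from the printed rows, not the library's code, and ``date_features`` is a hypothetical helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datetime import date\n",
"\n",
"\n",
"def date_features(d):\n",
"    # Hypothetical helper mirroring the columns shown in the output above\n",
"    day_of_week = d.isoweekday() % 7 + 1  # 1 = Sunday ... 7 = Saturday\n",
"    return {\n",
"        \"day_of_month\": d.day,\n",
"        \"day_of_week\": day_of_week,\n",
"        \"week_of_year\": d.isocalendar()[1],  # ISO week number\n",
"        \"week_of_month\": (d.day - 1) // 7 + 1,\n",
"        \"weekend\": int(day_of_week in (1, 7)),\n",
"        \"quarter\": (d.month - 1) // 3 + 1,\n",
"        \"month\": d.month,\n",
"        \"year\": d.year,\n",
"    }\n",
"\n",
"\n",
"first_row = date_features(date(2011, 1, 29))\n",
"print(first_row)"
]
},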
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Combine Features\n",
"\n",
"Let's combine all of the feature extraction steps into a single ``FeatureExtractor``."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"feature_extractor = FeatureExtractor(\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" lag_window_features={\n",
" \"lag\": [7 * (i + 1) for i in range(8)],\n",
" \"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n",
" },\n",
" date_features=[\n",
" \"day_of_month\",\n",
" \"day_of_week\",\n",
" \"week_of_year\",\n",
" \"week_of_month\",\n",
" \"weekend\",\n",
" \"quarter\",\n",
" \"month\",\n",
" \"year\",\n",
" ],\n",
" count_consecutive_values={\n",
" \"value\": 0,\n",
" \"lags\": [7, 14, 21, 28],\n",
" },\n",
" history_length=True,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### PySpark DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n",
"| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|history_length|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 1| 31| 2| 5| 5| 0| 1| 1|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 2| 1| 3| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 3| 2| 4| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 4| 3| 5| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 5| 4| 6| 5| 1| 0| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 6| 5| 7| 5| 1| 1| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 7| 6| 1| 5| 1| 1| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null| null| null| null| null| 2.0| 2.0| 2.0| null| null| null| null| null| null| null| null| null| 0| null| null| null| 8| 7| 2| 6| 1| 0| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null| null| null| null| null| 1.0| 1.0| 1.0| null| null| null| null| null| null| null| null| null| 1| null| null| null| 9| 8| 3| 6| 2| 0| 1| 2|2011|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null| null| null| null| null| 0.6666666666666666| 0.6666666666666666| 0.6666666666666666| null| null| null| null| null| null| null| null| null| 2| null| null| null| 10| 9| 4| 6| 2| 0| 1| 2|2011|\n",
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"df_train = feature_extractor.transform(df).localCheckpoint()\n",
"df_train.show(10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pandas DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" id item_id dept_id \\\n",
"0 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"1 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"2 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"3 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"4 FOODS_1_011_WI_2_evaluation FOODS_1_011 FOODS_1 \n",
"... ... ... ... \n",
"1470899 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470900 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470901 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470902 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"1470903 HOUSEHOLD_2_514_WI_3_evaluation HOUSEHOLD_2_514 HOUSEHOLD_2 \n",
"\n",
" cat_id store_id state_id date sales lag_7 lag_14 lag_21 \\\n",
"0 FOODS WI_2 WI 2011-01-31 2.0 NaN NaN NaN \n",
"1 FOODS WI_2 WI 2011-02-01 0.0 NaN NaN NaN \n",
"2 FOODS WI_2 WI 2011-02-02 0.0 NaN NaN NaN \n",
"3 FOODS WI_2 WI 2011-02-03 0.0 NaN NaN NaN \n",
"4 FOODS WI_2 WI 2011-02-04 0.0 NaN NaN NaN \n",
"... ... ... ... ... ... ... ... ... \n",
"1470899 HOUSEHOLD WI_3 WI 2016-05-18 0.0 0.0 0.0 0.0 \n",
"1470900 HOUSEHOLD WI_3 WI 2016-05-19 0.0 0.0 0.0 0.0 \n",
"1470901 HOUSEHOLD WI_3 WI 2016-05-20 1.0 0.0 0.0 0.0 \n",
"1470902 HOUSEHOLD WI_3 WI 2016-05-21 0.0 0.0 0.0 0.0 \n",
"1470903 HOUSEHOLD WI_3 WI 2016-05-22 0.0 0.0 0.0 0.0 \n",
"\n",
" lag_28 lag_35 lag_42 lag_49 lag_56 window_7_lag_7_mean \\\n",
"0 NaN NaN NaN NaN NaN NaN \n",
"1 NaN NaN NaN NaN NaN NaN \n",
"2 NaN NaN NaN NaN NaN NaN \n",
"3 NaN NaN NaN NaN NaN NaN \n",
"4 NaN NaN NaN NaN NaN NaN \n",
"... ... ... ... ... ... ... \n",
"1470899 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"1470900 0.0 1.0 0.0 0.0 0.0 0.0 \n",
"1470901 1.0 0.0 0.0 0.0 0.0 0.0 \n",
"1470902 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"1470903 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" window_14_lag_7_mean window_30_lag_7_mean window_7_lag_14_mean \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"1470899 0.071429 0.166667 0.142857 \n",
"1470900 0.071429 0.100000 0.142857 \n",
"1470901 0.071429 0.100000 0.142857 \n",
"1470902 0.071429 0.066667 0.142857 \n",
"1470903 0.071429 0.066667 0.142857 \n",
"\n",
" window_14_lag_14_mean window_30_lag_14_mean window_7_lag_21_mean \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"1470899 0.142857 0.166667 0.142857 \n",
"1470900 0.142857 0.166667 0.142857 \n",
"1470901 0.071429 0.166667 0.000000 \n",
"1470902 0.071429 0.166667 0.000000 \n",
"1470903 0.071429 0.166667 0.000000 \n",
"\n",
" window_14_lag_21_mean window_30_lag_21_mean window_7_lag_28_mean \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"1470899 0.142857 0.166667 0.142857 \n",
"1470900 0.071429 0.166667 0.000000 \n",
"1470901 0.071429 0.166667 0.142857 \n",
"1470902 0.071429 0.166667 0.142857 \n",
"1470903 0.071429 0.166667 0.142857 \n",
"\n",
" window_14_lag_28_mean window_30_lag_28_mean \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"1470899 0.214286 0.133333 \n",
"1470900 0.214286 0.133333 \n",
"1470901 0.285714 0.166667 \n",
"1470902 0.285714 0.166667 \n",
"1470903 0.285714 0.166667 \n",
"\n",
" count_consecutive_value_lag_7 count_consecutive_value_lag_14 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"1470899 9.0 2.0 \n",
"1470900 10.0 3.0 \n",
"1470901 11.0 4.0 \n",
"1470902 12.0 5.0 \n",
"1470903 13.0 6.0 \n",
"\n",
" count_consecutive_value_lag_21 count_consecutive_value_lag_28 \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"1470899 5.0 6.0 \n",
"1470900 6.0 7.0 \n",
"1470901 7.0 0.0 \n",
"1470902 8.0 1.0 \n",
"1470903 9.0 2.0 \n",
"\n",
" history_length day_of_month day_of_week week_of_year \\\n",
"0 1 31 2 5 \n",
"1 2 1 3 5 \n",
"2 3 2 4 5 \n",
"3 4 3 5 5 \n",
"4 5 4 6 5 \n",
"... ... ... ... ... \n",
"1470899 1936 18 4 20 \n",
"1470900 1937 19 5 20 \n",
"1470901 1938 20 6 20 \n",
"1470902 1939 21 7 20 \n",
"1470903 1940 22 1 20 \n",
"\n",
" week_of_month weekend quarter month year \n",
"0 5 0 1 1 2011 \n",
"1 1 0 1 2 2011 \n",
"2 1 0 1 2 2011 \n",
"3 1 0 1 2 2011 \n",
"4 1 0 1 2 2011 \n",
"... ... ... ... ... ... \n",
"1470899 3 0 2 5 2016 \n",
"1470900 3 0 2 5 2016 \n",
"1470901 3 0 2 5 2016 \n",
"1470902 3 1 2 5 2016 \n",
"1470903 4 1 2 5 2016 \n",
"\n",
"[1470904 rows x 41 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"feature_extractor.transform(df.toPandas(), spark=spark)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We can not pass the features created by ``FeatureExtractor`` to ``ForecastFlowML`` for training. As mentioned in the lag feature creation step, we are going to set ``use_lag_range=28`` to use lags which are 28 days after from the most recent lag features. "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" group | \n",
" forecast_horizon | \n",
" model | \n",
" start_time | \n",
" end_time | \n",
" elapsed_seconds | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" CA_2 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:36) | \n",
" 01-May-2023 (02:43:41) | \n",
" 5.5 | \n",
"
\n",
" \n",
" | 1 | \n",
" CA_3 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:42) | \n",
" 01-May-2023 (02:43:50) | \n",
" 7.9 | \n",
"
\n",
" \n",
" | 2 | \n",
" WI_2 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:50) | \n",
" 01-May-2023 (02:43:53) | \n",
" 2.7 | \n",
"
\n",
" \n",
" | 3 | \n",
" WI_3 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:53) | \n",
" 01-May-2023 (02:43:57) | \n",
" 3.8 | \n",
"
\n",
" \n",
" | 4 | \n",
" CA_1 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:34) | \n",
" 01-May-2023 (02:43:40) | \n",
" 5.9 | \n",
"
\n",
" \n",
" | 5 | \n",
" CA_4 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:41) | \n",
" 01-May-2023 (02:43:49) | \n",
" 7.7 | \n",
"
\n",
" \n",
" | 6 | \n",
" TX_1 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:49) | \n",
" 01-May-2023 (02:43:53) | \n",
" 3.8 | \n",
"
\n",
" \n",
" | 7 | \n",
" TX_3 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:53) | \n",
" 01-May-2023 (02:43:57) | \n",
" 4.2 | \n",
"
\n",
" \n",
" | 8 | \n",
" WI_1 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:58) | \n",
" 01-May-2023 (02:44:00) | \n",
" 2.1 | \n",
"
\n",
" \n",
" | 9 | \n",
" TX_2 | \n",
" [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... | \n",
" [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... | \n",
" 01-May-2023 (02:43:32) | \n",
" 01-May-2023 (02:43:39) | \n",
" 7.2 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" group forecast_horizon \\\n",
"0 CA_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"1 CA_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"2 WI_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"3 WI_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"4 CA_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"5 CA_4 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"6 TX_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"7 TX_3 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"8 WI_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"9 TX_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n",
"\n",
" model start_time \\\n",
"0 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:36) \n",
"1 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:42) \n",
"2 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:50) \n",
"3 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:53) \n",
"4 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:34) \n",
"5 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:41) \n",
"6 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:49) \n",
"7 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:53) \n",
"8 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:58) \n",
"9 [\u0003clightgbm.sklearn\\nLGBMRegressor\\nq\u0000)q\u0001}q\u0002... 01-May-2023 (02:43:32) \n",
"\n",
" end_time elapsed_seconds \n",
"0 01-May-2023 (02:43:41) 5.5 \n",
"1 01-May-2023 (02:43:50) 7.9 \n",
"2 01-May-2023 (02:43:53) 2.7 \n",
"3 01-May-2023 (02:43:57) 3.8 \n",
"4 01-May-2023 (02:43:40) 5.9 \n",
"5 01-May-2023 (02:43:49) 7.7 \n",
"6 01-May-2023 (02:43:53) 3.8 \n",
"7 01-May-2023 (02:43:57) 4.2 \n",
"8 01-May-2023 (02:44:00) 2.1 \n",
"9 01-May-2023 (02:43:39) 7.2 "
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forecast_flow = ForecastFlowML(\n",
" group_col=\"store_id\",\n",
" id_col=\"id\",\n",
" date_col=\"date\",\n",
" target_col=\"sales\",\n",
" date_frequency=\"days\",\n",
" model_horizon=7,\n",
" max_forecast_horizon=28,\n",
" model=LGBMRegressor(),\n",
" use_lag_range=28,\n",
")\n",
"trained_models = forecast_flow.train(df_train).toPandas()\n",
"trained_models"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Examine Features\n",
"\n",
"Let's examine which features are used for each model."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" model_0 | \n",
" model_1 | \n",
" model_2 | \n",
" model_3 | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" count_consecutive_value_lag_7 | \n",
" count_consecutive_value_lag_14 | \n",
" count_consecutive_value_lag_21 | \n",
" count_consecutive_value_lag_28 | \n",
"
\n",
" \n",
" | 1 | \n",
" day_of_month | \n",
" day_of_month | \n",
" day_of_month | \n",
" day_of_month | \n",
"
\n",
" \n",
" | 2 | \n",
" day_of_week | \n",
" day_of_week | \n",
" day_of_week | \n",
" day_of_week | \n",
"
\n",
" \n",
" | 3 | \n",
" history_length | \n",
" history_length | \n",
" history_length | \n",
" history_length | \n",
"
\n",
" \n",
" | 4 | \n",
" lag_14 | \n",
" lag_14 | \n",
" lag_21 | \n",
" lag_28 | \n",
"
\n",
" \n",
" | 5 | \n",
" lag_21 | \n",
" lag_21 | \n",
" lag_28 | \n",
" lag_35 | \n",
"
\n",
" \n",
" | 6 | \n",
" lag_28 | \n",
" lag_28 | \n",
" lag_35 | \n",
" lag_42 | \n",
"
\n",
" \n",
" | 7 | \n",
" lag_35 | \n",
" lag_35 | \n",
" lag_42 | \n",
" lag_49 | \n",
"
\n",
" \n",
" | 8 | \n",
" lag_7 | \n",
" lag_42 | \n",
" lag_49 | \n",
" lag_56 | \n",
"
\n",
" \n",
" | 9 | \n",
" month | \n",
" month | \n",
" month | \n",
" month | \n",
"
\n",
" \n",
" | 10 | \n",
" quarter | \n",
" quarter | \n",
" quarter | \n",
" quarter | \n",
"
\n",
" \n",
" | 11 | \n",
" week_of_month | \n",
" week_of_month | \n",
" week_of_month | \n",
" week_of_month | \n",
"
\n",
" \n",
" | 12 | \n",
" week_of_year | \n",
" week_of_year | \n",
" week_of_year | \n",
" week_of_year | \n",
"
\n",
" \n",
" | 13 | \n",
" weekend | \n",
" weekend | \n",
" weekend | \n",
" weekend | \n",
"
\n",
" \n",
" | 14 | \n",
" window_14_lag_7_mean | \n",
" window_14_lag_14_mean | \n",
" window_14_lag_21_mean | \n",
" window_14_lag_28_mean | \n",
"
\n",
" \n",
" | 15 | \n",
" window_30_lag_7_mean | \n",
" window_30_lag_14_mean | \n",
" window_30_lag_21_mean | \n",
" window_30_lag_28_mean | \n",
"
\n",
" \n",
" | 16 | \n",
" window_7_lag_7_mean | \n",
" window_7_lag_14_mean | \n",
" window_7_lag_21_mean | \n",
" window_7_lag_28_mean | \n",
"
\n",
" \n",
" | 17 | \n",
" year | \n",
" year | \n",
" year | \n",
" year | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" model_0 model_1 \\\n",
"0 count_consecutive_value_lag_7 count_consecutive_value_lag_14 \n",
"1 day_of_month day_of_month \n",
"2 day_of_week day_of_week \n",
"3 history_length history_length \n",
"4 lag_14 lag_14 \n",
"5 lag_21 lag_21 \n",
"6 lag_28 lag_28 \n",
"7 lag_35 lag_35 \n",
"8 lag_7 lag_42 \n",
"9 month month \n",
"10 quarter quarter \n",
"11 week_of_month week_of_month \n",
"12 week_of_year week_of_year \n",
"13 weekend weekend \n",
"14 window_14_lag_7_mean window_14_lag_14_mean \n",
"15 window_30_lag_7_mean window_30_lag_14_mean \n",
"16 window_7_lag_7_mean window_7_lag_14_mean \n",
"17 year year \n",
"\n",
" model_2 model_3 \n",
"0 count_consecutive_value_lag_21 count_consecutive_value_lag_28 \n",
"1 day_of_month day_of_month \n",
"2 day_of_week day_of_week \n",
"3 history_length history_length \n",
"4 lag_21 lag_28 \n",
"5 lag_28 lag_35 \n",
"6 lag_35 lag_42 \n",
"7 lag_42 lag_49 \n",
"8 lag_49 lag_56 \n",
"9 month month \n",
"10 quarter quarter \n",
"11 week_of_month week_of_month \n",
"12 week_of_year week_of_year \n",
"13 weekend weekend \n",
"14 window_14_lag_21_mean window_14_lag_28_mean \n",
"15 window_30_lag_21_mean window_30_lag_28_mean \n",
"16 window_7_lag_21_mean window_7_lag_28_mean \n",
"17 year year "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pickle\n",
"\n",
"features = {}\n",
"for i in range(4):\n",
" model = pickle.loads(bytes(trained_models[\"model\"].iloc[0][i], \"latin1\"))\n",
" features[f\"model_{i}\"] = sorted(model.feature_name_)\n",
"pd.DataFrame(features)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sspark37",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}