{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This document will walk you through a complete example of using the Pandas library to do data preparation for a neural network model.\n", "\n", "Pandas is one of the most popular libraries in the Python world for doing data science. Pandas can help you to pre-process raw data before feeding it to a model.\n", "\n", "Before working through this notebook you should familiarize yourself with the basic operations in Pandas. A good way to do this is to read through \n", "[this online tutorial](https://www.learndatasci.com/tutorials/python-pandas-tutorial-complete-introduction-for-beginners/)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data set we will work through in this notebook is available on the course web site. Run the code in the cell below to fetch the CSV file with the data you will need." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "URL transformed to HTTPS due to an HSTS policy\n", "--2023-09-21 13:17:02-- https://www.lawrence.edu/fast/greggj/CMSC490/short_rides.csv\n", "Resolving www.lawrence.edu (www.lawrence.edu)... 143.44.124.14\n", "Connecting to www.lawrence.edu (www.lawrence.edu)|143.44.124.14|:443... connected.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://www7.lawrence.edu/fast/greggj/CMSC490/short_rides.csv [following]\n", "--2023-09-21 13:17:03-- https://www7.lawrence.edu/fast/greggj/CMSC490/short_rides.csv\n", "Resolving www7.lawrence.edu (www7.lawrence.edu)... 143.44.124.14\n", "Connecting to www7.lawrence.edu (www7.lawrence.edu)|143.44.124.14|:443... connected.\n", "HTTP request sent, awaiting response... 
301 Moved Permanently\n", "Location: https://www2.lawrence.edu/fast/greggj/CMSC490/short_rides.csv [following]\n", "--2023-09-21 13:17:03-- https://www2.lawrence.edu/fast/greggj/CMSC490/short_rides.csv\n", "Resolving www2.lawrence.edu (www2.lawrence.edu)... 143.44.124.14\n", "Connecting to www2.lawrence.edu (www2.lawrence.edu)|143.44.124.14|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: http://www2.lawrence.edu/fast/GREGGJ/CMSC490/short_rides.csv [following]\n", "--2023-09-21 13:17:03-- http://www2.lawrence.edu/fast/GREGGJ/CMSC490/short_rides.csv\n", "Connecting to www2.lawrence.edu (www2.lawrence.edu)|143.44.124.14|:80... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 1658 (1.6K) [text/csv]\n", "Saving to: ‘short_rides.csv.1’\n", "\n", "short_rides.csv.1 100%[===================>] 1.62K --.-KB/s in 0s \n", "\n", "2023-09-21 13:17:03 (185 MB/s) - ‘short_rides.csv.1’ saved [1658/1658]\n", "\n" ] } ], "source": [ "!wget http://www.lawrence.edu/fast/greggj/CMSC490/short_rides.csv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data set consists of data I assembled from Strava on a series of bike rides I did in 2016 and 2017. All of these rides followed the same route, and on each of them I rode the same bicycle. What varied from ride to ride was mostly the weather: over the course of the year the temperature and the wind conditions varied quite a bit.\n", "\n", "The data set also records the year of each ride and which set of wheels I used. The year is relevant because the route changed slightly from 2016 to 2017, and because I slow down a little each year as I get older. As for the wheels, I used the same bike for all of these rides, but I have two different sets of wheels for it. 
In the spring and fall I use a heavier, more puncture-resistant wheel set, while in the summer I use a lighter and faster set of wheels.\n", "\n", "In this study we will model the average speed of a ride as a function of the factors that may affect it.\n", "\n", "We start by loading the data from the CSV file into a Pandas DataFrame. The DataFrame's head() method shows the first five rows." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>year</th><th>wheel</th><th>speed</th><th>temp</th><th>wind</th><th>dir</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>0</th><td>2016</td><td>1</td><td>15.998098</td><td>64.0</td><td>5.0</td><td>NaN</td></tr>\n", "    <tr><th>1</th><td>2016</td><td>0</td><td>16.800493</td><td>71.0</td><td>11.0</td><td>170.0</td></tr>\n", "    <tr><th>2</th><td>2016</td><td>0</td><td>16.919170</td><td>65.0</td><td>7.0</td><td>95.0</td></tr>\n", "    <tr><th>3</th><td>2016</td><td>0</td><td>16.620407</td><td>77.0</td><td>7.0</td><td>120.0</td></tr>\n", "    <tr><th>4</th><td>2016</td><td>0</td><td>16.583063</td><td>77.0</td><td>3.0</td><td>220.0</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " year wheel speed temp wind dir\n", "0 2016 1 15.998098 64.0 5.0 NaN\n", "1 2016 0 16.800493 71.0 11.0 170.0\n", "2 2016 0 16.919170 65.0 7.0 95.0\n", "3 2016 0 16.620407 77.0 7.0 120.0\n", "4 2016 0 16.583063 77.0 3.0 220.0" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('short_rides.csv')\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the full data set, the first thing to watch out for is missing values. The 'dir' column, which holds the wind direction, has a number of NaN entries: when the wind speed drops too low, the weather service sometimes does not report a wind direction at all.\n", "\n", "One common strategy for handling missing values is to replace the NaNs in a column with the mean of that column. First we compute the mean of the 'dir' column." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "184.20212765957447" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['dir'].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The pandas fillna() method replaces the NaN values in a column with a value of our choice. Here we replace all of the NaNs in the 'dir' column with the mean wind direction, rounded to 184 degrees. Keep in mind that an ordinary mean of compass angles is only a rough placeholder, because angles wrap around: 350 and 10 degrees are nearly the same direction, yet their arithmetic mean is 180, the opposite direction." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>year</th><th>wheel</th><th>speed</th><th>temp</th><th>wind</th><th>dir</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>year</th><td>1.000000</td><td>-0.076339</td><td>-0.277386</td><td>-0.131213</td><td>0.187193</td><td>0.104080</td></tr>\n", "    <tr><th>wheel</th><td>-0.076339</td><td>1.000000</td><td>-0.699100</td><td>-0.708325</td><td>-0.102983</td><td>0.174093</td></tr>\n", "    <tr><th>speed</th><td>-0.277386</td><td>-0.699100</td><td>1.000000</td><td>0.709205</td><td>0.002943</td><td>-0.235181</td></tr>\n", "    <tr><th>temp</th><td>-0.131213</td><td>-0.708325</td><td>0.709205</td><td>1.000000</td><td>0.018917</td><td>-0.187772</td></tr>\n", "    <tr><th>wind</th><td>0.187193</td><td>-0.102983</td><td>0.002943</td><td>0.018917</td><td>1.000000</td><td>0.251351</td></tr>\n", "    <tr><th>dir</th><td>0.104080</td><td>0.174093</td><td>-0.235181</td><td>-0.187772</td><td>0.251351</td><td>1.000000</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " year wheel speed temp wind dir\n", "year 1.000000 -0.076339 -0.277386 -0.131213 0.187193 0.104080\n", "wheel -0.076339 1.000000 -0.699100 -0.708325 -0.102983 0.174093\n", "speed -0.277386 -0.699100 1.000000 0.709205 0.002943 -0.235181\n", "temp -0.131213 -0.708325 0.709205 1.000000 0.018917 -0.187772\n", "wind 0.187193 -0.102983 0.002943 0.018917 1.000000 0.251351\n", "dir 0.104080 0.174093 -0.235181 -0.187772 0.251351 1.000000" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['dir'] = df['dir'].fillna(value=184)\n", "df.corr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After this first round of data cleaning, the next thing to examine is the correlation matrix for the data set, which gives a sense of which features are most strongly correlated with the output value. In this example the output we are trying to explain is the speed. Looking at the speed row of the correlation matrix, we see that temperature (0.71) and wheel choice (-0.70) are strongly correlated with speed, and year (-0.28) and wind direction (-0.24) more weakly.\n", "\n", "The one surprise is the correlation between riding speed and wind speed: it is almost exactly 0 (0.003), even though wind clearly affects how fast a cyclist can ride.\n", "\n", "The issue is the way the data set represents the wind: as a combination of wind speed ('wind') and wind direction ('dir'). This causes two problems. The first is that wind speed alone cannot tell us whether the wind is helping or hurting, since the same wind is a headwind on some parts of the route and a tailwind on others; its effect depends on its direction as well as its strength. The second is that although wind direction is likely to affect speed, that relationship is almost certainly non-linear. The direction is stored as an angle measured clockwise from north, and that angle wraps around at 360 degrees, so directions just either side of north get numbers at opposite ends of the range.\n", "\n", "The fix for both of these problems is to switch to a more appropriate representation for wind. 
Wind is a vector, and the most natural way to represent a vector is by its components rather than by a length and a direction. The code below uses the 'wind' and 'dir' columns to construct two new columns, 's' and 'w', which decompose the wind vector into southerly and westerly components. 's' equals the full wind speed when the wind blows from due south (180 degrees) and is negative when it blows from the north. 'w' works the same way for a wind from due west (270 degrees). After converting to this representation we drop the original 'wind' and 'dir' columns, along with any rows that still contain missing values.\n", "\n", "One final bit of data cleaning is to transform the year column from 2016 or 2017 into 0 or 1." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>year</th><th>wheel</th><th>speed</th><th>temp</th><th>s</th><th>w</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>year</th><td>1.000000</td><td>-0.099840</td><td>-0.264604</td><td>-0.131213</td><td>0.095990</td><td>0.190441</td></tr>\n", "    <tr><th>wheel</th><td>-0.099840</td><td>1.000000</td><td>-0.688161</td><td>-0.708325</td><td>-0.269389</td><td>0.099353</td></tr>\n", "    <tr><th>speed</th><td>-0.264604</td><td>-0.688161</td><td>1.000000</td><td>0.709205</td><td>0.224236</td><td>-0.213081</td></tr>\n", "    <tr><th>temp</th><td>-0.131213</td><td>-0.708325</td><td>0.709205</td><td>1.000000</td><td>0.389297</td><td>-0.040870</td></tr>\n", "    <tr><th>s</th><td>0.095990</td><td>-0.269389</td><td>0.224236</td><td>0.389297</td><td>1.000000</td><td>0.021641</td></tr>\n", "    <tr><th>w</th><td>0.190441</td><td>0.099353</td><td>-0.213081</td><td>-0.040870</td><td>0.021641</td><td>1.000000</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " year wheel speed temp s w\n", "year 1.000000 -0.099840 -0.264604 -0.131213 0.095990 0.190441\n", "wheel -0.099840 1.000000 -0.688161 -0.708325 -0.269389 0.099353\n", "speed -0.264604 -0.688161 1.000000 0.709205 0.224236 -0.213081\n", "temp -0.131213 -0.708325 0.709205 1.000000 0.389297 -0.040870\n", "s 0.095990 -0.269389 0.224236 0.389297 1.000000 0.021641\n", "w 0.190441 0.099353 -0.213081 -0.040870 0.021641 1.000000" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 'dir' is the direction the wind blows from, in degrees clockwise from north.\n", "df['s'] = df['wind']*np.cos(((df['dir']-180)/180)*np.pi)  # southerly component\n", "df['w'] = df['wind']*np.cos(((df['dir']-270)/180)*np.pi)  # westerly component\n", "df['year'] = df['year']-2016  # 2016 -> 0, 2017 -> 1\n", "df = df.drop(['wind','dir'],axis=1).dropna()  # also drops rows with any remaining NaNs\n", "df.corr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After the wind fix, the correlation matrix looks much more reasonable. Speed is now noticeably correlated with both wind components (0.22 for 's' and -0.21 for 'w'), as you would expect." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to start building and training a model. To have a baseline to compare the eventual neural network against, I will first use the scikit-learn package to build a simple machine learning model.\n", "\n", "The first step is to split the data set into a training set, used to train the model, and a test set, used to evaluate it.\n", "\n", "The scikit-learn train_test_split function does this job nicely, randomly splitting the original data set into a training set and a test set. A widely used convention is to set aside 20% of the data as the test set, so that is what we do here."
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "train_set, test_set = train_test_split(df,test_size=0.2,random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model we will use for this example is one of the simplest available: a multiple linear regression model.\n", "\n", "The code below creates the model and fits it to the training set's X (feature) and y (speed) values." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LinearRegression()" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LinearRegression\n", "lin_reg = LinearRegression()\n", "X = train_set[['year','wheel','temp','s','w']]\n", "y = train_set['speed']\n", "lin_reg.fit(X,y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the model has been fit to the training data, we can use it to make predictions. We feed the model the X values from the test set and compare its predictions with the actual speeds in the test set. The output below shows the difference between each prediction and the actual speed, in mph." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "30 0.664158\n", "43 0.160004\n", "29 -0.166340\n", "47 -0.428257\n", "27 -0.007010\n", "40 0.697200\n", "14 -0.821402\n", "22 -0.621427\n", "5 0.063666\n", "28 -0.370111\n", "Name: speed, dtype: float64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions = lin_reg.predict(test_set[['year','wheel','temp','s','w']])\n", "predictions - test_set['speed']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All ten predictions are within about 0.8 mph of the actual speed, and half of them are within 0.4 mph.\n", "\n", "A standard metric to apply to a set of predictions is the mean squared error. 
As usual, scikit-learn has an evaluation function ready to compute this metric." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.23658250196854227" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import mean_squared_error\n", "mse = mean_squared_error(test_set['speed'], predictions)  # arguments are (y_true, y_pred)\n", "mse" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One last thing to look at is the model's coefficients, which give a sense of the impact of each input feature. The coefficients are listed in the same order as the feature columns: year, wheel, temp, s, w.\n", "\n", "The year coefficient says that going from 2016 to 2017 my average speed dropped by about 0.3 mph. Much of this likely reflects the roundabout installed on my route, which forced me to stop at an intersection I would typically zip through without stopping. The wheel coefficient says that putting on the heavier spring-and-fall wheels lowers my average speed by about 0.75 mph, which lines up pretty well with my own experience.\n", "\n", "The wind coefficients are harder to interpret. Because most of the route runs east-west, with the westbound segment more exposed to the wind than the eastbound one, I would expect the westerly component to matter more than the southerly one. In this particular fit, however, the southerly coefficient (about 0.0048) is more than ten times larger in magnitude than the westerly one (about -0.00035). With only a few dozen training rides, and with 's' noticeably correlated with temperature (0.39 in the correlation matrix above), the individual wind coefficients from a single fit should be taken with a grain of salt." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(15.593571424281194,\n", " array([-2.85780789e-01, -7.58538983e-01, 1.47548761e-02, 4.81419590e-03,\n", " -3.48648341e-04]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lin_reg.intercept_,lin_reg.coef_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One final experiment before we conclude this first pass through the data. 
The MSE we computed above is based on a single train/test split of the data. Like any result that depends on a random sample, this value will fluctuate from one split to the next. One way to account for this is to perform cross-validation.\n", "\n", "To do a cross-validation we start by shuffling the data set. We then split it into N folds of (nearly) equal size and run N experiments. In each experiment we set aside one of the folds as a test set, train the model on the remaining N-1 folds, and compute the mean squared error on the held-out fold. Finally, we average the N mean squared errors and take the square root to get a root-mean-square (RMS) error. This gives a better estimate of the model's true error than any single split does.\n", "\n", "As always, scikit-learn has a handy function for this: cross_val_score. When cv is an integer and the model is a regressor, cross_val_score splits the data into consecutive folds without shuffling, so I first shuffle the full data set myself. The trick is to use the NumPy permutation function to make a randomly permuted array of row positions, then pass that array to the pandas iloc indexer, which returns the rows in the shuffled order." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-0.36602975, -0.14647831, -0.10125056, -0.14462003, -0.150603 ])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import cross_val_score\n", "shuffled_indices = np.random.permutation(len(df))\n", "shuffled_data = df.iloc[shuffled_indices]\n", "full_X = shuffled_data[['year','wheel','temp','s','w']]\n", "full_Y = shuffled_data['speed']\n", "scores = cross_val_score(lin_reg,full_X,full_Y,scoring='neg_mean_squared_error',cv=5)\n", "scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The scores are negative because scikit-learn's scoring functions follow a greater-is-better convention, so 'neg_mean_squared_error' reports each fold's mean squared error with its sign flipped. The scores for the five folds vary somewhat, which is to be expected. 
Averaging the errors over the five folds (with the sign flipped back) and taking the square root gives a better sense of the typical error we can expect from this model." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.42637580855203794" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.sqrt(-scores.mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tells us that the model predicts my average speed on the route with an RMS error of about 0.43 mph, a pretty decent result for such a simple model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we want to build a simple neural network model for the same data set.\n", "\n", "First, I am going to use a slightly different procedure to construct the training and test sets. We use the Pandas sample() method with frac=1 to shuffle all of the rows in the DataFrame, and then split off the last 11 rows as the test set." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
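<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>year</th><th>wheel</th><th>speed</th><th>temp</th><th>s</th><th>w</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>47</th><td>1</td><td>0</td><td>16.577971</td><td>55.0</td><td>6.062178e+00</td><td>-3.500000</td></tr>\n", "    <tr><th>32</th><td>1</td><td>0</td><td>16.852150</td><td>66.0</td><td>5.985384e+00</td><td>0.418539</td></tr>\n", "    <tr><th>50</th><td>1</td><td>1</td><td>15.388298</td><td>37.0</td><td>4.898587e-16</td><td>8.000000</td></tr>\n", "    <tr><th>45</th><td>1</td><td>0</td><td>15.833345</td><td>70.0</td><td>5.282728e+00</td><td>11.328847</td></tr>\n", "    <tr><th>38</th><td>1</td><td>0</td><td>16.395075</td><td>74.0</td><td>5.196152e+00</td><td>3.000000</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " year wheel speed temp s w\n", "47 1 0 16.577971 55.0 6.062178e+00 -3.500000\n", "32 1 0 16.852150 66.0 5.985384e+00 0.418539\n", "50 1 1 15.388298 37.0 4.898587e-16 8.000000\n", "45 1 0 15.833345 70.0 5.282728e+00 11.328847\n", "38 1 0 16.395075 74.0 5.196152e+00 3.000000" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = df.sample(frac=1)\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(The tiny value 4.898587e-16 in the 's' column of row 50 is floating-point round-off. That ride's wind came from due west, 270 degrees, so its southerly component is really 0.)\n", "\n", "I will use the Pandas to_numpy() method to convert parts of the DataFrame to NumPy arrays, because Keras works best with NumPy arrays as its inputs. The first 41 shuffled rows become the training set and the remaining 11 the test set. I also subtract the mean speed of the training rides from the targets, so that the network only has to learn how far each ride deviates from an average ride." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(52, 5)\n" ] } ], "source": [ "features = df.drop(['speed'], axis=1).to_numpy()\n", "speeds = df['speed'].to_numpy()\n", "print(features.shape)\n", "n_train = 41  # first 41 shuffled rows for training, last 11 for testing\n", "# Center the targets on the mean training speed; using only training rows keeps test data out of preprocessing.\n", "mean_speed = speeds[:n_train].mean()\n", "train_data = features[:n_train, :]\n", "train_targets = speeds[:n_train] - mean_speed\n", "test_data = features[n_train:, :]\n", "test_targets = speeds[n_train:] - mean_speed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are ready to build and train our neural network. We start with a simple network: a BatchNormalization layer that rescales the input features, followed by two hidden layers of 12 ReLU units each and a single linear output unit."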
] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [], "source": [ "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "\n", "def build_model():\n", "    # Batch-normalized inputs, two hidden layers of 12 ReLU units, one linear output.\n", "    model = keras.Sequential([\n", "        layers.BatchNormalization(),\n", "        layers.Dense(12, activation=\"relu\"),\n", "        layers.Dense(12, activation=\"relu\"),\n", "        layers.Dense(1)\n", "    ])\n", "    model.compile(optimizer=\"Adam\", loss=\"mse\", metrics=[\"mae\"])\n", "    return model" ] }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/80\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1/1 [==============================] - 1s 692ms/step - loss: 0.7459 - mae: 0.7059 - val_loss: 49.1181 - val_mae: 6.6360\n", "Epoch 2/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.7248 - mae: 0.6959 - val_loss: 27.1337 - val_mae: 4.8620\n", "Epoch 3/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.7042 - mae: 0.6861 - val_loss: 18.2167 - val_mae: 3.9391\n", "Epoch 4/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.6843 - mae: 0.6764 - val_loss: 13.3767 - val_mae: 3.3404\n", "Epoch 5/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.6649 - mae: 0.6669 - val_loss: 10.3496 - val_mae: 2.9083\n", "Epoch 6/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.6461 - mae: 0.6574 - val_loss: 8.3035 - val_mae: 2.5796\n", "Epoch 7/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.6279 - mae: 0.6480 - val_loss: 6.8263 - val_mae: 2.3152\n", "Epoch 8/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.6102 - mae: 0.6387 - val_loss: 5.7170 - val_mae: 2.0963\n", "Epoch 9/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.5932 - mae: 0.6296 - val_loss: 4.8597 - val_mae: 1.9124\n", "Epoch 10/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.5768 - mae: 0.6207 - 
val_loss: 4.1776 - val_mae: 1.7536\n", "Epoch 11/80\n", "1/1 [==============================] - 0s 15ms/step - loss: 0.5610 - mae: 0.6119 - val_loss: 3.6246 - val_mae: 1.6126\n", "Epoch 12/80\n", "1/1 [==============================] - 0s 21ms/step - loss: 0.5460 - mae: 0.6035 - val_loss: 3.1705 - val_mae: 1.4871\n", "Epoch 13/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.5315 - mae: 0.5958 - val_loss: 2.7933 - val_mae: 1.3828\n", "Epoch 14/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.5178 - mae: 0.5883 - val_loss: 2.4772 - val_mae: 1.2988\n", "Epoch 15/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.5049 - mae: 0.5812 - val_loss: 2.2100 - val_mae: 1.2265\n", "Epoch 16/80\n", "1/1 [==============================] - 0s 17ms/step - loss: 0.4926 - mae: 0.5741 - val_loss: 1.9848 - val_mae: 1.1616\n", "Epoch 17/80\n", "1/1 [==============================] - 0s 17ms/step - loss: 0.4806 - mae: 0.5672 - val_loss: 1.7928 - val_mae: 1.1030\n", "Epoch 18/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.4690 - mae: 0.5604 - val_loss: 1.6269 - val_mae: 1.0488\n", "Epoch 19/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.4580 - mae: 0.5537 - val_loss: 1.4828 - val_mae: 0.9987\n", "Epoch 20/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.4474 - mae: 0.5472 - val_loss: 1.3570 - val_mae: 0.9533\n", "Epoch 21/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.4373 - mae: 0.5408 - val_loss: 1.2468 - val_mae: 0.9213\n", "Epoch 22/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.4275 - mae: 0.5345 - val_loss: 1.1499 - val_mae: 0.8913\n", "Epoch 23/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.4181 - mae: 0.5283 - val_loss: 1.0644 - val_mae: 0.8632\n", "Epoch 24/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.4089 - mae: 0.5223 - val_loss: 0.9888 - val_mae: 0.8368\n", 
"Epoch 25/80\n", "1/1 [==============================] - 0s 14ms/step - loss: 0.4003 - mae: 0.5165 - val_loss: 0.9215 - val_mae: 0.8118\n", "Epoch 26/80\n", "1/1 [==============================] - 0s 15ms/step - loss: 0.3922 - mae: 0.5110 - val_loss: 0.8616 - val_mae: 0.7882\n", "Epoch 27/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.3842 - mae: 0.5056 - val_loss: 0.8080 - val_mae: 0.7656\n", "Epoch 28/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3764 - mae: 0.5002 - val_loss: 0.7599 - val_mae: 0.7440\n", "Epoch 29/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3688 - mae: 0.4950 - val_loss: 0.7167 - val_mae: 0.7235\n", "Epoch 30/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3617 - mae: 0.4902 - val_loss: 0.6778 - val_mae: 0.7039\n", "Epoch 31/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3548 - mae: 0.4855 - val_loss: 0.6430 - val_mae: 0.6855\n", "Epoch 32/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3481 - mae: 0.4809 - val_loss: 0.6119 - val_mae: 0.6681\n", "Epoch 33/80\n", "1/1 [==============================] - 0s 16ms/step - loss: 0.3417 - mae: 0.4764 - val_loss: 0.5838 - val_mae: 0.6514\n", "Epoch 34/80\n", "1/1 [==============================] - 0s 16ms/step - loss: 0.3356 - mae: 0.4720 - val_loss: 0.5582 - val_mae: 0.6352\n", "Epoch 35/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.3296 - mae: 0.4676 - val_loss: 0.5358 - val_mae: 0.6206\n", "Epoch 36/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3238 - mae: 0.4633 - val_loss: 0.5152 - val_mae: 0.6065\n", "Epoch 37/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3182 - mae: 0.4591 - val_loss: 0.4963 - val_mae: 0.5930\n", "Epoch 38/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.3128 - mae: 0.4549 - val_loss: 0.4791 - val_mae: 0.5800\n", "Epoch 39/80\n", "1/1 
[==============================] - 0s 12ms/step - loss: 0.3075 - mae: 0.4507 - val_loss: 0.4632 - val_mae: 0.5674\n", "Epoch 40/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.3024 - mae: 0.4466 - val_loss: 0.4486 - val_mae: 0.5564\n", "Epoch 41/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2974 - mae: 0.4426 - val_loss: 0.4351 - val_mae: 0.5476\n", "Epoch 42/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2926 - mae: 0.4386 - val_loss: 0.4227 - val_mae: 0.5391\n", "Epoch 43/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2879 - mae: 0.4347 - val_loss: 0.4112 - val_mae: 0.5340\n", "Epoch 44/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2836 - mae: 0.4309 - val_loss: 0.4006 - val_mae: 0.5291\n", "Epoch 45/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2795 - mae: 0.4273 - val_loss: 0.3908 - val_mae: 0.5242\n", "Epoch 46/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2755 - mae: 0.4237 - val_loss: 0.3819 - val_mae: 0.5197\n", "Epoch 47/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2715 - mae: 0.4203 - val_loss: 0.3736 - val_mae: 0.5151\n", "Epoch 48/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2677 - mae: 0.4171 - val_loss: 0.3658 - val_mae: 0.5107\n", "Epoch 49/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2640 - mae: 0.4139 - val_loss: 0.3585 - val_mae: 0.5062\n", "Epoch 50/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2604 - mae: 0.4108 - val_loss: 0.3517 - val_mae: 0.5019\n", "Epoch 51/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2567 - mae: 0.4076 - val_loss: 0.3453 - val_mae: 0.4976\n", "Epoch 52/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2531 - mae: 0.4044 - val_loss: 0.3393 - val_mae: 0.4934\n", "Epoch 53/80\n", "1/1 [==============================] - 0s 
12ms/step - loss: 0.2495 - mae: 0.4012 - val_loss: 0.3334 - val_mae: 0.4891\n", "Epoch 54/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2461 - mae: 0.3981 - val_loss: 0.3274 - val_mae: 0.4844\n", "Epoch 55/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2428 - mae: 0.3950 - val_loss: 0.3218 - val_mae: 0.4799\n", "Epoch 56/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2395 - mae: 0.3919 - val_loss: 0.3166 - val_mae: 0.4757\n", "Epoch 57/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2363 - mae: 0.3888 - val_loss: 0.3116 - val_mae: 0.4717\n", "Epoch 58/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2331 - mae: 0.3858 - val_loss: 0.3070 - val_mae: 0.4677\n", "Epoch 59/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2301 - mae: 0.3828 - val_loss: 0.3025 - val_mae: 0.4637\n", "Epoch 60/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2272 - mae: 0.3799 - val_loss: 0.2981 - val_mae: 0.4594\n", "Epoch 61/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2245 - mae: 0.3770 - val_loss: 0.2943 - val_mae: 0.4554\n", "Epoch 62/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2219 - mae: 0.3743 - val_loss: 0.2917 - val_mae: 0.4521\n", "Epoch 63/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2194 - mae: 0.3716 - val_loss: 0.2900 - val_mae: 0.4491\n", "Epoch 64/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.2170 - mae: 0.3690 - val_loss: 0.2884 - val_mae: 0.4462\n", "Epoch 65/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2147 - mae: 0.3664 - val_loss: 0.2867 - val_mae: 0.4432\n", "Epoch 66/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2125 - mae: 0.3639 - val_loss: 0.2854 - val_mae: 0.4412\n", "Epoch 67/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2103 - mae: 0.3614 
- val_loss: 0.2842 - val_mae: 0.4397\n", "Epoch 68/80\n", "1/1 [==============================] - 0s 11ms/step - loss: 0.2081 - mae: 0.3589 - val_loss: 0.2831 - val_mae: 0.4382\n", "Epoch 69/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2060 - mae: 0.3566 - val_loss: 0.2820 - val_mae: 0.4367\n", "Epoch 70/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2041 - mae: 0.3543 - val_loss: 0.2811 - val_mae: 0.4352\n", "Epoch 71/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.2022 - mae: 0.3521 - val_loss: 0.2802 - val_mae: 0.4337\n", "Epoch 72/80\n", "1/1 [==============================] - 0s 21ms/step - loss: 0.2003 - mae: 0.3499 - val_loss: 0.2794 - val_mae: 0.4322\n", "Epoch 73/80\n", "1/1 [==============================] - 0s 27ms/step - loss: 0.1985 - mae: 0.3478 - val_loss: 0.2787 - val_mae: 0.4307\n", "Epoch 74/80\n", "1/1 [==============================] - 0s 17ms/step - loss: 0.1967 - mae: 0.3457 - val_loss: 0.2778 - val_mae: 0.4290\n", "Epoch 75/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.1949 - mae: 0.3436 - val_loss: 0.2768 - val_mae: 0.4270\n", "Epoch 76/80\n", "1/1 [==============================] - 0s 14ms/step - loss: 0.1933 - mae: 0.3416 - val_loss: 0.2758 - val_mae: 0.4253\n", "Epoch 77/80\n", "1/1 [==============================] - 0s 13ms/step - loss: 0.1916 - mae: 0.3396 - val_loss: 0.2749 - val_mae: 0.4237\n", "Epoch 78/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.1900 - mae: 0.3377 - val_loss: 0.2740 - val_mae: 0.4220\n", "Epoch 79/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.1884 - mae: 0.3357 - val_loss: 0.2732 - val_mae: 0.4205\n", "Epoch 80/80\n", "1/1 [==============================] - 0s 12ms/step - loss: 0.1868 - mae: 0.3338 - val_loss: 0.2724 - val_mae: 0.4189\n" ] } ], "source": [ "from tensorflow import keras\n", "\n", "model = build_model()\n", "history = model.fit(train_data, train_targets,\n", " 
epochs=80, batch_size=41,validation_data=(test_data, test_targets))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is very promising. The validation mse at the end of training is very much in line with the mse we saw in the earlier model.\n", "\n", "We can also produce a plot of the validation loss over time." ] }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEWCAYAAABhffzLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAc70lEQVR4nO3de5gV9Z3n8fcHaFS6lYs0iI2IJniLM956HS/JbCKaUWIEN4+JTlQmcdY1yeYyMSrOPptMMpM8bjbr5jaTiSYqRuMlMV5WsypLvERjjOAVVEQNKoLQeOMqcvnuH7869qHtxgPddep01+f1PPVU1a/OqfOlgc+p/lXVrxQRmJlZeQwqugAzM6svB7+ZWck4+M3MSsbBb2ZWMg5+M7OScfCbmZWMg98GHEkh6f3Z8r9L+u+1vHY7PufTku7c3jq3st8PS1rc1/s1q3DwW8ORdIekb3XTPlXSK5KG1LqviDgnIv65D2qamH1JvPPZEXF1RHy0t/s2qzcHvzWiK4AzJKlL+xnA1RGxsf4lmQ0cDn5rRDcBo4APVRokjQROBK6UdLikByS9IWmppB9LGtrdjiRdIelfqtbPy96zRNJnu7z2Y5IekbRS0kuS/qlq873Z/A1JqyUdKenvJN1X9f6jJD0k6c1sflTVtrsl/bOk+yWtknSnpNG1/DAk7Z+9/w1J8yWdVLVtiqQns32+LOlrWftoSbdm73lN0u8l+f+7AQ5+a0ARsQ64HjizqvmTwNMR8RiwCfgHYDRwJDAZ+Px77VfS8cDXgOOAScCxXV6yJvvMEcDHgM9JmpZt++tsPiIiWiLigS77HgXcBvwQ2BW4GLhN0q5VL/tb4DPAGGBoVst71dwE/B/gzux9XwSulrRv9pKfA/8lInYGDgR+l7WfCywGWoGxwD8CHp/FAAe/Na6ZwCmSdsrWz8zaiIi5EfHHiNgYEYuAnwL/sYZ9fhK4PCLmRcQa4J+qN0bE3RHxRERsjojHgWtq3C+kL4qFEfGLrK5rgKeBj1e95vKIeKbqi+3gGvZ7BNACXBQRb0fE74BbgdOy7RuAAyTtEhGvR8TDVe3jgD0jYkNE/D48MJdlHPzWkCLiPqADmCppb+A/AL8EkLRP1o3xiqSVwHdIR//vZXfgpar1F6o3SvorSXdJ6pD0JnBOjfut7PuFLm0vAG1V669ULa8lBXpNNUfE5h72+wlgCvCCpHskHZm1/0/gWeBOSc9LmlHbH8PKwMFvjexK0pH+GcCdEbEsa/8J6Wh6UkTsQurG6HoiuDtLgT2q1id02f5L4BZgj4gYDvx71X7f62h5CbBnl7YJwMs11PVe+92jS//8O/uNiIciYiqpG+gm0m8SRMSqiDg3IvYm/dbxVUmTe1mLDRAOfmtkV5L64f8zWTdPZmdgJbBa0n7A52rc3/XA30k6QNIw4Btdtu8MvBYRb0k6nNQnX9EBbAb27mHfvwX2kfS3koZI+hRwAKlbpjceJJ17OF9Sk6QPk4L8WklDs3sJhkfEBtLPZBOApBMlvT+7MqrSvqmXtdgA4eC3hpX13/8BaCYdiVd8j
RTKq4BLgetq3N//Bb5POgH6LJ0nQis+D3xL0irg62RHz9l71wLfBu7PrpQ5osu+XyVddXQu8CpwPnBiRKyopbat1Pw2cBJwArAC+DfgzIh4OnvJGcCirMvrHOD0rH0S8P+A1cADwL9FxN29qcUGDvl8j5lZufiI38ysZBz8ZmYl4+A3MysZB7+ZWcnUPMphkUaPHh0TJ04sugwzs35l7ty5KyKitWt7vwj+iRMnMmfOnKLLMDPrVyR1vZsccFePmVnpOPjNzEom164eSYtId1duAjZGRHs2fO11wERgEfDJiHg9zzrMzKxTPY74PxIRB0dEe7Y+A5gdEZOA2dm6mZnVSRFdPVPpHHBrJjCtgBrMzEor7+AP0njgcyWdnbWNjYilANl8THdvlHS2pDmS5nR0dORcpplZeeR9OefREbFE0hhglqSn3/MdmYi4BLgEoL293SPJmZn1kVyP+CNiSTZfDtwIHA4skzQOIJsvz+vzb70VLroor72bmfVPuQW/pGZJO1eWgY8C80jjqk/PXjYduDmvGu64A7773bz2bmbWP+XZ1TMWuDE9AIghwC8j4nZJDwHXSzoLeBE4Ja8Cmpth9eq89m5m1j/lFvwR8TxwUDftrwJ1efZnczNs2JCmpqZ6fKKZWeMb0HfuNjen+Zo1xdZhZtZIHPxmZiXj4DczKxkHv5lZyTj4zcxKZkAHf0tLmjv4zcw6Dejg9xG/mdm7OfjNzErGwW9mVjIOfjOzkilF8Hu8HjOzTgM6+AcPhh128BG/mVm1AR38kI76HfxmZp0c/GZmJePgNzMrGQe/mVnJOPjNzEpmwAd/S4uD38ys2oAPfh/xm5ltycFvZlYyDn4zs5Jx8JuZlUwpgn/9eti4sehKzMwaQymCH3zUb2ZW4eA3MysZB7+ZWck4+M3MSsbBb2ZWMgM++Fta0tzBb2aWDPjg9xG/mdmWHPxmZiWTe/BLGizpEUm3ZuujJM2StDCbj8zz8x38ZmZbqscR/5eBp6rWZwCzI2ISMDtbz42D38xsS7kGv6TxwMeAn1U1TwVmZsszgWl51uDgNzPbUt5H/N8Hzgc2V7WNjYilANl8THdvlHS2pDmS5nR0dGx3AUOGwNChsHr1du/CzGxAyS34JZ0ILI+Iudvz/oi4JCLaI6K9tbW1V7V4hE4zs05Dctz30cBJkqYAOwK7SLoKWCZpXEQslTQOWJ5jDYCD38ysWm5H/BFxYUSMj4iJwKnA7yLidOAWYHr2sunAzXnVUOHgNzPrVMR1/BcBx0laCByXrefKwW9m1inPrp53RMTdwN3Z8qvA5Hp8boWD38ys04C/cxfSeD0OfjOzpBTB7yN+M7NODn4zs5Jx8JuZlYyD38ysZEoT/G+9BZs2FV2JmVnxShP8AGvXFluHmVkjKFXwe6A2M7OSBb/7+c3MHPxmZqXj4DczK5lSBH9LS5o7+M3MShL8PuI3M+vk4DczKxkHv5lZyTj4zcxKxsFvZlYypQj+pqY0OfjNzEoS/OAROs3MKkoV/B6rx8ysZMHvI34zMwe/mVnplCb4W1oc/GZmUKLg9xG/mVni4DczKxkHv5lZyTj4zcxKxsFvZlYypQr+detg8+aiKzEzK1apgh9g7dpi6zAzK1rpgt/dPWZWdrkFv6QdJf1J0mOS5kv6ZtY+StIsSQuz+ci8aqhWCX6P12NmZZfnEf964JiIOAg4GDhe0hHADGB2REwCZmfrufMRv5lZklvwR1I5vm7KpgCmAjOz9pnAtLxqqObgNzNLcu3jlzRY0qPAcmBWRDwIjI2IpQDZfEyeNVS0tKS5g9/Myi7X4I+ITRFxMDAeOFzSgbW+V9LZkuZImtPR0dHrWnzEb2aW1OWqnoh4A7gbOB5YJmkcQDZf3sN7LomI9ohob21t7XUNDn4zsyTPq3paJY3IlncCjgWeBm4Bpmcvmw7cnFcN1Rz8ZmbJkBz3PQ6YKWkw6Qvm+oi4VdIDwPWSzgJeBE7JsYZ3OPjNz
JLcgj8iHgcO6ab9VWByXp/bEwe/mVlSmjt3hw6FIUMc/GZmpQl+8AidZmbg4DczK53SBb/H6jGzsitd8PuI38zKrlTB39Li4DczK1XwDx8Or79edBVmZsUqVfC3tcHLLxddhZlZsWoKfknNkgZly/tIOklSU76l9b22NujogPXri67EzKw4tR7x3wvsKKmN9PCUzwBX5FVUXtra0nzp0mLrMDMrUq3Br4hYC/wn4EcRcTJwQH5l5WP8+DRfvLjYOszMilRz8Es6Evg0cFvWlucAb7moHPG7n9/MyqzW4P8KcCFwY0TMl7Q3cFduVeXEwW9mVuNRe0TcA9wDkJ3kXRERX8qzsDyMGAHDhrmrx8zKrdaren4paRdJzcCTwAJJ5+VbWt+TfEmnmVmtXT0HRMRKYBrwW2ACcEZeReVp/HgHv5mVW63B35Rdtz8NuDkiNgCRW1U5amtzV4+ZlVutwf9TYBHQDNwraU9gZV5F5amtDZYsgc2bi67EzKwYNQV/RPwwItoiYkokLwAfybm2XIwfDxs2wIoVRVdiZlaMWk/uDpd0saQ52fS/SEf//U7lkk5395hZWdXa1XMZsAr4ZDatBC7Pq6g8+Vp+Myu7Wu++fV9EfKJq/ZuSHs2hntxVhm1w8JtZWdV6xL9O0gcrK5KOBtblU1K+xo6FwYPd1WNm5VXrEf85wJWShmfrrwPT8ykpX4MHw267+YjfzMqr1iEbHgMOkrRLtr5S0leAx3OsLTe+icvMymybnsAVESuzO3gBvppDPXXhm7jMrMx68+hF9VkVdebxesyszHoT/P1yyAZIXT0rV8KqVUVXYmZWf1vt45e0iu4DXsBOuVRUB9XX8u+3X7G1mJnV21aDPyJ2rlch9eTgN7My601XT7/lm7jMrMxKGfwer8fMyiy34Je0h6S7JD0lab6kL2ftoyTNkrQwm4/Mq4ae7LQTjBzpI34zK6c8j/g3AudGxP7AEcAXJB0AzABmR8QkYHa2Xne+icvMyiq34I+IpRHxcLa8CngKaAOmAjOzl80kPdWr7nwTl5mVVV36+CVNBA4BHgTGRsRSSF8OwJge3nN2Zfz/jo6OPq/JN3GZWVnlHvySWoAbgK9UDffwniLikohoj4j21tbWPq9r/HhYtiw9jcvMrExyDf7sAe03AFdHxG+y5mWSxmXbxwHL86yhJ21tEAFLlxbx6WZmxcnzqh4BPweeioiLqzbdQueQztOBm/OqYWv8JC4zK6tax+PfHkcDZwBPVD2t6x+Bi4DrJZ0FvAickmMNPfJNXGZWVrkFf0TcR88jeE7O63Nr5Zu4zKysSnnnLsCoUbDDDj7iN7PyKW3wSzBxIjz3XNGVmJnVV2mDH+Dgg+Hhh4uuwsysvkod/IcdBi+8AK++WnQlZmb1U+rgP/TQNH/kkWLrMDOrp1IH/yGHpPncucXWYWZWT6UO/lGjYK+93M9vZuVS6uCH1N3j4DezMnHwHwrPPgtvvll0JWZm9VH64D/ssDT3CV4zK4vSB3/lBK+7e8ysLEof/GPGpAHbfGWPmZVF6YMfUnePj/jNrCwc/KQTvAsWwOrVRVdiZpY/Bz8p+CPg0UeLrsTMLH8Ofjqv7HF3j5mVgYMfGDcOdtvNwW9m5eDgzxx2mK/sMbNycPBnDj0UnnwS1q4tuhIzs3w5+DOHHgqbN8MTTxRdiZlZvhz8mcoJ3oceKrYOM7O8Ofgz48enIZpvv73oSszM8uXgz0gwbRrMmgWrVhVdjZlZfhz8VU4+Gd5+20f9ZjawOfirHHUUtLbCjTcWXYmZWX4c/FUGD4aTToLbbktH/mZmA5GDv4tp02DlSrjrrqIrMTPLh4O/i2OPheZmd/eY2cDl4O9ixx3hhBPg5pvTDV1mZgONg78bJ58Mr7wCDz5YdCVmZn3Pwd+NKVNgyBC46aaiKzEz63u5Bb+kyyQtlzSvqm2UpFmSFmbzkXl9fm+MGAHHHJP6+SOKrsbMrG/lecR/BXB8l7YZw
OyImATMztYb0rRpsHBhGrHTzGwgyS34I+Je4LUuzVOBmdnyTGBaXp/fWyefDE1N8JOfFF2JmVnfqncf/9iIWAqQzcfU+fNrtttucPrpcNll0NFRdDVmZn2nYU/uSjpb0hxJczoKSt7zzoN16+BHPyrk483MclHv4F8maRxANl/e0wsj4pKIaI+I9tbW1roVWG3//WHqVPjxj2H16kJKMDPrc/UO/luA6dnydODmOn/+NrvgAnj9dfjZz4quxMysb+R5Oec1wAPAvpIWSzoLuAg4TtJC4LhsvaEdeSR86ENw8cWwYUPR1ZiZ9d6QvHYcEaf1sGlyXp+ZlwsugBNPhGuugTPPLLoaM7PeadiTu41kyhQ48ED47nc9fo+Z9X8O/hpIMGMGzJ8PM2e+9+vNzBqZg79Gp52W+vq/+tU0gJuZWX/l4K/RoEFw6aXpuv4vfrHoaszMtp+Dfxvsuy98/evw61975E4z678c/NvovPPgoIPg85+HN94ouhozs23n4N9GTU3pZq5ly+D884uuxsxs2zn4t0N7ezrJe+mlcOWVRVdjZrZtHPzb6dvfTg9r+fu/h3vuKboaM7PaOfi309ChcMMN8P73p7H7FywouiIzs9o4+HthxAi47bb0fN4pUzxuv5n1Dw7+XtprL7jlFliyBD7+cXit6zPHzMwajIO/DxxxBFx7LTzySLq796WXiq7IzKxnDv4+MnUq3HEHLF6chnKeP7/oiszMuufg70Mf/jDce28awfODH/TVPmbWmBz8feygg+APf4CxY2HyZPjWt2DjxqKrMjPr5ODPwcSJ8OCDcOqp8I1vpN8EXnih6KrMzBIHf06GD4erroJf/AIefzz9JnD55X6Qi5kVz8Gfs9NPh0cfTU/w+uxn0xVAf/xj0VWZWZk5+Otg773TSd8rr+y86ufMM939Y2bFcPDXyaBBcMYZaWiHCy+E665Lwz185jPw9NNFV2dmZeLgr7Odd4bvfAeefTaN6X/ddXDAAXDKKenyz4iiKzSzgc7BX5A99oAf/AAWLUq/Acyala7+2X9/+N73PO6PmeXHwV+wMWPSEM9LlsAVV8Cuu6anfO2+O5xwAlx2mcf/MbO+5eBvEMOGwfTpcP/9MG9eetDLggVw1lnpZrC/+Rv4/vfT+QB3B5lZbyj6QYq0t7fHnDlzii6j7iLg4YfhV79KD3evjPm/555w7LFpQLgPfSiNECoVWqqZNSBJcyOi/V3tDv7+Y9GiNBDc7bfD3Xd3Pux9993TJaLt7XDYYWkaNarAQs2sITj4B5jNm9MIoPfdB7//PfzpT/Dcc53bJ0xIN439xV+k+f77wz77pKuKzKwcHPwl8PrrqWtozpw0TMQTT6RzAhs2dL5m993TF8D73pe6iPbeO80nTIDddkv3G5jZwNBT8A8pohjLx8iRaUTQyZM72zZsgGeeSecHKtMzz6RHRr7yypbvb2qCtrZ0qWlbW/qS2H339IUwdmznfNdd/QVh1p85+Ae4pib4wAfS1NWaNem8wZ//nJ4a9uKLaXrppfRbw8svw7p1737foEHpHMLo0dDampZ33TXNR41KX0AjRnROu+ySpuHDobnZXxpmRXPwl1hzc89fCpCuKlq5EpYuhWXLOqeODlixIs07OuD559MXxWuvdf9FUU2ClpZ0rqEyNTentpaWtDxsWJoqyzvtlKZhw2DHHdPyjjt2Tjvs0Dmvnob4X7dZtwr5ryHpeOAHwGDgZxFxURF12NZJ6Sh9+HDYb7/a3rNuHbz5Zrri6I030nmHlSs7pzffhFWr0rR6dZqvWZO6nVavTstr16b5+vW9q3/QIBg6dMupqend867TkCFbLvc0DR7c87x6qrQNGvTubZW26vnWlqsnqbb2ynp1e3dtXdu7vqbrNl9C3H/VPfglDQb+FTgOWAw8JOmWiHiy3rVY36scne+2W+/3tWlT+iJZty59GaxdC2+9laZK+/r1aXrrrc7l6mnDBnj77bT89tud65Xl6
mnNmvS0tI0b03rX+aZNW27ftKmzrayqvwS6fim8V3vXqbK/7vZfvW1rr+1ufVtf39227v7c27u+re/96U/T/Tp9qYgj/sOBZyPieQBJ1wJTAQe/bWHw4M4uoEYWkS6vrf4iqCxX1qu3V6bqts2bO6fq9cr2ymdU2qrXN2/esoaInterX9t1vrXtXd/fdVt367Vuq1xY2PUCw+62be213a1v6+u729bd3/f2rm/reyGfS7CLCP424KWq9cXAX3V9kaSzgbMBJkyYUJ/KzLaD1Nl1Y9YfFHF9RXe/RL3rey4iLomI9ohob21trUNZZmblUETwLwb2qFofDywpoA4zs1IqIvgfAiZJ2kvSUOBU4JYC6jAzK6W69/FHxEZJ/xW4g3Q552URMb/edZiZlVUh1/FHxG+B3xbx2WZmZeeb583MSsbBb2ZWMg5+M7OS6Rfj8UvqAF6o8eWjgRU5ltMbjVpbo9YFjVtbo9YFjVtbo9YFA7e2PSPiXTdC9Yvg3xaS5nT34IFG0Ki1NWpd0Li1NWpd0Li1NWpdUL7a3NVjZlYyDn4zs5IZiMF/SdEFbEWj1taodUHj1taodUHj1taodUHJahtwffxmZrZ1A/GI38zMtsLBb2ZWMgMq+CUdL2mBpGclzSi4lsskLZc0r6ptlKRZkhZm85EF1LWHpLskPSVpvqQvN0JtknaU9CdJj2V1fbMR6qqqb7CkRyTd2mB1LZL0hKRHJc1psNpGSPq1pKezf29HFl2bpH2zn1VlWinpK0XXVVXfP2T//udJuib7f9HntQ2Y4K96lu8JwAHAaZIOKLCkK4Dju7TNAGZHxCRgdrZebxuBcyNif+AI4AvZz6no2tYDx0TEQcDBwPGSjmiAuiq+DDxVtd4odQF8JCIOrrrWu1Fq+wFwe0TsBxxE+vkVWltELMh+VgcDhwFrgRuLrgtAUhvwJaA9Ig4kjV58ai61RcSAmIAjgTuq1i8ELiy4ponAvKr1BcC4bHkcsKABfm43kx583zC1AcOAh0mP5Cy8LtLDgmYDxwC3NtLfJbAIGN2lrfDagF2AP5NdQNJItVXV8lHg/kapi87H0o4ijZx8a1Zjn9c2YI746f5Zvm0F1dKTsRGxFCCbjymyGEkTgUOAB2mA2rLulEeB5cCsiGiIuoDvA+cDm6vaGqEuSI8tvVPS3Ow51Y1S295AB3B51kX2M0nNDVJbxanANdly4XVFxMvA94AXgaXAmxFxZx61DaTgr+lZvpZIagFuAL4SESuLrgcgIjZF+hV8PHC4pAMLLglJJwLLI2Ju0bX04OiIOJTUxfkFSX9ddEGZIcChwE8i4hBgDcV2h20he/rfScCviq6lIuu7nwrsBewONEs6PY/PGkjB3x+e5btM0jiAbL68iCIkNZFC/+qI+E0j1QYQEW8Ad5POkRRd19HASZIWAdcCx0i6qgHqAiAilmTz5aS+6sMbpLbFwOLstzaAX5O+CBqhNkhflA9HxLJsvRHqOhb4c0R0RMQG4DfAUXnUNpCCvz88y/cWYHq2PJ3Uv15XkgT8HHgqIi5ulNoktUoakS3vRPpP8HTRdUXEhRExPiImkv5N/S4iTi+6LgBJzZJ2riyT+oPnNUJtEfEK8JKkfbOmycCTjVBb5jQ6u3mgMep6EThC0rDs/+lk0gnxvq+tqBMrOZ0cmQI8AzwH/LeCa7mG1E+3gXT0cxawK+kk4cJsPqqAuj5I6gJ7HHg0m6YUXRvwl8AjWV3zgK9n7YX/zKpq/DCdJ3cLr4vUj/5YNs2v/JtvhNqyOg4G5mR/pzcBIxuhNtLFA68Cw6vaCq8rq+ObpAOeecAvgB3yqM1DNpiZlcxA6uoxM7MaOPjNzErGwW9mVjIOfjOzknHwm5mVjIPfSk3Spi6jNfbZ3aWSJqpqdFazRjGk6ALMCrYu0jARZqXhI36zbmTj3P+P7BkBf5L0/qx9T0mzJT2ezSdk7WMl3aj0PIHHJB2V7WqwpEuzMdbvzO5KRtKXJ
D2Z7efagv6YVlIOfiu7nbp09XyqatvKiDgc+DFphE6y5Ssj4i+Bq4EfZu0/BO6J9DyBQ0l30gJMAv41Ij4AvAF8ImufARyS7eecfP5oZt3znbtWapJWR0RLN+2LSA+GeT4b1O6ViNhV0grS2OgbsvalETFaUgcwPiLWV+1jIml46UnZ+gVAU0T8i6TbgdWkoQxuiojVOf9Rzd7hI36znkUPyz29pjvrq5Y30Xle7WOkJ8YdBsyV5PNtVjcOfrOefapq/kC2/AfSKJ0Anwbuy5ZnA5+Ddx4os0tPO5U0CNgjIu4iPeBlBPCu3zrM8uKjDCu7nbKnflXcHhGVSzp3kPQg6QDptKztS8Blks4jPWHqM1n7l4FLJJ1FOrL/HGl01u4MBq6SNJz0AKH/HekZBGZ14T5+s25kffztEbGi6FrM+pq7eszMSsZH/GZmJeMjfjOzknHwm5mVjIPfzKxkHPxmZiXj4DczK5n/D35uagZsV2iuAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "val_loss = history.history[\"val_loss\"]\n", "epochs = range(1, len(val_loss) + 1)\n", "plt.title(\"Validation loss\")\n", "plt.plot(epochs, val_loss, \"b\", label=\"Validation loss\")\n", "plt.xlabel(\"Epochs\")\n", "plt.ylabel(\"Loss\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also look at the differences between the test targets and the predictions from the model." ] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 117, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions = model(test_data)\n", "predictions - test_targets.reshape(11,1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These look pretty reasonable." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One final comment - since we normalized the data for this example and we did not normalize the inputs to the regression model above, we have to scale the mse for the neural network model by the standard deviation of the original outputs:" ] }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.41441898152670054" ] }, "execution_count": 118, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.sqrt(history.history[\"val_loss\"][-1]*df['speed'].std())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is comparable to the rms error of the earlier model, which was about 0.4 mph." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }