How to Diagnose Overfitting and Underfitting of LSTM Models
It can be difficult to determine whether your Long Short-Term Memory model is performing well on your sequence prediction problem.
You may be getting a good model skill score, but it is important to know whether your model is a good fit for your data, or whether it is underfit or overfit and could do better with a different configuration.
In this tutorial, you will discover how you can diagnose the fit of your LSTM model on your sequence prediction problem.
After completing this tutorial, you will know:
- How to gather and plot training history of LSTM models.
- How to diagnose an underfit, good fit, and overfit model.
- How to develop more robust diagnostics by averaging multiple model runs.
Let’s get started.
Tutorial Overview
This tutorial is divided into 6 parts; they are:
- Training History in Keras
- Diagnostic Plots
- Underfit Example
- Good Fit Example
- Overfit Example
- Multiple Runs Example
1. Training History in Keras
You can learn a lot about the behavior of your model by reviewing its performance over time.
LSTM models are trained by calling the fit() function. This function returns a History object whose history attribute is a dictionary containing a trace of the loss and any other metrics specified when the model was compiled. These scores are recorded at the end of each epoch.
```python
...
history = model.fit(...)
```
For example, if your model was compiled to optimize the log loss (binary_crossentropy) and measure accuracy each epoch, then the log loss and accuracy will be calculated and recorded in the history trace for each training epoch.
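To make those two recorded numbers concrete, the sketch below computes binary cross-entropy (log loss) and accuracy for a handful of predictions in plain Python. It illustrates the standard formulas rather than reproducing Keras's own code; the clipping constant `eps` is an assumption added to avoid taking log(0), and the 0.5 threshold is the usual decision boundary for a binary output. The function names are hypothetical.

```python
import math

def binary_log_loss(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; eps (an assumption) clips probabilities away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of predictions that fall on the correct side of the threshold."""
    correct = sum(1 for y, p in zip(y_true, y_prob) if (p > threshold) == (y == 1))
    return correct / len(y_true)

# Toy labels and predicted probabilities (made up for illustration)
y_true = [1, 0, 1, 1]
y_prob = [0.9, 0.2, 0.4, 0.8]
print(round(binary_log_loss(y_true, y_prob), 4))
print(binary_accuracy(y_true, y_prob))  # 0.75
```

Keep in mind that the training value Keras logs for an epoch is averaged over the batches seen during that epoch, while the validation value is computed once the epoch has finished.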
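Each score is accessed by a key in the history dictionary of the object returned from calling fit(). By default, the loss optimized when fitting the model is called “loss”. Accuracy is called “acc” in Keras versions before 2.3; from Keras 2.3 onward, metrics are recorded under the exact name you pass, so metrics=['accuracy'] is recorded as “accuracy”.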
```python
...
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(X, Y, epochs=100)
print(history.history['loss'])
print(history.history['accuracy'])  # key is 'acc' in Keras versions before 2.3
```
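Because the accuracy key differs between Keras versions, a small lookup helper can save a KeyError when you reuse plotting code. The function `get_metric` below is hypothetical, not part of Keras; it works on the plain dictionary stored in `history.history`, which maps each metric name to a list with one value per epoch.

```python
def get_metric(history_dict, *candidates):
    """Return the per-epoch list for the first key in candidates that exists.

    history_dict is the dictionary stored in history.history after fit().
    """
    for key in candidates:
        if key in history_dict:
            return history_dict[key]
    raise KeyError(f"none of {candidates} found; available keys: {sorted(history_dict)}")

# Toy dictionary standing in for history.history (values made up)
toy_history = {"loss": [0.69, 0.52, 0.41], "acc": [0.55, 0.71, 0.80]}
acc = get_metric(toy_history, "accuracy", "acc")
for epoch, value in enumerate(acc, start=1):
    print(f"epoch {epoch}: accuracy={value:.2f}")
```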
Keras also allows you to specify a separate validation dataset while fitting your model; this dataset is evaluated using the same loss and metrics.
This can be done by setting the validation_split argument on fit() to use a portion of the training data as a validation dataset.
```python
...
history = model.fit(X, Y, epochs=100, validation_split=0.33)
```
This can also be done by setting the validation_data argument and passing a tuple of X and y datasets.
```python
...
history = model.fit(X, Y, epochs=100, validation_data=(valX, valY))
```
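The Keras documentation states that validation_split holds out the last samples of the arrays you pass, before any shuffling. For sequence data ordered in time, that means the validation set is the most recent part of the series. The sketch below reproduces that slicing with NumPy so you can see which rows are held out, or build an explicit (valX, valY) pair for validation_data. `split_last_fraction` is a hypothetical helper, and rounding the split point down is an assumption that may differ slightly from your Keras version.

```python
import numpy as np

def split_last_fraction(X, y, validation_split):
    """Hold out the final fraction of samples, mimicking validation_split."""
    # Index of the first validation sample (rounded down; an assumption)
    split_at = int(len(X) * (1.0 - validation_split))
    return (X[:split_at], y[:split_at]), (X[split_at:], y[split_at:])

# 9 samples shaped [samples, timesteps, features], as an LSTM expects
X = np.arange(9, dtype=float).reshape((9, 1, 1))
y = np.arange(1, 10, dtype=float)
(trainX, trainY), (valX, valY) = split_last_fraction(X, y, 0.33)
print(trainX.shape, valX.shape)  # (6, 1, 1) (3, 1, 1)
```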
The metrics evaluated on the validation dataset are keyed using the same names, with a “val_” prefix.
```python
...
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(X, Y, epochs=100, validation_split=0.33)
print(history.history['loss'])
print(history.history['accuracy'])      # 'acc' in Keras versions before 2.3
print(history.history['val_loss'])
print(history.history['val_accuracy'])  # 'val_acc' in Keras versions before 2.3
```
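Since each validation key is the training key with a “val_” prefix, train and validation scores can be paired generically. The hypothetical helper below reports each metric's final-epoch value on both datasets and the gap between them, a quick numeric summary to read alongside the plots.

```python
def final_gaps(history_dict):
    """Map each metric to (train, validation, validation - train) at the last epoch."""
    report = {}
    for name, train_values in history_dict.items():
        val_key = "val_" + name
        if name.startswith("val_") or val_key not in history_dict:
            continue  # skip validation keys themselves and unpaired metrics
        train_last = train_values[-1]
        val_last = history_dict[val_key][-1]
        report[name] = (train_last, val_last, val_last - train_last)
    return report

# Toy dictionary (made-up values) in the shape of history.history
toy_history = {
    "loss": [0.60, 0.40, 0.30],
    "val_loss": [0.65, 0.50, 0.45],
}
for name, (train, val, gap) in final_gaps(toy_history).items():
    print(f"{name}: train={train:.2f} val={val:.2f} gap={gap:+.2f}")
```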
2. Diagnostic Plots
The training history of your LSTM models can be used to diagnose the behavior of your model.
You can plot the performance of your model using the Matplotlib library. For example, you can plot training loss vs. validation loss as follows:
```python
from matplotlib import pyplot
...
history = model.fit(X, Y, epochs=100, validation_data=(valX, valY))
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model train vs validation loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper right')
pyplot.show()
```
Creating and reviewing these plots can help to inform you about possible new configurations to try in order to get better performance from your model.
Next, we will look at some examples. We will judge model skill on the train and validation sets in terms of the loss being minimized, but you can use any metric that is meaningful for your problem.
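Alongside the plot, it can help to locate the epoch with the lowest validation loss. If that is the final epoch, training may have stopped too early; if it is much earlier, the later epochs were not improving skill on held-out data. A minimal sketch, assuming the losses are the list stored under 'val_loss'; `best_epoch` is a hypothetical helper.

```python
def best_epoch(val_losses):
    """Return (1-based epoch, value) of the lowest validation loss."""
    index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return index + 1, val_losses[index]

val_losses = [0.50, 0.42, 0.39, 0.41, 0.44]  # made-up values
epoch, value = best_epoch(val_losses)
print(f"lowest validation loss {value:.2f} at epoch {epoch} of {len(val_losses)}")
```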
3. Underfit Example
An underfit model is one that has not yet learned the problem well: its skill on the training dataset is poorer than it could be, and so is its skill on the validation dataset.
This can be diagnosed from a plot where the training loss is lower than the validation loss and the validation loss is still decreasing when training ends, a trend that suggests further improvements are possible.
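One way to make “a trend that suggests further improvements” less subjective is to fit a straight line to the last few epochs of validation loss and look at its slope: clearly negative means still improving, near zero means it has leveled off. The sketch below uses an ordinary least-squares slope; the names `tail_slope` and `tail_trend`, the window of 10 epochs, and the tolerance are hypothetical choices, and the labels are a rough heuristic rather than a definitive diagnosis.

```python
def tail_slope(values, window):
    """Least-squares slope of the last `window` values against epoch index."""
    tail = values[-window:]
    n = len(tail)
    if n < 2:
        raise ValueError("need at least two epochs to estimate a slope")
    mean_x = (n - 1) / 2.0
    mean_y = sum(tail) / n
    num = sum((x - mean_x) * (v - mean_y) for x, v in enumerate(tail))
    den = sum((x - mean_x) ** 2 for x in range(n))
    return num / den

def tail_trend(val_losses, window=10, tol=1e-4):
    """Rough label for the end of a validation-loss curve (heuristic only)."""
    slope = tail_slope(val_losses, window)
    if slope < -tol:
        return "improving"  # still falling: more epochs may help
    if slope > tol:
        return "worsening"  # rising: a possible sign of overfitting
    return "flat"  # leveled off

# Synthetic curves, not results from a real model
decreasing = [1.0 / (epoch + 1) for epoch in range(50)]
flat = [0.3] * 50
rising = [0.1 + 0.01 * epoch for epoch in range(50)]
print(tail_trend(decreasing), tail_trend(flat), tail_trend(rising))  # improving flat worsening
```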
A small contrived example of an underfit LSTM model is provided below.
```python
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from matplotlib import pyplot
from numpy import array

# return training data
def get_train():
    seq = [[0.0, 0.1], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((len(X), 1, 1))
    return X, y

# return validation data
def get_val():
    seq = [[0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 1.0]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((len(X), 1, 1))
    return X, y

# define model
model = Sequential()
model.add(LSTM(10, input_shape=(1, 1)))
model.add(Dense(1, activation='linear'))
# compile model
model.compile(loss='mse', optimizer='adam')
# fit model
X, y = get_train()
valX, valY = get_val()
history = model.fit(X, y, epochs=100, validation_data=(valX, valY), shuffle=False)
# plot train and validation loss
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model train vs validation loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper right')
pyplot.show()
```
Running this example produces a plot of train and validation loss showing the characteristic of an underfit model. In this case, performance may be improved by increasing the number of training epochs.
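If the curves suggest that more epochs would help, you do not have to start over: calling fit() again on the same Keras model continues training from its current weights, but the History object it returns covers only the new epochs. To plot the whole run, you can concatenate the per-epoch lists. `merge_histories` is a hypothetical helper that works on the history.history dictionaries; the Keras calls are shown only as comments.

```python
def merge_histories(*history_dicts):
    """Concatenate per-epoch metric lists from consecutive fit() calls."""
    merged = {}
    for history_dict in history_dicts:
        for name, values in history_dict.items():
            # setdefault creates a fresh list, so the inputs are not modified
            merged.setdefault(name, []).extend(values)
    return merged

# With a compiled Keras model (not run here):
# h1 = model.fit(X, y, epochs=100, validation_data=(valX, valY), shuffle=False)
# h2 = model.fit(X, y, epochs=100, validation_data=(valX, valY), shuffle=False)
# combined = merge_histories(h1.history, h2.history)

first = {"loss": [0.9, 0.7], "val_loss": [1.0, 0.8]}  # made-up values
second = {"loss": [0.6, 0.5], "val_loss": [0.7, 0.65]}
combined = merge_histories(first, second)
print(combined["val_loss"])  # [1.0, 0.8, 0.7, 0.65]
```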
Alternatively, a model may be underfit if performance on the training set is better than on the validation set and performance has leveled off. Below is an example of an underfit model with insufficient memory cells.
```python
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from matplotlib import pyplot
from numpy import array

# return training data
def get_train():
    seq = [[0.0, 0.1], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((5, 1, 1))
    return X, y

# return validation data
def get_val():
    seq = [[0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 1.0]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((len(X), 1, 1))
    return X, y

# define model with a single memory cell
model = Sequential()
model.add(LSTM(1, input_shape=(1, 1)))
model.add(Dense(1, activation='linear'))
# compile model
model.compile(loss='mae', optimizer='sgd')
# fit model
X, y = get_train()
valX, valY = get_val()
history = model.fit(X, y, epochs=300, validation_data=(valX, valY), shuffle=False)
# plot train and validation loss (same plotting code as the previous example)
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model train vs validation loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper right')
pyplot.show()
```
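One way to see why a single memory cell can underfit is to count trainable parameters. A standard LSTM layer has four gates, each with input weights, recurrent weights and a bias, so with n units and m input features it holds 4 × (n·m + n·n + n) parameters, which is what model.summary() reports for a Keras LSTM layer with biases. The sketch below applies that formula to the LSTM(10) model of the first example and the LSTM(1) model of the second; `lstm_params` and `dense_params` are hypothetical helpers.

```python
def lstm_params(units, input_features):
    """Trainable parameters of an LSTM layer: 4 gates x (input weights + recurrent weights + bias)."""
    return 4 * (units * input_features + units * units + units)

def dense_params(inputs, outputs):
    """Weights plus one bias per output unit."""
    return inputs * outputs + outputs

for units in (1, 10):
    total = lstm_params(units, 1) + dense_params(units, 1)
    print(f"LSTM({units}) + Dense(1): {total} trainable parameters")
# LSTM(1) + Dense(1): 14 trainable parameters
# LSTM(10) + Dense(1): 491 trainable parameters
```

With 14 parameters against 491, the single-cell model has far less capacity to fit the mapping, which is consistent with its loss leveling off early.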