Custom training loops with PyTorch
Introduction
In a previous post, we saw a couple of examples of how to construct a linear regression model, define a custom loss function, have TensorFlow automatically compute the gradients of the loss function with respect to the trainable parameters, and then update the model’s parameters. We will do the same in this post, but this time we will use PyTorch. It’s been a while since I wanted to switch from TensorFlow to PyTorch, and what better way than to start from the basics?
Fit a quadratic regression model to data by minimizing MSE
Generate training data
First, we will generate some data coming from a quadratic model, i.e., \(y = a x^2 + b x + c\), and we will add some noise to make the setup a bit more realistic.
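A minimal sketch of this step might look like the following; the true coefficients, the noise scale, and the sample size are illustrative assumptions, not values from the original post:

```python
import torch

torch.manual_seed(42)  # for reproducibility (seed value is arbitrary)

# "True" coefficients of the quadratic model -- illustrative values only
a_true, b_true, c_true = 2.0, -1.0, 0.5

# Inputs and noisy targets: y = a*x^2 + b*x + c + Gaussian noise
x = torch.linspace(-3, 3, 200)
y = a_true * x**2 + b_true * x + c_true + 0.5 * torch.randn(x.shape)
```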
Define a model with trainable parameters
In this step, we define the model, namely \(y = f(x) = a x^2 + b x + c\). Given the model’s parameters \(a, b, c\) and an input tensor \(x\), we calculate the output tensor \(y_\text{pred}\):
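One way to express this as a function (the names `predict` and `params` are my own, hypothetical choices):

```python
def predict(params, x):
    """Return y_pred = a*x^2 + b*x + c, with params holding (a, b, c)."""
    a, b, c = params
    return a * x**2 + b * x + c
```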
Define a custom loss function
Here we define a custom loss function that calculates the mean squared error between the model’s predictions and the actual target values in the dataset.
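A sketch of such a loss function, assuming predictions and targets are tensors of the same shape:

```python
def mse_loss(y_pred, y):
    """Mean squared error between predictions and targets."""
    return ((y_pred - y) ** 2).mean()
```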
We then assign some initial random values to the parameters \(a, b, c\), and also tell PyTorch that we want it to compute gradients with respect to this parameter tensor.
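In code, this comes down to a single line; packing the three parameters into one tensor is an assumption on my part:

```python
# Random initial guesses for a, b, c. requires_grad=True tells PyTorch's
# autograd engine to track operations on this tensor so that gradients
# can later be computed with .backward().
params = torch.randn(3, requires_grad=True)
```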
Here is a helper function that draws the predictions and the actual targets in the same plot. Before training the model, we expect a considerable discrepancy between the two.
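A minimal version of such a helper, assuming matplotlib and the hypothetical `predict` function from above:

```python
import matplotlib.pyplot as plt

def plot_predictions(params, x, y):
    """Scatter the data and overlay the model's current predictions."""
    with torch.no_grad():          # no gradient tracking needed for plotting
        y_pred = predict(params, x)
    plt.scatter(x, y, s=8, label="data")
    plt.plot(x, y_pred, color="red", label="model")
    plt.legend()
    plt.show()
```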
Define a custom training loop
This is the heart of our setup. Given the current values of the model’s parameters, we construct a function that computes the model’s predictions, measures how much they deviate from the actual targets, and updates the parameters via gradient descent.
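A sketch of one such training step, assuming the helpers defined above; the learning rate is an arbitrary illustrative value:

```python
def training_step(params, x, y, lr=1e-3):
    """Perform one gradient-descent update and return the loss value."""
    loss = mse_loss(predict(params, x), y)  # forward pass
    loss.backward()                         # d(loss)/d(params) into params.grad
    with torch.no_grad():                   # update outside the autograd graph
        params -= lr * params.grad
        params.grad.zero_()                 # reset accumulated gradients
    return loss.item()
```

The update is wrapped in `torch.no_grad()` so that the parameter modification itself is not recorded in the computational graph, and the gradients are zeroed afterwards because PyTorch accumulates them across calls to `.backward()`.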
Run the custom training loop
We repeatedly apply the previous step until the training process converges to a particular combination of \(a, b, c\).
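For example, with the hypothetical `training_step` above (the number of iterations and the logging interval are arbitrary):

```python
for epoch in range(5_000):
    loss = training_step(params, x, y)
    if epoch % 500 == 0:
        print(f"epoch {epoch:5d}   loss {loss:.4f}")
```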
Final results
Finally, we superimpose the dataset with the best quadratic regression model PyTorch converged to:
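With the sketched helpers, this could look as follows:

```python
print("Learned parameters (a, b, c):", params.detach().tolist())
plot_predictions(params, x, y)
```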