Linear Regression
Simply put, linear regression is building a model that fits a line to data samples with the least possible loss.
To do so, the model should figure out a proper relation, if one exists, between the independent values (x) and the dependent values (y). This relation could be proportional, it could be something else, or there could be no relation at all.
As with other machine learning problems, it is impossible to predict something with no error, so our goal is to build a model that produces the least possible loss, which is computed from the difference between the actual and predicted values.
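For example, here is a minimal sketch of one such loss, the Mean Squared Error used later in this post (the numbers are made up for illustration):

```python
import numpy as np

def mse_loss(y_true, y_pred):
    """Mean Squared Error: the average of the squared differences
    between the actual and predicted values."""
    return np.mean((y_true - y_pred) ** 2)

# made-up actual and predicted values
y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([1.1, 1.9, 3.2])
print(mse_loss(y_true, y_pred))  # 0.02: small, since predictions are close
```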
Let's say we would like to know the relation between the height and weight of a person.
We can already guess that the taller a person is, the heavier they tend to be. Let's find out if this is true.
The data we are going to use is Kaggle's weight-height dataset uploaded by Mustafa Ali.
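A minimal sketch of loading the dataset with pandas, assuming the downloaded CSV is named `weight-height.csv` (the actual file name may differ):

```python
import pandas as pd

# hypothetical file name for the Kaggle weight-height dataset
df = pd.read_csv("weight-height.csv")

print(df.head(2))  # first two rows
print(df.tail(2))  # last two rows
```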
| | Gender | Height | Weight |
| --- | --- | --- | --- |
| 0 | Male | 73.847017 | 241.893563 |
| 1 | Male | 68.781904 | 162.310473 |

| | Gender | Height | Weight |
| --- | --- | --- | --- |
| 5000 | Female | 58.910732 | 102.088326 |
| 5001 | Female | 65.230013 | 141.305823 |
We have 10,000 data samples with gender, height, and weight features.
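The summary statistics below can be reproduced with pandas:

```python
import pandas as pd

df = pd.read_csv("weight-height.csv")  # hypothetical file name, as above
print(df[["Height", "Weight"]].describe())  # count, mean, std, min, quartiles, max
```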
| | Height | Weight |
| --- | --- | --- |
| count | 10000.000000 | 10000.000000 |
| mean | 66.367560 | 161.440357 |
| std | 3.847528 | 32.108439 |
| min | 54.263133 | 64.700127 |
| 25% | 63.505620 | 135.818051 |
| 50% | 66.318070 | 161.212928 |
| 75% | 69.174262 | 187.169525 |
| max | 78.998742 | 269.989699 |
It seems that our assumption is right. The weight increases as the height does. Also by the looks of it, we could just ignore gender and treat the samples as one bigger group since one line could still fit pretty decently.
If we zoom out and view the height and weight samples (of male and female), it looks like this.
So clearly, we cannot fit a line that goes through the origin to these samples; the model needs an intercept (bias) term as well.
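A minimal matplotlib sketch of such a zoomed-out scatter plot (again assuming the `df` from above):

```python
import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv("weight-height.csv")  # hypothetical file name from above

# scatter plot of all samples, colored by gender
for gender, group in df.groupby("Gender"):
    plt.scatter(group["Height"], group["Weight"], s=2, label=gender)

plt.xlabel("Height")
plt.ylabel("Weight")
plt.legend()
plt.xlim(0, df["Height"].max() * 1.1)  # zoom out so the origin is visible
plt.ylim(0, df["Weight"].max() * 1.1)
plt.show()
```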
As mentioned in the Gradient Descent post, we first have to choose which loss function we are going to use and define its partial derivatives.
Let's reuse the code from that post and try running gradient descent.
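As a rough sketch (not the exact code from that post), gradient descent for a line `y_hat = w * x + b` with the MSE loss could look like this:

```python
import numpy as np

def gradient_descent(x, y, lr=0.01, epochs=1000):
    """Fit y_hat = w * x + b by minimizing the MSE loss."""
    w, b = 0.0, 0.0
    n = len(x)
    for _ in range(epochs):
        y_hat = w * x + b
        # partial derivatives of the MSE loss with respect to w and b
        dw = -2 / n * np.sum(x * (y - y_hat))
        db = -2 / n * np.sum(y - y_hat)
        w -= lr * dw
        b -= lr * db
    return w, b
```

Running something like this directly on the raw height and weight values with a fixed learning rate leads to the behavior described next.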
We see that the loss grows toward infinity and eventually becomes NaN. This usually happens when the x and y values are not small, so the summed losses and gradients blow up.
To fix this, we can rescale the features. Two common options are normalization and standardization.
Let's use both and compare.
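A minimal sketch of the two scaling methods, assuming NumPy arrays:

```python
import numpy as np

def normalize(v):
    """Min-max normalization: rescale values into the [0, 1] range."""
    return (v - v.min()) / (v.max() - v.min())

def standardize(v):
    """Standardization: shift to zero mean and scale to unit standard deviation."""
    return (v - v.mean()) / v.std()

# hypothetical usage with the height (x) and weight (y) arrays
# x_norm, y_norm = normalize(x), normalize(y)
# x_std, y_std = standardize(x), standardize(y)
```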
We see that standardization converged faster than normalization. As shown, the speed of convergence depends on which scaling method we choose. However, this does not mean we can use whichever one we want in every situation; some cases (or models) prefer normalization over standardization and vice versa.
One example is when we work with an SVM model. In that case, standardization is better for maximizing the margin between the two classes. More details will come in another post.
Since we standardized the training samples, we have to apply the same standardization, using the training statistics, when we predict on new samples.
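For example, a sketch of predicting a weight for a new height, assuming the model was trained on standardized values (the weight `w`, bias `b`, and the stored means and standard deviations are placeholders):

```python
def predict_weight(height, w, b, x_mean, x_std, y_mean, y_std):
    """Standardize the input with the training statistics,
    predict in the standardized space, then undo the scaling."""
    x = (height - x_mean) / x_std   # same scaling as the training data
    y_hat = w * x + b               # prediction in standardized space
    return y_hat * y_std + y_mean   # back to the original weight scale
```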
The linear regression we used is Ordinary Least Squares, but there are other kinds of linear regression as well, such as:

1. Weighted Least Squares
2. Generalized Least Squares
3. Ridge Regression
4. Lasso Regression
5. Elastic Net Regression
There are other forms not mentioned here as well. The last three are regularized regressions, which will be covered in a separate post.
It is also possible to have a linear regression whose fitted line is actually not a straight line! The model is still linear in its weights, even though it is nonlinear in x.
For example, let's say we have the following samples.
If we use the model from above, we get a straight-line fit just like this.
Since we have two different weights, the partial derivatives are different as well.
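As a sketch, assuming (hypothetically) a quadratic model of the form `y_hat = w1 * x^2 + w2 * x` with the MSE loss, the derivatives would look like this:

$$
\hat{y}_i = w_1 x_i^2 + w_2 x_i, \qquad
L = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2
$$

$$
\frac{\partial L}{\partial w_1} = -\frac{2}{n}\sum_{i=1}^{n} x_i^2 \left(y_i - \hat{y}_i\right), \qquad
\frac{\partial L}{\partial w_2} = -\frac{2}{n}\sum_{i=1}^{n} x_i \left(y_i - \hat{y}_i\right)
$$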
Although it requires us to know which kind of model generated the samples, it is possible to fit a model to nonlinear data this way.
This post only deals with basic linear regression, without any regularization such as Lasso, Ridge, or Elastic Net. There are many versions of it besides Ordinary Least Squares. These topics will be covered in later posts.
You can find the full code here.
Thank you all for reading and if you find any errors or typos or have any suggestions, please let me know.
One naive loss function can be the mean of the differences between the actual and predicted values, where $\hat{y}_i$ denotes the predicted value for sample $i$ and $n$ is the number of samples. In this post, we will use the Mean Squared Error function, $\frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2$.
The equation used to generate the plots above involves our new weights $w_1$ and $w_2$.