
Linear Regression

Linear models

  • A Simple Linear Regression Example (source)

This example uses only one feature of the diabetes dataset (column index 2, body mass index) so that the regression can be shown as a two-dimensional plot. The plot shows the straight line fitted by linear regression: the line that minimizes the residual sum of squares between the observed responses in the dataset and the responses predicted by the linear approximation.

The coefficients, the mean squared error, and the coefficient of determination (R²) on the 20 held-out test samples are also calculated.
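For a single feature, the least-squares line has a closed form: the slope is Σ(xᵢ − x̄)(yᵢ − ȳ) / Σ(xᵢ − x̄)² and the intercept is ȳ − slope · x̄. The sketch below is a minimal illustration on synthetic data (the data-generating line y = 3x + 2 and all variable names are assumptions, not part of the original example). It computes both quantities directly and checks them against `LinearRegression`, which should agree up to floating-point error.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical synthetic data: y = 3x + 2 plus Gaussian noise
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=50)
y = 3.0 * x + 2.0 + rng.normal(scale=0.5, size=50)

# Closed-form least-squares estimates for a single feature
x_mean, y_mean = x.mean(), y.mean()
slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
intercept = y_mean - slope * x_mean

# Residual sum of squares of the fitted line
rss = np.sum((y - (slope * x + intercept)) ** 2)

# LinearRegression expects a 2-D X of shape (n_samples, n_features)
regr = LinearRegression().fit(x.reshape(-1, 1), y)

print("slope:", slope, "vs", regr.coef_[0])
print("intercept:", intercept, "vs", regr.intercept_)
print("RSS:", rss)
```

Any other line, for example one with a slightly different slope, gives a larger residual sum of squares than the one computed here.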

import numpy as np
from sklearn import datasets, linear_model
  • Import Diabetes Dataset
diabetes = datasets.load_diabetes()

# Use only one feature: column 2 (BMI), kept 2-D with shape (n_samples, 1)
diabetes_X = diabetes.data[:, np.newaxis, 2]
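scikit-learn estimators expect `X` to be a 2-D array of shape `(n_samples, n_features)`, which is why the column is selected with `np.newaxis` rather than as a plain vector. A minimal sketch of the shapes involved, using a small made-up array as a stand-in for `diabetes.data` (the array and names are illustrative assumptions):

```python
import numpy as np

# Stand-in for diabetes.data: 4 samples x 3 features (illustrative values)
data = np.arange(12, dtype=float).reshape(4, 3)

one_d = data[:, 2]               # shape (4,): a plain vector
two_d = data[:, np.newaxis, 2]   # shape (4, 1): a one-column matrix

# Equivalent ways to keep a single column as a 2-D array
alt_slice = data[:, 2:3]
alt_reshape = data[:, 2].reshape(-1, 1)

print(one_d.shape, two_d.shape)  # (4,) (4, 1)
```

The slice form `data[:, 2:3]` is often considered the most readable of the three.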
  • Split Dataset to Train and Test
# Split the data into training/testing sets (the last 20 rows are held out)
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]

# Split the targets into training/testing sets
diabetes_y_train = diabetes.target[:-20]
diabetes_y_test = diabetes.target[-20:]
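The manual slicing can also be expressed with scikit-learn's `train_test_split`. With `shuffle=False` and an integer `test_size`, the function keeps the original row order, so the last 20 rows become the test set. A minimal sketch under that assumption:

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

diabetes = load_diabetes()
X = diabetes.data[:, np.newaxis, 2]
y = diabetes.target

# shuffle=False preserves row order; test_size=20 holds out exactly 20 rows
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=20, shuffle=False
)

print(X_train.shape, X_test.shape)  # (422, 1) (20, 1)
```

If the rows might be ordered in some meaningful way, passing `shuffle=True` together with a fixed `random_state` gives a random but reproducible split instead.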
  • Fit the Model
# Create linear regression object
regr = linear_model.LinearRegression()

# Train the model using the training sets
regr.fit(diabetes_X_train, diabetes_y_train)
# Notebook output (from an older scikit-learn; the `normalize` parameter was removed in version 1.2):
# LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)
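After fitting, the model stores the learned weights in `coef_` and the intercept in `intercept_`, and `predict` is simply a dot product plus the intercept. A minimal sketch on hypothetical toy data with an exact linear relation (y = 1·x₀ + 2·x₁ + 3; the data are an assumption for illustration):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical toy data with an exact linear relation y = 1*x0 + 2*x1 + 3
X = np.array([[1.0, 1.0], [1.0, 2.0], [2.0, 2.0], [2.0, 3.0]])
y = X @ np.array([1.0, 2.0]) + 3.0

regr = LinearRegression().fit(X, y)

# A fitted linear model predicts with X @ coef_ + intercept_
manual = X @ regr.coef_ + regr.intercept_
print(np.allclose(manual, regr.predict(X)))  # True
print(regr.coef_, regr.intercept_)
```

In the one-feature diabetes example, `coef_` holds a single value, which is the slope of the plotted line.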
  • Calculate Regression Coefficients, Mean Squared Error and R² Score
# The coefficients
print('Coefficients: \n', regr.coef_)

# The mean squared error
print("Mean squared error: %.2f"
      % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))

# Coefficient of determination (R^2): 1 is perfect prediction
print('R^2 score: %.2f' % regr.score(diabetes_X_test, diabetes_y_test))
Coefficients: 
 [938.23786125]
Mean squared error: 2548.07
R^2 score: 0.47
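The same metrics are available as functions in `sklearn.metrics`, and both follow directly from their definitions: MSE is the mean of the squared residuals, and R² = 1 − SS_res / SS_tot. `regr.score` returns R² for regressors. Note that scikit-learn's separate `explained_variance_score` coincides with R² only when the residuals have zero mean. A minimal sketch that recomputes the test-set metrics both ways (it repeats the data preparation so it runs on its own):

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score

diabetes = datasets.load_diabetes()
X = diabetes.data[:, np.newaxis, 2]
X_train, X_test = X[:-20], X[-20:]
y_train, y_test = diabetes.target[:-20], diabetes.target[-20:]

regr = linear_model.LinearRegression().fit(X_train, y_train)
y_pred = regr.predict(X_test)

# Library helpers (argument order: y_true, y_pred)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

# The same quantities from their definitions
mse_manual = np.mean((y_test - y_pred) ** 2)
ss_res = np.sum((y_test - y_pred) ** 2)
ss_tot = np.sum((y_test - y_test.mean()) ** 2)
r2_manual = 1.0 - ss_res / ss_tot

print("MSE: %.2f" % mse)
print("R^2: %.2f" % r2)
```

These values should reproduce the numbers printed above, since the split and the model are the same.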
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()  # apply seaborn's default plot style

# Plot outputs: test points in black, fitted line in blue
plt.figure(figsize=(12, 8))

plt.scatter(diabetes_X_test,
            diabetes_y_test,
            color='black')

plt.plot(diabetes_X_test,
         regr.predict(diabetes_X_test),
         color='blue',
         linewidth=3)

plt.grid(True)
plt.show()
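The plot draws the fitted line only at the 20 test inputs, which works because every predicted point lies on the same straight line. An alternative is to evaluate the model on an ordered grid spanning the data range. The sketch below does this with hypothetical synthetic data standing in for `diabetes_X_test` and `diabetes_y_test`, and assumes a headless environment (the `Agg` backend, with the figure rendered to an in-memory buffer instead of a window):

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: headless run, no display window
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical 1-D data standing in for the diabetes test set
rng = np.random.default_rng(1)
x = rng.uniform(-0.1, 0.1, size=(20, 1))
y = 900.0 * x.ravel() + 150.0 + rng.normal(scale=20.0, size=20)

regr = LinearRegression().fit(x, y)

# Evaluate the fitted line on an ordered grid covering the data range
grid = np.linspace(x.min(), x.max(), 100)
y_line = regr.predict(grid.reshape(-1, 1))

fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(x.ravel(), y, color="black", label="observed")
(line,) = ax.plot(grid, y_line, color="blue", linewidth=3, label="fitted line")
ax.legend()

# Render to an in-memory PNG instead of calling plt.show()
buf = io.BytesIO()
fig.savefig(buf, format="png")
plt.close(fig)
```

The grid approach matters more for models whose predictions are not a straight line, where plotting at unsorted inputs would produce a zig-zag.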
