# sns（seaborn）畫法
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Load the salary dataset. The columns YearsExperience and Salary are used
# below; the CSV path is resolved relative to the current working directory.
income = pd.read_csv('Salary_Data.csv')
# Inspect the column dtypes. A bare `income.dtypes` expression is only
# displayed by a notebook; print it so the check is also visible in a script.
print(income.dtypes)
# Scatter plot with a fitted linear regression line; ci=None turns off the
# confidence-interval band.
sns.lmplot(x='YearsExperience', y='Salary', data=income, ci=None)
# Display the seaborn figure; when run as a script nothing is shown otherwise.
plt.show()
# [圖片：image.png]
# plt（matplotlib）畫法（搭配 statsmodels）
# Statistical modelling module (its formula interface is used below).
import statsmodels.api as sm

# Start a fresh figure so this plot does not draw onto the axes of whatever
# figure is currently active (e.g. the seaborn plot when run as a script).
plt.figure()
# Scatter plot of the raw observations.
plt.scatter(x=income.YearsExperience, y=income.Salary, color="blue")
# Fit a simple (one-predictor) OLS regression: Salary ~ YearsExperience.
fit = sm.formula.ols("Salary~YearsExperience", data=income).fit()
# Predict on the training data. A one-column DataFrame lets the formula
# machinery find YearsExperience by column name, instead of relying on
# statsmodels guessing whether a bare Series is a row or a column.
pred = fit.predict(exog=income[["YearsExperience"]])
# Draw the fitted regression line over the scatter.
plt.plot(income.YearsExperience, pred, color="coral", linewidth=1)
plt.show()
# [圖片：image.png]
# plt（matplotlib）畫法（搭配 sklearn）
from sklearn.linear_model import LinearRegression
import numpy as np

# scikit-learn expects a 2-D feature matrix (n_samples, n_features), so the
# single predictor is reshaped into one column.
X = np.asarray(income.YearsExperience).reshape(-1, 1)
y = income.Salary

# Ordinary least-squares fit, then predictions on the training inputs.
lr = LinearRegression()
lr.fit(X, y)
predict = lr.predict(X)

# Fresh figure so this plot does not draw onto a previously active figure.
plt.figure()
plt.scatter(X, y, color="blue", s=60)
plt.plot(X, predict, color="coral", linewidth=1)
# Without this a script finishes without ever displaying the figure.
plt.show()
# [圖片：image.png]