一維插值
應(yīng)用場景
樣本點是必須滿足的(關(guān)鍵點不允許偏移)
經(jīng)濟學(xué)陌僵、氣象,已經(jīng)有一些數(shù)據(jù)嗅骄,不想用微分方程机杜,在論文里放圖
線性插值與樣條插值
import numpy as np
import pylab as pl
from scipy import interpolate
import matplotlib.pyplot as plt
x = np.linspace(0, 2*np.pi + np.pi/4, 10)
y = np.sin(x)
x_new = np.linspace(0, 2*np.pi + np.pi/4, 100)
f_linear = interpolate.interp1d(x, y) #一維插值
tck = interpolate.splrep(x, y) #形成樣條關(guān)系式
y_bspline = interpolate.splev(x_new, tck) #B樣條
#可視化
plt.xlabel(u'ampere/A')
plt.ylabel(u'volt/V')
plt.plot(x, y, "o", label = u"The original data")
plt.plot(x_new, f_linear(x_new), label = u"Linear interpolation")
plt.plot(x_new, y_bspline, label = u"B-spline interpolation")
pl.legend()
pl.show()
高階樣條插值
#創(chuàng)建數(shù)據(jù)點集
import numpy as np
x = np.linspace(0, 10, 11)
y = np.sin(x)
#繪制數(shù)據(jù)點集
import pylab as pl
pl.figure(figsize = (12,9))
pl.plot(x, y, 'ro')
#根據(jù)kind創(chuàng)建interpld對象f、計算插值結(jié)果
xnew = np.linspace(0, 10, 101)
from scipy import interpolate
for kind in ['nearest', 'zero', 'linear', 'quadratic', 5]:
f = interpolate.interp1d(x, y, kind = kind)
ynew = f(xnew)
pl.plot(xnew, ynew, label = str(kind))
pl.xticks(fontsize = 20)
pl.yticks(fontsize = 20)
pl.legend(loc = 'lower right')
pl.show()
5階樣條更加接近正弦曲線其弊,但x范圍更大以后5階的龍格現(xiàn)象也越明顯
二維插值
方法與一維數(shù)據(jù)插值類似癞己,為二維樣條插值
import numpy as np
from scipy import interpolate
import pylab as pl
import matplotlib as mpl
def func(x, y):
return (x+y)*np.exp(-5.0*(x**2+y**2))
#X-Y軸分為15*15的網(wǎng)格
y, x = np.mgrid[-1:1:15j, -1:1:15j]
#計算每個網(wǎng)格點上函數(shù)值
fvals = func(x, y)
#三次樣條二維插值
newfunc = interpolate.interp2d(x, y, fvals, kind = 'cubic')
#計算100*100網(wǎng)格上插值
xnew = np.linspace(-1, 1, 100)
ynew = np.linspace(-1, 1, 100)
np.meshgrid(xnew, ynew)
fnew = newfunc(xnew, ynew)
#可視化
#讓imshow的參數(shù)interpolation設(shè)置為'nearest'方便比較插值處理
pl.subplot(121)
im1 = pl.imshow(fvals, extent = [-1, 1, -1, 1],
cmap = mpl.cm.hot, interpolation='nearest', origin='lower')
pl.colorbar(im1)
pl.subplot(122)
im2 = pl.imshow(fnew, extent = [-1, 1, -1, 1],
cmap = mpl.cm.hot, interpolation='nearest', origin='lower')
pl.colorbar(im2)
pl.tight_layout()
pl.show()
二維插值的三維圖
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from scipy import interpolate
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
def func(x, y):
return (x+y)*np.exp(-5.0*(x**2+y**2))
#X-Y軸分為20*20的網(wǎng)格
x = np.linspace(-1, 1, 20)
y = np.linspace(-1, 1, 20)
x, y = np.meshgrid(x, y)
fvals = func(x,y)
#畫分圖1
fig = plt.figure(figsize = (9,6))
ax = plt.subplot(1, 2, 1, projection = '3d')
surf = ax.plot_surface(x, y, fvals, rstride=2, cstride=2,
cmap=cm.coolwarm, linewidth=0.5, antialiased=True)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('f(x,y)')
plt.colorbar(surf, shrink=0.5, aspect=5) #添加顏色條標注
#二維插值
newfunc = interpolate.interp2d(x, y, fvals, kind='cubic')
#計算100*100網(wǎng)格上插值
xnew = np.linspace(-1, 1, 100)
ynew = np.linspace(-1, 1, 100)
fnew = newfunc(xnew, ynew)
xnew, ynew = np.meshgrid(xnew, ynew)
ax2 = plt.subplot(1, 2, 2, projection='3d')
surf2 = ax2.plot_surface(xnew, ynew, fnew, rstride=2, cstride=2,
cmap=cm.coolwarm, linewidth=0.5, antialiased=True)
ax2.set_xlabel('xnew')
ax2.set_ylabel('ynew')
ax2.set_zlabel('fnew(x,y)')
plt.colorbar(surf2, shrink=0.5, aspect=5) #添加顏色條標注
plt.tight_layout()
plt.show()
左圖的二維數(shù)據(jù)集的函數(shù)值由于樣本較少,會顯得粗糙梭伐。而右圖對二維樣本數(shù)據(jù)進行三次樣條插值痹雅,擬合得到更多數(shù)據(jù)點的樣本值,繪圖后圖像明顯光滑多了
最小二乘擬合
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import leastsq
plt.figure(figsize=(9,9))
X = np.array([8.19, 2.72, 6.39, 8.71, 4.7, 2.66, 3.78])
Y = np.array([7.01, 2.78, 6.47, 6.71, 4.1, 4.23, 4.05])
#計算以p為參數(shù)的直線與原始數(shù)據(jù)之間的誤差
def f(p):
k, b = p
return (Y-(k*X+b))
#leastsq使得f的輸出數(shù)組的平方和最小糊识,參數(shù)初始值為[1,0]
r = leastsq(f, [1,0])
k, b = r[0]
plt.scatter(X, Y, s=100, alpha=1.0, marker='o', label = 'Data Points')
x = np.linspace(0, 10, 1000)
y = k*x + b
plt.plot(x, y, color='r', linewidth=5,
linestyle=":", markersize=20, label = 'Fitting Curve')
plt.legend(loc=0, numpoints=1)
leg = plt.gca().get_legend()
ltext = leg.get_texts()
plt.setp(ltext, fontsize='xx-large')
plt.xlabel('Amphere/A', fontsize=20)
plt.ylabel('Volt/V', fontsize=20)
plt.xlim(0, x.max()*1.1)
plt.ylim(0, y.max()*1.1)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(loc='upper left')