- 在幾百個點組成的小規(guī)模數(shù)據(jù)集上雅任, 簡化版
SMO
算法的運行是沒有什么問題的, 但是在更大的數(shù)據(jù)集上的運行速度就會變慢咨跌。剛才巳經(jīng)討論了簡化版 SMO
算 法 沪么,下面我們就討論完整版的Platt SMO
算法。在這兩個版本中锌半,實現(xiàn) alpha
的更改和代數(shù)運算的優(yōu)化環(huán)節(jié)一模一樣禽车。在優(yōu)化過程中 ,唯一的不同就是選擇 alpha
的方式刊殉。完整版的 Platt SMO
算法應(yīng)用了一些能夠提速的啟發(fā)方法殉摔。或許讀者已經(jīng)意識到记焊,上一節(jié)的例子在執(zhí)行時存在一定的時間提升空間逸月。
- Platt
SMO
算法是通過一個外循環(huán)來選擇第一個 alpha
值的,并且其選擇過程會在兩種方式之間進(jìn)行交替: 一種方式是在所有數(shù)據(jù)集上進(jìn)行單遍掃描遍膜, 另一種方式則是在非邊界 alpha
中實現(xiàn)單遍掃描碗硬。而所謂非邊界 alpha
指的就是那些不等于邊界0或C
的 alpha
值 。對整個數(shù)據(jù)集的掃描相當(dāng)容易 瓢颅,而實現(xiàn)非邊界 alpha
值的掃描時恩尾,首先需要建立這些 alpha
的列表,然后再對這個表進(jìn)行遍歷挽懦。同時翰意,該步驟會跳過那些已知的不會改變的 alpha
值。
- 在選擇第一個
alpha
值后信柿,算法會通過一個內(nèi)循環(huán)來選擇第二個 alpha
值 冀偶。在優(yōu)化過程中,會通過最大化步長的方式來獲得第二個 alpha
值渔嚷。在簡化版 SMO
算法中进鸠,我們會在選擇 j
之后計算錯誤率 Ej 。但在這里圃伶,我們會建立一個全局的緩存用于保存誤差值堤如,并從中選擇使得步長或者說 Ei - Ej 最 大 的 alpha
值 蒲列。
- 在講述改進(jìn)后的代碼之前,我們必須要對上節(jié)的代碼進(jìn)行清理搀罢。下面的程序清單中包含 1 個用于清理代碼的數(shù)據(jù)結(jié)構(gòu)和 3 個用于對
E
進(jìn)行緩存的輔助函數(shù)蝗岖。
完整版 SMO 算法輔助函數(shù)
# 建立一個數(shù)據(jù)結(jié)構(gòu)來保存所有的重要值,這樣較為便利
class optStruct:
def __init__(self, dataMatIn, classLabels, C, toler):
self.X = dataMatIn
self.labelMat = classLabels
self.C = C
self.tol = toler
self.m = np.shape(dataMatIn)[0]
self.alphas = np.mat(np.zeros((self.m, 1)))
self.b = 0
# 誤差緩存,第一列為是否有效標(biāo)志位榔至,第二列為實際的E值
self.eCache = np.mat(np.zeros((self.m, 2)))
# 計算并返回 E 值
def calcEk(oS, k):
# 預(yù)測值
fXk = float(np.multiply(oS.alphas, oS.labelMat).T * (oS.X * oS.X[k,:].T)) + oS.b
# 誤差值
Ek = fXk - float(oS.labelMat[k])
return Ek
# 內(nèi)循環(huán)中的啟發(fā)式方法
# 用于選擇第二個 alpha 或者說內(nèi)循環(huán)的 alpha 值
def selectJ(i, oS, Ei):
maxK = -1
maxDeltaE = 0
Ej = 0
oS.eCache[i] = [1, Ei]
# 返回 eCache 第0列非0值下標(biāo)
validEcacheList = np.nonzero(oS.eCache[:, 0])[0]
if len(validEcacheList) > 1:
for k in validEcacheList:
if k == i:
continue
Ek = calcEk(oS, k)
deltaE = abs(Ei - Ek)
if (deltaE > maxDeltaE):
maxK = k
maxDeltaE = deltaE
Ej = Ek
return maxK, Ej
else:
j = selectJrand(i, oS.m)
Ej = calcEk(oS, j)
return j, Ej
# 計算誤差值并存入緩存中抵赢,在對alpha值進(jìn)行優(yōu)化之后會用到這個值
def updateEk(oS, k):
Ek = calcEk(oS, k)
oS.eCache[k] = [1, Ek]
完整版 SMO 算法中的優(yōu)化例程
- 此實現(xiàn)代碼幾乎和
smoSimple()
函數(shù)一模一樣, 但是這里的代碼已經(jīng)使用了自己的數(shù)據(jù)結(jié)構(gòu)唧取。該結(jié)構(gòu)在參數(shù) oS
中傳遞铅鲤。第二個重要的修改就是使用 selectJ ()
而不是 selectJrand()
來選擇第二個 alpha
的值。最后枫弟,在 alpha 值改變時更新 Ecache
邢享。
def innerL(i, oS):
Ei = calcEk(oS, i)
if ((oS.labelMat[i] * Ei < -oS.tol) and (oS.alphas[i] < oS.C)) or \
((oS.labelMat[i] * Ei > oS.tol) and (oS.alphas[i] > 0) ):
# 用于選擇第二個 alpha 或者說內(nèi)循環(huán)的 alpha 值
j, Ej = selectJ(i, oS, Ei)
alphaIoId = oS.alphas[i].copy()
alphaJoId = oS.alphas[j].copy()
if (oS.labelMat[i] != oS.labelMat[j]):
L = max(0, oS.alphas[j] - oS.alphas[i])
H = min(oS.C, oS.C + oS.alphas[j] - oS.alphas[i])
else:
L = max(0, oS.alphas[j] + oS.alphas[i] - oS.C)
H = min(oS.C, oS.alphas[j] + oS.alphas[i])
if L == H:
# print('L==H')
return 0
eta = 2.0 * oS.X[i,:] * oS.X[j,:].T - oS.X[i,:] * oS.X[i,:].T - oS.X[j,:] * oS.X[j,:].T
if eta >= 0:
# print('eta >= 0')
return 0
oS.alphas[j] -= oS.labelMat[j] * (Ei - Ej) / eta
oS.alphas[j] = clipAlpha(oS.alphas[j], H, L)
updateEk(oS, j) # 更新誤差緩存
if (abs(oS.alphas[j] - alphaJoId) < 0.00001):
# print('j not moving enough')
return 0
oS.alphas[i] += oS.labelMat[j] * oS.labelMat[i] * (alphaJoId - oS.alphas[j])
updateEk(oS, i) # 更新誤差緩存
b1 = oS.b - Ei - oS.labelMat[i] * (oS.alphas[i] - alphaIoId) * \
oS.X[i,:] * oS.X[i,:].T - oS.labelMat[j] * \
(oS.alphas[j] - alphaJoId) * oS.X[i,:] * oS.X[j,:].T
b2 = oS.b - Ej - oS.labelMat[i] * (oS.alphas[i] - alphaIoId) * \
oS.X[i,:] * oS.X[j,:].T - oS.labelMat[j] * \
(oS.alphas[j] - alphaJoId) * oS.X[j,:] * oS.X[j,:].T
if (0 < oS.alphas[i]) and (oS.alphas[i] < oS.C):
oS.b = b1
elif (0 < oS.alphas[j]) and (oS.alphas[j] < oS.C):
oS.b = b2
else:
oS.b = (b1 + b2) / 2.0
return 1
else:
return 0
完整版 SMO 算法中的外循環(huán)代碼
def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup = ('lin', 0)):
oS = optStruct(np.mat(dataMatIn), np.mat(classLabels).transpose(), C, toler)
iter = 0
entireSet = True
alphaPairsChanged = 0
while(iter < maxIter) and ((alphaPairsChanged > 0) or (entireSet)):
alphaPairsChanged = 0
if entireSet:
for i in range(oS.m): # 遍歷所有的值
alphaPairsChanged += innerL(i, oS)
# print('fullSet, iter: %d i: %d, pairs changed %d' % (iter, i, alphaPairsChanged))
iter += 1
else:
# 遍歷非邊界值
nonBoundIs = np.nonzero((0 < oS.alphas.A) * (oS.alphas.A < C))[0]
for i in nonBoundIs:
alphaPairsChanged += innerL(i, oS)
# print('non-bound, iter: %d i: %d, pairs changed %d' %(iter, i, alphaPairsChanged))
iter += 1
if entireSet:
entireSet = False
elif (alphaPairsChanged == 0):
entireSet = True
# print('iteration number: %d' % iter)
return oS.b, oS.alphas
w 的計算
def calcWs(alphas, dataArr, labelArr):
X = np.mat(dataArr) # (100, 2)
labelMat = np.mat(labelArr).transpose() #(100, 1)
m, n = np.shape(X) # m = 100, n = 2
w = np.zeros((n, 1)) # (100, 1)
for i in range(m):
w += np.multiply(alphas[i] * labelMat[i], X[i,:].T)
return w
畫出分類示意圖
# 畫出完整分類圖
def plotFigure(weights, b):
x, y = loadDataSet('testSet.txt')
xarr = np.array(x)
n = np.shape(x)[0]
x1 = []; y1 = []
x2 = []; y2 = []
for i in np.arange(n):
if int(y[i]) == 1:
x1.append(xarr[i,0]); y1.append(xarr[i,1])
else:
x2.append(xarr[i,0]); y2.append(xarr[i,1])
plt.scatter(x1, y1, s = 30, c = 'r', marker = 's')
plt.scatter(x2, y2, s = 30, c = 'g')
# 畫出 SVM 分類直線
xx = np.arange(0, 10, 0.1)
# 由分類直線 weights[0] * xx + weights[1] * yy1 + b = 0 易得下式
yy1 = (-weights[0] * xx - b) / weights[1]
# 由分類直線 weights[0] * xx + weights[1] * yy2 + b + 1 = 0 易得下式
yy2 = (-weights[0] * xx - b - 1) / weights[1]
# 由分類直線 weights[0] * xx + weights[1] * yy3 + b - 1 = 0 易得下式
yy3 = (-weights[0] * xx - b + 1) / weights[1]
plt.plot(xx, yy1.T)
plt.plot(xx, yy2.T)
plt.plot(xx, yy3.T)
# 畫出支持向量點
for i in range(n):
if alphas[i] > 0.0:
plt.scatter(xarr[i,0], xarr[i,1], s = 150, c = 'none', alpha = 0.7, linewidth = 1.5, edgecolor = 'red')
plt.xlim((-2, 12))
plt.ylim((-8, 6))
plt.show()
主函數(shù)
if __name__ == '__main__':
dataArr, labelArr = loadDataSet('/home/gcb/data/testSet.txt')
b, alphas = smoP(dataArr, labelArr, 0.6, 0.001, 40)
w = calcWs(alphas, dataArr, labelArr)
plotFigure(w, b)
print(b)
print(alphas[alphas > 0]) # 支持向量對應(yīng)的 alpha > 0
print(w)
參考