In the book, the perceptron learning algorithm aims to minimize the total distance from all misclassified points to the separating hyperplane, so (dropping the constant factor $1/\|w\|$) the loss function is defined as

$$L(w, b) = -\sum_{x_i \in M} y_i (w \cdot x_i + b)$$

where $M$ is the set of misclassified samples.
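As a small illustration (not from the original notes), the loss above can be evaluated directly with NumPy; the helper name perceptron_loss and the assumed array shapes are illustrative only:

import numpy as np

def perceptron_loss(w, b, X, y):
    # X: (n_samples, n_features), y: labels in {+1, -1}
    margins = y * (X @ w + b)      # y_i (w . x_i + b) for every sample
    in_M = margins <= 0            # indicator of the misclassified set M
    return -np.sum(margins[in_M])  # -sum over M of y_i (w . x_i + b)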
Using stochastic gradient descent (SGD), consider a single misclassified sample $(x_i, y_i)$, whose loss is

$$L(w, b) = -y_i (w \cdot x_i + b)$$
Taking gradients with respect to the parameters gives

$$\nabla_w L(w, b) = -y_i x_i, \qquad \nabla_b L(w, b) = -y_i$$
so the parameters are updated as

$$w \leftarrow w + \eta\, y_i x_i, \qquad b \leftarrow b + \eta\, y_i$$

where $\eta$ is the learning rate (alpha in the code below).
import numpy as np


class Perceptron(object):
    def __init__(self, feature_num, alpha, max_step=10000):
        self._alpha = alpha                  # learning rate eta
        self._w = np.zeros(feature_num)      # weight vector w
        self._b = 0.0                        # bias b
        self._max_step = max_step            # cap on the number of sweeps

    def fit(self, X, y):
        misclassify = True
        step = 0
        # Sweep the training set until no point is misclassified
        # or the iteration budget is exhausted.
        while misclassify and step < self._max_step:
            misclassify = False
            step += 1
            for tx, ty in zip(X, y):
                # A point is misclassified when y_i (w . x_i + b) <= 0.
                if ty * (np.dot(tx, self._w) + self._b) <= 0:
                    # SGD update: w <- w + eta * y_i * x_i, b <- b + eta * y_i
                    self._w += self._alpha * ty * tx
                    self._b += self._alpha * ty
                    misclassify = True

    def predict(self, X):
        # Sign of the decision function w . x + b
        # (points exactly on the hyperplane are assigned -1).
        return np.where(X @ self._w + self._b > 0, 1, -1)
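A minimal usage sketch on a toy linearly separable dataset (the data, seed, and hyperparameter values are assumptions for illustration, not from the original):

import numpy as np

rng = np.random.default_rng(0)
# Two Gaussian blobs, one per class, separable by a hyperplane.
X_pos = rng.normal(loc=[2.0, 2.0], scale=0.5, size=(50, 2))    # label +1
X_neg = rng.normal(loc=[-2.0, -2.0], scale=0.5, size=(50, 2))  # label -1
X = np.vstack([X_pos, X_neg])
y = np.concatenate([np.ones(50), -np.ones(50)])

model = Perceptron(feature_num=2, alpha=0.1)
model.fit(X, y)
print((model.predict(X) == y).mean())  # training accuracy, expected 1.0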