本文結構:
- 什么是線性單元
- 有什么用
- 代碼實現(xiàn)
1. 什么是線性單元
線性單元和感知器的區(qū)別就是在激活函數(shù):
感知器的 f 是階越函數(shù):
線性單元的激活函數(shù)是線性的:
所以線性模型的公式如下:
2. 有什么用
感知器存在一個問題卖漫,就是遇到線性不可分的數(shù)據(jù)時,就可能無法收斂,所以要使用一個可導的線性函數(shù)來替代階躍函數(shù),即線性單元,這樣就會收斂到一個最佳的近似上笙僚。
3. 代碼實現(xiàn)
1. 繼承Perceptron,初始化線性單元
from perceptron import Perceptron
#定義激活函數(shù)f
f = lambda x: x
class LinearUnit(Perceptron):
def __init__(self, input_num):
'''初始化線性單元,設置輸入?yún)?shù)的個數(shù)'''
Perceptron.__init__(self, input_num, f)
2. 定義一個線性單元, 調(diào)用 train_linear_unit
進行訓練
- 打印訓練獲得的權重
- 輸入?yún)?shù)值 [3.4] 測試一下預測值
if __name__ == '__main__':
'''訓練線性單元'''
linear_unit = train_linear_unit()
# 打印訓練獲得的權重
print linear_unit
# 測試
print 'Work 3.4 years, monthly salary = %.2f' % linear_unit.predict([3.4])
print 'Work 15 years, monthly salary = %.2f' % linear_unit.predict([15])
print 'Work 1.5 years, monthly salary = %.2f' % linear_unit.predict([1.5])
print 'Work 6.3 years, monthly salary = %.2f' % linear_unit.predict([6.3])
- 其中訓練的過程就是:
- 獲得訓練數(shù)據(jù)桥滨,
- 設定迭代次數(shù),學習速率等參數(shù)
- 再返回訓練好的線性單元
def train_linear_unit():
'''
使用數(shù)據(jù)訓練線性單元
'''
# 創(chuàng)建感知器弛车,輸入?yún)?shù)的特征數(shù)為1(工作年限)
lu = LinearUnit(1)
# 訓練齐媒,迭代10輪, 學習速率為0.01
input_vecs, labels = get_training_dataset()
lu.train(input_vecs, labels, 10, 0.01)
#返回訓練好的線性單元
return lu
完整代碼
from perceptron import Perceptron
#定義激活函數(shù)f
f = lambda x: x
class LinearUnit(Perceptron):
def __init__(self, input_num):
'''初始化線性單元,設置輸入?yún)?shù)的個數(shù)'''
Perceptron.__init__(self, input_num, f)
def get_training_dataset():
'''
捏造5個人的收入數(shù)據(jù)
'''
# 構建訓練數(shù)據(jù)
# 輸入向量列表纷跛,每一項是工作年限
input_vecs = [[5], [3], [8], [1.4], [10.1]]
# 期望的輸出列表喻括,月薪,注意要與輸入一一對應
labels = [5500, 2300, 7600, 1800, 11400]
return input_vecs, labels
def train_linear_unit():
'''
使用數(shù)據(jù)訓練線性單元
'''
# 創(chuàng)建感知器贫奠,輸入?yún)?shù)的特征數(shù)為1(工作年限)
lu = LinearUnit(1)
# 訓練唬血,迭代10輪, 學習速率為0.01
input_vecs, labels = get_training_dataset()
lu.train(input_vecs, labels, 10, 0.01)
#返回訓練好的線性單元
return lu
if __name__ == '__main__':
'''訓練線性單元'''
linear_unit = train_linear_unit()
# 打印訓練獲得的權重
print linear_unit
# 測試
print 'Work 3.4 years, monthly salary = %.2f' % linear_unit.predict([3.4])
print 'Work 15 years, monthly salary = %.2f' % linear_unit.predict([15])
print 'Work 1.5 years, monthly salary = %.2f' % linear_unit.predict([1.5])
print 'Work 6.3 years, monthly salary = %.2f' % linear_unit.predict([6.3])
學習資料:
https://www.zybuluo.com/hanbingtao/note/448086
推薦閱讀 歷史技術博文鏈接匯總
也許可以找到你想要的
我是 不會停的蝸牛 Alice
85后全職主婦
喜歡人工智能,行動派
創(chuàng)造力唤崭,思考力拷恨,學習力提升修煉進行中
歡迎您的喜歡,關注和評論谢肾!