線(xiàn)性單元
- 感知器有一個(gè)問(wèn)題,當(dāng)面對(duì)的數(shù)據(jù)集不是線(xiàn)性可分的時(shí)候用踩,『感知器規(guī)則』可能無(wú)法收斂窃款,這意味著我們永遠(yuǎn)也無(wú)法完成一個(gè)感知器的訓(xùn)練宛官。為了解決這個(gè)問(wèn)題渺贤,我們使用一個(gè)連續(xù)的線(xiàn)性函數(shù)來(lái)替代感知器的階躍函數(shù)九秀,這種感知器就叫做線(xiàn)性單元详瑞。線(xiàn)性單元在面對(duì)線(xiàn)性不可分的數(shù)據(jù)集時(shí)掂林,會(huì)收斂到一個(gè)最佳的近似上。
- 那么線(xiàn)性單元就是將感知機(jī)的輸出激活函數(shù)由分段函數(shù)改為了連續(xù)函數(shù)坝橡,進(jìn)而輸出的值域也由
舉例說(shuō)明
當(dāng)我們說(shuō)模型時(shí),我們實(shí)際上在談?wù)摳鶕?jù)輸入預(yù)測(cè)輸出y的算法计寇。比如锣杂,可以是一個(gè)人的工作年限,可以是他的月薪番宁,我們可以用某種算法來(lái)根據(jù)一個(gè)人的工作年限來(lái)預(yù)測(cè)他的收入元莫。
其中是可以擬合年限輸入和月薪輸出的待求權(quán)重參數(shù)。工作年限稱(chēng)為一個(gè)特征蝶押,輸入可以包含多個(gè)特征如:行業(yè)踱蠢,公司,職級(jí)等棋电。當(dāng)特征變多時(shí)茎截,對(duì)應(yīng)的每個(gè)特征都需要一個(gè)權(quán)重用于擬合輸入和輸出之間的關(guān)系。
,矩陣表示
其中
代碼
由于相較于Perceptron只改變了激活函數(shù)赶盔,所以我們可以繼承Perceptron快速實(shí)現(xiàn)LinerUnit
class LinerUnit(Perceptron):
def __init__(self, input_dim, activator) -> None:
super().__init__(input_dim, activator)
生成訓(xùn)練數(shù)據(jù)胜臊,定義可視化
# 新定義的連續(xù)線(xiàn)性激活函數(shù)
def liner_activater(x):
return x
def get_training_dataset():
"""
construct training_set, consist of n samples
Working years and corresponding salary.
"""
data = [[5], [3], [8], [1.4], [10.1], [8.1]]
labels = [5500, 2300, 7600, 1800, 11400, 20000]
return data, labels
def train_liner_unit(iterations, lr):
"""
Train a liner_unit with training_set.
"""
lu = LinerUnit(input_dim=1, activator=liner_activater)
lu.train(*get_training_dataset(), iterations=iterations, lr=lr)
return lu
def show_results(linear_unit, samples):
"""
Visualize the line after the linear unit fit
"""
predicts = [linear_unit.predict(s) for s in samples]
plt.scatter(samples, predicts, marker="o")
x_fit = np.linspace(start=0, stop=max(samples), num=100)
y_fit = linear_unit.weights * x_fit + linear_unit.bias
plt.plot(x_fit, y_fit, linestyle="-")
plt.xlabel("Working years")
plt.ylabel("Salary")
plt.show()
訓(xùn)練隶校,測(cè)試,并可視化
if __name__ == "__main__":
linear_unit = train_liner_unit(10, 0.1)
test_samples = [[3.4], [15], [1.5], [6.3], [8]]
# test
for year in test_samples:
print(f"Work {year} years, monthly salary = {linear_unit.predict(year)}")
show_results(linear_unit=linear_unit, samples=test_samples)