# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import train_Car_Data
import time
data = train_Car_Data.load_Data(download=False)
new_Data = train_Car_Data.covert2onehot(data)
# Prepare all the data: training set and test set
new_Data = new_Data.values.astype(np.float32)  # convert the one-hot DataFrame to float32
np.random.shuffle(new_Data)                    # shuffle the rows
sep = int(0.7 * len(new_Data))                 # index of the 70% split point
train_data = new_Data[:sep]
test_Data = new_Data[sep:]
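# Aside: train_Car_Data is not shown here. As a rough idea, covert2onehot
# presumably one-hot encodes the six categorical attributes of the UCI car
# evaluation data (21 feature columns) plus the four-class label (4 columns),
# giving the 25 columns assumed below. A minimal sketch with pandas
# (hypothetical; the real helper may differ):
import pandas as pd
def covert2onehot_sketch(df):
    # get_dummies expands each categorical column into one 0/1 column per value
    return pd.get_dummies(df)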
# Build the network
tf_input = tf.placeholder(tf.float32, [None, 25], "input")  # shape is [batch, 25 columns]
tfx = tf_input[:, :21]   # no constraint on rows; the first 21 columns are the features
tfy = tf_input[:, 21:]   # the last 4 columns are the one-hot class label
l1 = tf.layers.dense(tfx, 128, tf.nn.relu, name="l1")  # hidden layer: inputs, units, activation, name
l2 = tf.layers.dense(l1, 128, tf.nn.relu, name="l2")   # second hidden layer
out = tf.layers.dense(l2, 4, name="l3")                # output layer: raw logits, one per class
prediction = tf.nn.softmax(out, name="pred")           # softmax probabilities, kept for comparison later
loss = tf.losses.softmax_cross_entropy(onehot_labels=tfy, logits=out)  # cross-entropy between labels y and the logits
accuracy = tf.metrics.accuracy(  # returns (acc, update_op) and creates 2 local variables
    labels=tf.argmax(tfy, axis=1), predictions=tf.argmax(out, axis=1),
)[1]
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)  # gradient descent
train_op = opt.minimize(loss)                               # training op: minimize the loss
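# Aside: the loss takes the raw logits `out`, not `prediction`, because
# tf.losses.softmax_cross_entropy applies a numerically stable log-softmax
# internally. For intuition, the same computation in plain NumPy
# (illustration only, not the library's implementation):
def softmax_xent_sketch(logits, onehot):
    z = logits - logits.max(axis=1, keepdims=True)            # stabilize before exp
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # log-softmax
    return -(onehot * log_p).sum(axis=1).mean()               # mean cross-entropy over the batch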
sess = tf.Session()
# tf.metrics.accuracy keeps its running totals in local variables,
# so initialize both global and local variables
sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
# plt.ion()
# fig , (ax1,ax2) = plt.subplots(1,2,figsize=(8,4))
# accuracies,steps = [], []
for t in range(4000):
    # draw a random mini-batch of 32 rows (sampling with replacement)
    batch_index = np.random.randint(len(train_data), size=32)
    sess.run(train_op, {tf_input: train_data[batch_index]})
    if t % 50 == 0:
        # evaluate on the held-out 30% every 50 steps
        acc_, pre_, loss_ = sess.run([accuracy, prediction, loss], {tf_input: test_Data})
        # accuracies.append(acc_)
        # steps.append(t)
        print("Set: %i " % t, "| Accurate: %.2f" % acc_, "| Loss: %.2f" % loss_)
writer = tf.summary.FileWriter('./my_graph', sess.graph)  # dump the graph once for TensorBoard
#         ax1.cla()
#
#         for c in range(4):
#             bp = ax1.bar(x=c+0.1, height=sum((np.argmax(pre_, axis=1) == c)), width=0.2, color='red')
#             bt = ax1.bar(x=c-0.1, height=sum((np.argmax(test_Data[:, 21:], axis=1) == c)), width=0.2, color='blue')
#         ax1.set_xticks(range(4), ["accepted", "good", "unaccepted", "very good"])
#         ax1.legend(handles=[bp, bt], labels=["prediction", "target"])
#         ax1.set_ylim((0, 400))
#         ax2.cla()
#         ax2.plot(steps, accuracies, label="accuracy")
#         ax2.set_ylim(ymax=1)
#         ax2.set_ylabel("accuracy")
#
#
# plt.ioff()
# plt.show()
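The FileWriter line dumps the graph to ./my_graph, which can be browsed with tensorboard --logdir ./my_graph. The prediction tensor also works for inference after training; since tfx simply slices off the label columns, a full 25-column row can be fed as-is. A small sketch on one held-out row:

sample = test_Data[:1]                            # one 25-column row
probs = sess.run(prediction, {tf_input: sample})  # softmax probabilities over the 4 classes
print(np.argmax(probs, axis=1))                   # predicted class index, 0..3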
Output:
Set: 0 | Accurate: 0.69 | Loss: 1.22
Set: 50 | Accurate: 0.74 | Loss: 0.53
Set: 100 | Accurate: 0.78 | Loss: 0.39
Set: 150 | Accurate: 0.80 | Loss: 0.29
Set: 200 | Accurate: 0.82 | Loss: 0.25
Set: 250 | Accurate: 0.84 | Loss: 0.21
Set: 300 | Accurate: 0.85 | Loss: 0.17
Set: 350 | Accurate: 0.86 | Loss: 0.17
Set: 400 | Accurate: 0.87 | Loss: 0.14
Set: 450 | Accurate: 0.88 | Loss: 0.13
Set: 500 | Accurate: 0.89 | Loss: 0.11
Set: 550 | Accurate: 0.90 | Loss: 0.10
Set: 600 | Accurate: 0.90 | Loss: 0.10
Set: 650 | Accurate: 0.91 | Loss: 0.09
Set: 700 | Accurate: 0.91 | Loss: 0.08
Set: 750 | Accurate: 0.91 | Loss: 0.07
Set: 800 | Accurate: 0.92 | Loss: 0.06
Set: 850 | Accurate: 0.92 | Loss: 0.06
Set: 900 | Accurate: 0.93 | Loss: 0.06
Set: 950 | Accurate: 0.93 | Loss: 0.05
Set: 1000 | Accurate: 0.93 | Loss: 0.05
Set: 1050 | Accurate: 0.93 | Loss: 0.05
Set: 1100 | Accurate: 0.94 | Loss: 0.06
Set: 1150 | Accurate: 0.94 | Loss: 0.04
Set: 1200 | Accurate: 0.94 | Loss: 0.04
Set: 1250 | Accurate: 0.94 | Loss: 0.04
Set: 1300 | Accurate: 0.94 | Loss: 0.03
Set: 1350 | Accurate: 0.95 | Loss: 0.03
Set: 1400 | Accurate: 0.95 | Loss: 0.03
Set: 1450 | Accurate: 0.95 | Loss: 0.03
Set: 1500 | Accurate: 0.95 | Loss: 0.03
Set: 1550 | Accurate: 0.95 | Loss: 0.03
Set: 1600 | Accurate: 0.95 | Loss: 0.03
Set: 1650 | Accurate: 0.95 | Loss: 0.03
Set: 1700 | Accurate: 0.96 | Loss: 0.02
Set: 1750 | Accurate: 0.96 | Loss: 0.03
Set: 1800 | Accurate: 0.96 | Loss: 0.02
Set: 1850 | Accurate: 0.96 | Loss: 0.02
Set: 1900 | Accurate: 0.96 | Loss: 0.02
Set: 1950 | Accurate: 0.96 | Loss: 0.02
Set: 2000 | Accurate: 0.96 | Loss: 0.02
Set: 2050 | Accurate: 0.96 | Loss: 0.02
Set: 2100 | Accurate: 0.96 | Loss: 0.02
Set: 2150 | Accurate: 0.96 | Loss: 0.02
Set: 2200 | Accurate: 0.97 | Loss: 0.02
Set: 2250 | Accurate: 0.97 | Loss: 0.02
Set: 2300 | Accurate: 0.97 | Loss: 0.02
Set: 2350 | Accurate: 0.97 | Loss: 0.02
Set: 2400 | Accurate: 0.97 | Loss: 0.02
Set: 2450 | Accurate: 0.97 | Loss: 0.02
Set: 2500 | Accurate: 0.97 | Loss: 0.02
Set: 2550 | Accurate: 0.97 | Loss: 0.02
Set: 2600 | Accurate: 0.97 | Loss: 0.01
Set: 2650 | Accurate: 0.97 | Loss: 0.01
Set: 2700 | Accurate: 0.97 | Loss: 0.01
Set: 2750 | Accurate: 0.97 | Loss: 0.01
Set: 2800 | Accurate: 0.97 | Loss: 0.01
Set: 2850 | Accurate: 0.97 | Loss: 0.01
Set: 2900 | Accurate: 0.97 | Loss: 0.01
Set: 2950 | Accurate: 0.97 | Loss: 0.01
Set: 3000 | Accurate: 0.97 | Loss: 0.01
Set: 3050 | Accurate: 0.97 | Loss: 0.01
Set: 3100 | Accurate: 0.97 | Loss: 0.01
Set: 3150 | Accurate: 0.97 | Loss: 0.01
Set: 3200 | Accurate: 0.98 | Loss: 0.01
Set: 3250 | Accurate: 0.98 | Loss: 0.01
Set: 3300 | Accurate: 0.98 | Loss: 0.01
Set: 3350 | Accurate: 0.98 | Loss: 0.01
Set: 3400 | Accurate: 0.98 | Loss: 0.01
Set: 3450 | Accurate: 0.98 | Loss: 0.01
Set: 3500 | Accurate: 0.98 | Loss: 0.01
Set: 3550 | Accurate: 0.98 | Loss: 0.01
Set: 3600 | Accurate: 0.98 | Loss: 0.01
Set: 3650 | Accurate: 0.98 | Loss: 0.01
Set: 3700 | Accurate: 0.98 | Loss: 0.01
Set: 3750 | Accurate: 0.98 | Loss: 0.01
Set: 3800 | Accurate: 0.98 | Loss: 0.01
Set: 3850 | Accurate: 0.98 | Loss: 0.01
Set: 3900 | Accurate: 0.98 | Loss: 0.01
Set: 3950 | Accurate: 0.98 | Loss: 0.01
Process finished with exit code 0
As the log shows, the accuracy climbs steadily while the loss steadily decreases.
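One caveat: tf.metrics.accuracy is a streaming metric, so each printed value is a running average over all evaluations so far, not the accuracy at step t alone; this is part of why the curve looks so smooth. For a fresh reading per evaluation, reset the metric's local variables first, e.g.:

sess.run(tf.local_variables_initializer())           # reset the metric's running counters
acc_now = sess.run(accuracy, {tf_input: test_Data})  # accuracy on this evaluation only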