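Fashion-MNIST is an upgraded version of MNIST.
The data looks like this: 70,000 grayscale images in different categories:
The task is to classify these images, sorting clothes, shoes, and bags neatly into their categories.
Organizing the data: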
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D,MaxPool2D,AveragePooling2D,Flatten,Dense
import pandas as pd
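# Root folder of the dataset; each class sits in a sub-directory named by its integer label
path = Path.cwd()/"mnist_fashion"
pathes = []

def to_the_end(path):
    # Recursively collect every file below `path`
    if path.is_file():
        pathes.append(path)
    else:
        for i in path.iterdir():
            to_the_end(i)

to_the_end(path)

def show_data(row,col,x_train):
    # Plot the first row*col images in a grid
    for index in range(1,row*col+1):
        ax = plt.subplot(row,col,index)
        # subplot indices start at 1, array indices at 0
        ax.imshow(x_train[index-1],"gray")
        plt.axis('off')

# Load every image as an array; the label is the name of the parent directory
X = np.array([plt.imread(str(i)) for i in pathes])
Y = np.array([int(p.parent.name) for p in pathes])
# Default split: 75% train, 25% test
X_train, X_test, y_train, y_test = train_test_split(X,Y)
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
X_train, X_test = np.expand_dims(X_train,-1),np.expand_dims(X_test,-1)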
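The recursive walk above can also be written with `Path.rglob`. What follows is only a minimal sketch, assuming the same layout as the loading code (one sub-directory per class, named with the integer label). The helper name `collect_labeled_files` and the throwaway directory in the demo are hypothetical and not part of the original code.

```python
from pathlib import Path
import tempfile


def collect_labeled_files(root):
    """Return (paths, labels) for every file under root.

    Assumes each file's parent directory is named with its integer
    class label, as in the loading code above.
    """
    # rglob("*") walks the tree recursively; sorting makes the order reproducible.
    paths = sorted(p for p in Path(root).rglob("*") if p.is_file())
    labels = [int(p.parent.name) for p in paths]
    return paths, labels


# Demo on a throwaway directory tree (hypothetical file names).
with tempfile.TemporaryDirectory() as tmp:
    for label in ("0", "1"):
        class_dir = Path(tmp) / label
        class_dir.mkdir()
        (class_dir / "a.png").touch()
        (class_dir / "b.png").touch()
    demo_paths, demo_labels = collect_labeled_files(tmp)

print(len(demo_paths), demo_labels)
```

Sorting the paths is a design choice rather than a requirement: `iterdir` and `rglob` do not promise any particular order, so a sorted list keeps images and labels aligned in the same way on every run.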
Classifying the dataset with a CNN:
model = Sequential()
model.add(Conv2D(32,(3,3),input_shape=(28,28,1)))
model.add(MaxPool2D())
model.add(Conv2D(64,(3,3)))
model.add(MaxPool2D())
model.add(Conv2D(128,(3,3)))
model.add(MaxPool2D())
model.add(Flatten())
model.add(Dense(128,activation="relu"))
model.add(Dense(10,activation="softmax"))
model.summary()
model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["accuracy"])
# validation_data is passed as a tuple (x_val, y_val), the form documented by Keras
history = model.fit(X_train,y_train,batch_size=500,epochs=10,validation_data=(X_test,y_test))
history = pd.DataFrame(history.history)
history.plot()
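The output shapes and parameter counts of the layer stack above can be worked out by hand. The sketch below assumes the Keras defaults the model relies on: `Conv2D` with "valid" padding and stride 1, and `MaxPool2D` with a 2x2 window and stride 2. The helper names are hypothetical and exist only for this illustration.

```python
def conv_out(size, kernel=3):
    """Spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after non-overlapping pooling (remainder is dropped)."""
    return size // pool


def conv_params(in_channels, filters, kernel=3):
    """kernel*kernel*in_channels weights plus one bias, per filter."""
    return (kernel * kernel * in_channels + 1) * filters


def dense_params(in_units, out_units):
    """One weight per input plus one bias, per output unit."""
    return (in_units + 1) * out_units


size, channels, total = 28, 1, 0
for filters in (32, 64, 128):
    size = conv_out(size)                    # 3x3 convolution
    total += conv_params(channels, filters)
    channels = filters
    size = pool_out(size)                    # 2x2 max pooling
    print(f"after Conv2D({filters}) + MaxPool2D: {size}x{size}x{channels}")

flat = size * size * channels                # units entering the first Dense layer
total += dense_params(flat, 128) + dense_params(128, 10)
print("flattened units:", flat)
print("total parameters:", total)
```

The feature map shrinks from 28 to 13, then 5, then 1, so `Flatten` passes on only 128 values. The total comes out to 110,474, which is the figure `model.summary()` reports.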
Results:
Model: "sequential_33"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_83 (Conv2D)           (None, 26, 26, 32)        320
_________________________________________________________________
max_pooling2d_41 (MaxPooling (None, 13, 13, 32)        0
_________________________________________________________________
conv2d_84 (Conv2D)           (None, 11, 11, 64)        18496
_________________________________________________________________
max_pooling2d_42 (MaxPooling (None, 5, 5, 64)          0
_________________________________________________________________
conv2d_85 (Conv2D)           (None, 3, 3, 128)         73856
_________________________________________________________________
max_pooling2d_43 (MaxPooling (None, 1, 1, 128)         0
_________________________________________________________________
flatten_10 (Flatten)         (None, 128)               0
_________________________________________________________________
dense_11 (Dense)             (None, 128)               16512
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1290
=================================================================
Total params: 110,474
Trainable params: 110,474
Non-trainable params: 0
_________________________________________________________________
Train on 52500 samples, validate on 17500 samples
Epoch 1/10
52500/52500 [==============================] - 39s 745us/sample - loss: 0.8906 - accuracy: 0.6817 - val_loss: 0.6245 - val_accuracy: 0.7737
Epoch 2/10
52500/52500 [==============================] - 39s 741us/sample - loss: 0.5823 - accuracy: 0.7905 - val_loss: 0.5329 - val_accuracy: 0.8085
Epoch 3/10
52500/52500 [==============================] - 39s 746us/sample - loss: 0.5072 - accuracy: 0.8181 - val_loss: 0.4819 - val_accuracy: 0.8294
Epoch 4/10
52500/52500 [==============================] - 43s 815us/sample - loss: 0.4567 - accuracy: 0.8355 - val_loss: 0.4415 - val_accuracy: 0.8447
Epoch 5/10
52500/52500 [==============================] - 48s 919us/sample - loss: 0.4260 - accuracy: 0.8475 - val_loss: 0.4359 - val_accuracy: 0.8456
Epoch 6/10
52500/52500 [==============================] - 44s 836us/sample - loss: 0.3940 - accuracy: 0.8569 - val_loss: 0.4029 - val_accuracy: 0.8545
Epoch 7/10
52500/52500 [==============================] - 39s 744us/sample - loss: 0.3785 - accuracy: 0.8630 - val_loss: 0.4127 - val_accuracy: 0.8489
Epoch 8/10
52500/52500 [==============================] - 39s 741us/sample - loss: 0.3580 - accuracy: 0.8705 - val_loss: 0.3710 - val_accuracy: 0.8696
Epoch 9/10
52500/52500 [==============================] - 40s 753us/sample - loss: 0.3443 - accuracy: 0.8749 - val_loss: 0.3732 - val_accuracy: 0.8636
Epoch 10/10
52500/52500 [==============================] - 39s 745us/sample - loss: 0.3316 - accuracy: 0.8800 - val_loss: 0.3634 - val_accuracy: 0.8716
Validation accuracy reached 87%. Not bad at all!
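A few numbers in the training log can be cross-checked with simple arithmetic. The sketch below copies the accuracy values and sample counts from the log by hand; the variable names are chosen for this illustration only.

```python
import math

# Per-epoch accuracies copied from the training log above
train_accuracy = [0.6817, 0.7905, 0.8181, 0.8355, 0.8475,
                  0.8569, 0.8630, 0.8705, 0.8749, 0.8800]
val_accuracy = [0.7737, 0.8085, 0.8294, 0.8447, 0.8456,
                0.8545, 0.8489, 0.8696, 0.8636, 0.8716]

# Sample counts and batch size from the log and the fit() call
n_train, n_val, batch_size = 52500, 17500, 500

# Epoch (1-based) with the highest validation accuracy
best_epoch = max(range(len(val_accuracy)), key=val_accuracy.__getitem__) + 1

# Gap between training and validation accuracy after the last epoch
gap = train_accuracy[-1] - val_accuracy[-1]

# Number of weight updates per epoch
steps_per_epoch = math.ceil(n_train / batch_size)

# Share of the data held out for validation
val_fraction = n_val / (n_train + n_val)

print(best_epoch, round(gap, 4), steps_per_epoch, val_fraction)
```

The held-out share is 0.25, which matches the default `test_size` of `train_test_split` when no size is given. Validation accuracy dips at epochs 7 and 9 before peaking in the last epoch, and the train/validation gap is still under one percentage point, so the log shows no strong sign of overfitting after 10 epochs.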