[Tesorflow Keras 學習筆記]新手一定要玩的MNIST手寫數字辨識

甜不辣馬拉松
8 min readDec 9, 2020

--

Josef Steppan
目錄
1.前言
2.資料庫介紹
2.1image的部分
2.2label的部分
3.資料前處理(data normalize label one hot encoding)
3.1影像資料
3.2標註資料
4.建立多元感知器模型(Multilayer perceptron)
5.開始訓練
6.畫出訓練歷程
7.測試集評估
8.測試集評估(混淆矩陣表示) confusion table
8.1如果有過度擬和的情況(overfiting),加入dropout
8.2用新的模型架構重新訓練
8.3多加一層隱藏層
9.完整Code
10.小心得

前言

剛踏入這個領域,很多人都以MNIST資料庫當作小試身手,可以說是machine learning 或 deep learning 的 hello world!那就來玩吧~

資料庫介紹

這是一個大型手寫數字資料庫,經常被使用在機器學習便是的領域,每一張圖片為 28*28大小,這個數據庫當中包含60000筆訓練影像和10000筆測試影像。

從keras 也有接口可以下載MNIST的資料,現在就來實際載資料,以及了解其資料型態,呈現方式。

image的部分

把影像畫出來

畫出影像和其label

label 的部分

資料前處理(data normalize label one hot encoding)

影像資料:

  • 原本資料型態為 1*28*28 轉換成1*684 (這個684=28*28)
  • 然後再進行影像的正規化,全部除以255(這是顏色的最大值)

標註資料:

  • 使用 one hot encodeing 轉換數值,因為這個是數字分類,其實要做的事情是分類,每個類別間的關係並沒有大小關係,所以通過one hot 會讓訓練的情況更平均、更好

到這邊基本的資料前處理都完成了,接下來就是建立模型

建立多元感知器模型(Multilayer perceptron)

模型架構:

Model: "sequential_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_9 (Dense) (None, 256) 200960
_________________________________________________________________
dense_10 (Dense) (None, 10) 2570
=================================================================
Total params: 203,530
Trainable params: 203,530
Non-trainable params: 0
_________________________________________________________________

從這個 summary 可以看出 這一個模型是兩層的模型
然後隱藏層有256個神經元
輸出層有10個神經元
另外是 param 參數
參數的計算方式第一個是 200960=256*784+256
另外一個是2570=256*10+10=2570
下面有一個全部訓練 total params=200960+2570=203530

開始訓練

訓練主要有兩個動作: model.copmpile + model.fit

訓練歷程:

畫出訓練歷程

acc
loss

測試集評估

測試資料的準確度為 0.978

測試集評估(混淆矩陣表示) confusion table

如果有過度擬和的情況(overfiting),加入dropout

過度擬和可以從訓練的曲線圖來看,當訓練的曲線和測試的曲線在越多次epoch 後兩者相去甚遠,就有可能是 overfiting,這時候可以試試看 dropout

dropout 是把一些神經元丟掉,不要那麼多的特徵值,造成它只會訓練的資料,驗證的資料就通通答不出來,因此修改後的模型為

Model: “sequential_6” _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_13 (Dense) (None, 1000) 785000 _________________________________________________________________ dropout_3 (Dropout) (None, 1000) 0 _________________________________________________________________ dense_14 (Dense) (None, 10) 10010 ================================================================= Total params: 795,010 Trainable params: 795,010 Non-trainable params: 0 _________________________________________________________________ None

用新的模型架構重新訓練

訓練歷程:

測試結果:

這裡效果沒有一開始的九成七還好,模型再調整一些

多加一層隱藏層

訓練歷程:

測試結果:

這次的準確度明顯提升了! 到達 0.980的準確度,增加隱藏層是有效的

完整Code

小心得

透過這個手寫數字資料庫的練習,基本從資料前處理、建構模型、訓練模型、執行預測、預測成效評估,都包辦了呢!

這應該是最基本的模型架構,後續會有各種模型的變形,如果資料辨識的難度增加,模型架構也需要做加寬、加深的調整。

那以MNIST 手寫數字辨識上,影像的難度較低,因此模型結構不需要太複雜,即可辨別出各個數字,在測試的準確度也達0.978、0.98,已是蠻高的準確度,較容易出錯的數字如數字8辨錯成0、數字8辨錯成3、識字4辨錯成9、數字7辨錯成9,這幾項比較常見的錯誤,其餘目前錯誤率較低,之後有機會試試看其他模型是否可以來降低判錯的部分。

❤️感恩看到這裡的你,希望這篇文章有幫上你👏歡迎拍手給我鼓勵,我是甜不辣馬拉松,我們下次見

--

--

甜不辣馬拉松

幻想自己是貝多芬,可是敲打的卻是機械鍵盤