nn-vec-recognise-images.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from model.neural_network_vec import Network
  2. import tensorflow as tf
  3. import joblib
  4. def load_dataset():
  5. """ 导入mnist数据 """
  6. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  7. # 扁平为1维
  8. x_train = x_train.reshape(x_train.shape[0], -1, 1)
  9. x_test = x_test.reshape(x_test.shape[0], -1, 1)
  10. # # 归一化
  11. x_train = x_train.astype('float32')
  12. x_test = x_test.astype('float32')
  13. x_train /= 255
  14. x_test /= 255
  15. # 转onehot
  16. y_train = tf.keras.utils.to_categorical(y_train)
  17. y_test = tf.keras.utils.to_categorical(y_test)
  18. # 扁平为1维
  19. y_train = y_train.reshape(y_train.shape[0], -1, 1)
  20. y_test = y_test.reshape(y_test.shape[0], -1, 1)
  21. return (x_train, y_train), (x_test, y_test)
  22. def get_result(vec):
  23. """ 手写数字识别 网络的输出是一个多维向量,这个向量第n个(从0开始编号)元素的值最大,那么n就是网络的识别结果 """
  24. max_value_index = 0
  25. max_value = 0
  26. for i in range(len(vec)):
  27. if vec[i] > max_value:
  28. max_value = vec[i]
  29. max_value_index = i
  30. return max_value_index
  31. def evaluate(network, test_data_set, test_labels):
  32. """ 使用正确率评估模型,比较直观 """
  33. error = 0
  34. total = len(test_data_set)
  35. for i in range(total):
  36. label = get_result(test_labels[i])
  37. predict = get_result(network.predict(test_data_set[i]))
  38. # print("预测值:%d, 实际值:%d " % (predict, label))
  39. if label != predict:
  40. error += 1
  41. return 1. - float(error) / float(total)
  42. def main():
  43. (x_train, y_train), (x_test, y_test) = load_dataset()
  44. y_train = y_train.reshape(y_train.shape[0], -1, 1)
  45. nn = Network([784, 300, 10])
  46. nn.fix(x_train, y_train, 0.5, 2)
  47. # 检查梯度
  48. nn.gradient_check(x_train[0], y_train[0])
  49. # 保存模型
  50. joblib.dump(nn, 'export/nn-vec-reconginise-images-train.pkl')
  51. # 查看正确率
  52. print(evaluate(nn, x_test, y_test)) # 训练2轮的准确率:0.9704
  53. if __name__ == '__main__':
  54. main()