nn-recognise-images.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from model.neural_network import Network
  2. import tensorflow as tf
  3. import joblib
  4. def image2vec(picture):
  5. """ 将图像转化为样本的输入向量 """
  6. sample = []
  7. for i in range(3):
  8. for j in range(1):
  9. sample.append(picture[i][j])
  10. return sample
  11. def load_dataset():
  12. """ 导入mnist数据 """
  13. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  14. # 转为4D tensor,MNIST是灰度的,所以我们只有一个通道
  15. x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
  16. x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
  17. # # 归一化
  18. x_train = x_train.astype('float32')
  19. x_test = x_test.astype('float32')
  20. x_train /= 255
  21. x_test /= 255
  22. # 转onehot
  23. y_train = tf.keras.utils.to_categorical(y_train)
  24. y_test = tf.keras.utils.to_categorical(y_test)
  25. return (x_train.tolist(), y_train.tolist()), (x_test.tolist(), y_test.tolist())
  26. def get_result(vec):
  27. """ 手写数字识别 网络的输出是一个多维向量,这个向量第n个(从0开始编号)元素的值最大,那么n就是网络的识别结果 """
  28. max_value_index = 0
  29. max_value = 0
  30. for i in range(len(vec)):
  31. if vec[i] > max_value:
  32. max_value = vec[i]
  33. max_value_index = i
  34. return max_value_index
  35. def evaluate(network, test_data_set, test_labels):
  36. """ 使用正确率评估模型,比较直观 """
  37. error = 0
  38. total = len(test_data_set)
  39. for i in range(total):
  40. label = get_result(test_labels[i])
  41. predict = get_result(network.predict(test_data_set[i]))
  42. if label != predict:
  43. error += 1
  44. return 1. - float(error) / float(total)
  45. def main():
  46. x_train = [
  47. [[0.], [0.], [0.]],
  48. [[1.], [1.], [1.]],
  49. [[2.], [2.], [2.]],
  50. [[3.], [3.], [3.]]
  51. ]
  52. y_train = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
  53. # (x_train, y_train), (x_test, y_test) = load_dataset()
  54. # x_train = x_train[:10]
  55. # y_train = y_train[:10]
  56. x_train = [image2vec(i) for i in x_train]
  57. nn = Network([3, 2, 4])
  58. nn.fix(x_train, y_train, 0.5, 10)
  59. joblib.dump(nn, "export/nn-train.pkl")
  60. print(evaluate(nn, x_train, y_train))
  61. print(get_result(nn.predict([1., 1., 1.])))
  62. print(get_result(nn.predict([2., 2., 2.])))
  63. print(get_result(nn.predict([3., 3., 3.])))
  64. if __name__ == '__main__':
  65. main()