8-4-LogisticRegression-recognise-images.py 600 B

12345678910111213141516171819202122
  1. from datasets import Datasets
  2. from sklearn.linear_model import LogisticRegression
  3. from sklearn.model_selection import cross_val_score
  4. def main():
  5. # 加载MNIST数据
  6. train_data, valid_data, test_data = Datasets.load_mnist()
  7. x_train, y_train = train_data
  8. x_test, y_test = test_data
  9. # 逻辑回归训练并预测
  10. lr = LogisticRegression(solver='lbfgs', max_iter=2000)
  11. lr.fit(x_train, y_train)
  12. print(lr.score(x_test, y_test))
  13. scores = cross_val_score(lr, x_test, y_test, cv=10, scoring="accuracy")
  14. print(scores.mean())
  15. if __name__ == "__main__":
  16. main()