# mnist_data_processing.py
  1. import gzip
  2. import pickle
  3. import csv
  4. # unpickling the .pkl file
  5. with open('mnist.pkl', 'rb') as f:
  6. data = pickle.load(f)
  7. # dissecting the original dataset into three different datasets
  8. training_set = data[0]
  9. validation_set = data[1]
  10. testing_set = data[2]
  11. dataset_size = 50 # setting the truncated dataset size (max = 50000)
  12. '''
  13. Generating CSV files for truncated training, cross validation and testing sets
  14. First column: label
  15. Remaining columns: features
  16. Each individual row represents an image's pixel values (features) along with its digit (label)
  17. '''
  18. # generating a CSV file for training set
  19. with open('training_set.csv','wb') as out:
  20. csv_out = csv.writer(out)
  21. for i in range(dataset_size):
  22. current_row = [training_set[1][i]] + list(training_set[0][i])
  23. csv_out.writerow(current_row)
  24. # generating a CSV file for cross validation set
  25. with open('validation_set.csv','wb') as out:
  26. csv_out = csv.writer(out)
  27. for i in range(dataset_size):
  28. current_row = [validation_set[1][i]] + list(validation_set[0][i])
  29. csv_out.writerow(current_row)
  30. # generating a CSV file for testing set
  31. with open('testing_set.csv','wb') as out:
  32. csv_out = csv.writer(out)
  33. for i in range(dataset_size):
  34. current_row = [testing_set[1][i]] + list(testing_set[0][i])
  35. csv_out.writerow(current_row)