12345678910111213141516171819202122232425262728293031323334353637383940414243 |
import csv
import gzip
import pickle
import sys
# The csv module writes bytes on Python 2 but text on Python 3, and
# pickle.load() only accepts ``encoding`` on Python 3, so both the pickle
# call and the CSV file mode below depend on the running interpreter.
_IS_PY3 = sys.version_info[0] >= 3


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE(review): unpickling can execute arbitrary code, so *path* must come
    from a trusted source.
    """
    with open(path, 'rb') as f:
        if _IS_PY3:
            # 'latin1' accepts any byte string written by a Python 2 pickler;
            # the default 'ASCII' raises UnicodeDecodeError on non-ASCII bytes.
            return pickle.load(f, encoding='latin1')
        return pickle.load(f)


def _write_csv(path, split, row_count):
    """Write the first *row_count* examples of *split* to *path* as CSV.

    First column: label.
    Remaining columns: features.
    Each row holds one image's pixel values (features) preceded by its
    digit (label).

    *split* is indexed as ``split[0]`` (features, one entry per image) and
    ``split[1]`` (labels). If the split holds fewer than *row_count*
    examples, only the available rows are written.
    """
    features, labels = split[0], split[1]
    if _IS_PY3:
        # newline='' hands line-ending control to the csv module, as its
        # documentation requires for files passed to csv.writer.
        out = open(path, 'w', newline='')
    else:
        out = open(path, 'wb')
    with out:
        writer = csv.writer(out)
        for label, pixels in zip(labels[:row_count], features[:row_count]):
            writer.writerow([label] + list(pixels))


data = _load_pickle('mnist.pkl')
training_set = data[0]
validation_set = data[1]
testing_set = data[2]

# Number of examples kept from each split.
dataset_size = 50

# Generate CSV files for the truncated training, cross-validation and
# testing sets.
_write_csv('training_set.csv', training_set, dataset_size)
_write_csv('validation_set.csv', validation_set, dataset_size)
_write_csv('testing_set.csv', testing_set, dataset_size)
|