transform_cuhk03.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. """Refactor file directories, save/rename images and partition the
  2. train/val/test set, in order to support the unified dataset interface.
  3. """
  4. from __future__ import print_function
  5. import sys
  6. sys.path.insert(0, '.')
  7. from zipfile import ZipFile
  8. import os.path as osp
  9. import sys
  10. import h5py
  11. from scipy.misc import imsave
  12. from itertools import chain
  13. from bpm.utils.utils import may_make_dir
  14. from bpm.utils.utils import load_pickle
  15. from bpm.utils.utils import save_pickle
  16. from bpm.utils.dataset_utils import partition_train_val_set
  17. from bpm.utils.dataset_utils import new_im_name_tmpl
  18. from bpm.utils.dataset_utils import parse_im_name
  19. def save_images(mat_file, save_dir, new_im_name_tmpl):
  20. def deref(mat, ref):
  21. return mat[ref][:].T
  22. def dump(mat, refs, pid, cam, im_dir):
  23. """Save the images of a person under one camera."""
  24. for i, ref in enumerate(refs):
  25. im = deref(mat, ref)
  26. if im.size == 0 or im.ndim < 2: break
  27. fname = new_im_name_tmpl.format(pid, cam, i)
  28. imsave(osp.join(im_dir, fname), im)
  29. mat = h5py.File(mat_file, 'r')
  30. labeled_im_dir = osp.join(save_dir, 'labeled/images')
  31. detected_im_dir = osp.join(save_dir, 'detected/images')
  32. all_im_dir = osp.join(save_dir, 'all/images')
  33. may_make_dir(labeled_im_dir)
  34. may_make_dir(detected_im_dir)
  35. may_make_dir(all_im_dir)
  36. # loop through camera pairs
  37. pid = 0
  38. for labeled, detected in zip(mat['labeled'][0], mat['detected'][0]):
  39. labeled, detected = deref(mat, labeled), deref(mat, detected)
  40. assert labeled.shape == detected.shape
  41. # loop through ids in a camera pair
  42. for i in range(labeled.shape[0]):
  43. # We don't care about whether different persons are under same cameras,
  44. # we only care about the same person being under different cameras or not.
  45. dump(mat, labeled[i, :5], pid, 0, labeled_im_dir)
  46. dump(mat, labeled[i, 5:], pid, 1, labeled_im_dir)
  47. dump(mat, detected[i, :5], pid, 0, detected_im_dir)
  48. dump(mat, detected[i, 5:], pid, 1, detected_im_dir)
  49. dump(mat, chain(detected[i, :5], labeled[i, :5]), pid, 0, all_im_dir)
  50. dump(mat, chain(detected[i, 5:], labeled[i, 5:]), pid, 1, all_im_dir)
  51. pid += 1
  52. if pid % 100 == 0:
  53. sys.stdout.write('\033[F\033[K')
  54. print('Saving images {}/{}'.format(pid, 1467))
  55. def transform(zip_file, train_test_partition_file, save_dir=None):
  56. """Save images and partition the train/val/test set.
  57. """
  58. print("Extracting zip file")
  59. root = osp.dirname(osp.abspath(zip_file))
  60. if save_dir is None:
  61. save_dir = root
  62. may_make_dir(save_dir)
  63. with ZipFile(zip_file) as z:
  64. z.extractall(path=save_dir)
  65. print("Extracting zip file done")
  66. mat_file = osp.join(save_dir, osp.basename(zip_file)[:-4], 'cuhk-03.mat')
  67. save_images(mat_file, save_dir, new_im_name_tmpl)
  68. if osp.exists(train_test_partition_file):
  69. train_test_partition = load_pickle(train_test_partition_file)
  70. else:
  71. raise RuntimeError('Train/test partition file should be provided.')
  72. for im_type in ['detected', 'labeled']:
  73. trainval_im_names = train_test_partition[im_type]['train_im_names']
  74. trainval_ids = list(set([parse_im_name(n, 'id')
  75. for n in trainval_im_names]))
  76. # Sort ids, so that id-to-label mapping remains the same when running
  77. # the code on different machines.
  78. trainval_ids.sort()
  79. trainval_ids2labels = dict(zip(trainval_ids, range(len(trainval_ids))))
  80. train_val_partition = \
  81. partition_train_val_set(trainval_im_names, parse_im_name, num_val_ids=100)
  82. train_im_names = train_val_partition['train_im_names']
  83. train_ids = list(set([parse_im_name(n, 'id')
  84. for n in train_val_partition['train_im_names']]))
  85. # Sort ids, so that id-to-label mapping remains the same when running
  86. # the code on different machines.
  87. train_ids.sort()
  88. train_ids2labels = dict(zip(train_ids, range(len(train_ids))))
  89. # A mark is used to denote whether the image is from
  90. # query (mark == 0), or
  91. # gallery (mark == 1), or
  92. # multi query (mark == 2) set
  93. val_marks = [0, ] * len(train_val_partition['val_query_im_names']) \
  94. + [1, ] * len(train_val_partition['val_gallery_im_names'])
  95. val_im_names = list(train_val_partition['val_query_im_names']) \
  96. + list(train_val_partition['val_gallery_im_names'])
  97. test_im_names = list(train_test_partition[im_type]['query_im_names']) \
  98. + list(train_test_partition[im_type]['gallery_im_names'])
  99. test_marks = [0, ] * len(train_test_partition[im_type]['query_im_names']) \
  100. + [1, ] * len(
  101. train_test_partition[im_type]['gallery_im_names'])
  102. partitions = {'trainval_im_names': trainval_im_names,
  103. 'trainval_ids2labels': trainval_ids2labels,
  104. 'train_im_names': train_im_names,
  105. 'train_ids2labels': train_ids2labels,
  106. 'val_im_names': val_im_names,
  107. 'val_marks': val_marks,
  108. 'test_im_names': test_im_names,
  109. 'test_marks': test_marks}
  110. partition_file = osp.join(save_dir, im_type, 'partitions.pkl')
  111. save_pickle(partitions, partition_file)
  112. print('Partition file for "{}" saved to {}'.format(im_type, partition_file))
  113. if __name__ == '__main__':
  114. import argparse
  115. parser = argparse.ArgumentParser(description="Transform CUHK03 Dataset")
  116. parser.add_argument(
  117. '--zip_file',
  118. type=str,
  119. default='~/Dataset/cuhk03/cuhk03_release.zip')
  120. parser.add_argument(
  121. '--save_dir',
  122. type=str,
  123. default='~/Dataset/cuhk03')
  124. parser.add_argument(
  125. '--train_test_partition_file',
  126. type=str,
  127. default='~/Dataset/cuhk03/re_ranking_train_test_split.pkl')
  128. args = parser.parse_args()
  129. zip_file = osp.abspath(osp.expanduser(args.zip_file))
  130. train_test_partition_file = osp.abspath(osp.expanduser(
  131. args.train_test_partition_file))
  132. save_dir = osp.abspath(osp.expanduser(args.save_dir))
  133. transform(zip_file, train_test_partition_file, save_dir)