# preprocess.py

  1. from config import FLAGS
  2. import numpy as np
  3. import math
  4. import h5py
  5. import glob
  6. from util.utils import load_nifti, save_nifti
  7. import os
  8. def rotate_flip(data, r=0, f_lf=False):
  9. #rotate 90
  10. data = np.rot90(data,r)
  11. if f_lf:
  12. data = np.fliplr(data)
  13. return data
  14. def create_hdf5(img_data, t2_data, img_label, save_path):
  15. assert img_data.shape == img_label.shape, 'shape of data and label must be the same..'
  16. f = h5py.File(save_path, "w")
  17. dset = f.create_dataset("t1data", img_data.shape, dtype=np.int16)
  18. tset = f.create_dataset("t2data", t2_data.shape, dtype=np.int16)
  19. lset = f.create_dataset("label", img_data.shape, dtype=np.uint8)
  20. dset[...] = img_data
  21. lset[...] = img_label
  22. tset[...] = t2_data
  23. print('saved hdf5 file in %s' % (save_path, ))
  24. f.close()
  25. def get_nifti_path():
  26. t1_path, t2_path, label_path = '', '', ''
  27. dir_list = glob.glob('%s/*/' %(FLAGS.train_data_dir,))
  28. # print dir_list, '....'
  29. for _dir in dir_list:
  30. # file_list = glob.glob('%s/*.nii' % (_dir, ))
  31. img_id = _dir.split('/')[-2]
  32. t1_path = '%s%s-T1.nii.gz' %(_dir, img_id)
  33. t2_path = '%s%s-T2.nii.gz' %(_dir, img_id)
  34. label_path = '%s%s-label.nii.gz' %(_dir, img_id)
  35. yield t1_path, t2_path, label_path
  36. def remove_backgrounds(img_data, t2_data, img_label):
  37. nonzero_label = img_label != 0
  38. nonzero_label = np.asarray(nonzero_label)
  39. nonzero_index = np.nonzero(nonzero_label)
  40. nonzero_index = np.asarray(nonzero_index)
  41. x_min, x_max = nonzero_index[0,:].min(), nonzero_index[0,:].max()
  42. y_min, y_max = nonzero_index[1,:].min(), nonzero_index[1,:].max()
  43. z_min, z_max = nonzero_index[2,:].min(), nonzero_index[2,:].max()
  44. # print x_min, x_max
  45. # print y_min, y_max
  46. # print z_min, z_max
  47. x_min = x_min - FLAGS.prepost_pad if x_min-FLAGS.prepost_pad>=0 else 0
  48. y_min = y_min - FLAGS.prepost_pad if y_min-FLAGS.prepost_pad>=0 else 0
  49. z_min = z_min - FLAGS.prepost_pad if z_min-FLAGS.prepost_pad>=0 else 0
  50. x_max = x_max + FLAGS.prepost_pad if x_max+FLAGS.prepost_pad<=img_data.shape[0] else img_data.shape[0]
  51. y_max = y_max + FLAGS.prepost_pad if y_max+FLAGS.prepost_pad<=img_data.shape[1] else img_data.shape[1]
  52. z_max = z_max + FLAGS.prepost_pad if z_max+FLAGS.prepost_pad<=img_data.shape[2] else img_data.shape[2]
  53. return (img_data[x_min:x_max, y_min:y_max, z_min:z_max], t2_data[x_min:x_max, y_min:y_max, z_min:z_max],
  54. img_label[x_min:x_max, y_min:y_max, z_min:z_max])
  55. def generate_nifti_data():
  56. for img_path, t2_path, label_path in get_nifti_path():
  57. nifti_data, nifti_img = load_nifti(img_path)
  58. t2_data, t2_img = load_nifti(t2_path)
  59. nifti_label, _label = load_nifti(label_path)
  60. img_id = img_path.split('/')[-2]
  61. if len(nifti_data.shape)==3:
  62. pass
  63. elif len(nifti_data.shape)==4:
  64. nifti_data = nifti_data[:,:,:,0]
  65. t2_data = t2_data[:,:,:,0]
  66. nifti_label = nifti_label[:,:,:,0]
  67. t1_data = np.asarray(nifti_data, np.int16)
  68. t2_data = np.asarray(t2_data, np.int16)
  69. nifti_label = np.asarray(nifti_label, np.uint8)
  70. nifti_label[nifti_label==10] = 1
  71. nifti_label[nifti_label==150] = 2
  72. nifti_label[nifti_label==250] = 3
  73. croped_data, t2_data, croped_label = remove_backgrounds(t1_data,t2_data, nifti_label)
  74. t1_name = img_path.split('/')[-1].replace('.nii.gz', '')
  75. t2_name = t2_path.split('/')[-1].replace('.nii.gz', '')
  76. for _r in xrange(4):
  77. for flip in [True, False]:
  78. save_path = '%s/%s_r%d_f%d.h5' %(FLAGS.hdf5_dir, img_id, _r, flip)
  79. print ('>> start to creat hdf5: %s' % (save_path,))
  80. aug_data = rotate_flip(croped_data, r=_r, f_lf=flip )
  81. aug_label = rotate_flip(croped_label, r=_r, f_lf=flip )
  82. aug_t2_data = rotate_flip(t2_data, r=_r, f_lf=flip)
  83. create_hdf5(aug_data,aug_t2_data, aug_label, save_path)
  84. save_nifit_path = '%s/%s_r%d_f%d_data.nii' % (FLAGS.hdf5_dir, t1_name,_r, flip )
  85. save_nifit_label_path = '%s/%s_r%d_f%d_label.nii' % (FLAGS.hdf5_dir, img_id, _r, flip)
  86. t2_path = '%s/%s_r%d_f%d_data.nii' % (FLAGS.hdf5_dir, t2_name, _r, flip)
  87. # break
  88. def generate_file_list():
  89. # if os.pa
  90. file_list = glob.glob('%s/*.h5' %(FLAGS.hdf5_dir,))
  91. file_list.sort()
  92. with open(FLAGS.hdf5_list_path, 'w') as _file:
  93. for _file_path in file_list:
  94. _file.write(_file_path)
  95. _file.write('\n')
  96. with open(FLAGS.hdf5_train_list_path, 'w') as _file:
  97. for _file_path in file_list[8:]:
  98. _file.write(_file_path)
  99. _file.write('\n')
  100. with open(FLAGS.hdf5_validation_list_path, 'w') as _file:
  101. for _file_path in file_list[0:8]:
  102. _file.write(_file_path)
  103. _file.write('\n')
  104. def main():
  105. generate_nifti_data()
  106. generate_file_list()
  107. if __name__ == '__main__':
  108. main()