123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- from config import FLAGS
- import numpy as np
- import math
- import h5py
- import glob
- from util.utils import load_nifti, save_nifti
- import os
def rotate_flip(data, r=0, f_lf=False):
    """Rotate ``data`` by ``r`` quarter turns, then optionally mirror it.

    The rotation is ``np.rot90(data, r)``, which acts on the first two axes.
    When ``f_lf`` is truthy, the rotated array is also flipped left-right
    with ``np.fliplr``, which reverses axis 1.

    Args:
        data: array with at least two dimensions.
        r: number of 90-degree rotations passed to ``np.rot90``.
        f_lf: whether to apply ``np.fliplr`` after rotating.

    Returns:
        The rotated (and possibly flipped) array.
    """
    rotated = np.rot90(data, r)
    return np.fliplr(rotated) if f_lf else rotated
def create_hdf5(img_data, t2_data, img_label, save_path):
    """Write one T1/T2/label sample to a new HDF5 file.

    ``save_path`` is opened in mode ``"w"`` (created, or truncated if it
    already exists) and receives three datasets:

    * ``t1data``: ``img_data`` stored as int16
    * ``t2data``: ``t2_data`` stored as int16, with its own shape
    * ``label``: ``img_label`` stored as uint8

    Args:
        img_data: T1 volume. Must have the same shape as ``img_label``.
        t2_data: T2 volume.
        img_label: label volume.
        save_path: destination file path.

    Raises:
        AssertionError: if ``img_data`` and ``img_label`` shapes differ.
            Note that this check is stripped when Python runs with ``-O``.
    """
    assert img_data.shape == img_label.shape, 'shape of data and label must be the same..'
    # The context manager closes the file even if a dataset write raises,
    # so a failed sample never leaves an open handle behind.
    with h5py.File(save_path, "w") as f:
        dset = f.create_dataset("t1data", img_data.shape, dtype=np.int16)
        tset = f.create_dataset("t2data", t2_data.shape, dtype=np.int16)
        lset = f.create_dataset("label", img_label.shape, dtype=np.uint8)
        # Assigning through [...] lets h5py cast the arrays to the dataset dtypes.
        dset[...] = img_data
        lset[...] = img_label
        tset[...] = t2_data
    # Report success only once the file has actually been flushed and closed.
    print('saved hdf5 file in %s' % (save_path, ))
def get_nifti_path():
    """Yield ``(t1_path, t2_path, label_path)`` for every subject directory.

    Every immediate sub-directory of ``FLAGS.train_data_dir`` is treated as
    one subject. Its basename is the subject id, which prefixes the three
    file names ``<id>-T1.nii.gz``, ``<id>-T2.nii.gz`` and ``<id>-label.nii.gz``.

    The paths are only built, never checked, so a missing file surfaces when
    the caller tries to load it. Subjects come out in the unsorted order
    returned by ``glob.glob``.
    """
    subject_dirs = glob.glob('%s/*/' %(FLAGS.train_data_dir,))
    for subject_dir in subject_dirs:
        # The '*/' pattern makes every match end in '/'. That puts the id in
        # the second-to-last component and lets the directory be prefixed
        # directly onto the file name.
        subject_id = subject_dir.split('/')[-2]
        t1_file = '%s%s-T1.nii.gz' %(subject_dir, subject_id)
        t2_file = '%s%s-T2.nii.gz' %(subject_dir, subject_id)
        label_file = '%s%s-label.nii.gz' %(subject_dir, subject_id)
        yield t1_file, t2_file, label_file
-
-
def remove_backgrounds(img_data, t2_data, img_label, pad=None):
    """Crop all three volumes to the padded bounding box of labelled voxels.

    The box is the smallest axis-aligned box, over the first three axes,
    that contains every non-zero voxel of ``img_label``. It is grown by
    ``pad`` voxels on each side and clipped to ``img_data.shape``. The same
    crop is then applied to ``img_data``, ``t2_data`` and ``img_label``.

    Args:
        img_data: T1 volume, at least 3-D. Its shape bounds the crop.
        t2_data: T2 volume, assumed to share ``img_data``'s leading shape.
        img_label: label volume. Non-zero voxels count as foreground.
        pad: margin in voxels added on every side. ``None`` (the default)
            means ``FLAGS.prepost_pad``.

    Returns:
        ``(img_crop, t2_crop, label_crop)``, as slices of the inputs.

    Raises:
        ValueError: if ``img_label`` has no non-zero voxel.
    """
    if pad is None:
        pad = FLAGS.prepost_pad
    # Coordinates of the foreground voxels, one row per voxel.
    # Only the first three axes take part in the crop.
    coords = np.argwhere(np.asarray(img_label) != 0)[:, :3]
    if coords.shape[0] == 0:
        raise ValueError('img_label contains no foreground voxels; nothing to crop to')
    lower = np.maximum(coords.min(axis=0) - pad, 0)
    # Slice stops are exclusive. The "+ 1" keeps the last foreground index
    # inside the crop and makes the high-side margin equal to ``pad``.
    upper = np.minimum(coords.max(axis=0) + pad + 1, img_data.shape[:coords.shape[1]])
    crop = tuple(slice(int(lo), int(hi)) for lo, hi in zip(lower, upper))
    return img_data[crop], t2_data[crop], img_label[crop]
def _first_volume(vol):
    """Return the first volume of a 4-D array; leave any other rank unchanged."""
    if len(vol.shape) == 4:
        return vol[:, :, :, 0]
    return vol


def _load_subject(t1_path, t2_path, label_path):
    """Load one subject as int16 T1, int16 T2 and uint8 label arrays.

    Raw label values 10, 150 and 250 are remapped to 1, 2 and 3.
    """
    t1_data, _ = load_nifti(t1_path)
    t2_data, _ = load_nifti(t2_path)
    label, _ = load_nifti(label_path)
    # Each volume is reduced on its own, so a 4-D file paired with 3-D ones
    # still yields matching 3-D arrays.
    t1_data = np.asarray(_first_volume(t1_data), np.int16)
    t2_data = np.asarray(_first_volume(t2_data), np.int16)
    label = np.asarray(_first_volume(label), np.uint8)
    # No target value (1, 2, 3) is also a source value, so the order of the
    # in-place remapping does not matter.
    for raw_value, class_id in ((10, 1), (150, 2), (250, 3)):
        label[label == raw_value] = class_id
    return t1_data, t2_data, label


def _write_augmentations(img_id, t1_data, t2_data, label):
    """Write the 4 rotations x 2 flip states of one cropped subject to HDF5."""
    for rot in range(4):
        for flip in (True, False):
            # %d renders the flip flag as 1 or 0.
            save_path = '%s/%s_r%d_f%d.h5' %(FLAGS.hdf5_dir, img_id, rot, flip)
            print('>> start to creat hdf5: %s' % (save_path,))
            create_hdf5(rotate_flip(t1_data, r=rot, f_lf=flip),
                        rotate_flip(t2_data, r=rot, f_lf=flip),
                        rotate_flip(label, r=rot, f_lf=flip),
                        save_path)


def generate_nifti_data():
    """Crop, augment and write every training subject to HDF5.

    For each ``(T1, T2, label)`` path triple yielded by ``get_nifti_path``:

    1. Load the volumes, keeping only the first volume of any 4-D file, and
       remap the label values.
    2. Crop all three to the padded label bounding box (``remove_backgrounds``).
    3. Write 8 files, ``<FLAGS.hdf5_dir>/<id>_r<k>_f<0|1>.h5``: 4 rotations,
       each with and without a left-right flip.
    """
    for t1_path, t2_path, label_path in get_nifti_path():
        # The subject id is the name of the directory that holds its files.
        img_id = t1_path.split('/')[-2]
        t1_data, t2_data, label = _load_subject(t1_path, t2_path, label_path)
        t1_data, t2_data, label = remove_backgrounds(t1_data, t2_data, label)
        _write_augmentations(img_id, t1_data, t2_data, label)
def generate_file_list(n_validation=8):
    """Write the full, training and validation lists of generated HDF5 files.

    All ``*.h5`` files in ``FLAGS.hdf5_dir`` are sorted by path, then written
    one per line (each followed by a newline) to three files:

    * ``FLAGS.hdf5_list_path``: every file
    * ``FLAGS.hdf5_train_list_path``: every file after the first ``n_validation``
    * ``FLAGS.hdf5_validation_list_path``: the first ``n_validation`` files

    Args:
        n_validation: number of leading (sorted) files held out for
            validation. The default of 8 matches the 8 rotation/flip files
            written per subject by the augmentation step. It presumably holds
            out one whole subject, assuming ``hdf5_dir`` contains only those
            files (to confirm).

    Raises:
        ValueError: if ``n_validation`` is negative.
    """
    if n_validation < 0:
        raise ValueError('n_validation must be >= 0, got %r' % (n_validation,))
    file_list = sorted(glob.glob('%s/*.h5' %(FLAGS.hdf5_dir,)))
    splits = (
        (FLAGS.hdf5_list_path, file_list),
        (FLAGS.hdf5_train_list_path, file_list[n_validation:]),
        (FLAGS.hdf5_validation_list_path, file_list[:n_validation]),
    )
    for list_path, paths in splits:
        with open(list_path, 'w') as out:
            out.writelines('%s\n' % (path,) for path in paths)
-
def main():
    """Run the preprocessing pipeline end to end.

    First the augmented HDF5 samples are generated. Then the list files that
    index the ``*.h5`` files in the output directory are written.
    """
    pipeline = (generate_nifti_data, generate_file_list)
    for stage in pipeline:
        stage()


if __name__ == '__main__':
    main()
|