layer_matching.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. import os
  2. import warnings
  3. warnings.filterwarnings("ignore")
  4. warnings.filterwarnings("ignore")
  5. os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2' # 只显示 warning 和 Error
  6. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  7. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  8. class LayerMatching:
  9. concat_size_limit = 1e4
  10. def __init__(self):
  11. self.layers = {}
  12. self.constraints = {}
  13. self.layers['flatten'] = LayerMatching.flatten
  14. self.constraints['flatten'] = LayerMatching.flatten_constraints
  15. self.layer_concats = {}
  16. self.input_legal = {}
  17. self.layer_concats['flatten'] = LayerMatching.flatten_dense
  18. self.input_legal['flatten'] = LayerMatching.flatten_dense_input_legal
  19. self.layer_concats['repeat_vector'] = LayerMatching.repeat_vector_dense
  20. self.input_legal['repeat_vector'] = LayerMatching.repeat_vector_dense_input_legal
  21. self.layer_concats['cropping1d'] = LayerMatching.cropping1d_dense
  22. self.input_legal['cropping1d'] = LayerMatching.cropping1d_dense_input_legal
  23. self.layer_concats['cropping2d'] = LayerMatching.cropping2d_dense
  24. self.input_legal['cropping2d'] = LayerMatching.cropping2d_dense_input_legal
  25. self.layer_concats['cropping3d'] = LayerMatching.cropping3d_dense
  26. self.input_legal['cropping3d'] = LayerMatching.cropping3d_dense_input_legal
  27. self.layer_concats['upsampling_1d'] = LayerMatching.upsampling_1d_dense
  28. self.input_legal['upsampling_1d'] = LayerMatching.upsampling_1d_dense_input_legal
  29. self.layer_concats['upsampling_2d'] = LayerMatching.upsampling_2d_dense
  30. self.input_legal['upsampling_2d'] = LayerMatching.upsampling_2d_dense_input_legal
  31. self.layer_concats['upsampling_3d'] = LayerMatching.upsampling_3d_dense
  32. self.input_legal['upsampling_3d'] = LayerMatching.upsampling_3d_dense_input_legal
  33. self.layer_concats['zeropadding_1d'] = LayerMatching.zeropadding_1d_conv
  34. self.input_legal['zeropadding_1d'] = LayerMatching.zeropadding_1d_conv_input_legal
  35. self.layer_concats['zeropadding_2d'] = LayerMatching.zeropadding_2d_conv
  36. self.input_legal['zeropadding_2d'] = LayerMatching.zeropadding_2d_conv_input_legal
  37. self.layer_concats['zeropadding_3d'] = LayerMatching.zeropadding_3d_conv
  38. self.input_legal['zeropadding_3d'] = LayerMatching.zeropadding_3d_conv_input_legal
  39. self.layer_concats['global_max_pooling_1d'] = LayerMatching.global_max_pooling_1d_dense
  40. self.input_legal['global_max_pooling_1d'] = LayerMatching.global_pooling_1d_dense_input_legal
  41. self.layer_concats['global_average_pooling_1d'] = LayerMatching.global_average_pooling_1d_dense
  42. self.input_legal['global_average_pooling_1d'] = LayerMatching.global_pooling_1d_dense_input_legal
  43. self.layer_concats['global_max_pooling_2d'] = LayerMatching.global_max_pooling_2d_dense
  44. self.input_legal['global_max_pooling_2d'] = LayerMatching.global_pooling_2d_dense_input_legal
  45. self.layer_concats['global_average_pooling_2d'] = LayerMatching.global_average_pooling_2d_dense
  46. self.input_legal['global_average_pooling_2d'] = LayerMatching.global_pooling_2d_dense_input_legal
  47. self.layer_concats['global_max_pooling_3d'] = LayerMatching.global_max_pooling_3d_dense
  48. self.input_legal['global_max_pooling_3d'] = LayerMatching.global_pooling_3d_dense_input_legal
  49. self.layer_concats['global_average_pooling_3d'] = LayerMatching.global_average_pooling_3d_dense
  50. self.input_legal['global_average_pooling_3d'] = LayerMatching.global_pooling_3d_dense_input_legal
  51. self.layer_concats['simple_rnn'] = LayerMatching.simple_rnn_dense
  52. self.input_legal['simple_rnn'] = LayerMatching.simple_rnn_dense_input_legal
  53. self.layer_concats['gru'] = LayerMatching.gru_dense
  54. self.input_legal['gru'] = LayerMatching.gru_dense_input_legal
  55. self.layer_concats['lstm'] = LayerMatching.lstm_dense
  56. self.input_legal['lstm'] = LayerMatching.lstm_dense_input_legal
  57. self.layer_concats['conv_lstm_2d'] = LayerMatching.conv_lstm_2d_dense
  58. self.input_legal['conv_lstm_2d'] = LayerMatching.conv_lstm_2d_dense_input_legal
  59. @staticmethod
  60. def flatten(input_shape):
  61. import keras
  62. return keras.layers.Flatten()
  63. @staticmethod
  64. def flatten_constraints(input_shape):
  65. input_shape = input_shape.as_list()
  66. input_shape_len = len(input_shape)
  67. constraints = []
  68. if input_shape_len < 2:
  69. return None
  70. constraints = []
  71. dim_size = 1
  72. for i in range(input_shape_len):
  73. if i == 0:
  74. continue
  75. constraints.append('= input_{} {}'.format(i, input_shape[i]))
  76. dim_size *= input_shape[i]
  77. constraint_str = '= output_{} {}'.format(1, dim_size)
  78. constraints.append(constraint_str)
  79. return constraints
  80. # --------------------------------------------
  81. @staticmethod
  82. def flatten_dense(input_shape):
  83. import keras
  84. layer_concat = []
  85. layer_concat.append(keras.layers.Flatten())
  86. units = 1
  87. for i in range(len(input_shape)):
  88. if i == 0:
  89. continue
  90. units *= input_shape[i]
  91. layer_concat.append(keras.layers.Dense(units))
  92. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  93. return layer_concat
  94. @staticmethod
  95. def flatten_dense_input_legal(input_shape):
  96. input_shape = input_shape.as_list()
  97. is_legal = len(input_shape) > 3 and input_shape[0] is None
  98. concat_size = 1
  99. for i, dim in enumerate(input_shape):
  100. if i == 0:
  101. continue
  102. is_legal = is_legal and dim is not None
  103. if dim is not None:
  104. concat_size *= dim
  105. return is_legal and concat_size <= LayerMatching.concat_size_limit
  106. @staticmethod
  107. def repeat_vector_dense(input_shape):
  108. n = 3
  109. import keras
  110. layer_concat = []
  111. layer_concat.append(keras.layers.RepeatVector(n))
  112. layer_concat.append(keras.layers.Reshape((input_shape[1] * n,)))
  113. layer_concat.append(keras.layers.Dense(input_shape[1]))
  114. return layer_concat
  115. @staticmethod
  116. def repeat_vector_dense_input_legal(input_shape):
  117. input_shape = input_shape.as_list()
  118. return len(input_shape) == 2 and input_shape[0] is None and input_shape[1] is not None \
  119. and input_shape[1] <= LayerMatching.concat_size_limit
  120. @staticmethod
  121. def cropping1d_dense(input_shape):
  122. import keras
  123. layer_concat = []
  124. layer_concat.append(keras.layers.Cropping1D(cropping=(1, 1)))
  125. layer_concat.append(keras.layers.Dense(input_shape[1]))
  126. return layer_concat
  127. @staticmethod
  128. def cropping1d_dense_input_legal(input_shape):
  129. input_shape = input_shape.as_list()
  130. return len(input_shape) == 3 and input_shape[0] is None and input_shape[1] is not None and input_shape[1] > 2 \
  131. and input_shape[2] is not None and input_shape[1] * input_shape[2] <= LayerMatching.concat_size_limit
  132. @staticmethod
  133. def cropping2d_dense(input_shape):
  134. import keras
  135. layer_concat = []
  136. layer_concat.append(keras.layers.Cropping2D(cropping=((1, 1), (1, 1))))
  137. layer_concat.append(keras.layers.Reshape(((input_shape[1] - 2) * (input_shape[2] - 2) * input_shape[3],)))
  138. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3]))
  139. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  140. return layer_concat
  141. @staticmethod
  142. def cropping2d_dense_input_legal(input_shape):
  143. input_shape = input_shape.as_list()
  144. return len(input_shape) == 4 and input_shape[0] is None \
  145. and input_shape[1] is not None and input_shape[1] > 2 \
  146. and input_shape[2] is not None and input_shape[2] > 2 \
  147. and input_shape[3] is not None \
  148. and input_shape[1] * input_shape[2] * input_shape[3] <= LayerMatching.concat_size_limit
  149. @staticmethod
  150. def cropping3d_dense(input_shape):
  151. import keras
  152. layer_concat = []
  153. layer_concat.append(keras.layers.Cropping3D(cropping=((1, 1), (1, 1), (1, 1))))
  154. layer_concat.append(keras.layers.Reshape(((input_shape[1] - 2) * (input_shape[2] - 2) * (input_shape[3] - 2) * input_shape[4],)))
  155. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4]))
  156. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  157. return layer_concat
  158. @staticmethod
  159. def cropping3d_dense_input_legal(input_shape):
  160. input_shape = input_shape.as_list()
  161. return len(input_shape) == 5 and input_shape[0] is None \
  162. and input_shape[1] is not None and input_shape[1] > 2 \
  163. and input_shape[2] is not None and input_shape[2] > 2 \
  164. and input_shape[3] is not None and input_shape[3] > 2 \
  165. and input_shape[4] is not None \
  166. and input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4] <= LayerMatching.concat_size_limit
  167. @staticmethod
  168. def upsampling_1d_dense(input_shape):
  169. import keras
  170. layer_concat = []
  171. layer_concat.append(keras.layers.UpSampling1D(size=2))
  172. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2]))
  173. return layer_concat
  174. @staticmethod
  175. def upsampling_1d_dense_input_legal(input_shape):
  176. input_shape = input_shape.as_list()
  177. return len(input_shape) == 3 and input_shape[0] is None and input_shape[1] is not None \
  178. and input_shape[2] is not None and input_shape[1] * input_shape[2] <= LayerMatching.concat_size_limit
  179. @staticmethod
  180. def upsampling_2d_dense(input_shape):
  181. import keras
  182. layer_concat = []
  183. layer_concat.append(keras.layers.UpSampling2D(size=(2, 2)))
  184. layer_concat.append(keras.layers.Flatten())
  185. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3]))
  186. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  187. return layer_concat
  188. @staticmethod
  189. def upsampling_2d_dense_input_legal(input_shape):
  190. input_shape = input_shape.as_list()
  191. return len(input_shape) == 4 and input_shape[0] is None \
  192. and input_shape[1] is not None and input_shape[2] is not None and input_shape[3] is not None \
  193. and input_shape[1] * input_shape[2] * input_shape[3] <= LayerMatching.concat_size_limit
  194. @staticmethod
  195. def upsampling_3d_dense(input_shape):
  196. import keras
  197. layer_concat = []
  198. layer_concat.append(keras.layers.UpSampling3D(size=(2, 2, 2)))
  199. layer_concat.append(keras.layers.Flatten())
  200. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4]))
  201. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  202. return layer_concat
  203. @staticmethod
  204. def upsampling_3d_dense_input_legal(input_shape):
  205. input_shape = input_shape.as_list()
  206. return len(input_shape) == 5 and input_shape[0] is None \
  207. and input_shape[1] is not None \
  208. and input_shape[2] is not None \
  209. and input_shape[3] is not None \
  210. and input_shape[4] is not None \
  211. and input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4] <= LayerMatching.concat_size_limit
  212. @staticmethod
  213. def zeropadding_1d_conv(input_shape):
  214. import keras
  215. layer_concat = []
  216. layer_concat.append(keras.layers.ZeroPadding1D(padding=1))
  217. layer_concat.append(keras.layers.Conv1D(input_shape[-1], 3))
  218. return layer_concat
  219. @staticmethod
  220. def zeropadding_1d_conv_input_legal(input_shape):
  221. input_shape = input_shape.as_list()
  222. return len(input_shape) == 3 and input_shape[0] is None \
  223. and input_shape[1] is not None and input_shape[2] is not None \
  224. and input_shape[1] * input_shape[2] <= LayerMatching.concat_size_limit
  225. @staticmethod
  226. def zeropadding_2d_conv(input_shape):
  227. import keras
  228. layer_concat = []
  229. layer_concat.append(keras.layers.ZeroPadding2D(padding=(1, 1)))
  230. layer_concat.append(keras.layers.Conv2D(input_shape[-1], 3))
  231. return layer_concat
  232. @staticmethod
  233. def zeropadding_2d_conv_input_legal(input_shape):
  234. input_shape = input_shape.as_list()
  235. return len(input_shape) == 4 and input_shape[0] is None \
  236. and input_shape[1] is not None \
  237. and input_shape[2] is not None \
  238. and input_shape[3] is not None \
  239. and input_shape[1] * input_shape[2] * input_shape[3] <= LayerMatching.concat_size_limit
  240. @staticmethod
  241. def zeropadding_3d_conv(input_shape):
  242. import keras
  243. layer_concat = []
  244. layer_concat.append(keras.layers.ZeroPadding3D(padding=(1, 1, 1)))
  245. layer_concat.append(keras.layers.Conv3D(input_shape[-1], 3))
  246. return layer_concat
  247. @staticmethod
  248. def zeropadding_3d_conv_input_legal(input_shape):
  249. input_shape = input_shape.as_list()
  250. return len(input_shape) == 5 and input_shape[0] is None \
  251. and input_shape[1] is not None \
  252. and input_shape[2] is not None \
  253. and input_shape[3] is not None \
  254. and input_shape[4] is not None \
  255. and input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4] <= LayerMatching.concat_size_limit
  256. @staticmethod
  257. def global_max_pooling_1d_dense(input_shape):
  258. import keras
  259. layer_concat = []
  260. layer_concat.append(keras.layers.GlobalMaxPooling1D())
  261. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2]))
  262. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  263. return layer_concat
  264. @staticmethod
  265. def global_average_pooling_1d_dense(input_shape):
  266. import keras
  267. layer_concat = []
  268. layer_concat.append(keras.layers.GlobalAveragePooling1D())
  269. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2]))
  270. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  271. return layer_concat
  272. @staticmethod
  273. def global_pooling_1d_dense_input_legal(input_shape):
  274. input_shape = input_shape.as_list()
  275. return len(input_shape) == 3 and input_shape[0] is None and input_shape[1] is not None \
  276. and input_shape[2] is not None and input_shape[1] * input_shape[2] <= LayerMatching.concat_size_limit
  277. @staticmethod
  278. def global_max_pooling_2d_dense(input_shape):
  279. import keras
  280. layer_concat = []
  281. layer_concat.append(keras.layers.GlobalMaxPooling2D())
  282. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3]))
  283. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  284. return layer_concat
  285. @staticmethod
  286. def global_average_pooling_2d_dense(input_shape):
  287. import keras
  288. layer_concat = []
  289. layer_concat.append(keras.layers.GlobalAveragePooling2D())
  290. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3]))
  291. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  292. return layer_concat
  293. @staticmethod
  294. def global_pooling_2d_dense_input_legal(input_shape):
  295. input_shape = input_shape.as_list()
  296. return len(input_shape) == 4 and input_shape[0] is None \
  297. and input_shape[1] is not None \
  298. and input_shape[2] is not None \
  299. and input_shape[3] is not None \
  300. and input_shape[1] * input_shape[2] * input_shape[3] <= LayerMatching.concat_size_limit
  301. @staticmethod
  302. def global_max_pooling_3d_dense(input_shape):
  303. import keras
  304. layer_concat = []
  305. layer_concat.append(keras.layers.GlobalMaxPooling3D())
  306. layer_concat.append(keras.layers.Flatten())
  307. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4]))
  308. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  309. return layer_concat
  310. @staticmethod
  311. def global_average_pooling_3d_dense(input_shape):
  312. import keras
  313. layer_concat = []
  314. layer_concat.append(keras.layers.GlobalAveragePooling3D())
  315. layer_concat.append(keras.layers.Flatten())
  316. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4]))
  317. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  318. return layer_concat
  319. @staticmethod
  320. def global_pooling_3d_dense_input_legal(input_shape):
  321. input_shape = input_shape.as_list()
  322. return len(input_shape) == 5 and input_shape[0] is None \
  323. and input_shape[1] is not None \
  324. and input_shape[2] is not None \
  325. and input_shape[3] is not None \
  326. and input_shape[4] is not None \
  327. and input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4] <= LayerMatching.concat_size_limit
  328. @staticmethod
  329. def simple_rnn_dense(input_shape):
  330. import keras
  331. layer_concat = []
  332. layer_concat.append(keras.layers.SimpleRNN(50))
  333. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2]))
  334. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  335. return layer_concat
  336. @staticmethod
  337. def simple_rnn_dense_input_legal(input_shape):
  338. input_shape = input_shape.as_list()
  339. return len(input_shape) == 3 and input_shape[0] is None \
  340. and input_shape[1] is not None \
  341. and input_shape[2] is not None \
  342. and input_shape[1] * input_shape[2] <= LayerMatching.concat_size_limit
  343. @staticmethod
  344. def gru_dense(input_shape):
  345. import keras
  346. layer_concat = []
  347. layer_concat.append(keras.layers.GRU(50))
  348. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2]))
  349. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  350. return layer_concat
  351. @staticmethod
  352. def gru_dense_input_legal(input_shape):
  353. input_shape = input_shape.as_list()
  354. return len(input_shape) == 3 and input_shape[0] is None and input_shape[1] is not None \
  355. and input_shape[2] is not None and input_shape[1] * input_shape[2] <= LayerMatching.concat_size_limit
  356. @staticmethod
  357. def lstm_dense(input_shape):
  358. import keras
  359. layer_concat = []
  360. layer_concat.append(keras.layers.LSTM(50))
  361. layer_concat.append(keras.layers.Dense(input_shape[1] * input_shape[2]))
  362. layer_concat.append(keras.layers.Reshape(input_shape[1:]))
  363. return layer_concat
  364. @staticmethod
  365. def lstm_dense_input_legal(input_shape):
  366. input_shape = input_shape.as_list()
  367. return len(input_shape) == 3 and input_shape[0] is None and input_shape[1] is not None \
  368. and input_shape[2] is not None and input_shape[1] * input_shape[2] <= LayerMatching.concat_size_limit
  369. @staticmethod
  370. def conv_lstm_2d_dense(input_shape):
  371. import keras
  372. layer_concat = []
  373. layer_concat.append(keras.layers.ConvLSTM2D(input_shape[-1], kernel_size=(1, 1), strides=(1, 1), padding='same', return_sequences=True))
  374. return layer_concat
  375. @staticmethod
  376. def conv_lstm_2d_dense_input_legal(input_shape):
  377. input_shape = input_shape.as_list()
  378. return len(input_shape) == 5 and input_shape[0] is None and input_shape[1] is not None \
  379. and input_shape[2] is not None and input_shape[2] > 3 \
  380. and input_shape[3] is not None and input_shape[3] > 3 \
  381. and input_shape[4] is not None \
  382. and input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4] <= LayerMatching.concat_size_limit
  383. if __name__ == '__main__':
  384. pass