detector.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import numpy as np
  2. import tensorflow as tf
  3. from anonymizer.utils import Box
  4. class Detector:
  5. def __init__(self, kind, weights_path):
  6. self.kind = kind
  7. self.detection_graph = tf.Graph()
  8. with self.detection_graph.as_default():
  9. od_graph_def = tf.GraphDef()
  10. with tf.gfile.GFile(weights_path, 'rb') as fid:
  11. serialized_graph = fid.read()
  12. od_graph_def.ParseFromString(serialized_graph)
  13. tf.import_graph_def(od_graph_def, name='')
  14. conf = tf.ConfigProto()
  15. self.session = tf.Session(graph=self.detection_graph, config=conf)
  16. def _convert_boxes(self, num_boxes, scores, boxes, image_height, image_width, detection_threshold):
  17. assert detection_threshold >= 0.001, 'Threshold can not be too close to "0".'
  18. result_boxes = []
  19. for i in range(int(num_boxes)):
  20. score = float(scores[i])
  21. if score < detection_threshold:
  22. continue
  23. y_min, x_min, y_max, x_max = map(float, boxes[i].tolist())
  24. box = Box(y_min=y_min * image_height, x_min=x_min * image_width,
  25. y_max=y_max * image_height, x_max=x_max * image_width,
  26. score=score, kind=self.kind)
  27. result_boxes.append(box)
  28. return result_boxes
  29. def detect(self, image, detection_threshold):
  30. image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
  31. num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
  32. detection_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
  33. detection_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
  34. image_height, image_width, channels = image.shape
  35. assert channels == 3, f'Invalid number of channels: {channels}. ' \
  36. f'Only images with three color channels are supported.'
  37. np_images = np.array([image])
  38. num_boxes, scores, boxes = self.session.run(
  39. [num_detections, detection_scores, detection_boxes],
  40. feed_dict={image_tensor: np_images})
  41. converted_boxes = self._convert_boxes(num_boxes=num_boxes[0], scores=scores[0], boxes=boxes[0],
  42. image_height=image_height, image_width=image_width,
  43. detection_threshold=detection_threshold)
  44. return converted_boxes