metrics_manager.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #!/usr/bin/env python
  2. # Copyright (c) 2020 Computer Vision Center (CVC) at the Universitat Autonoma de
  3. # Barcelona (UAB).
  4. #
  5. # This work is licensed under the terms of the MIT license.
  6. # For a copy, see <https://opensource.org/licenses/MIT>.
  7. # Allows the execution of user-implemented metrics
  8. """
  9. Welcome to the ScenarioRunner's metric module
  10. This is the main script to be executed when running a metric.
  11. It is responsible for parsing all the information and executing
  12. the metric specified by the user.
  13. """
  14. import os
  15. import sys
  16. import importlib
  17. import inspect
  18. import json
  19. import argparse
  20. from argparse import RawTextHelpFormatter
  21. import carla
  22. from srunner.metrics.tools.metrics_log import MetricsLog
  23. class MetricsManager(object):
  24. """
  25. Main class of the metrics module. Handles the parsing and execution of
  26. the metrics.
  27. """
  28. def __init__(self, args):
  29. """
  30. Initialization of the metrics manager. This creates the client, needed to parse
  31. the information from the recorder, extract the metrics class, and runs it
  32. """
  33. self._args = args
  34. # Parse the arguments
  35. recorder_str = self._get_recorder(self._args.log)
  36. criteria_dict = self._get_criteria(self._args.criteria)
  37. # Get the correct world and load it
  38. map_name = self._get_recorder_map(recorder_str)
  39. world = self._client.load_world(map_name)
  40. town_map = world.get_map()
  41. # Instanciate the MetricsLog, used to querry the needed information
  42. log = MetricsLog(recorder_str)
  43. # Read and run the metric class
  44. metric_class = self._get_metric_class(self._args.metric)
  45. metric_class(town_map, log, criteria_dict)
  46. def _get_recorder(self, log):
  47. """
  48. Parses the log argument into readable information
  49. """
  50. # Get the log information.
  51. self._client = carla.Client(self._args.host, int(self._args.port))
  52. recorder_file = "{}/{}".format(os.getenv('SCENARIO_RUNNER_ROOT', "./"), log)
  53. # Check that the file is correct
  54. if recorder_file[-4:] != '.log':
  55. print("ERROR: The log argument has to point to a .log file")
  56. sys.exit(-1)
  57. if not os.path.exists(recorder_file):
  58. print("ERROR: The specified log file does not exist")
  59. sys.exit(-1)
  60. recorder_str = self._client.show_recorder_file_info(recorder_file, True)
  61. return recorder_str
  62. def _get_criteria(self, criteria_file):
  63. """
  64. Parses the criteria argument into a dictionary
  65. """
  66. if criteria_file:
  67. with open(criteria_file) as fd:
  68. criteria_dict = json.load(fd)
  69. else:
  70. criteria_dict = None
  71. return criteria_dict
  72. def _get_metric_class(self, metric_file):
  73. """
  74. Function to extract the metrics class from the path given by the metrics
  75. argument. Returns the first class found that is a child of BasicMetric
  76. Args:
  77. metric_file (str): path to the metric's file.
  78. """
  79. # Get their module
  80. module_name = os.path.basename(metric_file).split('.')[0]
  81. sys.path.insert(0, os.path.dirname(metric_file))
  82. metric_module = importlib.import_module(module_name)
  83. # And their members of type class
  84. for member in inspect.getmembers(metric_module, inspect.isclass):
  85. # Get the first one with parent BasicMetrics
  86. member_parent = member[1].__bases__[0]
  87. if 'BasicMetric' in str(member_parent):
  88. return member[1]
  89. print("No child class of BasicMetric was found ... Exiting")
  90. sys.exit(-1)
  91. def _get_recorder_map(self, recorder_str):
  92. """
  93. Returns the name of the map the simulation took place in
  94. """
  95. header = recorder_str.split("\n")
  96. sim_map = header[1][5:]
  97. return sim_map
  98. def main():
  99. """
  100. main function
  101. """
  102. # pylint: disable=line-too-long
  103. description = ("Scenario Runner's metrics module. Evaluate the execution of a specific scenario by developing your own metric.\n")
  104. parser = argparse.ArgumentParser(description=description,
  105. formatter_class=RawTextHelpFormatter)
  106. parser.add_argument('--host', default='127.0.0.1',
  107. help='IP of the host server (default: localhost)')
  108. parser.add_argument('--port', '-p', default=2000,
  109. help='TCP port to listen to (default: 2000)')
  110. parser.add_argument('--log', required=True,
  111. help='Path to the CARLA recorder .log file (relative to SCENARIO_RUNNER_ROOT).\nThis file is created by the record functionality at ScenarioRunner')
  112. parser.add_argument('--metric', required=True,
  113. help='Path to the .py file defining the used metric.\nSome examples at srunner/metrics')
  114. parser.add_argument('--criteria', default="",
  115. help='Path to the .json file with the criteria information.\nThis file is created by the record functionality at ScenarioRunner')
  116. # pylint: enable=line-too-long
  117. args = parser.parse_args()
  118. MetricsManager(args)
  119. if __name__ == "__main__":
  120. sys.exit(main())