checkpoint.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os

import torch

from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.c2_model_loading import load_c2_format
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.model_zoo import cache_url


class Checkpointer(object):
    """Saves and loads model/optimizer/scheduler state, tracking the most
    recently saved checkpoint in a ``last_checkpoint`` file under ``save_dir``."""

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        if not self.save_dir:
            return
        if not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)
        self.tag_last_checkpoint(save_file)

    def load(self, f=None, use_latest=True):
        if self.has_checkpoint() and use_latest:
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return os.path.exists(save_file)

    def get_checkpoint_file(self):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with open(save_file, "r") as f:
                last_saved = f.read()
                last_saved = last_saved.strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def tag_last_checkpoint(self, last_filename):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint):
        load_state_dict(self.model, checkpoint.pop("model"))


class DetectronCheckpointer(Checkpointer):
    """Checkpointer that additionally resolves ``catalog://`` names, downloads
    and caches http(s) URLs, and converts Caffe2 Detectron ``.pkl`` weights."""

    def __init__(
        self,
        cfg,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        super(DetectronCheckpointer, self).__init__(
            model, optimizer, scheduler, save_dir, save_to_disk, logger
        )
        self.cfg = cfg.clone()

    def _load_file(self, f):
        # catalog lookup
        if f.startswith("catalog://"):
            paths_catalog = import_file(
                "maskrcnn_benchmark.config.paths_catalog", self.cfg.PATHS_CATALOG, True
            )
            catalog_f = paths_catalog.ModelCatalog.get(f[len("catalog://") :])
            self.logger.info("{} points to {}".format(f, catalog_f))
            f = catalog_f
        # download url files
        if f.startswith("http"):
            # if the file is a url path, download it and cache it
            cached_f = cache_url(f)
            self.logger.info("url {} cached in {}".format(f, cached_f))
            f = cached_f
        # convert Caffe2 checkpoint from pkl
        if f.endswith(".pkl"):
            return load_c2_format(self.cfg, f)
        # load native detectron.pytorch checkpoint
        loaded = super(DetectronCheckpointer, self)._load_file(f)
        if "model" not in loaded:
            loaded = dict(model=loaded)
        return loaded
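

# ---------------------------------------------------------------------------
# A minimal usage sketch, assuming a standard maskrcnn_benchmark training
# setup; the "output" save_dir and the checkpoint naming scheme below are
# illustrative values, not anything this module prescribes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from maskrcnn_benchmark.config import cfg
    from maskrcnn_benchmark.modeling.detector import build_detection_model
    from maskrcnn_benchmark.solver import make_lr_scheduler, make_optimizer

    model = build_detection_model(cfg)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    os.makedirs("output", exist_ok=True)  # save() expects save_dir to exist
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, save_dir="output", save_to_disk=True
    )

    # Resumes from output/last_checkpoint if present; otherwise falls back to
    # cfg.MODEL.WEIGHT, which may be a catalog:// name, an http(s) URL, or a
    # local path. Extra keys saved alongside the weights are returned here.
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    start_iter = extra_checkpoint_data.get("iteration", 0)

    # A training loop would periodically persist model/optimizer/scheduler
    # state together with bookkeeping such as the current iteration.
    checkpointer.save("model_{:07d}".format(start_iter), iteration=start_iter)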