weights.py 1.1 KB

123456789101112131415161718192021222324252627282930313233
  1. from pathlib import Path
  2. from google_drive_downloader import GoogleDriveDownloader as gdd
  3. WEIGHTS_GDRIVE_IDS = {
  4. '1.0.0': {
  5. 'face': '1CwChAYxJo3mON6rcvXsl82FMSKj82vxF',
  6. 'plate': '1Fls9FYlQdRlLAtw-GVS_ie1oQUYmci9g'
  7. }
  8. }
  9. def get_weights_path(base_path, kind, version='1.0.0'):
  10. assert version in WEIGHTS_GDRIVE_IDS.keys(), f'Invalid weights version "{version}"'
  11. assert kind in WEIGHTS_GDRIVE_IDS[version].keys(), f'Invalid weights kind "{kind}"'
  12. return str(Path(base_path) / f'weights_{kind}_v{version}.pb')
  13. def _download_single_model_weights(download_directory, kind, version):
  14. file_id = WEIGHTS_GDRIVE_IDS[version][kind]
  15. weights_path = get_weights_path(base_path=download_directory, kind=kind, version=version)
  16. if Path(weights_path).exists():
  17. return
  18. print(f'Downloading {kind} weights to {weights_path}')
  19. gdd.download_file_from_google_drive(file_id=file_id, dest_path=weights_path, unzip=False)
  20. def download_weights(download_directory, version='1.0.0'):
  21. for kind in ['face', 'plate']:
  22. _download_single_model_weights(download_directory=download_directory, kind=kind, version=version)