2727}
2828
2929
30- def _download (url : str , root : str ) -> bytes :
30+ def _download (url : str , root : str , in_memory : bool ) -> Union [ bytes , str ] :
3131 os .makedirs (root , exist_ok = True )
32- filename = os .path .basename (url )
3332
3433 expected_sha256 = url .split ("/" )[- 2 ]
35- download_target = os .path .join (root , filename )
34+ download_target = os .path .join (root , os . path . basename ( url ) )
3635
3736 if os .path .exists (download_target ) and not os .path .isfile (download_target ):
3837 raise RuntimeError (f"{ download_target } exists and is not a regular file" )
3938
4039 if os .path .isfile (download_target ):
4140 model_bytes = open (download_target , "rb" ).read ()
4241 if hashlib .sha256 (model_bytes ).hexdigest () == expected_sha256 :
43- return model_bytes
42+ return model_bytes if in_memory else download_target
4443 else :
4544 warnings .warn (f"{ download_target } exists, but the SHA256 checksum does not match; re-downloading the file" )
4645
@@ -58,15 +57,15 @@ def _download(url: str, root: str) -> bytes:
5857 if hashlib .sha256 (model_bytes ).hexdigest () != expected_sha256 :
5958 raise RuntimeError ("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." )
6059
61- return model_bytes
60+ return model_bytes if in_memory else download_target
6261
6362
def available_models() -> List[str]:
    """Return the names of the official models that `load_model` accepts."""
    return [*_MODELS]
6867
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `available_models()`, or a
        path to a checkpoint file containing the model dimensions and state_dict
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to CUDA when
        available, otherwise CPU
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance

    Raises
    ------
    RuntimeError
        if `name` is neither a known model name nor an existing file
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

    if name in _MODELS:
        # _download returns bytes when in_memory, else the cached file path.
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        if in_memory:
            # Fixed: close the file handle deterministically instead of
            # leaking it via a bare open(...).read().
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
        checkpoint = torch.load(fp, map_location=device)
    # Drop the (potentially large) raw checkpoint buffer as soon as possible.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
0 commit comments