Skip to content

Commit f296bcd

Browse files
drdaxxy and jongwook authored
Avoid keeping redundant copies of model weights in memory during load (openai#42)
* Don't keep copies of model weights in host memory
* Add type annotation

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
1 parent a4fe05a commit f296bcd

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

whisper/__init__.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,19 @@
2727
}
2828

2929

30-
def _download(url: str, root: str) -> bytes:
30+
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
3131
os.makedirs(root, exist_ok=True)
32-
filename = os.path.basename(url)
3332

3433
expected_sha256 = url.split("/")[-2]
35-
download_target = os.path.join(root, filename)
34+
download_target = os.path.join(root, os.path.basename(url))
3635

3736
if os.path.exists(download_target) and not os.path.isfile(download_target):
3837
raise RuntimeError(f"{download_target} exists and is not a regular file")
3938

4039
if os.path.isfile(download_target):
4140
model_bytes = open(download_target, "rb").read()
4241
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
43-
return model_bytes
42+
return model_bytes if in_memory else download_target
4443
else:
4544
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
4645

@@ -58,15 +57,15 @@ def _download(url: str, root: str) -> bytes:
5857
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
5958
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
6059

61-
return model_bytes
60+
return model_bytes if in_memory else download_target
6261

6362

6463
def available_models() -> List[str]:
    """Return the names of the available Whisper model sizes."""
    return list(_MODELS.keys())
6766

6867

69-
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper:
68+
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
7069
"""
7170
Load a Whisper ASR model
7271
@@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
7978
the PyTorch device to put the model into
8079
download_root: str
8180
path to download the model files; by default, it uses "~/.cache/whisper"
81+
in_memory: bool
82+
whether to preload the model weights into host memory
8283
8384
Returns
8485
-------
8586
model : Whisper
8687
The Whisper ASR model instance
8788
"""
89+
90+
if device is None:
91+
device = "cuda" if torch.cuda.is_available() else "cpu"
92+
if download_root is None:
93+
download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
94+
8895
if name in _MODELS:
89-
model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper"))
96+
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
9097
elif os.path.isfile(name):
91-
model_bytes = open(name, "rb").read()
98+
checkpoint_file = open(name, "rb").read() if in_memory else name
9299
else:
93100
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
94101

95-
with io.BytesIO(model_bytes) as fp:
96-
checkpoint = torch.load(fp, map_location="cpu")
102+
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
103+
checkpoint = torch.load(fp, map_location=device)
104+
del checkpoint_file
97105

98106
dims = ModelDimensions(**checkpoint["dims"])
99-
state_dict = checkpoint["model_state_dict"]
100107
model = Whisper(dims)
101-
model.load_state_dict(state_dict)
102-
103-
if device is None:
104-
device = "cuda" if torch.cuda.is_available() else "cpu"
108+
model.load_state_dict(checkpoint["model_state_dict"])
105109

106110
return model.to(device)

0 commit comments

Comments (0)