Skip to content

Commit 9e653bd

Browse files
authored
Fixed CoW RuntimeError in DecodingTask.run() (openai#240)
1 parent 02b7430 commit 9e653bd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

whisper/decoding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
615615
n_audio: int = mel.shape[0]
616616

617617
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
618-
tokens: Tensor = torch.tensor([self.initial_tokens]).expand(n_audio, -1)
618+
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
619619

620620
# detect language if requested, overwriting the language token
621621
languages, language_probs = self._detect_language(audio_features, tokens)

0 commit comments

Comments
 (0)