I used the following github repo: Speech Recognition.
But since it didn't have code to train and save the model, I looked online and added code to speech_recognition to train and save the model and create infer.py to make predictions.
But now even using a .flac file from the training dataset, it gives garbled output like: Transcript: ' f ' i ' ' ' i ' a ' ' a ' s ' o ' f ' o ' f ' i ' m ' i ' a ' i ' f ' o ' r ' ' i ' ' i '
This is the code:
data_preprocessing.py
import torch
import torch.nn as nn
import torch.utils.data
import torchaudio
class TextTransform:
'''Maps characters to integers and vice versa'''
def __init__(self):
char_map_str = '''
' 0
<SPACE> 1
a 2
b 3
c 4
d 5
e 6
f 7
g 8
h 9
i 10
j 11
k 12
l 13
m 14
n 15
o 16
p 17
q 18
r 19
s 20
t 21
u 22
v 23
w 24
x 25
y 26
z 27
'''
self.char_map = {}
self.index_map = {}
for line in char_map_str.strip().split('\n'):
ch, index = line.split()
self.char_map[ch] = int(index)
self.index_map[int(index)] = ch
self.index_map[1] = ' '
def text_to_int(self, text):
''' Use a character map and convert text to an integer sequence '''
int_sequence = []
for c in text:
if c == ' ':
ch = self.char_map['<SPACE>']
else:
ch = self.char_map[c]
int_sequence.append(ch)
return int_sequence
def int_to_text(self, labels):
''' Use a character map and convert integer labels to an text sequence '''
string = []
for i in labels:
string.append(self.index_map[i])
return ''.join(string).replace('', ' ')
# TODO: Questions: Log Mel Spectrogram vs Mel Spectrogram
class LogMelSpectogram(nn.Module):
# TODO: Understand the parameters
def __init__(self, sample_rate=8000, n_mels=81, win_length=160, hop_length=80):
super(LogMelSpectogram, self).__init__()
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels, win_length=win_length, hop_length=hop_length)
def forward(self, waveform):
mel_spectrogram = self.mel_spectrogram(waveform)
# Add 1e-6 to avoid taking log of zero
log_mel_spectrogram = torch.log(mel_spectrogram + 1e-9)
return log_mel_spectrogram
class PreprocessData(torch.utils.data.Dataset):
def __init__(self, dataset, validation_set, sample_rate=8000, n_mels=81, win_length=160, hop_length=80):
super(PreprocessData).__init__()
self.dataset = dataset
self.text_transform = TextTransform()
if validation_set:
self.preprocess_audio = LogMelSpectogram(sample_rate=sample_rate, n_mels=n_mels, win_length=win_length, hop_length=hop_length)
else:
# TODO: Why no frequency and time masking for validation set?
self.preprocess_audio = nn.Sequential(
LogMelSpectogram(sample_rate=sample_rate, n_mels=n_mels, win_length=win_length, hop_length=hop_length),
torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
torchaudio.transforms.TimeMasking(time_mask_param=35)
)
def __getitem__(self, index):
# waveform, sample_rate, label, speaker_id, chapter_id, utterance_id
waveform, _, label, _, _, _ = self.dataset[index]
# Convert waveform to log mel spectrogram
log_mel_spectrogram = self.preprocess_audio(waveform)
# Get the length of the log mel spectrogram
log_mel_spectrogram_len = log_mel_spectrogram.shape[0] // 2 # TODO: Why divide by 2?
# Convert label text to integer sequence
label_in_int = torch.tensor(self.text_transform.text_to_int(label.lower()))
# Get the length of the label
label_len = torch.tensor(len(label_in_int))
return log_mel_spectrogram, label_in_int, log_mel_spectrogram_len, label_len
def __len__(self):
return len(self.dataset)
gather.py:
import os
import torchaudio
def gather_data():
if not os.path.isdir('data'):
os.makedirs('data')
train_dataset = torchaudio.datasets.LIBRISPEECH('data/', url='train-clean-100', download=True)
test_dataset = torchaudio.datasets.LIBRISPEECH('data/', url='test-clean', download=True)
return train_dataset, test_dataset
model.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
def __init__(self, n_feats):
super(CNNLayerNorm, self).__init__()
self.layer_norm = nn.LayerNorm(n_feats)
def forward(self, x):
# x (batch, channel, feature, time)
x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
x = self.layer_norm(x)
return x.transpose(2, 3).contiguous() # (batch, channel, feature, time)
class ResidualCNN(nn.Module):
"""Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
except with layer norm instead of batch norm
"""
def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
super(ResidualCNN, self).__init__()
self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)
self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.layer_norm1 = CNNLayerNorm(n_feats)
self.layer_norm2 = CNNLayerNorm(n_feats)
def forward(self, x):
residual = x # (batch, channel, feature, time)
x = self.layer_norm1(x)
x = F.gelu(x)
x = self.dropout1(x)
x = self.cnn1(x)
x = self.layer_norm2(x)
x = F.gelu(x)
x = self.dropout2(x)
x = self.cnn2(x)
# residual connection - Skip connection
x += residual
return x # (batch, channel, feature, time)
class BidirectionalGRU(nn.Module):
def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
super(BidirectionalGRU, self).__init__()
self.BiGRU = nn.GRU(
input_size=rnn_dim, hidden_size=hidden_size,
num_layers=1, batch_first=batch_first, bidirectional=True)
self.layer_norm = nn.LayerNorm(rnn_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.layer_norm(x)
x = F.gelu(x)
x, _ = self.BiGRU(x)
x = self.dropout(x)
return x
class SpeechRecognitionModel(nn.Module):
"""Speech Recognition Model Inspired by DeepSpeech 2"""
def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
super(SpeechRecognitionModel, self).__init__()
# n_feats = n_feats//2
n_feats = math.ceil(n_feats / stride)
self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=1) # cnn for extracting heirachal features
# n residual cnn layers with filter size of 32
self.rescnn_layers = nn.Sequential(*[
ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
for _ in range(n_cnn_layers)
])
self.fully_connected = nn.Linear(n_feats*32, rnn_dim)
self.birnn_layers = nn.Sequential(*[
BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
for i in range(n_rnn_layers)
])
self.classifier = nn.Sequential(
nn.Linear(rnn_dim*2, rnn_dim), # birnn returns rnn_dim*2
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(rnn_dim, n_class)
)
def forward(self, x):
x = self.cnn(x)
x = self.rescnn_layers(x)
sizes = x.size()
x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time)
x = x.transpose(1, 2) # (batch, time, feature)
x = self.fully_connected(x)
x = self.birnn_layers(x)
x = self.classifier(x)
return x
speech_recognition.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from gather import gather_data
from data_preprocessing import PreprocessData, TextTransform
from model import SpeechRecognitionModel
def collate_fn(batch):
"""
Pads variable-length audio features and labels to max lengths in batch.
Returns:
features_tensor: (batch, 1, n_mels, max_time)
labels_tensor: (batch, max_label_len)
feature_lengths: (batch,)
label_lengths: (batch,)
"""
feats, labels, _, _ = zip(*batch)
feat_lens = [f.shape[-1] for f in feats]
label_lens = [l.shape[0] for l in labels]
max_feat_len = max(feat_lens)
padded_feats = [F.pad(f, (0, max_feat_len - f.shape[-1])) for f in feats]
feats_tensor = torch.stack(padded_feats)
max_label_len = max(label_lens)
padded_labels = [F.pad(l, (0, max_label_len - l.shape[0]), value=0) for l in labels]
labels_tensor = torch.stack(padded_labels)
return feats_tensor, labels_tensor, torch.tensor(feat_lens), torch.tensor(label_lens)
def train_one_epoch(model, device, dataloader, criterion, optimizer):
model.train()
total_loss = 0.0
for features, labels, _, label_lens in dataloader:
features, labels = features.to(device), labels.to(device)
batch_size = features.size(0)
# Forward pass
outputs = model(features) # (batch, time, classes)
time_steps = outputs.size(1)
log_probs = outputs.log_softmax(2).permute(1, 0, 2) # (time, batch, classes)
# Create input_lengths = full length for each (after downsampling)
input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long).to(device)
label_lens = label_lens.to(device)
# Compute loss
loss = criterion(log_probs, labels, input_lengths, label_lens)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Train Loss: {avg_loss:.4f}")
return avg_loss
def main(batch_size=8, epochs=20, lr=1e-4,
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
# Prepare data
train_ds, test_ds = gather_data()
train_ds = PreprocessData(train_ds, validation_set=False)
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True,
collate_fn=collate_fn
)
# Determine n_mels from one sample
sample_feat, _, _, _ = train_ds[0]
n_mels = sample_feat.shape[1]
text_transform = TextTransform()
n_classes = len(text_transform.char_map) + 1 # +1 for CTC blank
# Build model
model = SpeechRecognitionModel(
n_cnn_layers=3,
n_rnn_layers=5,
rnn_dim=512,
n_class=n_classes,
n_feats=n_mels
).to(device)
# Loss & optimizer
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Training loop
best_loss = float('inf')
for epoch in range(1, epochs + 1):
print(f"=== Epoch {epoch}/{epochs} ===")
try:
loss = train_one_epoch(model, device, train_loader, criterion, optimizer)
except RuntimeError as e:
if 'out of memory' in str(e):
print("⚠️ CUDA OOM: emptying cache and exiting training loop")
torch.cuda.empty_cache()
break
else:
raise
# Save best model
if loss < best_loss:
best_loss = loss
torch.save(model.state_dict(), 'speech_model.pth')
print("Model saved: speech_model.pth")
if device.type == 'cuda':
torch.cuda.empty_cache()
print("Training complete!")
if __name__ == '__main__':
main()
and then I checked online and added this script: infer.py:
import torch
import torchaudio
from model import SpeechRecognitionModel
from data_preprocessing import TextTransform
import torchaudio.transforms as T
def load_model(model_path, device, input_dim, output_dim):
text_transform = TextTransform()
n_classes = len(text_transform.char_map) + 1 # +1 for CTC blank
model = SpeechRecognitionModel(
n_cnn_layers=3,
n_rnn_layers=5,
rnn_dim=512,
n_class=n_classes,
n_feats=input_dim
).to(device)
checkpoint = torch.load(model_path, map_location=device)
checkpoint_state = checkpoint
if 'state_dict' in checkpoint:
checkpoint_state = checkpoint['state_dict']
model_state = model.state_dict()
for key in ['classifier.3.weight', 'classifier.3.bias']:
if key in checkpoint_state and key in model_state:
if checkpoint_state[key].shape != model_state[key].shape:
print(f"Skipping incompatible layer: {key}")
del checkpoint_state[key]
model.load_state_dict(checkpoint_state, strict=False)
model.eval()
return model
def greedy_decode(output, blank_index):
arg_maxes = torch.argmax(output, dim=2)
decodes = []
for i, args in enumerate(arg_maxes):
decode = []
previous = -1
for j in args:
if j != previous:
if j != blank_index:
decode.append(j.item())
previous = j
decodes.append(decode)
return decodes
def transcribe(audio_path, model_path):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tt = TextTransform()
n_feats = 81 # Must match model's training config
waveform, sample_rate = torchaudio.load(audio_path)
waveform = waveform.to(device)
resampler = T.Resample(orig_freq=sample_rate, new_freq=8000).to(device)
waveform = resampler(waveform)
mel_spec = T.MelSpectrogram(
sample_rate=8000,
n_mels=n_feats,
win_length=160,
hop_length=80
).to(device)
spec = mel_spec(waveform)
# if spec.ndim == 3:
# spec = spec.unsqueeze(0) # (1, channel, n_mels, time)
spec = spec.unsqueeze(0)
# CNN expects 4D: (batch, channel, features, time)
# Model's layer_norm1 is using x.transpose(2, 3), so shape must be (B, C, F, T)
model_input = spec # shape: (1, 1, n_feats, time)
model = load_model(model_path, device, n_feats, len(tt.char_map))
with torch.no_grad():
output = model(model_input) # (batch, time, classes)
print('output shape : ' ,output.shape)
pred_indices = greedy_decode(output, blank_index=len(tt.char_map))
pred_str = tt.int_to_text(pred_indices[0])
return pred_str
if __name__ == "__main__":
import sys
audio_file = sys.argv[1] if len(sys.argv) > 1 else "sample.flac"
weights_file = sys.argv[2] if len(sys.argv) > 2 else "speech_model.pth"
print("Transcript:", transcribe(audio_file, weights_file))
after training was done, it had a loss of 0.28.
output of infer.py:
output shape : torch.Size([1, 589, 29])
Transcript: ' f ' i ' ' ' i ' a ' ' a ' s ' o ' f ' o ' f ' i ' m ' i ' a ' i ' f ' o ' r ' ' i '