0

I used the following github repo: Speech Recognition.

But since it didn't have code to train and save the model, I looked online and added code to speech_recognition to train and save the model and create infer.py to make predictions.

But now even using a .flac file from the training dataset, it gives garbled output like: Transcript: ' f ' i ' ' ' i ' a ' ' a ' s ' o ' f ' o ' f ' i ' m ' i ' a ' i ' f ' o ' r ' ' i ' ' i '

This is the code:

data_preprocessing.py

import torch
import torch.nn as nn
import torch.utils.data
import torchaudio


class TextTransform:
  '''Maps characters to integers and vice versa'''
  def __init__(self):
    char_map_str = '''
      ' 0
      <SPACE> 1
      a 2
      b 3
      c 4
      d 5
      e 6
      f 7
      g 8
      h 9
      i 10
      j 11
      k 12
      l 13
      m 14
      n 15
      o 16
      p 17
      q 18
      r 19
      s 20
      t 21
      u 22
      v 23
      w 24
      x 25
      y 26
      z 27
      '''
    self.char_map = {}
    self.index_map = {}
    for line in char_map_str.strip().split('\n'):
      ch, index = line.split()
      self.char_map[ch] = int(index)
      self.index_map[int(index)] = ch
    self.index_map[1] = ' '

  def text_to_int(self, text):
    ''' Use a character map and convert text to an integer sequence '''
    int_sequence = []
    for c in text:
      if c == ' ':
        ch = self.char_map['<SPACE>']
      else:
        ch = self.char_map[c]
      int_sequence.append(ch)
    return int_sequence

  def int_to_text(self, labels):
    ''' Use a character map and convert integer labels to an text sequence '''
    string = []
    for i in labels:
      string.append(self.index_map[i])
    return ''.join(string).replace('', ' ')



# TODO: Questions: Log Mel Spectrogram vs Mel Spectrogram
class LogMelSpectogram(nn.Module):
  # TODO: Understand the parameters
  def __init__(self, sample_rate=8000, n_mels=81, win_length=160, hop_length=80):
    super(LogMelSpectogram, self).__init__()
    self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels, win_length=win_length, hop_length=hop_length)

  def forward(self, waveform):
    mel_spectrogram = self.mel_spectrogram(waveform)
    # Add 1e-6 to avoid taking log of zero
    log_mel_spectrogram = torch.log(mel_spectrogram + 1e-9)
    return log_mel_spectrogram



class PreprocessData(torch.utils.data.Dataset):
  def __init__(self, dataset, validation_set, sample_rate=8000, n_mels=81, win_length=160, hop_length=80):
    super(PreprocessData).__init__()
    self.dataset = dataset
    self.text_transform = TextTransform()
    if validation_set:
      self.preprocess_audio = LogMelSpectogram(sample_rate=sample_rate, n_mels=n_mels, win_length=win_length, hop_length=hop_length)
    else:
      # TODO: Why no frequency and time masking for validation set?
      self.preprocess_audio = nn.Sequential(
        LogMelSpectogram(sample_rate=sample_rate, n_mels=n_mels, win_length=win_length, hop_length=hop_length),
        torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
        torchaudio.transforms.TimeMasking(time_mask_param=35)
      )

  def __getitem__(self, index):
    # waveform, sample_rate, label, speaker_id, chapter_id, utterance_id
    waveform, _, label, _, _, _ = self.dataset[index]

    # Convert waveform to log mel spectrogram
    log_mel_spectrogram = self.preprocess_audio(waveform)
    # Get the length of the log mel spectrogram
    log_mel_spectrogram_len = log_mel_spectrogram.shape[0] // 2 # TODO: Why divide by 2?

    # Convert label text to integer sequence
    label_in_int = torch.tensor(self.text_transform.text_to_int(label.lower()))
    # Get the length of the label
    label_len = torch.tensor(len(label_in_int))

    return log_mel_spectrogram, label_in_int, log_mel_spectrogram_len, label_len

  def __len__(self):
    return len(self.dataset)

gather.py:

import os
import torchaudio

def gather_data():
  if not os.path.isdir('data'):
    os.makedirs('data')

  train_dataset = torchaudio.datasets.LIBRISPEECH('data/', url='train-clean-100', download=True)
  test_dataset = torchaudio.datasets.LIBRISPEECH('data/', url='test-clean', download=True)

  return train_dataset, test_dataset


model.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class CNNLayerNorm(nn.Module):
  """Layer normalization built for cnns input"""
  def __init__(self, n_feats):
    super(CNNLayerNorm, self).__init__()
    self.layer_norm = nn.LayerNorm(n_feats)

  def forward(self, x):
    # x (batch, channel, feature, time)
    x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
    x = self.layer_norm(x)
    return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) 

class ResidualCNN(nn.Module):
  """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
    except with layer norm instead of batch norm
  """
  def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
    super(ResidualCNN, self).__init__()

    self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)
    self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.layer_norm1 = CNNLayerNorm(n_feats)
    self.layer_norm2 = CNNLayerNorm(n_feats)

  def forward(self, x):
    residual = x  # (batch, channel, feature, time)
    x = self.layer_norm1(x)
    x = F.gelu(x)
    x = self.dropout1(x)
    x = self.cnn1(x)
    x = self.layer_norm2(x)
    x = F.gelu(x)
    x = self.dropout2(x)
    x = self.cnn2(x)
    # residual connection - Skip connection
    x += residual
    return x # (batch, channel, feature, time)
        
class BidirectionalGRU(nn.Module):

  def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
    super(BidirectionalGRU, self).__init__()

    self.BiGRU = nn.GRU(
      input_size=rnn_dim, hidden_size=hidden_size,
      num_layers=1, batch_first=batch_first, bidirectional=True)
    self.layer_norm = nn.LayerNorm(rnn_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    x = self.layer_norm(x)
    x = F.gelu(x)
    x, _ = self.BiGRU(x)
    x = self.dropout(x)
    return x


class SpeechRecognitionModel(nn.Module):
  """Speech Recognition Model Inspired by DeepSpeech 2"""

  def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
    super(SpeechRecognitionModel, self).__init__()
    # n_feats = n_feats//2
    n_feats = math.ceil(n_feats / stride)
    self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=1)  # cnn for extracting heirachal features

    # n residual cnn layers with filter size of 32
    self.rescnn_layers = nn.Sequential(*[
      ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) 
      for _ in range(n_cnn_layers)
    ])
    self.fully_connected = nn.Linear(n_feats*32, rnn_dim)
    self.birnn_layers = nn.Sequential(*[
      BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                      hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
      for i in range(n_rnn_layers)
    ])
    self.classifier = nn.Sequential(
      nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(rnn_dim, n_class)
    )

  def forward(self, x):
    x = self.cnn(x)
    x = self.rescnn_layers(x)
    sizes = x.size()
    x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
    x = x.transpose(1, 2) # (batch, time, feature)
    x = self.fully_connected(x)
    x = self.birnn_layers(x)
    x = self.classifier(x)
    return x

speech_recognition.py


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from gather import gather_data
from data_preprocessing import PreprocessData, TextTransform
from model import SpeechRecognitionModel


def collate_fn(batch):
    """
    Pads variable-length audio features and labels to max lengths in batch.
    Returns:
      features_tensor: (batch, 1, n_mels, max_time)
      labels_tensor: (batch, max_label_len)
      feature_lengths: (batch,)
      label_lengths: (batch,)
    """
    feats, labels, _, _ = zip(*batch)
    feat_lens = [f.shape[-1] for f in feats]
    label_lens = [l.shape[0] for l in labels]

    max_feat_len = max(feat_lens)
    padded_feats = [F.pad(f, (0, max_feat_len - f.shape[-1])) for f in feats]
    feats_tensor = torch.stack(padded_feats)

    max_label_len = max(label_lens)
    padded_labels = [F.pad(l, (0, max_label_len - l.shape[0]), value=0) for l in labels]
    labels_tensor = torch.stack(padded_labels)

    return feats_tensor, labels_tensor, torch.tensor(feat_lens), torch.tensor(label_lens)


def train_one_epoch(model, device, dataloader, criterion, optimizer):
    model.train()
    total_loss = 0.0
    for features, labels, _, label_lens in dataloader:
        features, labels = features.to(device), labels.to(device)
        batch_size = features.size(0)

        # Forward pass
        outputs = model(features)  # (batch, time, classes)
        time_steps = outputs.size(1)
        log_probs = outputs.log_softmax(2).permute(1, 0, 2)  # (time, batch, classes)

        # Create input_lengths = full length for each (after downsampling)
        input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long).to(device)
        label_lens = label_lens.to(device)

        # Compute loss
        loss = criterion(log_probs, labels, input_lengths, label_lens)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Train Loss: {avg_loss:.4f}")
    return avg_loss


def main(batch_size=8, epochs=20, lr=1e-4,
         device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    # Prepare data
    train_ds, test_ds = gather_data()
    train_ds = PreprocessData(train_ds, validation_set=False)
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn
    )

    # Determine n_mels from one sample
    sample_feat, _, _, _ = train_ds[0]
    n_mels = sample_feat.shape[1]

    text_transform = TextTransform()
    n_classes = len(text_transform.char_map) + 1  # +1 for CTC blank

    # Build model
    model = SpeechRecognitionModel(
        n_cnn_layers=3,
        n_rnn_layers=5,
        rnn_dim=512,
        n_class=n_classes,
        n_feats=n_mels
    ).to(device)

    # Loss & optimizer
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training loop
    best_loss = float('inf')
    for epoch in range(1, epochs + 1):
        print(f"=== Epoch {epoch}/{epochs} ===")
        try:
            loss = train_one_epoch(model, device, train_loader, criterion, optimizer)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print("⚠️ CUDA OOM: emptying cache and exiting training loop")
                torch.cuda.empty_cache()
                break
            else:
                raise

        # Save best model
        if loss < best_loss:
            best_loss = loss
            torch.save(model.state_dict(), 'speech_model.pth')
            print("Model saved: speech_model.pth")

        if device.type == 'cuda':
            torch.cuda.empty_cache()

    print("Training complete!")


if __name__ == '__main__':
    main()

and then I checked online and added this script: infer.py:


import torch
import torchaudio
from model import SpeechRecognitionModel
from data_preprocessing import TextTransform
import torchaudio.transforms as T


def load_model(model_path, device, input_dim, output_dim):
    text_transform = TextTransform()
    n_classes = len(text_transform.char_map) + 1  # +1 for CTC blank
    model = SpeechRecognitionModel(
        n_cnn_layers=3,
        n_rnn_layers=5,
        rnn_dim=512,
        n_class=n_classes,
        n_feats=input_dim
    ).to(device)
    checkpoint = torch.load(model_path, map_location=device)

    checkpoint_state = checkpoint
    if 'state_dict' in checkpoint:
        checkpoint_state = checkpoint['state_dict']

    model_state = model.state_dict()
    for key in ['classifier.3.weight', 'classifier.3.bias']:
        if key in checkpoint_state and key in model_state:
            if checkpoint_state[key].shape != model_state[key].shape:
                print(f"Skipping incompatible layer: {key}")
                del checkpoint_state[key]

    model.load_state_dict(checkpoint_state, strict=False)
    model.eval()
    return model


def greedy_decode(output, blank_index):
    arg_maxes = torch.argmax(output, dim=2)
    decodes = []
    for i, args in enumerate(arg_maxes):
        decode = []
        previous = -1
        for j in args:
            if j != previous:
                if j != blank_index:
                    decode.append(j.item())
                previous = j
        decodes.append(decode)
    return decodes


def transcribe(audio_path, model_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tt = TextTransform()
    n_feats = 81  # Must match model's training config

    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = waveform.to(device)
    resampler = T.Resample(orig_freq=sample_rate, new_freq=8000).to(device)
    waveform = resampler(waveform)

    mel_spec = T.MelSpectrogram(
        sample_rate=8000,
        n_mels=n_feats,
        win_length=160,
        hop_length=80
    ).to(device)
    spec = mel_spec(waveform)
    # if spec.ndim == 3:
    #     spec = spec.unsqueeze(0)  # (1, channel, n_mels, time)
    spec = spec.unsqueeze(0) 

    # CNN expects 4D: (batch, channel, features, time)
    # Model's layer_norm1 is using x.transpose(2, 3), so shape must be (B, C, F, T)
    model_input = spec  # shape: (1, 1, n_feats, time)

    model = load_model(model_path, device, n_feats, len(tt.char_map))

    with torch.no_grad():
        output = model(model_input)  # (batch, time, classes)
        print('output shape : ' ,output.shape)
        pred_indices = greedy_decode(output, blank_index=len(tt.char_map))

    pred_str = tt.int_to_text(pred_indices[0])
    return pred_str


if __name__ == "__main__":
    import sys
    audio_file = sys.argv[1] if len(sys.argv) > 1 else "sample.flac"
    weights_file = sys.argv[2] if len(sys.argv) > 2 else "speech_model.pth"
    print("Transcript:", transcribe(audio_file, weights_file))

after training was done, it had a loss of 0.28.

output of infer.py:

output shape :  torch.Size([1, 589, 29])
Transcript:  '   f ' i '   '   ' i ' a '   ' a '   s ' o '   f ' o ' f ' i '   m ' i ' a ' i '   f ' o ' r '   ' i '
1

0

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.