Poor man’s ASR pt. 2

What and why

Check out the Part 1 cliff-hanger for a recap of how we acquired some Radio National podcast transcripts and how we ended up here:

  • Dataset. We’ve been able to acquire a small sample of podcast audio and their accompanying transcripts. The podcast audio contains good topical/technical variance.
  • Model(s). We’d like to transcribe the podcast audio, using some combination of ASR/related models. Ideally, but not crucially, we’d also like to mimic the output form of commercial ASR providers.
  • Evaluation. We’d like to evaluate our open-source solution against its commercial rivals, using the accompanying podcast transcripts as the ground truth.

Model selection

Available libraries. We’re spoilt for choice here! Notable options include Silero, an open-source library with a particular focus on fast, CPU-friendly inference; Mozilla DeepSpeech, based on Baidu’s original DeepSpeech; SpeechBrain, a research-oriented framework for developing speech and language technologies; and various ASR models hosted on HF’s model hub, including contributions from Facebook such as wav2vec.

However, Nvidia’s NEMO appeared to provide the most complete out-of-the-box functionality, supporting ASR as well as adjacent functionality such as Speaker Recognition, Diarization and Audio Classification. Importantly, NEMO also provides NLP functionality such as Punctuation and Capitalisation, which will come in handy later in the pipeline.

Model spot-check. So we’ve been able to narrow down our library selection, but that still leaves the question of choosing a model. NEMO, at the time of writing, features over 60 ASR models alone, covering various combinations of language, architecture, model size and model parameters! Exhaustively testing all of these models is going to be... well, exhausting. So, after reading the fantastic model documentation NVIDIA provides, I reckon a sensible approach is to spot-check one model from (roughly) each of the major model families available within NEMO. This means we’ll be examining QuartzNet (stt_en_quartznet15x5), Citrinet (stt_en_citrinet_512), ContextNet (stt_en_contextnet_512), ConformerCTC (stt_en_conformer_ctc_medium) and ConformerTransducer (stt_en_conformer_transducer_medium) models in terms of how their memory footprint and speed scale with input length, and their performance (WER). Below, we take a single podcast, progressively slice the audio into longer fragments, and benchmark memory usage and elapsed time for each target model:

import math
import tempfile
import time
from pathlib import Path

import nemo.collections.asr as nemo_asr
import nvsmi
import torch
from pydub import AudioSegment

# scope the models
target_models = [
    "stt_en_quartznet15x5",
    "stt_en_citrinet_512",
    "stt_en_contextnet_512",
    "stt_en_conformer_ctc_medium",
    "stt_en_conformer_transducer_medium",
]
pretrained_models = [
    e
    for e in nemo_asr.models.ASRModel.list_available_models()
    if e.pretrained_model_name in target_models
]

# take a sample podcast
audio_segment = AudioSegment.from_wav(
    "/home/blog-os-asr/output/temp_dir/rewilding-the-scottish-highlands.wav"
)

# progressively longer slices (0-30 s, 0-60 s, ...) to gauge memory usage
s2ms = 1000
seconds_increment = 30
slice_intervals = [
    (0, e * s2ms)
    for e in range(
        seconds_increment,
        math.ceil(audio_segment.duration_seconds) + seconds_increment,
        seconds_increment,
    )
]

memory_usage_records = []
for pretrained_model in pretrained_models:
    print(f"Memory testing: {pretrained_model.pretrained_model_name}")

    # model classes are listed alongside model names
    model = pretrained_model.class_.from_pretrained(
        model_name=pretrained_model.pretrained_model_name
    )
    # GPU memory used by the loaded model, before any inference
    model_memory_footprint = nvsmi.get_gpu_processes()[0].used_memory

    with tempfile.TemporaryDirectory() as temp_dir:
        # transcribe() takes file paths as input; save each slice in the tmp dir
        fragment_path = Path(temp_dir) / "memory_test_fragment.wav"
        for start_ms, end_ms in slice_intervals:
            audio_slice = audio_segment[start_ms:end_ms]
            audio_slice.export(fragment_path, format="wav").close()
            try:
                before = time.time()
                transcription = model.transcribe(
                    paths2audio_files=[str(fragment_path)],
                    batch_size=1,
                )
                after = time.time()
            except RuntimeError:
                # CUDA OOM surfaces as a RuntimeError (torch.cuda.OutOfMemoryError in newer
                # PyTorch is a subclass); longer slices will fail too, so move onto the next model
                print("CUDA out of memory; skipping remaining slice intervals")
                torch.cuda.empty_cache()
                break

            # collect some metrics
            memory_usage_records.append(
                {
                    "model_name": pretrained_model.pretrained_model_name,
                    "input_size": audio_slice.duration_seconds,
                    "transcript": transcription,
                    "memory_usage": nvsmi.get_gpu_processes()[0].used_memory,
                    "time_elapsed": after - before,
                }
            )

    # clear for next model
    del model
    torch.cuda.empty_cache()

This allows us to accumulate resource usage records and plot the results:

Effect of input size on memory usage. Here, we can plainly see ConformerCTC and ConformerTransducer ramping aggressively in their memory usage, eventually running out of memory around the 7-minute mark, which has some interesting implications for monologue-style input audio.
Word Error Rates for each model, single podcast. Much of a muchness here, with the Conformer variant edging out the competition, but only by 0.05 points. Code not included.
Effect of input size on elapsed time. Curiously, the CitriNet and QuartzNet models appear to scale in a constant fashion.

So what does this tell us? Well, as a first pass it appears:

  • There isn’t a decisive difference in WER across the different models. A tenuous claim, given a sample size of 1, but I’ll give it a pass knowing that this is a poor man’s spot-check.
  • Given the generally even WER performance across models, in my mind this allows us to shift the focus to practical considerations such as memory usage and speed.
  • In which case the choice appears to be between a CitriNet and a QuartzNet model, with QuartzNet having a slight edge since it scales in a (relatively) constant fashion.
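To pull those per-model limits out of memory_usage_records rather than eyeballing the plots, a small stdlib-only summary could look like the sketch below. It assumes only the record keys used in the benchmarking loop; since a record is only appended when a slice transcribes successfully, the largest input_size per model approximates the longest input that fit in memory on this particular GPU. summarise_memory_records is a hypothetical helper, not part of NEMO or nvsmi.

```python
from collections import defaultdict


def summarise_memory_records(records):
    """Per model: longest input (s) that transcribed, peak GPU memory, slowest run (s)."""
    summary = defaultdict(
        lambda: {"max_input_s": 0.0, "peak_memory": 0, "max_time_s": 0.0}
    )
    for r in records:
        s = summary[r["model_name"]]
        s["max_input_s"] = max(s["max_input_s"], r["input_size"])
        # memory is kept in whatever units nvsmi reports
        s["peak_memory"] = max(s["peak_memory"], r["memory_usage"])
        s["max_time_s"] = max(s["max_time_s"], r["time_elapsed"])
    return dict(summary)


# e.g. summarise_memory_records(memory_usage_records)
```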

Prediction Pipeline

Re-sample and reformat audio. This is a hard requirement, as most of the NEMO models are trained on 16kHz mono WAV files, so we’ll need to re-sample and reformat regardless of the original file format (RE: MP3 podcasts) or whether its sample rate differs from 16kHz. Below I’ve re-purposed a snippet from the CTC segmentation tutorial, which was originally designed to slice large pieces of audio into batch-sized (a few seconds) form:

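import os
import subprocess

def resample_normalise_audio(in_file, out_file=None, sample_rate=16000):
    in_file = str(in_file)
    if not os.path.exists(in_file):
        raise FileNotFoundError(f"{in_file} not found")
    if out_file is None:
        # e.g. episode.mp3 -> episode_16000.wav
        out_file = f"{os.path.splitext(in_file)[0]}_{sample_rate}.wav"

    # 16-bit PCM, mono, SoX resampler at the target rate; -y overwrites existing output.
    # Passing a list (no shell) avoids breakage on paths containing spaces.
    subprocess.run(
        ["ffmpeg", "-y", "-i", in_file, "-acodec", "pcm_s16le", "-ac", "1",
         "-af", "aresample=resampler=soxr", "-ar", str(sample_rate), str(out_file)],
        check=True,
    )
    return out_file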
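Since the NEMO models expect 16kHz mono input, it can be worth sanity-checking what ffmpeg actually wrote before handing files to the model. A minimal sketch using only the standard library’s wave module, assuming the 16-bit PCM (pcm_s16le) output requested above; check_asr_ready_wav is a hypothetical helper, not part of NEMO.

```python
import wave


def check_asr_ready_wav(path, sample_rate=16000):
    """Return a list of problems; an empty list means mono, 16-bit PCM WAV at sample_rate."""
    problems = []
    with wave.open(str(path), "rb") as wav:
        if wav.getframerate() != sample_rate:
            problems.append(f"sample rate {wav.getframerate()} != {sample_rate}")
        if wav.getnchannels() != 1:
            problems.append(f"{wav.getnchannels()} channels, expected mono")
        if wav.getsampwidth() != 2:
            problems.append(f"{8 * wav.getsampwidth()}-bit samples, expected 16-bit")
    return problems


# e.g. check_asr_ready_wav("episode_16000.wav") should return [] after resampling
```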

Speaker diarization. Whilst NEMO does feature support for speaker diarization via MarbleNet, at the time of writing I found the associated example to be training-centric and quite verbose. I did notice pyannote as an alternative, which has recently integrated its models with the Hugging Face model hub. The following code applies the diarization model (a one-liner, yay), unpacks the diarized segments into records, and merges consecutive segments from the same speaker into single records:

import pandas as pd
from pyannote.audio import Pipeline
from pydub import AudioSegment

DIA_MODEL_NAME = "pyannote/speaker-diarization@2022.07"
DIA_MODEL = Pipeline.from_pretrained(DIA_MODEL_NAME)
PAUSE_THRESHOLD = 1  # seconds; merged segments shorter than this are dropped
MS = 1000  # seconds -> milliseconds (pydub slices in ms)

def _assign_child_segment(record, parent_audio_segment):
    return parent_audio_segment[record.start * MS : record.end * MS]

def diarize_mono_audio(in_file, audio_segment):
    diarization_raw = DIA_MODEL(str(in_file))
    diarized_segments = (
        pd.DataFrame(
            [
                {"start": turn.start, "end": turn.end, "speaker": speaker}
                for turn, _, speaker in diarization_raw.itertracks(yield_label=True)
            ]
        )
        # flag each change of speaker, then cumsum so consecutive same-speaker turns share a marker
        .assign(segment_marker=lambda x: x.speaker.ne(x.speaker.shift(1)).cumsum())
        # groupby segment, merge audio start/end times
        .groupby("segment_marker")
        .agg(
            speaker=("speaker", "first"),
            start=("start", "first"),
            end=("end", "last"),
            segment_marker_count=("speaker", "count"),
        )
        .assign(segment_len=lambda x: x.end - x.start)
        # TODO: finesse a merging strategy
        .query("segment_len >= @PAUSE_THRESHOLD")
        .reset_index(drop=True)
        .assign(
            audio_segment=lambda x: x.apply(
                lambda y: _assign_child_segment(y, audio_segment), axis=1
            )
        )
    )
    return diarized_segments

in_file = "/home/blog-os-asr/output/temp_dir/rewilding-the-scottish-highlands.wav"
audio_segment = AudioSegment.from_file(in_file)
diarized_segments = diarize_mono_audio(in_file, audio_segment)

The resulting dataframe looks like this, where I’ve also directly incorporated the relevant slices of audio (audio_segment) for later saving/convenience:

The essence of the pipeline is to slice incoming audio into memory-compatible chunks. Speaker diarization gives us a nice two-for-one in this respect.
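The shift/cumsum/groupby chain in the diarization code is fairly dense, so here is the same merging logic in plain Python. It is a sketch over hypothetical (start, end, speaker) tuples rather than pyannote objects: consecutive turns from the same speaker collapse into one segment spanning the first turn’s start to the last turn’s end, and merged segments shorter than the minimum length (PAUSE_THRESHOLD in the original) are dropped. merge_speaker_turns is a hypothetical name.

```python
from itertools import groupby


def merge_speaker_turns(turns, min_len=1.0):
    """Collapse consecutive (start, end, speaker) turns from the same speaker, then
    drop merged segments shorter than min_len seconds."""
    merged = []
    for speaker, group in groupby(turns, key=lambda turn: turn[2]):
        group = list(group)
        # mirrors the "first"/"last" aggregation: first turn's start, last turn's end
        start, end = group[0][0], group[-1][1]
        if end - start >= min_len:
            merged.append(
                {"speaker": speaker, "start": start, "end": end, "n_turns": len(group)}
            )
    return merged


turns = [(0.0, 2.0, "A"), (2.1, 5.0, "A"), (5.2, 5.8, "B"), (6.0, 9.0, "A")]
print(merge_speaker_turns(turns))
# [{'speaker': 'A', 'start': 0.0, 'end': 5.0, 'n_turns': 2},
#  {'speaker': 'A', 'start': 6.0, 'end': 9.0, 'n_turns': 1}]
```

As in the dataframe version, dropping a short segment can leave two same-speaker segments adjacent without merging them, which is presumably part of what the TODO about finessing a merging strategy refers to.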

Max-length segmentation. Whilst we’ve established that QuartzNet can transcribe audio slices over 6 minutes long, and that diarization is a good way to segment multi-speaker audio, out-of-memory errors are still possible: consider, for example, long-form audio featuring a single speaker. Therefore, we would still like to safeguard against long diarized segments. Here, I’ve opted for a naive, silence-based threshold segmentation, though I’m aware this could also be solved with an explicit silence detection model. The gist is that any diarized segment exceeding 4 minutes (generally unlikely in conversational audio) is split on silences, re-attempting with progressively adjusted silence thresholds (in dBFS, relative to that segment’s average loudness) until the median split is short enough; any split that is still too long is then cut into fixed-length chunks. If a diarized segment requires splitting, it is replaced by an appropriate number of records, returning a dataframe:

import logging

import pandas as pd
from pydub import silence, utils

ASR_LOGGER = logging.getLogger("asr_logger")
ASR_LOGGER.setLevel(logging.INFO)

SECOND_MAX_AUDIO = 240
MS = 1000  # seconds -> milliseconds (pydub works in ms)

def _pseudo_optimise_silence_split(audio_segment):
    # note: split_on_silence trims most of each silence, so the splits sum to slightly
    # less than the parent segment > small amounts of timestamp drift
    # anything quieter than silence_thresh counts as silence, so start strict (far below the
    # segment's average loudness) and relax towards it until the splits are short enough
    dbfs_offset = 40
    dbfs_offset_min = 10
    dbfs_delta = 10
    min_silence_len = 500  # ms
    dBFS = audio_segment.dBFS
    audio_segments = silence.split_on_silence(
        audio_segment, min_silence_len=min_silence_len, silence_thresh=dBFS - dbfs_offset
    )
    while (
        pd.Series([e.duration_seconds for e in audio_segments], dtype=float).median()
        >= SECOND_MAX_AUDIO
        and dbfs_offset > dbfs_offset_min
    ):
        ASR_LOGGER.warning(
            f"Unable to split segment on silences with silence_thresh of {dBFS - dbfs_offset}; re-attempting.."
        )
        dbfs_offset -= dbfs_delta
        audio_segments = silence.split_on_silence(
            audio_segment,
            min_silence_len=min_silence_len,
            silence_thresh=dBFS - dbfs_offset,
        )

    return audio_segments


def segment_utterances(audio_segment_record):
    if audio_segment_record.segment_len <= SECOND_MAX_AUDIO:
        return audio_segment_record.to_frame().T

    silence_splits = _pseudo_optimise_silence_split(audio_segment_record.audio_segment)

    # fall back to fixed-size chunks for any split that is still too long
    all_splits = []
    for split in silence_splits:
        if split.duration_seconds > SECOND_MAX_AUDIO:
            all_splits.extend(utils.make_chunks(split, SECOND_MAX_AUDIO * MS))
        else:
            all_splits.append(split)

    start_times = []
    start_time = audio_segment_record.start
    # TODO: re-work with cumsum?
    for e in all_splits:
        start_times.append(start_time)
        start_time += e.duration_seconds

    return (
        pd.DataFrame(
            [
                {
                    "audio_segment": e,
                    "speaker": audio_segment_record.speaker,
                    "segment_len": e.duration_seconds,
                }
                for e in all_splits
            ]
        )
        .assign(start=start_times)
        .assign(end=lambda x: x.start + x.segment_len)
    )

chunked_diarized_segments = pd.concat(
    [segment_utterances(record) for _, record in diarized_segments.iterrows()],
    ignore_index=True,
)

This returns a dataframe with much the same columns as the diarization output, except with the potential for a few more records capturing the silence-threshold splits.
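On the TODO about re-working the start-time loop with a cumulative sum: a stdlib sketch using itertools.accumulate, alongside a helper for the durations that fixed-size chunking produces. split_durations and start_times are hypothetical names, and the sketch assumes the splits are contiguous, which silence splitting doesn’t strictly guarantee (trimmed silences are the source of the drift mentioned in the code comments).

```python
from itertools import accumulate


def split_durations(duration_s, max_len_s=240.0):
    """Fixed-size chunk durations covering duration_s; the last chunk holds any remainder."""
    n_full, remainder = divmod(duration_s, max_len_s)
    return [max_len_s] * int(n_full) + ([remainder] if remainder > 0 else [])


def start_times(durations, offset=0.0):
    """Running total of split durations, shifted by the parent segment's start.
    accumulate(..., initial=...) requires Python 3.8+."""
    if not durations:
        return []
    return list(accumulate(durations[:-1], initial=offset))


durations = split_durations(600.0)  # [240.0, 240.0, 120.0]
print(start_times(durations, offset=30.0))  # [30.0, 270.0, 510.0]
```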

ASR. FINALLY, we can apply the ASR model. We iterate through the segments, save each one into a temp directory (adjusted for this snippet), collect the paths, and pipe them into the ASR model’s generic transcribe method, which, annoyingly and AFAIK, only accepts file paths and not arrays of raw audio. We then assign the transcription hypotheses directly onto the segment frame. We’re abusing dataframes at this point, but keeping everything together in one place is a nice thing.

from pathlib import Path

import nemo.collections.asr as nemo_asr

ASR_MODEL_NAME = "stt_en_quartznet15x5"
ASR_MODEL = nemo_asr.models.ASRModel.from_pretrained(model_name=ASR_MODEL_NAME)
BATCH_SIZE = 4
temp_dir = Path("../output/temp_dir")
temp_dir.mkdir(parents=True, exist_ok=True)

# build the path list explicitly, in dataframe order, so outputs line up with rows (RE: sorted() issues)
paths2audio_files = []
for idx, record in chunked_diarized_segments.iterrows():
    chunk_path = temp_dir / f"chunk_{idx}.wav"
    record.audio_segment.export(chunk_path, format="wav").close()
    paths2audio_files.append(str(chunk_path))

# return_hypotheses=True returns Hypothesis objects (with a .text attribute) rather than plain strings
asr_outputs = ASR_MODEL.transcribe(
    paths2audio_files=paths2audio_files,
    batch_size=BATCH_SIZE,
    return_hypotheses=True,
)
chunked_diarized_segments = chunked_diarized_segments.assign(
    asr_outputs=asr_outputs
)
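The “RE: sorted() issues” comment is terse. One likely culprit (my assumption) is that chunk file names sort lexicographically, so chunk_10.wav lands before chunk_2.wav, and transcripts would be assigned to the wrong rows if the path list were ever rebuilt by sorting a directory listing. Building the list explicitly in dataframe order sidesteps this; if sorting is unavoidable, a natural-sort key is one option. natural_key is a hypothetical helper.

```python
import re


def natural_key(name):
    """Sort key treating digit runs as integers, so chunk_2 sorts before chunk_10."""
    # re.split with a capture group alternates text and digit runs, so odd positions are digits:
    # "chunk_10.wav" -> ["chunk_", "10", ".wav"]
    parts = re.split(r"(\d+)", name)
    return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]


names = ["chunk_10.wav", "chunk_2.wav", "chunk_1.wav"]
print(sorted(names))                   # ['chunk_1.wav', 'chunk_10.wav', 'chunk_2.wav']
print(sorted(names, key=natural_key))  # ['chunk_1.wav', 'chunk_2.wav', 'chunk_10.wav']
```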

Exchange consolidation and grammar corrections. If you’re still here, awesome! There’s not much left to do besides consolidating individual speaker exchanges and applying a punctuation model to add full stops, commas, question marks and casing. This is super important for any downstream NLP that assumes such an output. Additionally, we reuse the shift/aggregate pattern from the diarization method to merge adjacent exchanges owned by the same speaker:

import pandas as pd
from nemo.collections.nlp.models import PunctuationCapitalizationModel

PUNCT_MODEL_NAME = "punctuation_en_bert"
PUNCT_MODEL = PunctuationCapitalizationModel.from_pretrained(PUNCT_MODEL_NAME)

def _punctuate_collapse_segment(record):
    # one exchange = consecutive chunks from the same speaker; join their text, then punctuate
    return {
        "speaker": record.iloc[0].speaker,
        "start": record.start.min(),
        "end": record.end.max(),
        "transcript": PUNCT_MODEL.add_punctuation_capitalization(
            [" ".join(record.asr_outputs.apply(lambda x: x.text).tolist())]
        )[0],
    }

# same shift/cumsum pattern as the diarization step: a new marker on each change of speaker
exchange_groups = chunked_diarized_segments.assign(
    segment_marker=lambda x: x.speaker.ne(x.speaker.shift(1)).cumsum()
).groupby("segment_marker")
punctuated_exchanges = pd.DataFrame(
    [_punctuate_collapse_segment(group) for _, group in exchange_groups]
)

The final transcription dataframe looks something like this:

Some obvious errors (jermy, legit, invaness, loc ness, a a), but otherwise a very readable transcript!

Evaluation and comparison

Cloud providers. Now that we have an open-source ASR solution, we’d like to benchmark it against the cloud providers. This will involve creating separate ingestion pipelines for each provider, but the gist of the pipeline will largely be the same:

  1. Create blob storage and upload audio. We’re working with fairly lengthy audio (5 min < audio length < 30 min), which the providers’ asynchronous APIs generally expect to read from object storage rather than receive inline.
  2. Asynchronous, non-streaming, individual ASR invocations. Most cloud providers offer variations of their ASR service depending on whether the API is invoked as streaming or non-streaming, async or sync, and individually or in batch. Since we’re working with static audio for analytical purposes, a fair choice is to use non-streaming, asynchronous methods where possible. Here, I’ve chosen individual ASR calls so I can track the elapsed time of each transcription, almost certainly at the expense of a more efficient batch method.
  3. Format output into a series of uniform transcript records (dictionary/JSON) that can be easily manipulated as part of the evaluation.
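Whatever the provider, steps 2 and 3 boil down to: submit a job, poll until it completes, and record the result along with the elapsed time. A provider-agnostic sketch of that loop, where submit and poll are hypothetical stand-ins for each SDK’s calls (not real API names), and the output keys mirror the hypothesis/elapsed_time fields the evaluation code reads.

```python
import time


def run_async_job(submit, poll, interval_s=5.0, timeout_s=3600.0):
    """Submit a job, poll until it finishes, and time the whole round-trip.

    submit() returns a job id; poll(job_id) returns (done, result). Both are
    hypothetical callables wrapping a provider's SDK."""
    start = time.monotonic()
    job_id = submit()
    while True:
        done, result = poll(job_id)
        if done:
            return {
                "job_id": job_id,
                "hypothesis": result,
                "elapsed_time": time.monotonic() - start,
            }
        if time.monotonic() - start > timeout_s:
            raise TimeoutError(f"job {job_id} still running after {timeout_s} s")
        time.sleep(interval_s)
```

Polling on a fixed interval is the simplest option; the AWS waiter classes mentioned below serve a similar purpose.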

I won’t go into too much detail for each specific cloud provider, since it’s largely a matter of tweaking the example code each provider supplies, but some brief gotcha notes below:

  • AWS. Example implementation here. Requires the CustomWaiter and WaitState demo classes, which live in a separate file here; I’ve included it within the notebooks dir of the repo.
  • Google. Example implementation here. MP3 support is available via the beta version of the Python client; otherwise, an MP3 to WAV conversion is required, which I’ve applied as part of every pipeline. Bit of a meh, but recall that the NEMO models all require 16kHz WAV as input in the first place, so we’re kind of already there. In any case, the converted WAV is only nominally higher quality, since it’s derived from the original MP3. Additionally, upon eyeballing the results I noticed that Google STT transcripts were often missing large sections, which we’ll no doubt see reflected in the final evaluation.
  • Azure. Example implementation here. Requires manually generating the Python client library (via a third-party Swagger API generator?), then downloading and installing it from source. Truly bizarre, and say goodbye to easily dockerising any Azure solution, I suppose?

WER evaluation. Now that we have a series of directories containing provider-specific transcripts we just need to glue them together and plot the results:

import json
from pathlib import Path

import pandas as pd
from jiwer import wer  # assumption: wer(reference, hypothesis) from the jiwer package

def load_transcripts(transcript_dir):
    records = []
    for e in transcript_dir.rglob("*.json"):
        record = json.loads(e.read_text())
        record["stem"] = e.stem  # file stem is the join key against the ground truth
        records.append(record)
    return pd.DataFrame(records)


transcript_root = Path("../output/radio_national_podcasts/transcripts")
aws = load_transcripts(transcript_root / "aws")
azure = load_transcripts(transcript_root / "azure")
gcp = load_transcripts(transcript_root / "gcp")
# not named `os`, to avoid shadowing the standard-library module
open_source = load_transcripts(transcript_root / "os")

# transcript_manifest: the scraped ground-truth transcripts from Part 1
ground_truth = transcript_manifest[["transcript", "stem"]]

wer_frames = []
for provider in [aws, azure, gcp, open_source]:
    wer_frames.append(
        pd.merge(ground_truth, provider, how="inner", on="stem").assign(
            wer=lambda x: x.apply(lambda y: wer(y.transcript, y.hypothesis), axis=1)
        )
    )

eval_res = (
    pd.concat(
        [
            e[["elapsed_time", "wer"]].describe().assign(provider=e.iloc[0].provider)
            for e in wer_frames
        ]
    )
    .reset_index()
    .pipe(lambda x: x[x["index"].str.contains("mean|min|50%|max")])
    .rename(mapper={"index": "metric"}, axis="columns", inplace=False)
)
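For reference, WER is the word-level edit distance (substitutions + deletions + insertions) between the hypothesis and the reference, divided by the number of reference words. The wer function above comes from a library (jiwer, I’m assuming); a minimal pure-Python equivalent makes the definition concrete. word_error_rate is a hypothetical name.

```python
def word_error_rate(reference, hypothesis):
    """(substitutions + deletions + insertions) / number of reference words,
    computed as a word-level Levenshtein distance."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference transcript is empty")
    # prev[j]: edit distance between the reference words seen so far and the first j hypothesis words
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,  # deletion: reference word missing from hypothesis
                curr[j - 1] + 1,  # insertion: extra hypothesis word
                prev[j - 1] + (r != h),  # substitution, free when the words match
            )
        prev = curr
    return prev[-1] / len(ref)
```

No text normalisation is applied here, so casing and punctuation differences count as errors; lowercasing and stripping punctuation from both transcripts before scoring is worth considering when comparing punctuated provider output against the ground truth.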
Effect of ASR provider on transcription time. On a per-file aggregate, our open-source solution is nearly an order of magnitude faster than the cloud providers, though it’s likely the cloud providers would come out on top if we were to properly parallelise the pipelines via the use of batch enrichment methods.
Effect of ASR provider on transcription WER. As noted in the preliminary notes, GCP’s WER is generally disastrous (~45% median) and highly variable (roughly 30% to 70% between the lower and upper bounds) due to the patchy transcription we observed earlier. AWS and Azure set the lower bound and are competitive with one another, featuring median WERs of ~20%. Our OS solution sits in the middle of the pack, at ~30%.
ASR provider and various transcription WER metrics. In absolute terms, Azure is the most performant STT provider, followed by AWS, OS and then GCP.

So what to make of these results? Well, I’d surmise that:

  • Azure is the most performant in absolute terms. Bit of a shock to me tbh, but here we are.
  • My GCP pipeline probably needs to be revisited, since a 45% out-of-the-box WER is completely unacceptable for a large cloud provider, and I suspect there are some bugs here.
  • Our OS implementation is relatively competitive with the cloud providers, and is just acceptable in absolute terms, going by the general rule of thumb that a WER of 30%+ is a significant performance degradation bordering on unusable.

The end

We’ve covered a lot of ground and shown how we can create a competitive alternative to cloud-based ASR using open-source utilities and models. ASR is a vast field, with lots of theoretical/practical details that I couldn’t possibly hope to capture within this series. With this in mind, the following is a list of known issues/possible improvements:

  • Improve spot-checking rigour. An obvious one: when selecting a NEMO model, we only evaluated WER on a single podcast. Ideally, we’d evaluate on a representative sample of the production data, or at least tweak the spot-check mechanism to evaluate/average across all of the scraped podcasts.
  • Silence splitting optimisation. Whilst conversational audio rarely exceeds the currently specified maximum length (4 minutes), a more realistic scenario is swapping out QuartzNet for a more memory-intensive model, which would need a lower maximum length and therefore make post-diarization segmentation more likely. In that event, the silence threshold segmentation method would warrant greater scrutiny, and would probably benefit from using a model or a more rigorous heuristic.
  • CTC decoding. An interesting feature returned across all of the commercial cloud APIs is word-by-word timestamps. Admittedly, we were more concerned with aggregations over these very detailed outputs (focusing on per-speaker exchanges and whole-of-transcript outputs), but to further align our OS implementation with the commercial providers we could take the raw outputs of the NEMO ASR models and score, decode and timestamp them with something like CTC decoding. Naturally, such an extension only applies to CTC-based models. QuartzNet qualifies, but transducer models such as ConformerTransducer do not, which complicates the otherwise model-agnostic implementation we have.
  • Speaker/channel adaptations. The above implementation assumes mono-summed audio as input and relies on diarization to identify speakers. Real-world data will often split speakers explicitly across channels, which would remove the need for model-based diarization, which is both expensive and somewhat risky (RE: model errors). Explicitly handling n-channel audio would remove this risk, though we’d probably still need to keep some sort of segmentation method handy for long outputs.
  • Env/dependency management. Frequently, I felt like I was fighting conflicting dependencies when developing the OS alternative. Partially, this is because for this series I’ve been using jarvislabs GPU servers, which make use of conda/cuda 11.1, which is quite fragile and prone to torch._C errors if you let rip with native pyannote or nemo_toolkit installs. Should have used a docker image blah blah blah.
  • Batch size and max-length model/hardware optimisation. During the spot-check, we were able to roughly determine the maximum single-input size for several models, using a Quadro RTX 5000 with 16GB of memory. These preliminary results are great for this machine, but ideally we could generalise these limits/optimal batch sizes across common hardware/memory configurations to improve throughput in a hardware-agnostic way.
  • Telephony specialisation. This would be a welcome improvement, seeing as this kind of data crops up all the time in industry. Luckily, NEMO features an example ASR training run explicitly aimed at capturing the nuances of telephony audio; more details here.
  • Remove click-ops and automate. A cloud evaluation script is conspicuously missing from the repo; this is because of some fairly awkward set-up involving object storage, resource creation, IAM permission-setting etc. Also, I was running out of time during the evaluation phase, so it was a bit of a rush.
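Regarding the CTC decoding item: the simplest form is greedy decoding, which takes the most likely token per output frame, collapses repeats, drops blanks, and converts frame indices into times. A sketch, assuming you already have per-frame argmax ids, a character vocabulary where " " marks word boundaries, a known blank index, and the duration of one output frame (frame_s). That duration depends on the model’s feature hop length and encoder stride, so check it for your model rather than trusting the toy values here. greedy_ctc_word_timestamps is a hypothetical helper, and this is greedy decoding only; beam search with a language model would be the more rigorous “score” step.

```python
def greedy_ctc_word_timestamps(frame_ids, vocab, blank_id, frame_s):
    """Greedy CTC decode: collapse repeated ids, drop blanks, split words on " ",
    and return (word, start_s, end_s) tuples derived from frame positions."""
    words, chars = [], []
    start = end = None  # frame indices of the current word's first/last emitted character
    prev = None

    def flush():
        if chars:
            words.append(("".join(chars), start * frame_s, (end + 1) * frame_s))

    for t, idx in enumerate(frame_ids):
        # a token is emitted only when it differs from the previous frame and isn't blank
        if idx != prev and idx != blank_id:
            if vocab[idx] == " ":
                flush()
                chars, start = [], None
            else:
                if start is None:
                    start = t
                chars.append(vocab[idx])
                end = t
        prev = idx
    flush()
    return words


# toy example: blank id 4; the blank between the two 3s lets "c" repeat
vocab = [" ", "a", "b", "c"]
print(greedy_ctc_word_timestamps([1, 1, 4, 2, 0, 0, 3, 4, 3], vocab, 4, 0.5))
# [('ab', 0.0, 2.0), ('cc', 3.0, 4.5)]
```

Wiring this to a NEMO model (obtaining the per-frame log-probabilities and the vocabulary) is left as an assumption to verify against the NEMO docs.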

So.. lots to do! Anyway, you can find all of the code here, which includes a few Radio National podcasts to partially replicate and test the evaluation scripts. I’ve also included a script to transcribe novel audio, which you can find here.

Banner art developed with Stable Diffusion.