A perfectly cromulent multi-label text classifier

What and why

Text classification is a staple of modern NLP, with use-cases spanning search, patent classification, news article classification and content moderation. The essence of the task is to assign some input text (sentence, paragraph, document) zero, one or many topically relevant labels. A key operational distinction is whether each input receives exactly one label or potentially many, which is the difference between the multi-class and multi-label variants of the task. The sklearn documentation has some useful notes here:

Multiclass classification. Is a classification task with more than two classes. Each sample can only be labelled as one class.
Multilabel classification. Is a classification task labelling each sample with m labels from n_classes possible classes, where m can be 0 to n_classes inclusive. This can be thought of as predicting properties of a sample that are not mutually exclusive.
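To make the distinction concrete, here's a minimal sketch of how the two variants look as indicator rows. The label names are borrowed from the topics used later in this article, and the helper to_indicator is purely illustrative.

```python
# Hypothetical label space, borrowed from the topics used later in this article.
LABELS = ["workplace", "property", "tax", "insurance", "super"]


def to_indicator(assigned, labels=LABELS):
    """Encode a collection of assigned labels as a 0/1 indicator row."""
    return [1 if label in assigned else 0 for label in labels]


# Multi-class: exactly one label per sample.
multiclass_row = to_indicator({"tax"})

# Multi-label: zero, one or many labels per sample.
multilabel_rows = [
    to_indicator(set()),                      # no relevant label
    to_indicator({"property"}),               # a single label
    to_indicator({"property", "insurance"}),  # several labels at once
]

print(multiclass_row)
for row in multilabel_rows:
    print(row)
```

A multi-class row always sums to one, whereas a multi-label row can sum to anything between zero and the number of labels.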

In any case, I’ve found over time that a richer analytic experience is usually created by the use of multi-label classifiers, which tend to capture the complexity of the input better. The rest of this article will attempt to bootstrap a multi-label topic classifier using AusFinance subreddit data, host the results on WandB and host the models via a Huggingface Space.

Scraping AusFinance

Since we need a static dataset which does not need to sync with live Reddit feeds and does not need to comprehensively cover the entirety of Reddit, I opted to use PSAW, a Python wrapper for the Pushshift project, as a way to search and retrieve Reddit submissions and comments. Here, a sensible approach might be to collect x submissions from a selection of y subreddits and then collect all associated comments within those submissions.

I originally had AusFinance in mind, though in the interests of building a balanced dataset I’d like to incorporate some other, similar subreddits. For generating similar subreddits, I’m a fan of Anvaka’s similarity recommender; a graph-based variant also exists here. So now we have a collection of related subreddits and an instance of the Pushshift API:

import math
from datetime import datetime, timedelta

import pandas as pd
from psaw import PushshiftAPI
from tqdm import tqdm

pushshift_client = PushshiftAPI()
last_month_start_epoch = int((datetime.now() - timedelta(days=30)).timestamp())
total_submission_limit = 100
reddit_query = ""

subreddits = [
    "fiaustralia",
    "ASX_Bets",
    "ausstocks",
    "AusProperty",
    "AusFinance",
    "AusEcon",
    "AusPropertyChat",
    "ASX",
    "AustralianAccounting",
]
per_subreddit_limit = math.ceil(total_submission_limit / len(subreddits))

We can use these to iterate over our subreddits, collecting roughly total_submission_limit / n_subreddits submissions (rounded up) posted to each subreddit in the past 30 days:

all_subreddit_submissions = []

for subreddit in tqdm(subreddits, desc=f"Collecting {per_subreddit_limit} submissions for each subreddit.."):
    submission_raw = list(
        pushshift_client.search_submissions(
            q=reddit_query,
            after=last_month_start_epoch,
            subreddit=subreddit,
            filter=["url", "author", "id", "parent_id", "link_id", "title", "subreddit"],
            limit=per_subreddit_limit,
        )
    )
    submissions_formatted = pd.DataFrame([e.d_ for e in submission_raw])
    all_subreddit_submissions.append(submissions_formatted)

all_subreddit_submissions = pd.concat(all_subreddit_submissions)

After which, we’ll collect all comments available within each of these submissions:

submissions_and_comments = []

for idx, record in tqdm(
    all_subreddit_submissions.iterrows(),
    total=all_subreddit_submissions.shape[0],
    desc="Collecting submission comments..",
):
    comments_raw = list(
        pushshift_client.search_comments(
            after=last_month_start_epoch,
            subreddit=record.subreddit,
            link_id=record.id,
            filter=["url", "author", "id", "parent_id", "title", "body", "subreddit"],
        )
    )
    comments_formatted = pd.DataFrame([e.d_ for e in comments_raw])

    # stack the submission row on top of its comments
    submissions_and_comments.append(pd.concat([record.to_frame().transpose(), comments_formatted], sort=True))

# combine the per-submission frames into a single dataset
submissions_and_comments = pd.concat(submissions_and_comments)

And now we have a dataset that looks like this:

[Image: preview of the collected dataset]

Fair warning if you’re thinking of re-running the above code: it took ~20 minutes to collect 3k records, probably because of my inefficient for-loop and the PSAW rate limit of 60 requests per minute (which is the same as the official API rate limit?). So things are slow, not comprehensive, and probably out of sync with Reddit IRL, but this is fine for our purposes since we just need a dataset to play with.

Bootstrapped Annotations

Within this project, we’ll be using a two-step annotation process to iteratively develop our dataset.

Quota-driven annotation. Within this first step, the essence of what we’re trying to do is:

  1. Obtain a series of cheap predictions
  2. Use these cheap predictions to transform the annotation task into a series of simpler, per-label binary confirmation sub-tasks

Here, I’ll be using another project of mine called clear-bow to generate these predictions, though any model capable of returning key-value style prediction results could be used (i.e. most models). Assuming a dataframe with a text_col and an optional, pre-existing label_col, we apply the cheap classifier, generating “soft” predictions (floats) that we consolidate with pre-existing “hard” predictions (ints) if they exist.

The soft/hard prediction distinction is important to ensure that any previous annotation work is preserved across subsequent iterations and we don’t have to “start from zero”. Additionally, using models that generate fuzzy confidences is essential, which, interestingly, excludes models like linear SVCs, which by default return hard class decisions rather than confidence scores.

import pandas as pd


def consolidate_hard_soft_labels(label_objects):
    # defer to hard labels where they exist, otherwise average soft labels
    # (non-dict entries, e.g. NaN for unlabelled rows, are dropped)
    label_objects = [obj for obj in label_objects if isinstance(obj, dict)]

    flattened_object = {}
    all_labels = pd.DataFrame(label_objects)
    for label in all_labels:
        # hard reject
        if -1 in set(all_labels[label]):
            flattened_object[label] = -1
        # hard accept
        elif 1 in set(all_labels[label]):
            flattened_object[label] = 1
        # mean otherwise
        else:
            flattened_object[label] = all_labels[label].mean()

    return flattened_object


if label_col in df:
    # consolidate old verifications with new predictions, defer to "hard" old predictions where they exist
    consolidated_labels = []
    for hard, soft in zip(
        df[label_col].tolist(), df[text_col].apply(
            model.predict_single).tolist()
    ):
        if type(hard) != dict:
            # RE: iteratively appending datasets
            consolidated_labels.append(soft)
        else:
            consolidated_labels.append(
                consolidate_hard_soft_labels([hard, soft]))
    df[label_col] = consolidated_labels
else:
    # otherwise create new prediction
    df[label_col] = df[text_col].apply(model.predict_single)
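To make the consolidation rule a bit more tangible, here's a dependency-free restatement with a small worked example. The function name consolidate and the sample values are my own illustrative assumptions, not part of the project code.

```python
def consolidate(label_objects):
    """Merge label dicts: a hard -1 wins, then a hard 1, otherwise average the soft scores."""
    merged = {}
    all_labels = {label for obj in label_objects for label in obj}
    for label in all_labels:
        values = [obj[label] for obj in label_objects if label in obj]
        if -1 in values:
            merged[label] = -1  # hard reject
        elif 1 in values:
            merged[label] = 1  # hard accept
        else:
            merged[label] = sum(values) / len(values)  # mean of soft scores
    return merged


# Illustrative values: "tax" was previously accepted, "property" previously rejected,
# and "super" has only ever received soft scores.
hard = {"tax": 1, "property": -1, "super": 0.2}
soft = {"tax": 0.4, "property": 0.9, "super": 0.6}
result = consolidate([hard, soft])
print(result)
```

Note that a fresh, confident prediction for "property" does not overturn the earlier rejection, which is the whole point of deferring to hard labels.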

Now we “explode” our original dataframe across the keys of the consolidated label objects, and focus our annotation efforts on each label category (pred_labels) instead of the individual record. This changes the annotation workflow from a “for each record, verify all labels” style of exercise to a “given x candidate examples of y label, positively verify z examples” style of exercise. Taking some cues from a prodigy annotation guide, I think this otherwise structurally convoluted approach is justified because of how it narrows the scope of the annotation experience. IMO, it’s an easier mental exercise to apply binary confirmation across explicitly filtered groups of similar, probable candidate examples than it is to annotate the entire label space across each record, which is typically a very sparse annotator experience.

all_verifications = []
for label in pred_labels:
    # only use unseen examples, if verification field exists
    if label_col in df:
        # hard (already verified) labels are stored as ints, soft predictions as floats
        is_verified = df[label_col].apply(lambda y: isinstance(y[label], int))
        seen_examples = df[is_verified]
        unseen_examples = df[~is_verified]
        if seen_examples.shape[0] > 0:
            msg.info(
                f"{seen_examples.shape[0]} pre-existing, verified examples found for label: '{label}'"
            )
    else:
        # nothing has been verified yet, so every example is unseen
        seen_examples = df.iloc[0:0]
        unseen_examples = df

    # reduce the quota by the number of examples already verified
    adjusted_n_examples = max(n_examples - seen_examples.shape[0], 0)
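As a rough illustration of the “explode” idea described above, here's a sketch that regroups record-level predictions into per-label candidate lists. The record structure, the example texts and the helper explode_by_label are assumptions made for this example only.

```python
# Illustrative records: each carries a dict of soft predictions keyed by label.
records = [
    {"text": "Should I salary sacrifice into super?", "labels": {"tax": 0.8, "super": 0.9}},
    {"text": "Landlord is raising the rent again.", "labels": {"property": 0.7, "tax": 0.1}},
]


def explode_by_label(records, label_key="labels"):
    """Group (text, score) candidates under each label instead of under each record."""
    per_label = {}
    for record in records:
        for label, score in record[label_key].items():
            per_label.setdefault(label, []).append((record["text"], score))
    return per_label


candidates = explode_by_label(records)
for label, examples in candidates.items():
    print(label, examples)
```

Each label now owns its own list of candidates, which is what lets the annotator answer one yes/no question at a time.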

Additionally, using the prediction confidence, we remove low-confidence predictions below a threshold and rank the candidate examples. We also place an upper limit on how many candidates to examine for any given label category before progressing on to the next. Capping the number of candidates is a practical way to ensure we don’t have to iterate through every document for misaligned, low-candidate label categories and also serves as a useful point of feedback when prototyping label schemas; if there are no candidate examples within your dataset for a specific label then it’s probably worth re-examining why.

# only annotate examples with relatively high confidence
annotate_input = (
    unseen_examples
    # only examine examples above prediction threshold
    .pipe(
        lambda x: x[
            x[label_col].apply(
                lambda y: True if y[label] >= prediction_thresh else False
            )
        ]
    )
)

if rank_candidates:
    # rank by prediction confidence
    annotate_input = annotate_input.reset_index(drop=True).pipe(
        lambda x: x.iloc[
            x[label_col]
            .apply(lambda x: x[label])
            .sort_values(ascending=False)
            .index
        ]
    )
else:
    # shuffle candidates otherwise
    annotate_input = annotate_input.sample(frac=1.0, random_state=42)

# take the first n=max_candidate records
annotate_input = annotate_input.head(max_candidates)
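The same filter, rank and cap logic can be sketched without pandas, which I find easier to reason about when prototyping. The function select_candidates and the example scores below are hypothetical.

```python
def select_candidates(examples, label, threshold=0.7, max_candidates=10):
    """Keep examples scoring at or above the threshold, highest confidence first, capped."""
    above = [e for e in examples if e["labels"].get(label, 0.0) >= threshold]
    ranked = sorted(above, key=lambda e: e["labels"][label], reverse=True)
    return ranked[:max_candidates]


# Illustrative soft predictions for a single label.
examples = [
    {"id": "a", "labels": {"tax": 0.95}},
    {"id": "b", "labels": {"tax": 0.40}},
    {"id": "c", "labels": {"tax": 0.80}},
    {"id": "d", "labels": {"tax": 0.75}},
]
selected = select_candidates(examples, "tax", threshold=0.7, max_candidates=2)
print([e["id"] for e in selected])
```

An empty result for a label is the feedback signal mentioned above: either the threshold is too strict or the label doesn't really exist in the dataset.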

We can, f i n a l l y, accept or reject the filtered, per-label candidate examples:

import copy
import os

import pandas as pd
from wasabi import msg


def binary_confirm_n_label_objects(
    df,
    label_col,
    label_object_key,
    accept_value="y",
    reject_value="n",
    n_examples: int = 10,
):
    # given a dataframe of records, positively verify (binary confirmation) a selected label until the quota is reached or examples run out. Return all annotations.
    updated_records = []
    msg.info(
        f"Input: '{accept_value}' to accept value, input '{reject_value}' to reject value. "
        + f"Label objects saved as '{label_col}' field within all records. "
        + f"Loop will break when '{n_examples}' positively affirmed or examples run out, whichever first. {df.shape[0]} candidates supplied.\n"
    )
    # accepted/rejected answers become hard (int) labels, anything else is a skip (0.0)
    verification_remapping = {accept_value: 1, reject_value: -1}
    n_accepted = 0
    for idx, record in df.iterrows():
        # early break once n_examples have been positively verified
        if n_accepted >= n_examples:
            break

        # otherwise, keep verifying
        d = record.to_dict()
        dc_d = copy.deepcopy(d)  # always use a deep copy
        for k, v in d.items():
            print("\033[1m", k, ": ", "\033[0m", v)
        val = input(f"Instance of {label_object_key}? ")
        dc_d[label_col][label_object_key] = verification_remapping.get(val, 0.0)
        if dc_d[label_col][label_object_key] == 1:
            n_accepted += 1
        updated_records.append(dc_d)
        os.system("clear")
    return pd.DataFrame(updated_records)


annotations = binary_confirm_n_label_objects(
    annotate_input, label_col, label, n_examples=adjusted_n_examples
)

We then consolidate the exploded records back into single-document records.
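The consolidation code isn't shown here, so the following is only a sketch of one way it could work, assuming each annotated row carries a document id and that hard labels are ints while soft predictions are floats. The helper merge_exploded and the field names are hypothetical.

```python
def merge_exploded(annotations, id_key="id", label_key="labels"):
    """Fold per-label annotation rows for the same document into one record."""
    merged = {}
    for row in annotations:
        record = merged.setdefault(row[id_key], {**row, label_key: {}})
        for label, value in row[label_key].items():
            current = record[label_key].get(label)
            # a hard label (int) always overrides a soft prediction (float)
            if current is None or isinstance(value, int):
                record[label_key][label] = value
    return list(merged.values())


# Illustrative rows: document x1 was reviewed twice, once per label.
annotations = [
    {"id": "x1", "text": "Tax on super?", "labels": {"tax": 1, "super": 0.6}},
    {"id": "x1", "text": "Tax on super?", "labels": {"tax": 0.8, "super": -1}},
    {"id": "x2", "text": "Office snacks", "labels": {"tax": 0.2, "super": 0.1}},
]
merged = merge_exploded(annotations)
print(merged)
```

The important property is that a verification made under one label is never lost when the same document shows up again under another label.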

Multi-label backfilling. So at this point, we’ll probably have a label set which looks something like this:

|   workplace |   property |   tax |   insurance |   super |
|------------:|-----------:|------:|------------:|--------:|
|           1 |          0 |     0 |           0 |       0 |
|           0 |          1 |     0 |           0 |       0 |
|           0 |          0 |     1 |           0 |       0 |
|           0 |          0 |     0 |           1 |       0 |
|           0 |          0 |     0 |           0 |       1 |

This is perfect for multi-class classification, where only a single label per example is required, but since we’re building a multi-label dataset we also need to backfill across the label space to create something which looks more like this:

|   workplace |   property |   tax |   insurance |   super |
|------------:|-----------:|------:|------------:|--------:|
|           1 |          0 |     0 |           1 |       0 |
|           0 |          1 |     0 |           1 |       0 |
|           0 |          0 |     1 |           0 |       0 |
|           0 |          0 |     0 |           1 |       0 |
|           0 |          0 |     1 |           0 |       1 |

This ensures we’re capturing the multi-label-ness of our dataset. We can achieve this by applying a similar predict/review workflow as described above. We start by consolidating the hard/soft label/predictions across examples:

consolidated_labels = []
for hard, soft in zip(
    df[label_col].tolist(), df[text_col].apply(model.predict_single).tolist()
):
    # using fresh predictions, consolidate label space
    if type(hard) != dict:
        # RE: iteratively appending datasets
        consolidated_labels.append(soft)
    else:
        consolidated_labels.append(consolidate_hard_soft_labels([hard, soft]))

However this time, we’ll enumerate all examples that exceed prediction_thresh, instead of just the max_candidates we saw initially:

import itertools
import os

from wasabi import msg

for label, (idx, label_object) in itertools.product(
    target_label_space, enumerate(consolidated_labels)
):
    if label_object[label] in [-1, 1]:
        # 1. pre-existing hard label, no change
        continue

    elif label_object[label] >= prediction_thresh:
        # 2. otherwise, the model proposes an additional label, so ask for confirmation
        os.system("clear")
        msg.info(f"**** Verifying all additional instances of: {label} ****")
        msg.text(f"\n\033[1mText: \033[0m \n{df.iloc[idx][text_col]}\n")
        msg.text(f"\033[1mInstance of {label}?\033[0m")
        confirmation = input()

        if confirmation == "n":
            label_object[label] = -1

        elif confirmation == "y":
            # assign in place.. yikes
            label_object[label] = 1
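For anyone wanting to try the backfilling loop without sitting at a prompt, here's a sketch that swaps input() for an injected confirm callable. The function backfill, its parameters and the sample data are illustrative assumptions rather than the project's actual code.

```python
import itertools


def backfill(label_objects, texts, label_space, threshold, confirm):
    """Ask `confirm(label, text)` about every unverified label at or above the threshold."""
    for label, (idx, label_object) in itertools.product(label_space, enumerate(label_objects)):
        score = label_object[label]
        if score in (-1, 1) or score < threshold:
            continue  # already verified, or not confident enough to ask
        answer = confirm(label, texts[idx])
        if answer == "y":
            label_object[label] = 1
        elif answer == "n":
            label_object[label] = -1
    return label_objects


# Illustrative data: the first record already has a hard "tax" label.
label_objects = [{"tax": 1, "super": 0.9}, {"tax": 0.75, "super": 0.1}]
texts = ["Salary sacrificing into super", "Is GST payable on this?"]
asked = []


def always_yes(label, text):
    """Stand-in for input(): record the question and accept it."""
    asked.append((label, text))
    return "y"


backfilled = backfill(label_objects, texts, ["tax", "super"], 0.7, always_yes)
print(backfilled)
```

Passing the confirmation step in as a function also makes the loop straightforward to unit test, which the interactive version is not.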

The actual dataset. Applying the above in practice resulted in 2 data iterations. Some details about each iteration are below:

  • Iteration 1. I manually perused a sample of documents from each subreddit and created an initial label scheme (below) this way. We’ll set a pretty low prediction threshold of 0.75, with n_examples=5 and max_candidates=10. After this initial annotation, I opted to remove the exchange and public_institution categories due to low candidate examples. I also opted to reduce the surface area of the property definition because of the number of false positives I observed during annotation.
workplace: ['WFH', 'boss', 'co-workers', 'culture', 'hybrid', 'life balance', 'office']
property: ['afford', 'agent', 'auction', 'bedroom', 'boom', 'builder', 'buy', 'defect', 'floor plan', 'house', 'landlord', 'layout', 'loan', 'mortgage', 'property', 'rate', 'real estate', 'refinance', 'rent', 'resident', 'salary', 'townhouse']
tax: ['gst', 'land tax', 'salary sacrifice', 'tax']
insurance: ['income protection', 'indemnity', 'insurance']
super: ['balance', 'contribution', 'fund', 'pension', 'retire', 'self-funded', 'super', 'after tax']
public_institution: ['bond', 'central bank', 'fair work', 'mint', 'rba', 'watch dog']
inflation: ['inflation', 'interest rates', 'petrol', 'phillip lowe', 'rba', 'reserve bank']
exchange: ['dollar', 'exchange', 'rate']
stocks: ['200', 'assets', 'asx', 'buy', 'commsec', 'dip', 'dividends', 'etf', 'high growth', 'indexed', 'invest', 'return', 'securities', 'selfwealth', 'shares', 'stock', 'van guard', 'vdhg', 'wealth']
toxic: ['bro', 'butt', 'fool', 'fuck', 'laughable', 'lol', 'salty', 'shit', 'tard']
  • Iteration 2. We’ll continue to use the labelling scheme from iteration 1, though in hindsight, I think I set the threshold filtering a little high, so we'll lower this back to 0.7 and re-start the annotation across the original aus_finance dataset to ensure we’re “catching” a broader range of examples. We’ll also double n_examples to 10, and increase the size of the test split to 0.5 to improve the support within the final, test classification reports.
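To give a feel for how a keyword label scheme like the one above can produce soft predictions, here's a toy scorer. This is not how clear-bow scores text; the fraction-of-keywords rule and the function below are my own simplifying assumptions.

```python
def predict_single(text, label_dictionary):
    """Score each label as the fraction of its keywords found in the text (illustrative only)."""
    lowered = text.lower()
    return {
        label: sum(keyword.lower() in lowered for keyword in keywords) / len(keywords)
        for label, keywords in label_dictionary.items()
    }


# Two categories copied from the iteration 1 label scheme.
label_dictionary = {
    "tax": ["gst", "land tax", "salary sacrifice", "tax"],
    "insurance": ["income protection", "indemnity", "insurance"],
}
scores = predict_single("Is income protection insurance tax deductible?", label_dictionary)
print(scores)
```

Naive substring matching like this also shows why broad keywords cause false positives: "tax" would happily match "taxi".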

Model Training and Evaluation

The models. I’m a big fan of creating baselines with cheap models and ratcheting up the complexity only when necessary, so we’ll be spot-checking a variety of models ranging from cheap to expensive:

  • Dictionary classifier. As “the cheapest” model to be evaluated. Originally used to generate our annotations.
  • Sklearn Linear SVC. As our “mid-range” model, a reliable and timeless classic.
  • Flair TARS. Task-Aware Representation of Sentences for Generic Text Classification is a very cool, very fun and relatively new technique that can be used for zero and few-shot learning. The gist of the method is to bundle NLP tasks together and to re-phrase semantic information provided by the class labels at the point of prediction as an auxiliary, binary classification task. Unfortunately, we won’t be throwing multi-task data examples at TARS, but we will be focusing on few-shot learning some labelled examples, as this allows us to evaluate TARS alongside traditionally supervised approaches.

Walking through the fit/predict code for each of these models is a little tedious and also features lots of annoying WandB boilerplate, so feel free to suss the implementations here, here and here for each model respectively. I will briefly jump into the main training and evaluation script, however. First, we load the training config, and dataset, and create our train/dev/test splits:

from pathlib import Path

import pandas as pd
import wandb
import yaml

api = wandb.Api()
CONFIG = yaml.safe_load(
    (Path(__file__).parents[0] / "train_config.yaml").read_bytes())

# 1. create/log splits
df = pd.read_csv(CONFIG["dataset"])
train_split, test_split = create_multi_label_train_test_splits(
    df, label_col=CONFIG["label_col"], test_size=CONFIG["test_size"]
)
test_split, dev_split = create_multi_label_train_test_splits(
    test_split, label_col=CONFIG["label_col"], test_size=CONFIG["test_size"]
)
with wandb.init(
    project=CONFIG["wandb_project"],
    name="reddit_aus_finance",
    group=CONFIG["wandb_group"],
    entity="cool_stonebreaker",
) as run:
    log_dataframe(run, train_split, "train_split", "Train split")
    log_dataframe(run, dev_split, "dev_split", "Dev split")
    log_dataframe(run, test_split, "test_split", "Test split")

The train/dev/test split method makes use of scikit-multilearn’s iterative stratification method to ensure that each of our splits contains roughly the same proportion of each label, a surprisingly complex task when dealing with multi-label data. We also log each of our train/dev/test splits as WandB artefacts within their own separate, group-tagged run. This ensures that for every subsequent round of annotations (group), we explicitly version our train/dev/test splits. We then fit, evaluate and log the config/binaries/performance of each of our models:

from model.dictionary import fit_and_log_dictionary_classifier
from model.flair_tars import fit_and_log_flair_tars_classifier
from train.model.sklearn_linear_svc import fit_and_log_sklearn_linear_svc_classifier

for model_config in CONFIG["models"]:
    if model_config["type"] == "dictionary_classifier":
        fit_and_log_dictionary_classifier(
            test_split=test_split, CONFIG=CONFIG, model_config=model_config)

    elif model_config["type"] == "sklearn_linear_svc":
        fit_and_log_sklearn_linear_svc_classifier(
            train_split=train_split,
            dev_split=dev_split,
            test_split=test_split,
            CONFIG=CONFIG,
            model_config=model_config,
        )

    elif model_config["type"] == "flair_tars":
        fit_and_log_flair_tars_classifier(
            train_split=train_split,
            dev_split=dev_split,
            test_split=test_split,
            CONFIG=CONFIG,
            model_config=model_config,
        )

    else:
        print(f"Unsupported model: {model_config['type']} found")

Where specifically, we’re logging the test predictions, test classification report, and the aggregate, weighted F1 as summary metrics across each run, to allow for coarse/fine inter/intra model comparisons. Some example logging below:

import wandb

def log_dataframe(run, df, name, description):
    # any type of df within a run
    df_artifact = wandb.Artifact(name, type="dataset", description=description)
    df_artifact.add(wandb.Table(dataframe=df), name=name)
    run.log_artifact(df_artifact)

# log
log_dataframe(run, test_preds, "test_preds", "Test predictions")
log_dataframe(
    run,
    classification_report,
    "test_classification_report",
    "Test classification report",
)
run.summary["test_f1"] = classification_report.query('label == "weighted avg"')[
    "f1-score"
].iloc[0]
run.summary["test_support"] = classification_report.query(
    'label == "weighted avg"'
)["support"].iloc[0]
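For reference, the “weighted avg” row of a classification report is the support-weighted mean of the per-label scores. A minimal sketch of that arithmetic, using made-up precision/recall/support values that are not results from this project:

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def weighted_f1(per_label):
    """Average per-label F1 scores, weighting each label by its support."""
    total_support = sum(row["support"] for row in per_label.values())
    return sum(
        f1(row["precision"], row["recall"]) * row["support"] for row in per_label.values()
    ) / total_support


# Illustrative numbers only.
per_label = {
    "tax": {"precision": 1.0, "recall": 0.5, "support": 30},
    "super": {"precision": 0.5, "recall": 0.5, "support": 10},
}
score = weighted_f1(per_label)
print(round(score, 3))
```

Because labels with more support dominate the average, it's worth keeping the per-label rows around too, which is why the full report is logged alongside the summary metric.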

Inter/Intra Model Comparisons

So far, so good, but I realise I’ve skipped over some of the reasons as to why the training/evaluation runs are set up as they are. In the course of this project, the main artefacts I’m interested in are:

  • Intra-group model performance. Within each round of annotations, compare the P/R/F1 of each of our models with one another. These comparisons are useful to understand how different types of models (cheap, classic, expensive/experimental) perform against each other. IMO, this is because we should only use expensive models when we have to; and if we can prove something cheaper works just as well then we should do that instead.
  • Inter-group model performance. Across multiple rounds of annotations, take the most performant model from each annotation round (via aggregate, weighted F1 scores), and compare performance across successive rounds of annotations. These comparisons are useful to understand how we’re generally tracking across the problem and whether we need to pursue additional rounds of annotation.
  • Provenance artefacts. Including things such as the original group train/dev/test splits previously discussed, model config (dictionary label schema), model binaries and test predictions. Provenance artefacts are useful for replicating our experiments at a later time and documenting important intermediate steps/calculations, like raw test-set predictions. I like to think of provenance artefacts as taking out a type of insurance in case we need to re-trace, debug or replicate specific parts of our pipelines.
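The inter-group comparison boils down to picking the best run per annotation round and lining those winners up. A small sketch of that selection step, where the run records and scores are invented for illustration and are not results from this project:

```python
def best_per_group(runs, metric="test_f1"):
    """Return the highest-scoring run for each annotation group."""
    best = {}
    for run in runs:
        current = best.get(run["group"])
        if current is None or run[metric] > current[metric]:
            best[run["group"]] = run
    return best


# Illustrative run summaries only.
runs = [
    {"group": "iteration_1", "type": "dictionary_classifier", "test_f1": 0.40},
    {"group": "iteration_1", "type": "sklearn_linear_svc", "test_f1": 0.55},
    {"group": "iteration_2", "type": "sklearn_linear_svc", "test_f1": 0.60},
    {"group": "iteration_2", "type": "flair_tars", "test_f1": 0.70},
]
winners = best_per_group(runs)
for group, run in sorted(winners.items()):
    print(group, run["type"], run["test_f1"])
```

If the winning score stops improving between rounds, that's a reasonable hint that more annotation isn't buying much.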

The good news is that logging the “dumb” provenance artefacts is super easy! The bad news is that I found it difficult to plot the inter/intra-group model performance using the out-of-the-box WandB tools.

I noted that WandB explicitly supports group/run-centric bar charts, though for some strange reason you must specify an aggregation mechanism across the groups, whereas we need to keep each member value intact but still grouped. WandB also supports grouped bar charts, though by this point I had opted to manually recreate/log the intra/inter performance charts with plotly. Annoyingly, because of the way I had structured each model as its own run, I also needed to filter/retrieve the relevant provenance artefacts (classification reports) for each model before creating the intra/inter comparison plots, which feature as their own runs.

import tempfile
from pathlib import Path

import pandas as pd
import plotly.express as px
import srsly
import wandb


def list_all_project_artifacts(api, CONFIG):
    # lame
    runs_artifacts = (
        pd.DataFrame(
            [
                {**{"run": run}, **run.__dict__["_attrs"]}
                for run in api.runs(
                    path=f"{CONFIG['wandb_entity']}/{CONFIG['wandb_project']}"
                )
            ]
        )
        .assign(
            artifacts=lambda x: x.run.apply(
                lambda y: [
                    {**{"artifact": e}, **e.__dict__} for e in y.logged_artifacts()
                ]
            )
        )
        .pipe(lambda x: x.explode("artifacts"))
        .reset_index(drop=True)
        .query('state == "finished"')
        .pipe(lambda x: x[x.config.apply(lambda y: len(y) >= 1)])
    )

    return pd.concat(
        [
            runs_artifacts,
            runs_artifacts.config.apply(pd.Series),
            runs_artifacts.artifacts.apply(pd.Series),
        ],
        axis=1,
    ).pipe(lambda x: x[["run", "type", "group", "_sequence_name", "artifact"]])

def parse_wandb_table_artifact(artifact):
    # urgh
    with tempfile.TemporaryDirectory() as temp_dir:
        download_dir = artifact.download(temp_dir)
        file = list(Path(download_dir).glob("*.json"))[0]
        table_json = srsly.read_json(file)
        return pd.DataFrame(table_json["data"], columns=table_json["columns"])

def log_intra_group_model_comparisons(project_artifacts, CONFIG):
    # yuck
    group_model_classification_reports = []

    # format, concat
    for idx, record in (
        project_artifacts.query('_sequence_name == "test_classification_report"').pipe(
            lambda x: x[x.group == CONFIG["wandb_group"]]
        )
    ).iterrows():
        group_model_classification_reports.append(
            (parse_wandb_table_artifact(record.artifact).assign(type=record.type))
        )
    group_model_classification_reports = pd.concat(group_model_classification_reports)

    # create plot
    fig = px.bar(
        (
            group_model_classification_reports.pipe(
                lambda x: x[~x["label"].str.contains("accuracy|samples|macro|micro")]
            )
        ),
        x="label",
        y="f1-score",
        color="type",
        barmode="group",
    )

    # log plot
    with wandb.init(
        project=CONFIG["wandb_project"],
        name=f"{CONFIG['wandb_group']}_intra_group_model_comparison",
        group=CONFIG["wandb_group"],
        entity=CONFIG["wandb_entity"],
        job_type="aux_plot",
    ) as run:
        run.log({f"{CONFIG['wandb_group']}_intra_model_comparison": fig})

Zooming back out to the main train script, where we thankfully don’t have to deal with this gross boilerplate so much:

import wandb
from eval_util import (list_all_project_artifacts,
                       log_inter_group_model_comparisons,
                       log_intra_group_model_comparisons)

api = wandb.Api()

# 3. log intra-model comparisons for current group
project_artifacts = list_all_project_artifacts(api, CONFIG)
log_intra_group_model_comparisons(project_artifacts, CONFIG)

# 4. update inter-group model comparisons, deleting any stale comparison run first
for run in api.runs(path="cool_stonebreaker/tyre_kick"):
    if run.name == "inter_group_model_comparison":
        run.delete()
log_inter_group_model_comparisons(project_artifacts, CONFIG)

Which, after burning through a few data annotation runs will eventually give us a report which looks like this:

Screenshots because exporting static reports from WandB with custom plotly charts results in all sorts of mangled mischief.

Looks pretty good! I reckon the flair_tars model from annotation group 2 is where we need it to be. You can check out the hosted report here.

Model Deployment

I use the word “deployment” generously because all we want to do is publish the best-performing models into a huggingface space. To do so, I re-purposed the calculator example, and have plugged each model from the annotation run x as an “operation” I guess:

import gradio as gr
import srsly
from clear_bow.classifier import DictionaryClassifier
from flair.data import Sentence
from flair.models import TARSClassifier
from joblib import load

loaded_models = {
    "dictionary_classifier": DictionaryClassifier(
        classifier_type="multi_label",
        label_dictionary=srsly.read_json(
            "./model_files/dictionary_classifier/label_dictionary.json"
        ),
    ),
    "sklearn_linear_svc": load("./model_files/sklearn_linear_svc/model.joblib"),
    "flair_tars": TARSClassifier.load("./model_files/flair_tars/final-model.pt"),
}


def predict_dictionary_classifier(reddit_comment, dictionary_classifier_model):
    return dictionary_classifier_model.predict_single(reddit_comment)


def predict_linear_svc(reddit_comment, sklearn_linear_svc_model):
    return dict(
        zip(
            sklearn_linear_svc_model.multi_label_classes_,
            sklearn_linear_svc_model.predict([reddit_comment])[0].toarray()[0],
        )
    )


def predict_flair_tars(text, flair_tars_model):
    sentence = Sentence(text)
    labels = flair_tars_model.get_current_label_dictionary().get_items()
    flair_tars_model.predict(sentence)
    pred_dict = {label: 0.0 for label in labels}
    for e in sentence.labels:
        label = e.to_dict()["value"]
        confidence = round(float(e.to_dict()["confidence"]), 2)
        pred_dict[label] = confidence
    return pred_dict


def model_selector(model_type, reddit_comment):
    if model_type == "dictionary_classifier":
        return predict_dictionary_classifier(
            reddit_comment, loaded_models["dictionary_classifier"]
        )
    elif model_type == "linear_svc":
        return predict_linear_svc(reddit_comment, loaded_models["sklearn_linear_svc"])
    elif model_type == "flair_tars":
        return predict_flair_tars(reddit_comment, loaded_models["flair_tars"])


demo = gr.Interface(
    model_selector,
    [gr.Radio(["dictionary_classifier", "linear_svc", "flair_tars"]), "text"],
    "text",
    examples=[
        [
            "dictionary_classifier",
            "Do you really have a $2,080,000 mortgage for an investment property that rents for $700 a week?",
        ],
        [
            "linear_svc",
            "I like the genie analogy. Anecdotally from peers and from own experience, if a role is advertised as 100% in office, then it’s a hard no.",
        ],
        [
            "flair_tars",
            "It’s a seller’s market now, transitioning into a balanced/buyer’s market. Prices are still historically high, but it’s clear the peak is behind us. Rising interest rates doing exactly as expected.",
        ],
    ],
    title="Few-shot multi-label classification",
    description="A comparison of models, ranging from cheap to expensive. Enjoy!",
)

if __name__ == "__main__":
    demo.launch()

For better or for worse, I’ve also incorporated the huggingface space repo as a submodule called deploy within the main multi-label repo. This was designed to keep everything in one place, though at times I did experience submodule regret, especially when flicking code between my local machine and the runpod GPU instance I was using for flair TARS model development.

Anyway, the only thing missing from the gradio showcase at this point is the model files themselves, which are copied over via train/copy_deploy_model_files.py. If you’re doing something similar, be sure to initialise Git Large File Storage on the HF space repo, which causes some additional heartache.

Voila

We now have a series of multi-label models trained on Reddit data, which perform pretty well! You can check out the final models here, as well as the development metrics here. And of course, you can find the accompanying repo here.

Banner art developed with stable diffusion.