# WB Doc Topics (Hugging Face collection)
#
# This is a collection of models trained on synthetically generated sentences
# conditioned on WBG topics. The models are designed to be used together as an
# ensemble.
#
# Collection size: 22 items (last-updated date not captured).
from transformers import pipeline
from tqdm.auto import tqdm
import pandas as pd
from transformers import AutoTokenizer
class WBGDocTopic:
"""
A class to handle document topic suggestion using multiple pre-trained text classification models.
This class loads a set of text classification models from Hugging Face's model hub and
provides a method to suggest topics for input documents based on the aggregated classification
results from all the models.
Attributes:
-----------
classifiers : dict
A dictionary mapping model names to corresponding classification pipelines. It holds
instances of Hugging Face's `pipeline` used for text classification.
Methods:
--------
__init__(classifiers: dict = None)
Initializes the `WBGDocTopic` instance. If no classifiers are provided, it loads a default
set of classifiers by calling `load_classifiers`.
load_classifiers()
Loads a predefined set of document topic classifiers into the `classifiers` dictionary.
It uses `tqdm` to display progress as the classifiers are loaded.
suggest_topics(input_docs: str | list[str]) -> list
Suggests topics for the given document or list of documents. It runs each document
through all classifiers, averages their scores, and returns a list of dictionaries where each
dictionary contains the mean and standard deviation of the topic scores per document.
Parameters:
-----------
input_docs : str or list of str
A single document or a list of documents for which to suggest topics.
Returns:
--------
list
A list of dictionaries, where each dictionary represents the suggested topics for
each document, along with the mean and standard deviation of the topic classification scores.
"""
def __init__(self, classifiers: dict = None, device: str = None
self.classifiers = classifiers or {}
self.device = device
if classifiers is None:
self.load_classifiers()
def load_classifiers(self):
num_evals = 5
num_train = 5
tokenizer = AutoTokenizer.from_pretrained("avsolatorio/doc-topic-model_eval-04_train-03")
for i in tqdm(range(num_evals)):
for j in tqdm(range(num_train)):
if i == j:
continue
model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device)
self.classifiers[model_name] = classifier
def suggest_topics(self, input_docs: str | list[str]):
if isinstance(input_docs, str):
input_docs = [input_docs]
doc_outs = {i: [] for i in range(len(input_docs))}
topics = []
for _, classifier in self.classifiers.items():
for doc_idx, doc in enumerate(classifier(input_docs)):
doc_outs[doc_idx].append(pd.DataFrame.from_records(doc, index="label"))
for doc_idx, outs in doc_outs.items():
all_scores = pd.concat(outs, axis=1)
mean_probs = all_scores.mean(axis=1).sort_values(ascending=False)
std_probs = all_scores.std(axis=1).loc[mean_probs.index]
output = pd.DataFrame({"score_mean": mean_probs, "score_std": std_probs})
output["doc_idx"] = doc_idx
output.reset_index(inplace=True)
topics.append(output.to_dict(orient="records"))
return topics
import nltk
# Download the nltk data if not present
nltk.download('punkt_tab')
nltk.download('punkt')
from collections import Counter
# Load the sent_tokenize method for quick sentence extraction
from nltk import sent_tokenize
# Process the input
sample_text = """A growing literature attributes gender inequality in labor market outcomes in part to the reduction in female labor supply after childbirth, the child penalty. However, if social norms constrain married women’s activities outside the home, then marriage can independently reduce employment, even in the absence childbearing. Given the correlation in timing between childbirth and marriage, conventional estimates of child penalties will conflate these two effects. The paper studies the marriage penalty in South Asia, a context featuring conservative gender norms and low female labor force participation. The study introduces a split-sample, pseudo-panel approach that allows for the separation of marriage and child penalties even in the absence of individual-level panel data. Marriage reduces women’s labor force participation in South Asia by 12 percentage points, whereas the marginal penalty of childbearing is small. Consistent with the central roles of both opportunity costs and social norms, the marriage penalty is smaller among cohorts with higher education and less conservative gender attitudes."""
# Split the abstract into sentences; each sentence is scored as its own document.
sents = sent_tokenize(sample_text)
# Create the instance which will load the models.
# Set the device to "cuda" if you want to use a GPU.
dtopic_model = WBGDocTopic(device=None)
# Infer the topics and scores (one list of records per sentence).
outs = dtopic_model.suggest_topics(sents)
# In a notebook, this line displays the raw results.
outs
# [[{'label': 'Gender',
# 'score_mean': 0.8776359841227531,
# 'score_std': 0.13074095501538094,
# 'doc_idx': 0},
# {'label': 'Labor Markets',
# 'score_mean': 0.20742715448141097,
# 'score_std': 0.20991565414467345,
# 'doc_idx': 0},
# {'label': "Girls' Education",
# 'score_mean': 0.19432228063233198,
# 'score_std': 0.21148874269682794,
# 'doc_idx': 0}, ...]]
# Count, across sentences, how often each topic is highly relevant.
# A topic counts when its mean score exceeds a (currently arbitrary) threshold of 0.1
# and also exceeds its standard deviation across the ensemble's models.
topic_counts = Counter(
    o["label"]
    for out in outs
    for o in out
    if o["score_mean"] > 0.1 and o["score_mean"] > o["score_std"]
)
topic_counts.most_common()