# Suraj-Yadav's picture
# Uploading food not food text classifier demo app.py
# f141439 verified
# raw
# history blame
# 3.37 kB
# 1. Import the required packages
from typing import Dict, Union

import gradio as gr
import torch
from transformers import pipeline
# Model identifier that the Gradio wrapper passes as `model_path`, which is
# forwarded to `pipeline(model=...)`.
huggingface_model_path = "Suraj-Yadav/learn_hf_food_not_food_text_classifier-distilbert-base-uncased"
# 2. Define function to use our model on given text
def _detect_device() -> str:
    """Return the name of the best available torch device: 'cuda', 'mps' or 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    # `torch.backends.mps` does not exist on older torch builds, so look it up defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def food_not_food_classifier(
    text: Union[str, list],
    model_path: str,
    batch_size: int = 32,
    device: Union[str, None] = None,
    get_classifier: bool = False,
) -> Dict[str, float]:
    """
    Classifies whether the given text is related to food or not, returning a dictionary of labels and their scores.

    Args:
        text (Union[str, list]): The input text or list of texts to classify.
            Only the scores of the first text are returned.
        model_path (str): The path to the Hugging Face model for classification.
        batch_size (int): The batch size for processing. Default is 32.
        device (str): The device to run inference on (e.g., 'cuda', 'cpu').
            Default is None (auto-detect best available).
        get_classifier (bool): If True, skip inference and return the
            constructed pipeline object instead of scores. Default is False.

    Returns:
        Dict[str, float]: A dictionary where the keys are the labels and the
        values are the classification scores (empty if the pipeline produced
        no output). When `get_classifier` is True the pipeline itself is
        returned instead.
    """
    if device is None:
        device = _detect_device()

    classifier = pipeline(
        task="text-classification",
        model=model_path,
        batch_size=batch_size,
        device=device,
        top_k=None,  # Keep all predictions
    )

    if get_classifier:
        return classifier

    # Shape: [[{'label': 'food', 'score': 0.95}, {'label': 'not_food', 'score': 0.05}]]
    results = classifier(text)
    if not results:
        # Nothing was classified (e.g. an empty list of texts).
        return {}
    return {output["label"]: output["score"] for output in results[0]}
# Pipeline built on the first request and reused afterwards, so the model is
# loaded once rather than on every Gradio call.
_gradio_classifier = None


def gradio_food_classifier(text: str) -> dict:
    """
    A wrapper function for Gradio to classify text using the food_not_food_classifier function.

    Args:
        text (str): The input text to classify.

    Returns:
        dict: Classification results as a dictionary of label and score.
    """
    global _gradio_classifier
    if _gradio_classifier is None:
        _gradio_classifier = food_not_food_classifier(
            text=text,
            model_path=huggingface_model_path,
            get_classifier=True,
        )

    # The pipeline returns one list of {'label', 'score'} dicts per input text.
    results = _gradio_classifier(text)
    return {output["label"]: output["score"] for output in results[0]}
# 3. Create a Gradio interface with details about our app
# Markdown shown under the title in the Gradio UI.
description = """
A text classifier to determine if a sentence is about food or not food.
Fine-tuned from [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased) on a [small dataset of food and not food text](https://huggingface.co/datasets/mrdbourke/learn_hf_food_not_food_image_captions).
See [source code](https://github.com/mrdbourke/learn-huggingface/blob/main/notebooks/hugging_face_text_classification_tutorial.ipynb).
"""

demo = gr.Interface(
    fn=gradio_food_classifier,
    inputs="text",
    outputs=gr.Label(num_top_classes=2),
    # Emoji restored from mis-decoded UTF-8 bytes in the original title.
    title="🍗🚫🥑 Food or Not Food Text Classifier",
    description=description,
    examples=[
        ["I whipped up a fresh batch of code, but it seems to have a syntax error."],
        ["A delicious photo of a plate of scrambled eggs, bacon and toast."],
    ],
)

# 4. Launch the interface
if __name__ == "__main__":
    demo.launch()