# 1. Import the required packages
import torch
import gradio as gr

from typing import Dict, Optional, Union
from transformers import Pipeline, pipeline

huggingface_model_path = "Suraj-Yadav/learn_hf_food_not_food_text_classifier-distilbert-base-uncased"

# 2. Define function to use our model on given text
def food_not_food_classifier(
    text: Union[str, list],
    model_path: str,
    batch_size: int = 32,
    device: Optional[str] = None,
    get_classifier: bool = False
) -> Union[Dict[str, float], Pipeline]:
    """
    Classifies whether the given text is about food or not, returning a dictionary of labels and their scores.

    Args:
        text (Union[str, list]): The input text or list of texts to classify. For a list, only the
            scores for the first text are returned.
        model_path (str): The Hugging Face Hub ID or local path of the classification model.
        batch_size (int): The batch size for processing. Default is 32.
        device (Optional[str]): The device to run inference on (e.g., 'cuda', 'cpu'). Default is None
            (auto-detect the best available device).
        get_classifier (bool): If True, return the loaded pipeline instead of classifying `text`.
            Default is False.

    Returns:
        Union[Dict[str, float], Pipeline]: A dictionary mapping each label to its classification score,
            or the text-classification pipeline itself if `get_classifier` is True.
    """
    
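    if device is None:
        # Auto-detect the best available device: a CUDA GPU if present, otherwise the CPU
        device = "cuda" if torch.cuda.is_available() else "cpu"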
    
    classifier = pipeline(
        task="text-classification",
        model=model_path,
        batch_size=batch_size, 
        device=device,
        top_k=None  # Return scores for every label, not just the highest-scoring one
    )

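    if get_classifier:
        return classifier

    # With top_k=None and a single string, the pipeline returns a nested list, e.g.
    # [[{'label': 'food', 'score': 0.9500328898429871}, {'label': 'not_food', 'score': 0.04996709153056145}]]
    results = classifier(text)

    # Convert the scores for the first (or only) text into a {label: score} dict
    output_dict = {}
    for output in results[0]:
        output_dict[output["label"]] = output["score"]

    return output_dict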
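

# Sketch (not part of the original app): `food_not_food_classifier` keeps only results[0],
# so a list of texts yields scores for the first text only. The hypothetical helper below
# shows one way to turn the pipeline's nested output (one list of {'label', 'score'} dicts
# per input text, as produced with top_k=None) into one {label: score} dict per text.
def pipeline_results_to_dicts(results: list) -> list:
    """Convert nested text-classification output into a list of {label: score} dicts (hypothetical helper)."""
    return [
        {prediction["label"]: prediction["score"] for prediction in text_predictions}
        for text_predictions in results
    ]

# Usage, assuming `classifier` is the pipeline returned with get_classifier=True:
#     pipeline_results_to_dicts(classifier(["A bowl of ramen.", "A red sports car."]))
# Each element of the returned list holds the label scores for the matching input text.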


def gradio_food_classifier(text: str) -> Dict[str, float]:
    """
    A wrapper function for Gradio to classify text using the food_not_food_classifier function.

    Args:
        text (str): The input text to classify.

    Returns:
        Dict[str, float]: Classification results as a dictionary mapping each label to its score.
    """
    return food_not_food_classifier(text=text, model_path=huggingface_model_path)
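

# Sketch (not part of the original app): as written, every Gradio request builds a new
# pipeline, which reloads the model weights. One option, assuming the model path and
# device stay fixed, is to memoise the loader with functools.lru_cache so the model is
# loaded once per process. `load_food_classifier` is a hypothetical name.
from functools import lru_cache


@lru_cache(maxsize=None)
def load_food_classifier(model_path: str, device: str = "cpu"):
    """Load the text-classification pipeline once per (model_path, device) pair (hypothetical helper)."""
    from transformers import pipeline  # imported lazily so this helper stands alone

    return pipeline(task="text-classification", model=model_path, device=device, top_k=None)

# Usage: classifier = load_food_classifier(huggingface_model_path)
# Repeated calls with the same arguments return the cached pipeline object
# instead of reloading the model.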


# 3. Create a Gradio interface with details about our app
description = """
A text classifier to determine if a sentence is about food or not food. 

Fine-tuned from [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased) on a [small dataset of food and not food text](https://huggingface.co/datasets/mrdbourke/learn_hf_food_not_food_image_captions).

See [source code](https://github.com/mrdbourke/learn-huggingface/blob/main/notebooks/hugging_face_text_classification_tutorial.ipynb).
"""

demo = gr.Interface(
    fn=gradio_food_classifier,
    inputs="text",
    outputs=gr.Label(num_top_classes=2),
    title="🍗🚫🥑 Food or Not Food Text Classifier",
    description=description,
    examples=[
        ["I whipped up a fresh batch of code, but it seems to have a syntax error."],
        ["A delicious photo of a plate of scrambled eggs, bacon and toast."],
    ],
)

# 4. Launch the interface
if __name__ == "__main__":
    demo.launch()