sapthesh committed
Commit 0e79c88 · verified · 1 Parent(s): 8efe266

Update app.py

Files changed (1):
  1. app.py (+26 -13)
app.py CHANGED
@@ -20,31 +20,44 @@ config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_
 print(f"Loading model from {model_name}...")
 model = CustomModel.from_pretrained(model_name, config=config, revision=revision, trust_remote_code=True)
 
+# Check if the model loaded successfully
 if model is None:
     print("Failed to load model. Exiting...")
     exit(1)
 else:
     print("Model loaded successfully.")
 
+# Define the text classification function
 def classify_text(text):
-    inputs = tokenizer(text, return_tensors="pt")
-    outputs = model(**inputs)
-    logits = outputs.logits
-    probabilities = torch.softmax(logits, dim=-1).tolist()[0]
-    predicted_class = torch.argmax(logits, dim=-1).item()
-    return {
-        "Predicted Class": predicted_class,
-        "Probabilities": probabilities
-    }
+    try:
+        # Tokenize the input text
+        inputs = tokenizer(text, return_tensors="pt")
+        # Pass the inputs to the model
+        outputs = model(**inputs)
+        # Get the logits and probabilities
+        logits = outputs.logits
+        probabilities = torch.softmax(logits, dim=-1).tolist()[0]
+        # Get the predicted class
+        predicted_class = torch.argmax(logits, dim=-1).item()
+        return {
+            "Predicted Class": predicted_class,
+            "Probabilities": probabilities
+        }
+    except Exception as e:
+        print(f"Error during text classification: {e}")
+        return {
+            "Predicted Class": "Error",
+            "Probabilities": []
+        }
 
 # Create a Gradio interface
 try:
     iface = gr.Interface(
-        fn=classify_text,
-        inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
+        fn=classify_text,  # Function to call
+        inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),  # Input component
         outputs=[
-            gr.Label(label="Predicted Class"),
-            gr.Label(label="Probabilities")
+            gr.Label(label="Predicted Class"),  # Output component for predicted class
+            gr.Label(label="Probabilities")  # Output component for probabilities
         ],
         title="DeepSeek-V3 Text Classification",
         description="Classify text using the DeepSeek-V3 model."
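
For reference, a minimal, self-contained sketch of the same tokenize -> forward -> softmax -> argmax flow that the updated classify_text wraps in a try/except. The checkpoint below is only a stand-in so the snippet runs on its own; app.py actually loads its model through CustomModel.from_pretrained with its own model_name and revision.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Stand-in checkpoint so the sketch is runnable on its own (not the model app.py loads).
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

def classify_text(text):
    try:
        inputs = tokenizer(text, return_tensors="pt")          # tokenize the input text
        with torch.no_grad():                                   # inference only, no gradients
            outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1).tolist()[0]
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {"Predicted Class": predicted_class, "Probabilities": probabilities}
    except Exception as e:
        print(f"Error during text classification: {e}")
        return {"Predicted Class": "Error", "Probabilities": []}

print(classify_text("The error handling added in this commit keeps the Space from crashing."))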
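
On the interface side, gr.Interface with a list of two output components generally expects the wrapped function to return two values in order, and gr.Label renders a confidence bar when handed a {class_name: confidence} dict. A minimal sketch of that return convention, keeping the same components as the diff; the label names and placeholder results below are hypothetical.

import gradio as gr

def classify_for_gradio(text):
    # Placeholder results standing in for the model call in classify_text.
    predicted_class = 1
    probabilities = [0.1, 0.9]
    labels = ["negative", "positive"]            # hypothetical class names
    confidences = dict(zip(labels, probabilities))
    # Two output components -> return two values, in order.
    return labels[predicted_class], confidences

iface = gr.Interface(
    fn=classify_for_gradio,
    inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
    outputs=[
        gr.Label(label="Predicted Class"),       # receives the first return value
        gr.Label(label="Probabilities"),         # receives the {class: confidence} dict
    ],
    title="DeepSeek-V3 Text Classification",
    description="Classify text using the DeepSeek-V3 model.",
)

if __name__ == "__main__":
    iface.launch()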