FridayMaster commited on
Commit
1631cdb
1 Parent(s): 947c082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -14
app.py CHANGED
@@ -1,27 +1,21 @@
1
- import gradio as gr
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
3
  import torch
4
 
5
  # Load the model and tokenizer
6
- model_name = 'FridayMaster/fine_tune_embedding' # Replace with your model's repository name
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name) # Use the appropriate class
9
 
10
  # Define a function to generate responses
11
  def generate_response(prompt):
12
  # Tokenize the input prompt
13
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
14
  with torch.no_grad():
15
- # Get the model output
16
- outputs = model(**inputs)
17
-
18
- # Process the output logits
19
- logits = outputs.logits
20
- predicted_class_id = logits.argmax().item()
21
-
22
- # Generate a response based on the predicted class
23
- response = f"Predicted class ID: {predicted_class_id}"
24
-
25
  return response
26
 
27
  # Create a Gradio interface
 
1
+
2
+ import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
 
6
  # Load the model and tokenizer
7
+ model_name = 'FridayMaster/fine_tune_embedding'
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name) # Use the appropriate class
10
 
11
  # Define a function to generate responses
12
  def generate_response(prompt):
13
  # Tokenize the input prompt
14
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
15
  with torch.no_grad():
16
+ # Generate a response using the model
17
+ outputs = model.generate(inputs['input_ids'], max_length=150, num_return_sequences=1)
18
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
19
  return response
20
 
21
  # Create a Gradio interface