danrdoran committed
Commit
8f92739
1 Parent(s): 6b2a5ca

Create app.py

Files changed (1)
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
+ import streamlit as st
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
+ from peft import get_peft_model, LoraConfig
+
+ # Define the same LoRA configuration used during fine-tuning
+ lora_config = LoraConfig(
+     r=8, # Low-rank parameter
+     lora_alpha=32, # Scaling parameter
+     lora_dropout=0.1, # Dropout rate
+     target_modules=["q", "v"], # The attention layers to apply LoRA to
+     bias="none"
+ )
+
+ # Wrap the base model with a LoRA adapter and load the tokenizer from the Hugging Face Hub
+ model = get_peft_model(T5ForConditionalGeneration.from_pretrained("google/flan-t5-large"), lora_config)
+ tokenizer = T5Tokenizer.from_pretrained("danrdoran/flan-t5-simplified-squad")
+
+ # Streamlit app UI
+ st.title("AI English Tutor")
+ st.write("Ask me a question, and I will help you!")
+
+ # Sidebar for user to control model generation parameters
+ st.sidebar.title("Model Parameters")
+ temperature = st.sidebar.slider("Temperature", 0.1, 1.5, 1.0, 0.1) # Default 1.0
+ top_p = st.sidebar.slider("Top-p (Nucleus Sampling)", 0.0, 1.0, 0.9, 0.05) # Default 0.9
+ top_k = st.sidebar.slider("Top-k", 0, 100, 50, 1) # Default 50
+ # Sampling toggle; temperature, top-p and top-k only take effect when sampling is enabled
+ do_sample = st.sidebar.checkbox("Enable Random Sampling", value=False)
+
+ # Input field for the student
+ student_question = st.text_input("Ask your question!")
+
+ # Generate and display a response using the model's generate() function
+ if student_question:
+     # Prepare the input for the model
+     input_text = f"You are a tutor. Explain the answer to this question to a young student: '{student_question}'"
+     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=256) # Reduced max_length to 256
+
+     # Generate response
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         # max_length=75,
+         # min_length=20,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         do_sample=do_sample, # Controlled by the sidebar checkbox
+         # num_beams=2, # Uncomment to use beam search instead of sampling
+         no_repeat_ngram_size=3, # Prevent repeating phrases of 3 words or more
+         length_penalty=1.0, # Neutral length penalty (only used by beam search)
+         early_stopping=False # Only relevant for beam search
+     )
+
+     # Decode the generated response
+     response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     st.write("Tutor's Answer:", response)
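
Note: as written, app.py wraps the base google/flan-t5-large weights with a freshly initialised LoRA adapter via get_peft_model, while only the tokenizer is pulled from danrdoran/flan-t5-simplified-squad. If that repository also hosts the trained LoRA adapter weights (an assumption, not confirmed by this commit), the fine-tuned model would typically be loaded with PeftModel.from_pretrained instead, roughly as in this sketch:

from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel

# Sketch only: assumes danrdoran/flan-t5-simplified-squad contains the trained LoRA adapter
base_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
model = PeftModel.from_pretrained(base_model, "danrdoran/flan-t5-simplified-squad")
tokenizer = T5Tokenizer.from_pretrained("danrdoran/flan-t5-simplified-squad")

With this loading path no LoraConfig needs to be redefined in the app, since the adapter configuration is read from the Hub repository alongside the weights.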