adhisetiawan committed on
Commit
8ce4fda
•
1 Parent(s): bb31f47

Create app.py

Files changed (1)
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
+ # Dependencies (gradio, torch, transformers, peft, bitsandbytes, accelerate)
+ # are expected to be installed via requirements.txt rather than at runtime.
+ import gradio as gr
+ import torch
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from threading import Thread
+
+ # Load the fine-tuned model and tokenizer
+ new_model = "Ronal999/phi2_DPO"
+ model = AutoPeftModelForCausalLM.from_pretrained(
+     new_model,
+     low_cpu_mem_usage=True,
+     torch_dtype=torch.float16,
+     load_in_4bit=True,
+     device_map="auto",  # 4-bit models cannot be moved with .to(); let accelerate place them
+ )
+ tokenizer = AutoTokenizer.from_pretrained(new_model)
+
+ # Stopping criteria: end generation once one of the stop token IDs is produced
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [29, 0]  # Token IDs that stop the generation
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+ # Prediction function used by the chat interface
+ def predict(message, history):
+     # Append the new user message (with an empty bot reply) to the history
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # Format the conversation as alternating <human>/<bot> turns
+     messages = "".join(["".join(["\n<human>:" + item[0], "\n<bot>:" + item[1]])
+                         for item in history_transformer_format])
+     model_inputs = tokenizer([messages], return_tensors="pt").to(model.device)
+
+     # Set up the streamer and run generation in a background thread
+     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=1.0,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop]),
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Yield partial messages as tokens are streamed back
+     partial_message = ""
+     for new_token in streamer:
+         if new_token != '<':
+             partial_message += new_token
+             yield partial_message
+
+ # Launch the Gradio chat interface
+ gr.ChatInterface(predict).queue().launch(debug=True)
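
Since app.py does not install packages at runtime, the Space needs a requirements.txt next to it. A minimal sketch of that file, assuming the package set implied by the imports above (it is not part of this commit, and versions are left unpinned):

# requirements.txt (assumed contents, inferred from the imports in app.py)
gradio
torch
transformers
peft
bitsandbytes
accelerate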