yuhuili committed
Commit: ad6ce34
Parent(s): 687d97d

Update app.py

Files changed (1): app.py (+6 -6)
app.py CHANGED
@@ -4,7 +4,7 @@ import time
 #os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
 import gradio as gr
 import argparse
-from model.ex_model import ExModel
+from model.ea_model import EaModel
 import torch
 from fastchat.model import get_conversation_template
 import re
@@ -76,7 +76,7 @@ def warmup(model):
     prompt += " "
     input_ids = model.tokenizer([prompt]).input_ids
     input_ids = torch.as_tensor(input_ids).cuda()
-    for output_ids in model.ex_generate(input_ids):
+    for output_ids in model.ea_generate(input_ids):
         ol=output_ids.shape[1]
 
 def bot(history, session_state):
@@ -113,7 +113,7 @@ def bot(history, session_state):
     total_ids=0
 
 
-    for output_ids in model.ex_generate(input_ids, temperature=temperature, top_p=top_p,
+    for output_ids in model.ea_generate(input_ids, temperature=temperature, top_p=top_p,
                                         max_steps=args.max_new_token):
         totaltime+=(time.time()-start_time)
         total_ids+=1
@@ -185,7 +185,7 @@ def clear(history,session_state):
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--ex-model-path",
+    "--ea-model-path",
     type=str,
     default="lmsys/vicuna-7b-v1.3",
     help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
@@ -207,9 +207,9 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-model = ExModel.from_pretrained(
+model = EaModel.from_pretrained(
     base_model_path=args.base_model_path,
-    ex_model_path=args.ex_model_path,
+    ea_model_path=args.ea_model_path,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
     load_in_4bit=args.load_in_4bit,
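
For context, here is a minimal sketch of how the renamed pieces fit together after this commit. It uses only the identifiers visible in the diff (EaModel.from_pretrained, model.ea_generate, model.tokenizer, and the --ea-model-path argument); the concrete paths, sampling values, and loop body are placeholders, and the --base-model-path / --load-in-4bit flag spellings are assumptions inferred from args.base_model_path and args.load_in_4bit rather than confirmed option names.

    # Sketch assuming the calls shown in the diff above; paths and values are placeholders.
    import torch
    from model.ea_model import EaModel

    model = EaModel.from_pretrained(
        base_model_path="lmsys/vicuna-7b-v1.3",    # placeholder base weights (local folder or HF repo ID)
        ea_model_path="path/to/ea-model-weights",  # placeholder for the weights passed via --ea-model-path
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        load_in_4bit=False,
    )

    prompt = "Hello"
    input_ids = torch.as_tensor(model.tokenizer([prompt]).input_ids).cuda()

    # ea_generate streams progressively longer output_ids, which warmup() and bot() consume above.
    for output_ids in model.ea_generate(input_ids, temperature=0.7, top_p=0.9, max_steps=512):
        print(output_ids.shape[1])  # current output length, mirroring ol / total_ids bookkeeping

Launching the demo itself would then look roughly like python app.py --ea-model-path <ea weights> --base-model-path <base weights>, where the second flag name is inferred from args.base_model_path in the constructor call.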