Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import time
|
|
4 |
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
5 |
import gradio as gr
|
6 |
import argparse
|
7 |
-
from model.
|
8 |
import torch
|
9 |
from fastchat.model import get_conversation_template
|
10 |
import re
|
@@ -76,7 +76,7 @@ def warmup(model):
|
|
76 |
prompt += " "
|
77 |
input_ids = model.tokenizer([prompt]).input_ids
|
78 |
input_ids = torch.as_tensor(input_ids).cuda()
|
79 |
-
for output_ids in model.
|
80 |
ol=output_ids.shape[1]
|
81 |
|
82 |
def bot(history, session_state):
|
@@ -113,7 +113,7 @@ def bot(history, session_state):
|
|
113 |
total_ids=0
|
114 |
|
115 |
|
116 |
-
for output_ids in model.
|
117 |
max_steps=args.max_new_token):
|
118 |
totaltime+=(time.time()-start_time)
|
119 |
total_ids+=1
|
@@ -185,7 +185,7 @@ def clear(history,session_state):
|
|
185 |
|
186 |
parser = argparse.ArgumentParser()
|
187 |
parser.add_argument(
|
188 |
-
"--
|
189 |
type=str,
|
190 |
default="lmsys/vicuna-7b-v1.3",
|
191 |
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
|
@@ -207,9 +207,9 @@ parser.add_argument(
|
|
207 |
)
|
208 |
args = parser.parse_args()
|
209 |
|
210 |
-
model =
|
211 |
base_model_path=args.base_model_path,
|
212 |
-
|
213 |
torch_dtype=torch.float16,
|
214 |
low_cpu_mem_usage=True,
|
215 |
load_in_4bit=args.load_in_4bit,
|
|
|
4 |
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
5 |
import gradio as gr
|
6 |
import argparse
|
7 |
+
from model.ea_model import EaModel
|
8 |
import torch
|
9 |
from fastchat.model import get_conversation_template
|
10 |
import re
|
|
|
76 |
prompt += " "
|
77 |
input_ids = model.tokenizer([prompt]).input_ids
|
78 |
input_ids = torch.as_tensor(input_ids).cuda()
|
79 |
+
for output_ids in model.ea_generate(input_ids):
|
80 |
ol=output_ids.shape[1]
|
81 |
|
82 |
def bot(history, session_state):
|
|
|
113 |
total_ids=0
|
114 |
|
115 |
|
116 |
+
for output_ids in model.ea_generate(input_ids, temperature=temperature, top_p=top_p,
|
117 |
max_steps=args.max_new_token):
|
118 |
totaltime+=(time.time()-start_time)
|
119 |
total_ids+=1
|
|
|
185 |
|
186 |
parser = argparse.ArgumentParser()
|
187 |
parser.add_argument(
|
188 |
+
"--ea-model-path",
|
189 |
type=str,
|
190 |
default="lmsys/vicuna-7b-v1.3",
|
191 |
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
|
|
|
207 |
)
|
208 |
args = parser.parse_args()
|
209 |
|
210 |
+
model = EaModel.from_pretrained(
|
211 |
base_model_path=args.base_model_path,
|
212 |
+
ea_model_path=args.ea_model_path,
|
213 |
torch_dtype=torch.float16,
|
214 |
low_cpu_mem_usage=True,
|
215 |
load_in_4bit=args.load_in_4bit,
|