vilarin committed on
Commit
063316d
·
verified ·
1 Parent(s): d875b4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -26
app.py CHANGED
@@ -8,7 +8,7 @@ from threading import Thread
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
- MODEL_LIST = "THUDM/glm-4-9b-chat, THUDM/glm-4-9b-chat-1m, THUDM/codegeex4-all-9b"
12
  #MODELS = os.environ.get("MODELS")
13
  #MODEL_NAME = MODELS.split("/")[-1]
14
 
@@ -26,7 +26,7 @@ CSS = """
26
  """
27
 
28
  model_chat = AutoModelForCausalLM.from_pretrained(
29
- "THUDM/glm-4-9b-chat",
30
  torch_dtype=torch.bfloat16,
31
  low_cpu_mem_usage=True,
32
  trust_remote_code=True,
@@ -34,17 +34,9 @@ model_chat = AutoModelForCausalLM.from_pretrained(
34
 
35
  tokenizer_chat = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat",trust_remote_code=True)
36
 
37
- model_code = AutoModelForCausalLM.from_pretrained(
38
- "THUDM/codegeex4-all-9b",
39
- torch_dtype=torch.bfloat16,
40
- low_cpu_mem_usage=True,
41
- trust_remote_code=True
42
- ).to(0).eval()
43
-
44
- tokenizer_code = AutoTokenizer.from_pretrained("THUDM/codegeex4-all-9b", trust_remote_code=True)
45
 
46
  @spaces.GPU
47
- def stream_chat(message: str, history: list, temperature: float, max_length: int, choice: str):
48
  print(f'message is - {message}')
49
  print(f'history is - {history}')
50
  conversation = []
@@ -54,12 +46,6 @@ def stream_chat(message: str, history: list, temperature: float, max_length: int
54
 
55
  print(f"Conversation is -\n{conversation}")
56
 
57
- if choice == "glm-4-9b-chat":
58
- tokenizer = tokenizer_chat
59
- model = model_chat
60
- else:
61
- model = model_code
62
- tokenizer = tokenizer_code
63
 
64
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
65
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@@ -71,6 +57,7 @@ def stream_chat(message: str, history: list, temperature: float, max_length: int
71
  top_k=1,
72
  temperature=temperature,
73
  repetition_penalty=1.2,
 
74
  )
75
  gen_kwargs = {**input_ids, **generate_kwargs}
76
 
@@ -97,24 +84,18 @@ with gr.Blocks(css=CSS) as demo:
97
  minimum=0,
98
  maximum=1,
99
  step=0.1,
100
- value=0.8,
101
  label="Temperature",
102
  render=False,
103
  ),
104
  gr.Slider(
105
  minimum=128,
106
- maximum=8192,
107
  step=1,
108
- value=1024,
109
  label="Max Length",
110
  render=False,
111
  ),
112
- gr.Radio(
113
- ["glm-4-9b-chat", "codegeex4-all-9b"],
114
- value="glm-4-9b-chat",
115
- label="Load Model",
116
- render=False,
117
- ),
118
  ],
119
  examples=[
120
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
 
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+ MODEL_LIST = "THUDM/LongWriter-glm4-9b"
12
  #MODELS = os.environ.get("MODELS")
13
  #MODEL_NAME = MODELS.split("/")[-1]
14
 
 
26
  """
27
 
28
  model_chat = AutoModelForCausalLM.from_pretrained(
29
+ "THUDM/LongWriter-glm4-9b",
30
  torch_dtype=torch.bfloat16,
31
  low_cpu_mem_usage=True,
32
  trust_remote_code=True,
 
34
 
35
  tokenizer_chat = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat",trust_remote_code=True)
36
 
 
 
 
 
 
 
 
 
37
 
38
  @spaces.GPU
39
+ def stream_chat(message: str, history: list, temperature: float, max_length: int):
40
  print(f'message is - {message}')
41
  print(f'history is - {history}')
42
  conversation = []
 
46
 
47
  print(f"Conversation is -\n{conversation}")
48
 
 
 
 
 
 
 
49
 
50
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
51
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
57
  top_k=1,
58
  temperature=temperature,
59
  repetition_penalty=1.2,
60
+ num_beams=1,
61
  )
62
  gen_kwargs = {**input_ids, **generate_kwargs}
63
 
 
84
  minimum=0,
85
  maximum=1,
86
  step=0.1,
87
+ value=0.5,
88
  label="Temperature",
89
  render=False,
90
  ),
91
  gr.Slider(
92
  minimum=128,
93
+ maximum=32768,
94
  step=1,
95
+ value=4096,
96
  label="Max Length",
97
  render=False,
98
  ),
 
 
 
 
 
 
99
  ],
100
  examples=[
101
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],