cutechicken commited on
Commit
d18a77d
Β·
verified Β·
1 Parent(s): 1c47184

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -12
app.py CHANGED
@@ -23,14 +23,22 @@ class ModelManager:
23
  def setup_model(self):
24
  try:
25
  print("ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ‹œμž‘...")
26
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
 
 
 
27
  print("ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ™„λ£Œ")
28
 
29
  print("λͺ¨λΈ λ‘œλ”© μ‹œμž‘...")
30
  self.model = AutoModelForCausalLM.from_pretrained(
31
  MODEL_ID,
 
32
  torch_dtype=torch.bfloat16,
33
- device_map="auto"
 
34
  )
35
  print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ")
36
  except Exception as e:
@@ -40,17 +48,29 @@ class ModelManager:
40
  @spaces.GPU
41
  def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
42
  try:
43
- conversation = []
 
44
  for msg in messages:
45
- conversation.append({"role": msg["role"], "content": msg["content"]})
46
-
47
- input_ids = self.tokenizer.apply_chat_template(
48
- conversation,
49
- tokenize=False,
50
- add_generation_prompt=True
51
- )
52
- inputs = self.tokenizer(input_ids, return_tensors="pt").to(0)
 
 
 
 
 
 
 
 
 
 
53
 
 
54
  streamer = TextIteratorStreamer(
55
  self.tokenizer,
56
  timeout=10.,
@@ -58,6 +78,7 @@ class ModelManager:
58
  skip_special_tokens=True
59
  )
60
 
 
61
  generate_kwargs = dict(
62
  **inputs,
63
  streamer=streamer,
@@ -65,12 +86,15 @@ class ModelManager:
65
  do_sample=True,
66
  temperature=temperature,
67
  top_p=top_p,
68
- eos_token_id=[255001]
 
69
  )
70
 
 
71
  thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
72
  thread.start()
73
 
 
74
  buffer = ""
75
  for new_text in streamer:
76
  buffer += new_text
@@ -81,6 +105,7 @@ class ModelManager:
81
  })()
82
 
83
  except Exception as e:
 
84
  raise Exception(f"응닡 생성 μ‹€νŒ¨: {e}")
85
 
86
  class ChatHistory:
 
23
  def setup_model(self):
24
  try:
25
  print("ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ‹œμž‘...")
26
+ self.tokenizer = AutoTokenizer.from_pretrained(
27
+ MODEL_ID,
28
+ token=HF_TOKEN,
29
+ trust_remote_code=True
30
+ )
31
+ if self.tokenizer.pad_token is None:
32
+ self.tokenizer.pad_token = self.tokenizer.eos_token
33
  print("ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ™„λ£Œ")
34
 
35
  print("λͺ¨λΈ λ‘œλ”© μ‹œμž‘...")
36
  self.model = AutoModelForCausalLM.from_pretrained(
37
  MODEL_ID,
38
+ token=HF_TOKEN,
39
  torch_dtype=torch.bfloat16,
40
+ device_map="auto",
41
+ trust_remote_code=True
42
  )
43
  print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ")
44
  except Exception as e:
 
48
  @spaces.GPU
49
  def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
50
  try:
51
+ # λ©”μ‹œμ§€ ν¬λ§·νŒ…
52
+ formatted_messages = []
53
  for msg in messages:
54
+ if msg["role"] == "system":
55
+ formatted_messages.append(f"System: {msg['content']}\n")
56
+ elif msg["role"] == "user":
57
+ formatted_messages.append(f"User: {msg['content']}\n")
58
+ elif msg["role"] == "assistant":
59
+ formatted_messages.append(f"Assistant: {msg['content']}\n")
60
+
61
+ # μž…λ ₯ ν…μŠ€νŠΈ 생성
62
+ prompt = "".join(formatted_messages)
63
+
64
+ # ν† ν¬λ‚˜μ΄μ§•
65
+ inputs = self.tokenizer(
66
+ prompt,
67
+ return_tensors="pt",
68
+ padding=True,
69
+ truncation=True,
70
+ max_length=4096
71
+ ).to(self.model.device)
72
 
73
+ # 슀트리머 μ„€μ •
74
  streamer = TextIteratorStreamer(
75
  self.tokenizer,
76
  timeout=10.,
 
78
  skip_special_tokens=True
79
  )
80
 
81
+ # 생성 μ„€μ •
82
  generate_kwargs = dict(
83
  **inputs,
84
  streamer=streamer,
 
86
  do_sample=True,
87
  temperature=temperature,
88
  top_p=top_p,
89
+ pad_token_id=self.tokenizer.pad_token_id,
90
+ eos_token_id=self.tokenizer.eos_token_id
91
  )
92
 
93
+ # 비동기 생성
94
  thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
95
  thread.start()
96
 
97
+ # 응닡 슀트리밍
98
  buffer = ""
99
  for new_text in streamer:
100
  buffer += new_text
 
105
  })()
106
 
107
  except Exception as e:
108
+ print(f"응닡 생성 쀑 였λ₯˜ λ°œμƒ: {e}")
109
  raise Exception(f"응닡 생성 μ‹€νŒ¨: {e}")
110
 
111
  class ChatHistory: