cutechicken committed on
Commit 08267c4 • 1 Parent(s): 2db4e16

Update app.py

Files changed (1)
  1. app.py +19 -10
app.py CHANGED
@@ -20,15 +20,23 @@ class ModelManager:
     def setup_model(self):
         try:
             print("Tokenizer loading started...")
-            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_ID,
+                token=HF_TOKEN,
+                use_fast=True
+            )
             print("Tokenizer loading complete")
 
             print("Model loading started...")
+            # ZeroGPU settings
             self.model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
                 token=HF_TOKEN,
                 torch_dtype=torch.float16,
-                device_map="auto"
+                device_map="balanced",      # balanced placement for ZeroGPU
+                max_memory={0: "8GiB"},     # ZeroGPU memory limit
+                offload_folder="offload",   # offload settings
+                low_cpu_mem_usage=True
             )
             print("Model loading complete")
         except Exception as e:
@@ -37,7 +45,6 @@ class ModelManager:
 
     def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
         try:
-            # Apply the chat template
             input_ids = self.tokenizer.apply_chat_template(
                 messages,
                 tokenize=True,
@@ -45,7 +52,7 @@ class ModelManager:
                 return_tensors="pt"
             ).to(self.model.device)
 
-            # Generate tokens
+            # Generation settings optimized for ZeroGPU
             gen_tokens = self.model.generate(
                 input_ids,
                 max_new_tokens=max_tokens,
@@ -53,16 +60,18 @@ class ModelManager:
                 temperature=temperature,
                 top_p=top_p,
                 pad_token_id=self.tokenizer.eos_token_id,
-                streamer=TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+                use_cache=True,   # use the KV cache for memory efficiency
+                num_beams=1       # disable beam search to save memory
             )
 
-            # Decode and stream the response
-            response_text = ""
-            for new_text in self.tokenizer.decode(gen_tokens[0], skip_special_tokens=True):
-                response_text += new_text
+            response_text = self.tokenizer.decode(gen_tokens[0][input_ids.shape[1]:], skip_special_tokens=True)
+
+            # Word-level streaming
+            words = response_text.split()
+            for word in words:
                 yield type('Response', (), {
                     'choices': [type('Choice', (), {
-                        'delta': {'content': new_text}
+                        'delta': {'content': word + " "}
                     })()]
                 })()
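
For context, here is the new load path in isolation as a minimal, runnable sketch; MODEL_ID and HF_TOKEN are hypothetical placeholders (app.py defines the real values). With device_map="balanced" and max_memory={0: "8GiB"}, accelerate caps GPU 0 at 8 GiB and routes weights that do not fit to CPU and, if necessary, to the offload folder on disk.

```python
# Standalone sketch of the ZeroGPU-oriented load path from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "some-org/some-model"  # hypothetical placeholder
HF_TOKEN = None                   # hypothetical placeholder; pass a real token for gated models

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.float16,
    device_map="balanced",        # spread layers across available devices
    max_memory={0: "8GiB"},       # cap GPU 0 at 8 GiB
    offload_folder="offload",     # disk destination for weights that do not fit
    low_cpu_mem_usage=True,       # stream weights in rather than materializing them twice
)
```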
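Since generate_response now yields ad-hoc objects shaped like OpenAI streaming chunks, each exposing .choices[0].delta['content'], a caller consumes it as below. A usage sketch, assuming ModelManager and setup_model() from app.py have run successfully:

```python
# Consuming the word-level stream produced by generate_response.
manager = ModelManager()
manager.setup_model()

messages = [{"role": "user", "content": "Hello!"}]  # hypothetical chat history

for chunk in manager.generate_response(messages, max_tokens=256):
    # Each chunk mimics an OpenAI streaming delta: a plain dict under .delta
    print(chunk.choices[0].delta['content'], end="", flush=True)
print()
```

Note that this streams only cosmetically: generation completes before the split, so the full response is already in memory, and any whitespace other than single spaces is lost by split().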
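The removed line passed a TextIteratorStreamer directly into generate(), which cannot stream on its own: generate() blocks until completion, so the streamer has to be drained from another thread. Below is a hedged sketch of the background-thread pattern documented for TextIteratorStreamer, in case true token-level streaming is wanted later; add_generation_prompt and do_sample are assumptions not visible in this diff.

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_chat(model, tokenizer, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,   # assumed; the diff elides this argument
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until finished, so run it in a worker thread and
    # iterate the streamer here; each item is decoded text as it is produced.
    thread = Thread(target=model.generate, kwargs={
        "input_ids": input_ids,
        "max_new_tokens": max_tokens,
        "do_sample": True,            # assumed, so temperature/top_p take effect
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    })
    thread.start()
    for new_text in streamer:
        yield new_text
    thread.join()
```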