cutechicken committed
Commit 79dc437 • 1 Parent(s): 5716c43

Update app.py

Files changed (1):
  1. app.py +25 -25
app.py CHANGED
@@ -17,13 +17,11 @@ class ModelManager:
     def __init__(self):
         self.tokenizer = None
         self.model = None
-        # Initialization happens on the first request
 
     def ensure_model_loaded(self):
         if self.model is None or self.tokenizer is None:
             self.setup_model()
 
-    @spaces.GPU
     def setup_model(self):
         try:
             print("Loading tokenizer...")
@@ -41,10 +39,11 @@ class ModelManager:
             self.model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
                 token=HF_TOKEN,
-                torch_dtype=torch.bfloat16,
+                torch_dtype=torch.float16,
                 device_map="auto",
                 trust_remote_code=True,
-                low_cpu_mem_usage=True
+                low_cpu_mem_usage=True,
+                max_memory={0: "13GB"}  # GPU memory limit
             )
             self.model.eval()
             print("Model loading complete")
@@ -54,32 +53,14 @@ class ModelManager:
             raise Exception(f"Model loading failed: {e}")
 
     @spaces.GPU
-    def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
+    def generate_text(self, prompt, max_tokens, temperature, top_p):
         try:
-            # Make sure the model is loaded
-            self.ensure_model_loaded()
-
-            # Prepare the input text
-            prompt = ""
-            for msg in messages:
-                role = msg["role"]
-                content = msg["content"]
-                if role == "system":
-                    prompt += f"System: {content}\n"
-                elif role == "user":
-                    prompt += f"Human: {content}\n"
-                elif role == "assistant":
-                    prompt += f"Assistant: {content}\n"
-            prompt += "Assistant: "
-
-            # Encode the input
             input_ids = self.tokenizer.encode(
                 prompt,
                 return_tensors="pt",
                 add_special_tokens=True
             ).to(self.model.device)
 
-            # Generate the response
             with torch.no_grad():
                 output_ids = self.model.generate(
                     input_ids,
@@ -92,11 +73,30 @@ class ModelManager:
                     num_return_sequences=1
                 )
 
-            # Decode the response
-            generated_text = self.tokenizer.decode(
+            return self.tokenizer.decode(
                 output_ids[0][input_ids.shape[1]:],
                 skip_special_tokens=True
             )
+        except Exception as e:
+            raise Exception(f"Text generation failed: {e}")
+
+    def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
+        try:
+            # Prepare the input text
+            prompt = ""
+            for msg in messages:
+                role = msg["role"]
+                content = msg["content"]
+                if role == "system":
+                    prompt += f"System: {content}\n"
+                elif role == "user":
+                    prompt += f"Human: {content}\n"
+                elif role == "assistant":
+                    prompt += f"Assistant: {content}\n"
+            prompt += "Assistant: "
+
+            # Call the method decorated with spaces.GPU
+            generated_text = self.generate_text(prompt, max_tokens, temperature, top_p)
 
             # Stream word by word
             words = generated_text.split()
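
For context, a minimal usage sketch of the refactored flow. Only ModelManager, ensure_model_loaded, and generate_response come from app.py above; the manager instance, the respond handler, and the sample messages are illustrative assumptions, not part of the commit:

# Hypothetical caller (not part of the commit). The commit moves @spaces.GPU
# off setup_model and generate_response and onto generate_text, so GPU time
# is only claimed for the actual generation call.
manager = ModelManager()

def respond(user_message):
    # Lazy-load the model and tokenizer on the first request
    manager.ensure_model_loaded()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_message},
    ]
    # generate_response builds the System:/Human:/Assistant: prompt and
    # delegates to the GPU-decorated generate_text; the trailing lines of
    # the diff suggest the result is then streamed word by word.
    return manager.generate_response(messages)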