cutechicken committed
Commit 63b4531 • 1 Parent(s): 79dc437

Update app.py

Files changed (1): app.py +10 -5
app.py CHANGED
@@ -36,16 +36,15 @@ class ModelManager:
             print("Tokenizer loading complete")
 
             print("Starting model load...")
+            # Settings to avoid initializing CUDA at load time
             self.model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
                 token=HF_TOKEN,
                 torch_dtype=torch.float16,
-                device_map="auto",
+                device_map=None,  # do not set device_map initially
                 trust_remote_code=True,
-                low_cpu_mem_usage=True,
-                max_memory={0: "13GB"}  # GPU memory limit
+                low_cpu_mem_usage=True
             )
-            self.model.eval()
             print("Model loading complete")
 
         except Exception as e:
@@ -55,11 +54,13 @@ class ModelManager:
     @spaces.GPU
     def generate_text(self, prompt, max_tokens, temperature, top_p):
         try:
+            # Set the device inside the GPU context
+            self.model = self.model.to("cuda")
             input_ids = self.tokenizer.encode(
                 prompt,
                 return_tensors="pt",
                 add_special_tokens=True
-            ).to(self.model.device)
+            ).to("cuda")
 
             with torch.no_grad():
                 output_ids = self.model.generate(
@@ -73,11 +74,15 @@ class ModelManager:
                 num_return_sequences=1
             )
 
+            # Move back to the CPU
+            self.model = self.model.to("cpu")
             return self.tokenizer.decode(
                 output_ids[0][input_ids.shape[1]:],
                 skip_special_tokens=True
             )
         except Exception as e:
+            if self.model.device.type == "cuda":
+                self.model = self.model.to("cpu")
             raise Exception(f"Text generation failed: {e}")
 
     def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
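
Taken together, the hunks move the Space from eager GPU placement at load time (device_map="auto" with a max_memory cap and an eval() call) to a ZeroGPU-friendly pattern: load the fp16 weights on the CPU with device_map=None, attach them to "cuda" only inside the @spaces.GPU-decorated call, and return them to the CPU afterwards, including on failure. Below is a minimal self-contained sketch of the resulting pattern. It assumes MODEL_ID and HF_TOKEN are defined as elsewhere in app.py (the placeholder values here are illustrative), and that the generation kwargs elided between hunks 2 and 3 use max_new_tokens and do_sample, which this diff does not show.

import os
import torch
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "some-org/some-model"  # placeholder; app.py defines the real value
HF_TOKEN = os.getenv("HF_TOKEN")

class ModelManager:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
        # Load entirely on the CPU: no device_map, so CUDA is never touched
        # in the main process (a ZeroGPU requirement).
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            device_map=None,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

    @spaces.GPU
    def generate_text(self, prompt, max_tokens=4000, temperature=0.7, top_p=0.9):
        try:
            # A GPU is attached only while this decorated call runs.
            self.model = self.model.to("cuda")
            input_ids = self.tokenizer.encode(
                prompt, return_tensors="pt", add_special_tokens=True
            ).to("cuda")
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    max_new_tokens=max_tokens,  # assumed kwargs, not shown in the diff
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    num_return_sequences=1,
                )
            # Release the GPU copy before the decorated call returns.
            self.model = self.model.to("cpu")
            return self.tokenizer.decode(
                output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
            )
        except Exception as e:
            # Ensure the weights land back on the CPU even on failure.
            if self.model.device.type == "cuda":
                self.model = self.model.to("cpu")
            raise Exception(f"Text generation failed: {e}")

Moving the weights host-to-device on every call trades per-request latency for compatibility: a ZeroGPU process must not initialize CUDA outside a @spaces.GPU context, so the model cannot stay resident on the GPU between requests.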