hahafofo commited on
Commit
1911897
1 Parent(s): 390173a
Files changed (1) hide show
  1. utils/chatglm.py +16 -5
utils/chatglm.py CHANGED
@@ -81,11 +81,22 @@ class ChatGLM(BasePredictor):
81
  trust_remote_code=True,
82
  resume_download=True
83
  )
84
- model = AutoModel.from_pretrained(
85
- model_name,
86
- trust_remote_code=True,
87
- resume_download=True
88
- ).half().to(self.device)
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  model = model.eval()
91
  self.model = model
 
81
  trust_remote_code=True,
82
  resume_download=True
83
  )
84
+ if self.device == 'cuda':
85
+ model = AutoModel.from_pretrained(
86
+ model_name,
87
+ trust_remote_code=True,
88
+ resume_download=True
89
+ ).half().to(self.device)
90
+ else:
91
+ model = AutoModel.from_pretrained(
92
+ "THUDM/chatglm-6b-int8",
93
+ trust_remote_code=True,
94
+ resume_download=True,
95
+ low_cpu_mem_usage=True,
96
+ torch_dtype=torch.float16
97
+ if self.device == 'cuda' else torch.float32,
98
+ device_map={'': self.device}
99
+ )
100
 
101
  model = model.eval()
102
  self.model = model