future-xy committed on
Commit
794e78c
1 Parent(s): 3237d78

fix cuda mismatch bugs

Browse files
Files changed (1) hide show
  1. src/backend/moe_infinity.py +4 -1
src/backend/moe_infinity.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import os
3
  from transformers import AutoTokenizer
4
  import transformers
 
5
  from moe_infinity import MoE
6
  from typing import List, Tuple, Optional, Union
7
 
@@ -26,7 +27,9 @@ class MoEHFLM(HFLM):
26
  self.offload_path = offload_path
27
  self.device_memory_ratio = device_memory_ratio
28
  self.use_chat_template = use_chat_template
29
- super().__init__(*args, **kwargs, pretrained=pretrained) # Assuming HFLM accepts a 'pretrained' arg and handles it
 
 
30
  # self._create_model()
31
 
32
  def _create_model(self, *args, **kwargs):
 
2
  import os
3
  from transformers import AutoTokenizer
4
  import transformers
5
+ from transformers import AutoModelForCausalLM
6
  from moe_infinity import MoE
7
  from typing import List, Tuple, Optional, Union
8
 
 
27
  self.offload_path = offload_path
28
  self.device_memory_ratio = device_memory_ratio
29
  self.use_chat_template = use_chat_template
30
+ if "device" in kwargs:
31
+ kwargs.pop("device")
32
+ super().__init__(*args, **kwargs, pretrained=pretrained, device="cuda:0") # Assuming HFLM accepts a 'pretrained' arg and handles it
33
  # self._create_model()
34
 
35
  def _create_model(self, *args, **kwargs):