future-xy
committed on
Commit
•
794e78c
1
Parent(s):
3237d78
fix cuda mismatch bugs
Browse files
src/backend/moe_infinity.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
import os
|
3 |
from transformers import AutoTokenizer
|
4 |
import transformers
|
|
|
5 |
from moe_infinity import MoE
|
6 |
from typing import List, Tuple, Optional, Union
|
7 |
|
@@ -26,7 +27,9 @@ class MoEHFLM(HFLM):
|
|
26 |
self.offload_path = offload_path
|
27 |
self.device_memory_ratio = device_memory_ratio
|
28 |
self.use_chat_template = use_chat_template
|
29 |
-
|
|
|
|
|
30 |
# self._create_model()
|
31 |
|
32 |
def _create_model(self, *args, **kwargs):
|
|
|
2 |
import os
|
3 |
from transformers import AutoTokenizer
|
4 |
import transformers
|
5 |
+
from transformers import AutoModelForCausalLM
|
6 |
from moe_infinity import MoE
|
7 |
from typing import List, Tuple, Optional, Union
|
8 |
|
|
|
27 |
self.offload_path = offload_path
|
28 |
self.device_memory_ratio = device_memory_ratio
|
29 |
self.use_chat_template = use_chat_template
|
30 |
+
if "device" in kwargs:
|
31 |
+
kwargs.pop("device")
|
32 |
+
super().__init__(*args, **kwargs, pretrained=pretrained, device="cuda:0") # Assuming HFLM accepts a 'pretrained' arg and handles it
|
33 |
# self._create_model()
|
34 |
|
35 |
def _create_model(self, *args, **kwargs):
|