Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,9 @@
|
|
1 |
# coding=utf-8
|
|
|
|
|
|
|
|
|
|
|
2 |
from src.logger import LoggerFactory
|
3 |
from src.prompt_concat import GetManualTestSamples, CreateTestDataset
|
4 |
from src.utils import decode_csv_to_json, load_json, save_to_json
|
@@ -23,12 +28,18 @@ import spaces
|
|
23 |
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
24 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
30 |
trust_remote_code=True)
|
31 |
|
|
|
|
|
|
|
|
|
|
|
32 |
# logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
33 |
# warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
34 |
|
|
|
1 |
# coding=utf-8
|
2 |
+
from typing import Dict
|
3 |
+
from typing import List
|
4 |
+
from typing import Tuple
|
5 |
+
from typing import Union
|
6 |
+
from pathlib import Path
|
7 |
from src.logger import LoggerFactory
|
8 |
from src.prompt_concat import GetManualTestSamples, CreateTestDataset
|
9 |
from src.utils import decode_csv_to_json, load_json, save_to_json
|
|
|
28 |
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
29 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
30 |
|
31 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
|
32 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
33 |
+
|
34 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
|
35 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map="auto",
|
36 |
trust_remote_code=True)
|
37 |
|
38 |
+
character_path = "./character"
|
39 |
+
|
40 |
+
def _resolve_path(path: Union[str, Path]) -> Path:
|
41 |
+
return Path(path).expanduser().resolve()
|
42 |
+
|
43 |
# logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
44 |
# warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
45 |
|