Update llm/huggingfacehub/hf_model.py
Browse files
llm/huggingfacehub/hf_model.py
CHANGED
@@ -33,10 +33,10 @@ logger.info("Getting information from hf_model module")
|
|
33 |
|
34 |
# work_dir = os.getcwd()
|
35 |
|
36 |
-
print("CWD : ", os.getcwd())
|
37 |
-
|
38 |
llm_dir = '/home/user/app/llm/'
|
39 |
|
|
|
|
|
40 |
|
41 |
class HF_Mistaril(HFInterface, ABC):
|
42 |
def __init__(self, prompt_entity: str, prompt_id: int = 0):
|
@@ -51,7 +51,8 @@ class HF_Mistaril(HFInterface, ABC):
|
|
51 |
temperature=self.model_config["temperature"],
|
52 |
max_new_tokens=self.model_config["max_new_tokens"],
|
53 |
top_k=self.model_config["top_k"],
|
54 |
-
model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]}
|
|
|
55 |
)
|
56 |
|
57 |
@staticmethod
|
@@ -99,7 +100,8 @@ class HF_TinyLlama(HFInterface, ABC):
|
|
99 |
temperature=self.model_config["temperature"],
|
100 |
max_new_tokens=self.model_config["max_new_tokens"],
|
101 |
top_k=self.model_config["top_k"],
|
102 |
-
model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]}
|
|
|
103 |
)
|
104 |
|
105 |
@staticmethod
|
|
|
33 |
|
34 |
# work_dir = os.getcwd()
|
35 |
|
|
|
|
|
36 |
llm_dir = '/home/user/app/llm/'
|
37 |
|
38 |
+
print("Path to prompts : ", os.path.join(os.getcwd(), "prompts.yaml"))
|
39 |
+
|
40 |
|
41 |
class HF_Mistaril(HFInterface, ABC):
|
42 |
def __init__(self, prompt_entity: str, prompt_id: int = 0):
|
|
|
51 |
temperature=self.model_config["temperature"],
|
52 |
max_new_tokens=self.model_config["max_new_tokens"],
|
53 |
top_k=self.model_config["top_k"],
|
54 |
+
model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]},
|
55 |
+
huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN")
|
56 |
)
|
57 |
|
58 |
@staticmethod
|
|
|
100 |
temperature=self.model_config["temperature"],
|
101 |
max_new_tokens=self.model_config["max_new_tokens"],
|
102 |
top_k=self.model_config["top_k"],
|
103 |
+
model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]},
|
104 |
+
huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN")
|
105 |
)
|
106 |
|
107 |
@staticmethod
|