CineAI commited on
Commit
2f68799
·
verified ·
1 Parent(s): 58abf09

Update llm/huggingfacehub/hf_model.py

Browse files
Files changed (1) hide show
  1. llm/huggingfacehub/hf_model.py +6 -4
llm/huggingfacehub/hf_model.py CHANGED
@@ -33,10 +33,10 @@ logger.info("Getting information from hf_model module")
33
 
34
  # work_dir = os.getcwd()
35
 
36
- print("CWD : ", os.getcwd())
37
-
38
  llm_dir = '/home/user/app/llm/'
39
 
 
 
40
 
41
  class HF_Mistaril(HFInterface, ABC):
42
  def __init__(self, prompt_entity: str, prompt_id: int = 0):
@@ -51,7 +51,8 @@ class HF_Mistaril(HFInterface, ABC):
51
  temperature=self.model_config["temperature"],
52
  max_new_tokens=self.model_config["max_new_tokens"],
53
  top_k=self.model_config["top_k"],
54
- model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]}
 
55
  )
56
 
57
  @staticmethod
@@ -99,7 +100,8 @@ class HF_TinyLlama(HFInterface, ABC):
99
  temperature=self.model_config["temperature"],
100
  max_new_tokens=self.model_config["max_new_tokens"],
101
  top_k=self.model_config["top_k"],
102
- model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]}
 
103
  )
104
 
105
  @staticmethod
 
33
 
34
  # work_dir = os.getcwd()
35
 
 
 
36
  llm_dir = '/home/user/app/llm/'
37
 
38
+ print("Path to prompts : ", os.path.join(os.getcwd(), "prompts.yaml"))
39
+
40
 
41
  class HF_Mistaril(HFInterface, ABC):
42
  def __init__(self, prompt_entity: str, prompt_id: int = 0):
 
51
  temperature=self.model_config["temperature"],
52
  max_new_tokens=self.model_config["max_new_tokens"],
53
  top_k=self.model_config["top_k"],
54
+ model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]},
55
+ huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN")
56
  )
57
 
58
  @staticmethod
 
100
  temperature=self.model_config["temperature"],
101
  max_new_tokens=self.model_config["max_new_tokens"],
102
  top_k=self.model_config["top_k"],
103
+ model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]},
104
+ huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN")
105
  )
106
 
107
  @staticmethod