CineAI commited on
Commit
c4f6685
·
verified ·
1 Parent(s): 179b941

Update llm/huggingfacehub/hf_model.py

Browse files
Files changed (1) hide show
  1. llm/huggingfacehub/hf_model.py +24 -15
llm/huggingfacehub/hf_model.py CHANGED
@@ -1,6 +1,6 @@
1
- import logging
2
  import os
3
  import yaml
 
4
 
5
  from abc import ABC
6
 
@@ -26,8 +26,12 @@ logger.addHandler(file_handler)
26
 
27
  logger.info("Getting information from hf_model module")
28
 
29
- path_to_prompts = os.path.join(os.getcwd(), "/llm/")
30
- print(path_to_prompts)
 
 
 
 
31
 
32
 
33
  class HF_Mistaril(HFInterface, ABC):
@@ -37,18 +41,21 @@ class HF_Mistaril(HFInterface, ABC):
37
 
38
  self.model_config = config["HF_Mistrail"]
39
 
 
40
  self.llm = HuggingFaceEndpoint(
41
  repo_id=self.model_config["model"],
42
- model_kwargs={"temperature": self.model_config["temperature"],
43
- "max_new_tokens": self.model_config["max_new_tokens"],
44
- "top_k": self.model_config["top_k"], "load_in_8bit": self.model_config["load_in_8bit"]})
 
45
 
46
  @staticmethod
47
  def __read_yaml():
48
  try:
49
- yaml_file = os.path.join(path_to_prompts, 'prompts.yaml')
50
- with open(yaml_file, 'r') as file:
51
- data = yaml.safe_load(file)
 
52
  return data
53
  except Exception as e:
54
  print(f"Execution filed : {e}")
@@ -84,16 +91,18 @@ class HF_TinyLlama(HFInterface, ABC):
84
 
85
  self.llm = HuggingFaceEndpoint(
86
  repo_id=self.model_config["model"],
87
- model_kwargs={"temperature": self.model_config["temperature"],
88
- "max_new_tokens": self.model_config["max_new_tokens"],
89
- "top_k": self.model_config["top_k"], "load_in_8bit": self.model_config["load_in_8bit"]})
 
90
 
91
  @staticmethod
92
  def __read_yaml():
93
  try:
94
- yaml_file = os.path.join(path_to_prompts, 'prompts.yaml')
95
- with open(yaml_file, 'r') as file:
96
- data = yaml.safe_load(file)
 
97
  return data
98
  except Exception as e:
99
  print(f"Execution filed : {e}")
 
 
1
  import os
2
  import yaml
3
+ import logging
4
 
5
  from abc import ABC
6
 
 
26
 
27
  logger.info("Getting information from hf_model module")
28
 
29
+ try:
30
+ os.chdir('/home/user/app/llm/')
31
+ except FileNotFoundError:
32
+ print("Error: Could not move up. You might be at the root directory.")
33
+
34
+ work_dir = os.getcwd()
35
 
36
 
37
  class HF_Mistaril(HFInterface, ABC):
 
41
 
42
  self.model_config = config["HF_Mistrail"]
43
 
44
+ # Додати repetition_penalty, task?, top_p, stop_sequences
45
  self.llm = HuggingFaceEndpoint(
46
  repo_id=self.model_config["model"],
47
+ temperature=self.model_config["temperature"],
48
+ max_new_tokens=self.model_config["max_new_tokens"],
49
+ top_k=self.model_config["top_k"],
50
+ model_kwargs=({"load_in_8bit": self.model_config["load_in_8bit"]})
51
 
52
  @staticmethod
53
  def __read_yaml():
54
  try:
55
+ yaml_file = os.path.join(work_dir, 'prompts.yaml')
56
+ with open(yaml_file, 'r') as f:
57
+ data = yaml.safe_load(f)
58
+ f.close()
59
  return data
60
  except Exception as e:
61
  print(f"Execution filed : {e}")
 
91
 
92
  self.llm = HuggingFaceEndpoint(
93
  repo_id=self.model_config["model"],
94
+ temperature=self.model_config["temperature"],
95
+ max_new_tokens=self.model_config["max_new_tokens"],
96
+ top_k=self.model_config["top_k"],
97
+ model_kwargs=({"load_in_8bit": self.model_config["load_in_8bit"]})
98
 
99
  @staticmethod
100
  def __read_yaml():
101
  try:
102
+ yaml_file = os.path.join(work_dir, 'prompts.yaml')
103
+ with open(yaml_file, 'r') as f:
104
+ data = yaml.safe_load(f)
105
+ f.close()
106
  return data
107
  except Exception as e:
108
  print(f"Execution filed : {e}")