Tao Wu committed
Commit 93a26b1
1 Parent(s): 9b1fd32
Files changed (1):
  1. app/embedding_setup.py +12 -55
app/embedding_setup.py CHANGED
@@ -32,12 +32,13 @@ embedding_sim = HuggingFaceBgeEmbeddings(
 db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embedding_int)
 retriever = db.as_retriever(search_kwargs={"k": TOP_K})
 
-
+lorax_client = pb.deployments.client("llama-3-8b-instruct")  # Insert deployment name here
 
 lora_weights_rec = REC_LORA_MODEL
 lora_weights_exp = EXP_LORA_MODEL
 hf_auth = os.environ.get("hf_token")
 
+
 tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, token=hf_auth)
 
 
@@ -46,28 +47,6 @@ second_token = 'Second'
 # Get the token IDs
 first_id = tokenizer.convert_tokens_to_ids(first_token)
 second_id = tokenizer.convert_tokens_to_ids(second_token)
-model = AutoModelForCausalLM.from_pretrained(
-    LLM_MODEL,
-    load_in_4bit=True,
-    torch_dtype=torch.float16,
-    token=hf_auth,
-)
-
-rec_adapter = PeftModel.from_pretrained(
-    model,
-    lora_weights_rec
-)
-
-
-tokenizer.padding_side = "left"
-# unwind broken decapoda-research config
-#model.half()  # seems to fix bugs for some users.
-rec_adapter.eval()
-
-rec_adapter.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
-rec_adapter.config.bos_token_id = 1
-rec_adapter.config.eos_token_id = 2
-
 
 
 def generate_prompt(target_occupation, skill_gap, courses):
@@ -100,32 +79,9 @@ def evaluate(
     **kwargs,
 ):
 
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
-    generation_config = GenerationConfig(
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        num_beams=num_beams,
-        **kwargs,
-    )
-    with torch.no_grad():
-        rec_adapter.to(device)
-        generation_output = rec_adapter.generate(
-            **inputs,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_new_tokens,
-            # batch_size=batch_size,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.eos_token_id,
-        )
-    scores = generation_output.scores[0].softmax(dim=-1)
-    logits = torch.tensor(scores[:, [first_id, second_id]], dtype=torch.float32).softmax(dim=-1)
-    s = generation_output.sequences
-    output = tokenizer.batch_decode(s, skip_special_tokens=True)
-    output = [_.split('Response:\n')[-1] for _ in output]
-    return output, logits.tolist()
+    resp = lorax_client.generate(prompt, adapter_id=REC_LORA_MODEL, adapter_source='hub', max_new_tokens=max_new_tokens)
+
+    return resp
 
 def compare_docs_with_context(doc_a, doc_b, target_occupation_name, target_occupation_dsp, skill_gap):
 
@@ -134,13 +90,14 @@ def compare_docs_with_context(doc_a, doc_b, target_occupation_name, target_occup
     target_occupation = f"name: {target_occupation_name} description: {target_occupation_dsp[:1500]}"
     skill_gap = skill_gap
     prompt = generate_prompt(target_occupation, skill_gap, courses)
-    prompt = [prompt]
-    output, logit = evaluate(prompt)
+    prompt = prompt
+    output = evaluate(prompt)
     # Compare based on the response: [A] means doc_a > doc_b, [B] means doc_a < doc_b
-    print(output, logit)
-    if logit[0][0] > logit[0][1]:
+    print(output)
+    result_token_id = output.details.token[0].id
+    if result_token_id == first_id:
         return 1  # doc_a should come before doc_b
-    elif logit[0][0] < logit[0][1]:
+    elif result_token_id == second_id:
         return -1  # doc_a should come after doc_b
     else:
         return 0  # Consider them equal if the response is unclear
@@ -148,7 +105,7 @@ def compare_docs_with_context(doc_a, doc_b, target_occupation_name, target_occup
 
 #-----------------------------------------explanation-------------------------------------
 
-lorax_client = pb.deployments.client("llama-3-8b-instruct")  # Insert deployment name here
+
 def generate_prompt_exp(input_text):
     return f"""
 ### Instruction:
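This commit swaps the locally loaded 4-bit PEFT model for a hosted LoRAX deployment, so the pairwise ranking now hinges on the id of the first token the deployment returns. A minimal end-to-end sketch of the new path, assuming the `predibase` Python SDK, a `PREDIBASE_API_TOKEN` environment variable, and the lorax-client response shape (`details.tokens`); the adapter repo id below is a placeholder:

    # Sketch only: the deployment name mirrors the commit; the adapter id and env var are assumptions.
    import os

    from predibase import Predibase

    pb = Predibase(api_token=os.environ["PREDIBASE_API_TOKEN"])
    lorax_client = pb.deployments.client("llama-3-8b-instruct")

    REC_LORA_MODEL = "your-org/rec-lora"  # hypothetical Hugging Face Hub adapter repo

    def rank_pair(prompt: str, first_id: int, second_id: int) -> int:
        # Generate remotely; the LoRA adapter is applied server-side by LoRAX.
        resp = lorax_client.generate(
            prompt,
            adapter_id=REC_LORA_MODEL,
            adapter_source="hub",
            max_new_tokens=2,
        )
        # Per-token details come back with the response; the first generated
        # token decides the comparison ('First' vs. 'Second').
        token_id = resp.details.tokens[0].id
        if token_id == first_id:
            return 1
        if token_id == second_id:
            return -1
        return 0

Unlike the removed local path, which softmaxed the scores of the 'First'/'Second' token ids, the remote call only exposes the sampled tokens, so any other first token falls through to the tie case (0).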