grahamwhiteuk committed
Commit 026d799
1 Parent(s): 27e0846

fix: deployment

Files changed (4)
  1. README.md +1 -0
  2. app.py +1 -1
  3. model.py +3 -2
  4. utils.py +4 -4
README.md CHANGED
@@ -7,5 +7,6 @@ sdk: gradio
 sdk_version: 4.44.1
 app_file: app.py
 pinned: false
+license: apache-2.0
 short_description: demo
 ---
app.py CHANGED
@@ -289,7 +289,7 @@ with gr.Blocks(
     with Modal(visible=False, elem_classes="modal") as modal:
         prompt = gr.Markdown("")
 
-    ### events
+    # events
 
     show_propt_button.click(
         on_show_prompt_click, inputs=[criteria, context, user_message, assistant_message, state], outputs=prompt
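For context, the `show_propt_button.click(...)` binding in this hunk fills the Markdown component that lives inside the `Modal`. A minimal sketch of that pattern, assuming the `gradio_modal` custom component and an illustrative `on_show_prompt_click` handler (the handler body and component labels are assumptions, not the app's actual code):

```python
import gradio as gr
from gradio_modal import Modal  # assumed: the Space uses the gradio_modal custom component


def on_show_prompt_click(criteria, context, user_message, assistant_message, state):
    # Illustrative only: render whatever prompt text the app would send to the model.
    return f"**Criteria:** {criteria}\n\n**User:** {user_message}\n\n**Assistant:** {assistant_message}"


with gr.Blocks() as demo:
    criteria = gr.Textbox(label="criteria")
    context = gr.Textbox(label="context")
    user_message = gr.Textbox(label="user message")
    assistant_message = gr.Textbox(label="assistant message")
    state = gr.State({})
    show_propt_button = gr.Button("Show prompt")

    with Modal(visible=False, elem_classes="modal") as modal:
        prompt = gr.Markdown("")

    # events
    show_propt_button.click(
        on_show_prompt_click,
        inputs=[criteria, context, user_message, assistant_message, state],
        outputs=prompt,
    ).then(lambda: Modal(visible=True), None, modal)  # open the modal once the prompt is populated
```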
model.py CHANGED
@@ -6,7 +6,7 @@ import spaces
 import torch
 from ibm_watsonx_ai.client import APIClient
 from ibm_watsonx_ai.foundation_models import ModelInference
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from vllm import LLM, SamplingParams
 
 from logger import logger
@@ -24,7 +24,8 @@ if inference_engine == "VLLM":
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
-    model = LLM(model=model_path, tensor_parallel_size=1)
+    # model = LLM(model=model_path, tensor_parallel_size=1)
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
 
 elif inference_engine == "WATSONX":
     client = APIClient(
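Because the deployment fix swaps the vLLM `LLM` engine for a plain `transformers` model, any downstream generation has to go through `model.generate` rather than vLLM's `generate(..., sampling_params)`. A minimal sketch of that call path, assuming greedy decoding to mirror `temperature=0.0`; the model path, message content, and token budget are illustrative assumptions, not values from this diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-granite/granite-guardian-3.0-2b"  # assumption: any causal LM with a chat template works here
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")

messages = [{"role": "user", "content": "Is this message harmful?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids,
        do_sample=False,            # greedy decoding, matching temperature=0.0 in the vLLM SamplingParams
        max_new_tokens=20,
        return_dict_in_generate=True,
        output_scores=True,         # per-step logits, a rough analogue of vLLM's logprobs=nlogprobs
    )

# Strip the prompt tokens and decode only the newly generated text.
generated = output.sequences[:, input_ids.shape[-1]:]
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```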
utils.py CHANGED
@@ -29,13 +29,13 @@ def get_result_description(sub_catalog_name, criteria_name):
     messages = {
         "general_harm": f"Is the {evaluated_component} message harmful by common-sense?",
         "social_bias": f"Is the {evaluated_component} message prejudiced based on identity or characteristics?",
-        "jailbreak": f"Does the user message try to manipulate the AI to generate harmful, undesired, or inappropriate content?",
+        "jailbreak": "Does the user message try to manipulate the AI to generate harmful, undesired, or inappropriate content?",
         "violence": f"Does the {evaluated_component} message contain content promoting physical, mental, or sexual harm?",
         "profanity": f"Does the {evaluated_component} message include offensive language or insults?",
         "unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal standards?",
-        "answer_relevance": f"Does the assistant response fail to address or properly answer the user question?",
-        "context_relevance": f"Is the retrieved context irrelevant to the user question or does not address their needs?",
-        "groundedness": f"Does the assistant response include claims or facts not supported by or contradicted by the provided context?",
+        "answer_relevance": "Does the assistant response fail to address or properly answer the user question?",
+        "context_relevance": "Is the retrieved context irrelevant to the user question or does not address their needs?",
+        "groundedness": "Does the assistant response include claims or facts not supported by or contradicted by the provided context?",
     }
     return messages[criteria_name]
 
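The utils.py hunk only removes the `f` prefix from entries whose text contains no `{evaluated_component}` placeholder (the cleanup flake8 reports as F541), so `get_result_description` returns byte-for-byte identical strings before and after this commit. A small sketch of that equivalence; the example sentence below is illustrative rather than copied from the dict:

```python
# An f-string with no replacement fields is just a plain string, so dropping the prefix changes nothing.
assert f"Does the user message try to manipulate the AI?" == "Does the user message try to manipulate the AI?"

# These are the criteria keys whose descriptions have no placeholder and therefore lost the f prefix.
unchanged_keys = ["jailbreak", "answer_relevance", "context_relevance", "groundedness"]
```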