Martín Santillán Cooper committed on
Commit 477d968
1 Parent(s): f97dae7

Adapt for deployment on HF ZeroSpace

Files changed (3)
  1. .env.example +4 -5
  2. requirements.txt +2 -1
  3. src/model.py +4 -2
.env.example CHANGED
@@ -1,5 +1,4 @@
-MODEL_PATH='../dmf_models/granite-guardian-8b-pipecleaner-r241024a'
-USE_CONDA='true'
-INFERENCE_ENGINE='' # one of [WATSONX, MOCK, VLLM]
-WATSONX_API_KEY=""
-WATSONX_PROJECT_ID=""
+MODEL_PATH='ibm-granite/granite-guardian-3.0-8b'
+INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, VLLM]
+WATSONX_API_KEY=''
+WATSONX_PROJECT_ID=''
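The example config now points at the public ibm-granite/granite-guardian-3.0-8b checkpoint, defaults INFERENCE_ENGINE to VLLM, and drops the USE_CONDA flag. For a local run these variables have to reach the process environment before src/model.py is imported; a minimal sketch of one way to do that, assuming the file is copied to .env and python-dotenv is installed (it is not listed in requirements.txt, and on a hosted Space the same values would normally be set as Space variables/secrets instead):

    # Hypothetical local bootstrap; python-dotenv is an assumption, not part of this repo.
    import os
    from dotenv import load_dotenv

    load_dotenv()  # copies KEY=VALUE pairs from a local .env file into os.environ

    inference_engine = os.getenv('INFERENCE_ENGINE', 'VLLM')  # one of WATSONX, MOCK, VLLM
    model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
    print(f"engine={inference_engine}, model={model_path}")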
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ tqdm
 jinja2
 ibm_watsonx_ai
 transformers
-gradio_modal
+gradio_modal
+spaces

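The notable addition is spaces, the Hugging Face package that provides the ZeroGPU integration used in src/model.py below (the gradio_modal line is simply rewritten, likely a newline fix). On a ZeroGPU Space, a GPU is attached only while a function wrapped in @spaces.GPU is executing. A stand-alone sketch of the decorator; the duration argument is an illustration of the package's API, not something this commit sets:

    # Illustrative only: how spaces.GPU scopes GPU access on a ZeroGPU Space.
    import spaces
    import torch

    @spaces.GPU(duration=120)  # optional per-call GPU time hint, in seconds
    def gpu_check():
        # CUDA becomes available inside the decorated call on ZeroGPU hardware
        return torch.cuda.is_available()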
src/model.py CHANGED
@@ -7,19 +7,20 @@ from ibm_watsonx_ai.client import APIClient
 from ibm_watsonx_ai.foundation_models import ModelInference
 from transformers import AutoTokenizer
 import math
+import spaces
 
 safe_token = "No"
 risky_token = "Yes"
 nlogprobs = 5
 
-inference_engine = os.getenv('INFERENCE_ENGINE')
+inference_engine = os.getenv('INFERENCE_ENGINE', 'VLLM')
 logger.debug(f"Inference engine is: '{inference_engine}'")
 
 if inference_engine == 'VLLM':
     import torch
     from vllm import LLM, SamplingParams
     from transformers import AutoTokenizer
-    model_path = os.getenv('MODEL_PATH') #"granite-guardian-3b-pipecleaner-r241024a"
+    model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
@@ -138,6 +139,7 @@ def parse_output_watsonx(generated_tokens_list):
 
     return label, prob_of_risk
 
+@spaces.GPU
 def generate_text(messages, criteria_name):
     logger.debug(f'Messages used to create the prompt are: \n{messages}')
 
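Taken together, the src/model.py changes give the module safe defaults when the environment variables are unset and scope GPU use to the generation call, which is what ZeroGPU requires. A rough end-to-end sketch of the VLLM path under those changes; module-level names mirror the diff where shown, while the lazy engine construction, the simplified signature, and the prompt/generation calls are assumptions rather than the repository's exact code:

    import os
    import spaces
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams

    model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    sampling_params = SamplingParams(temperature=0.0, logprobs=5)

    _llm = None  # assumption: build the vLLM engine lazily, inside the GPU-scoped call

    @spaces.GPU  # a GPU is attached only while generate_text runs
    def generate_text(messages):
        global _llm
        if _llm is None:
            _llm = LLM(model=model_path)  # first call pays the load cost while a GPU is present
        # Render the chat messages into a single prompt string for vLLM
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = _llm.generate(prompt, sampling_params)
        return outputs[0].outputs[0].text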