Spaces: Running on Zero

Commit 477d968 • Martín Santillán Cooper committed
Parent(s): f97dae7

Adapt for deployment on HF ZeroSpace

Files changed:
- .env.example +4 -5
- requirements.txt +2 -1
- src/model.py +4 -2
.env.example
CHANGED
@@ -1,5 +1,4 @@
-MODEL_PATH='
-
-
-
-WATSONX_PROJECT_ID=""
+MODEL_PATH='ibm-granite/granite-guardian-3.0-8b'
+INFERENCE_ENGINE='VLLM' # one of [WATSONX, MOCK, VLLM]
+WATSONX_API_KEY=''
+WATSONX_PROJECT_ID=''
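For context, a minimal sketch of how a local copy of this file (saved as .env) could be consumed on the Python side. Only the variable names and the fallback defaults are confirmed by this commit (.env.example and src/model.py below); the python-dotenv loader is an assumption and is not part of this diff, and on a Space the same values would typically arrive as environment variables or secrets.

# Hypothetical loader sketch -- variable names and os.getenv fallbacks mirror
# this commit; python-dotenv usage is assumed, not shown in the diff.
import os
from dotenv import load_dotenv  # assumed dependency, not listed in the visible requirements

load_dotenv()  # reads a local .env if present; harmless when the file is absent

model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
inference_engine = os.getenv('INFERENCE_ENGINE', 'VLLM')  # one of [WATSONX, MOCK, VLLM]
watsonx_api_key = os.getenv('WATSONX_API_KEY', '')
watsonx_project_id = os.getenv('WATSONX_PROJECT_ID', '')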
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ tqdm
 jinja2
 ibm_watsonx_ai
 transformers
-gradio_modal
+gradio_modal
+spaces
src/model.py
CHANGED
@@ -7,19 +7,20 @@ from ibm_watsonx_ai.client import APIClient
 from ibm_watsonx_ai.foundation_models import ModelInference
 from transformers import AutoTokenizer
 import math
+import spaces

 safe_token = "No"
 risky_token = "Yes"
 nlogprobs = 5

-inference_engine = os.getenv('INFERENCE_ENGINE')
+inference_engine = os.getenv('INFERENCE_ENGINE', 'VLLM')
 logger.debug(f"Inference engine is: '{inference_engine}'")

 if inference_engine == 'VLLM':
     import torch
     from vllm import LLM, SamplingParams
     from transformers import AutoTokenizer
-    model_path = os.getenv('MODEL_PATH')
+    model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
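The hunk above only wires up the tokenizer and the greedy SamplingParams; for orientation, here is a minimal sketch (not code from this repository) of how a vLLM engine built from the same model_path might be queried. The llm variable, the classify helper, and its prompt construction are illustrative assumptions, and the real app's Granite Guardian prompt template may take additional arguments.

# Illustrative sketch only -- engine construction and classify() are assumed;
# the tokenizer and sampling parameters mirror the hunk above.
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_path = 'ibm-granite/granite-guardian-3.0-8b'
tokenizer = AutoTokenizer.from_pretrained(model_path)
sampling_params = SamplingParams(temperature=0.0, logprobs=5)
llm = LLM(model=model_path)  # not shown in the hunk; assumed to exist elsewhere in model.py

def classify(messages):
    # Render the chat messages into a prompt and run one greedy generation.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = llm.generate(prompt, sampling_params)[0]
    return output.outputs[0].text.strip()  # expected "Yes" (risky) or "No" (safe)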
@@ -138,6 +139,7 @@ def parse_output_watsonx(generated_tokens_list):

     return label, prob_of_risk

+@spaces.GPU
 def generate_text(messages, criteria_name):
     logger.debug(f'Messages used to create the prompt are: \n{messages}')

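The second hunk is the core ZeroGPU change: src/model.py imports the spaces package and decorates the inference entry point with @spaces.GPU, so a GPU is attached only while generate_text runs. A minimal sketch of that pattern follows; the function body is a placeholder, and the behavior noted in the comments reflects how the Hugging Face spaces package is documented to work (the decorator is intended to have no effect outside ZeroGPU hardware).

# Minimal ZeroGPU usage sketch; the body below is a placeholder, not repo code.
import spaces

@spaces.GPU  # requests a GPU for the duration of each call on ZeroGPU Spaces
def generate_text(messages, criteria_name):
    # GPU-backed vLLM inference would run here, as in src/model.py above.
    ...

The decorator also accepts a duration argument (e.g. @spaces.GPU(duration=120)) when a call needs more than the default allocation window.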