Spaces: Running on Zero

Martín Santillán Cooper committed
Commit • 5269ad1
Parent(s): 665f60a

use vllm

Browse files
- model.py +63 -34
- requirements_frozen.txt +89 -2
model.py
CHANGED
@@ -1,22 +1,61 @@
-import torch
-from torch.nn.functional import softmax
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
 import os
-from time import time
+from time import time, sleep
 from logger import logger
-
+import math
+import torch
+from vllm import LLM, SamplingParams
+from transformers import AutoTokenizer
+
+
+
+safe_token = "No"
+unsafe_token = "Yes"
+nlogprobs = 5
+
+def parse_output(output):
+    label, prob = None, None
+
+    if nlogprobs > 0:
+        logprobs = next(iter(output.outputs)).logprobs
+        if logprobs is not None:
+            prob = get_probablities(logprobs)
+            prob_of_risk = prob[1]
+
+    res = next(iter(output.outputs)).text.strip()
+    if unsafe_token.lower() == res.lower():
+        label = unsafe_token
+    elif safe_token.lower() == res.lower():
+        label = safe_token
+    else:
+        label = "Failed"
+
+    return label, prob_of_risk.item()
+
+def get_probablities(logprobs):
+    safe_token_prob = 1e-50
+    unsafe_token_prob = 1e-50
+    for gen_token_i in logprobs:
+        for token_prob in gen_token_i.values():
+            decoded_token = token_prob.decoded_token
+            if decoded_token.strip().lower() == safe_token.lower():
+                safe_token_prob += math.exp(token_prob.logprob)
+            if decoded_token.strip().lower() == unsafe_token.lower():
+                unsafe_token_prob += math.exp(token_prob.logprob)
+
+    probabilities = torch.softmax(
+        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
+    )
+
+    return probabilities
+
 
 mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
 if not mock_model_call:
     use_conda = os.getenv('USE_CONDA', "false") == "true"
-    device = "cuda"
     model_path = os.getenv('MODEL_PATH')#"granite-guardian-3b-pipecleaner-r241024a"
-
+    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
+    model = LLM(model=model_path, tensor_parallel_size=1)
     tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        device_map=device if use_conda else None
-    )
 
 def generate_text(prompt):
     logger.debug('Starting evaluation...')
@@ -28,32 +67,22 @@ def generate_text(prompt):
         return {'assessment': 'Yes', 'certainty': 0.97}
     else:
         start = time()
-        tokenized_chat = tokenizer.apply_chat_template(
-            add_generation_prompt=True,
-            return_tensors="pt")
-        if use_conda:
-            tokenized_chat = tokenized_chat.to(device)
+
+        tokenized_chat = tokenizer.apply_chat_template([prompt], tokenize=False, add_generation_prompt=True)
+
         with torch.no_grad():
-            selected_logits = logits[0, -1, [vocab['No'], vocab['Yes']]]
-            probabilities = softmax(selected_logits, dim=0)
-            prob = probabilities[1].item()
-            logger.debug(f'Certainty is: {prob} from probabilities {probabilities}')
-            certainty = prob
-            assessment = 'Yes' if certainty > 0.5 else 'No'
-            certainty = 1 - certainty if certainty < 0.5 else certainty
-            certainty = f'{round(certainty,3)}'
+            output = model.generate(tokenized_chat, sampling_params, use_tqdm=False)
+
+        predicted_label = output[0].outputs[0].text.strip()
+
+        label, prob_of_risk = parse_output(output[0])
 
+        logger.debug(f'Model generated label: \n{label}')
+        logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
+
         end = time()
         total = end - start
         logger.debug(f'it took {round(total/60, 2)} mins')
 
-        return {'assessment': assessment, 'certainty': certainty}
+        return {'assessment': label, 'certainty': prob_of_risk}
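The core of this change is how the new code turns vLLM's per-token log-probabilities into a single probability of risk: get_probablities sums, across every generated position, the probability mass of any top-k candidate that decodes to the safe token ("No") or the unsafe token ("Yes"), then renormalizes the two sums with a softmax over their logs. Below is a minimal, self-contained sketch of that aggregation; FakeLogprob is a hypothetical stand-in for the logprob entries vLLM returns when SamplingParams(logprobs=5) is set, modeling only the two attributes the parsing code reads.

import math
from dataclasses import dataclass

import torch

safe_token = "No"
unsafe_token = "Yes"

@dataclass
class FakeLogprob:
    # Hypothetical stand-in for a vLLM logprob entry: the decoded text of the
    # candidate token and its log-probability at that generation step.
    decoded_token: str
    logprob: float

def get_probablities(logprobs):
    # Tiny epsilon so math.log below never sees zero when a token never
    # appears among the top-k candidates.
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:          # one {token_id: entry} dict per position
        for token_prob in gen_token_i.values():
            decoded_token = token_prob.decoded_token
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded_token.strip().lower() == unsafe_token.lower():
                unsafe_token_prob += math.exp(token_prob.logprob)
    # softmax over the two log-masses renormalizes them against each other:
    # the result is (safe_mass, unsafe_mass) / (safe_mass + unsafe_mass).
    return torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
    )

# One generated position whose top-2 candidates are "Yes" (88%) and "No" (12%):
fake_logprobs = [{0: FakeLogprob("Yes", math.log(0.88)),
                  1: FakeLogprob("No", math.log(0.12))}]
probs = get_probablities(fake_logprobs)
print(probs)   # tensor([0.1200, 0.8800]); probs[1] is the probability of risk

Since softmax(log a, log b) is exactly (a, b) / (a + b), the epsilon only matters when neither token surfaces in the top-k candidates, in which case the result degrades gracefully to an uninformative 0.5/0.5.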
requirements_frozen.txt
CHANGED
@@ -1,67 +1,154 @@
+accelerate==0.34.2
+aiobotocore==2.15.1
 aiofiles==23.2.1
+aiohappyeyeballs==2.4.3
+aiohttp==3.10.8
+aioitertools==0.12.0
+aiosignal==1.3.1
 annotated-types==0.7.0
 anyio==4.6.0
+async-timeout==4.0.3
+attrs==24.2.0
+botocore==1.35.23
 certifi==2024.8.30
 charset-normalizer==3.3.2
 click==8.1.7
+cloudpickle==3.0.0
 contourpy==1.3.0
 cycler==0.12.1
+datasets==3.0.1
+deprecation==2.1.0
+dill==0.3.8
+diskcache==5.6.3
+distro==1.9.0
+dmf-lib @ git+ssh://git@github.ibm.com/arc/dmf-library.git@6acf931132183153684c1c9a8edd6dbfec6f0372
+duckdb==1.1.1
+einops==0.8.0
 exceptiongroup==1.2.2
 fastapi==0.115.0
 ffmpy==0.4.0
 filelock==3.16.1
 fonttools==4.54.1
+frozenlist==1.4.1
+fsspec==2024.6.1
+gguf==0.10.0
 gradio==4.44.1
 gradio_client==1.3.0
 h11==0.14.0
 httpcore==1.0.6
+httptools==0.6.1
 httpx==0.27.2
 huggingface-hub==0.25.1
+ibm-cos-sdk==2.13.6
+ibm-cos-sdk-core==2.13.6
+ibm-cos-sdk-s3transfer==2.13.6
 idna==3.10
+importlib_metadata==8.5.0
 importlib_resources==6.4.5
+interegular==0.3.3
 Jinja2==3.1.4
+jiter==0.6.1
+jmespath==1.0.1
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
 kiwisolver==1.4.7
+lark==1.2.2
+llvmlite==0.43.0
+lm-format-enforcer==0.10.6
 markdown-it-py==3.0.0
 MarkupSafe==2.1.5
 matplotlib==3.9.2
 mdurl==0.1.2
+mistral_common==1.4.4
+mmh3==4.1.0
 mpmath==1.3.0
+msgpack==1.1.0
+msgspec==0.18.6
+multidict==6.1.0
+multiprocess==0.70.16
+nest-asyncio==1.6.0
 networkx==3.3
+numba==0.60.0
 numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-ml-py==12.560.30
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.77
+nvidia-nvtx-cu12==12.1.105
+openai==1.51.2
 orjson==3.10.7
+outlines==0.0.46
 packaging==24.1
 pandas==2.2.3
+partial-json-parser==0.2.1.1.post4
 pillow==10.4.0
+progressbar==2.5
+prometheus-fastapi-instrumentator==7.0.0
+prometheus_client==0.21.0
+protobuf==5.28.2
+psutil==6.0.0
+py-cpuinfo==9.0.0
+pyairports==2.1.1
+pyarrow==17.0.0
+pycountry==24.6.1
 pydantic==2.9.2
 pydantic_core==2.23.4
 pydub==0.25.1
 Pygments==2.18.0
+pyiceberg==0.7.1
 pyparsing==3.1.4
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
 python-multipart==0.0.12
 pytz==2024.2
 PyYAML==6.0.2
+pyzmq==26.2.0
+ray==2.37.0
+referencing==0.35.1
 regex==2024.9.11
 requests==2.32.3
 rich==13.9.2
+rpds-py==0.20.0
 ruff==0.6.9
+s3fs==2023.12.2
 safetensors==0.4.5
 semantic-version==2.10.0
+sentencepiece==0.2.0
 shellingham==1.5.4
 six==1.16.0
 sniffio==1.3.1
+sortedcontainers==2.4.0
 starlette==0.38.6
+strictyaml==1.7.3
 sympy==1.13.3
+tenacity==8.5.0
+tiktoken==0.7.0
 tokenizers==0.20.0
 tomlkit==0.12.0
-torch==2.
+torch==2.4.0
+torchvision==0.19.0
 tqdm==4.66.5
 transformers==4.45.1
+triton==3.0.0
 typer==0.12.5
 typing_extensions==4.12.2
 tzdata==2024.2
 urllib3==2.2.3
 uvicorn==0.31.0
+uvloop==0.20.0
+vllm==0.6.2
+watchfiles==0.24.0
 websockets==12.0
+wrapt==1.16.0
+xformers==0.0.27.post2
+xxhash==3.5.0
+yarl==1.13.1
+zipp==3.20.2
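For a quick smoke test of this module without loading any weights, the environment flags it already reads can route everything through the mock branch. A hypothetical check, assuming model.py and logger.py are importable and the pinned dependencies above are installed (the top-level vllm import still has to resolve even though no model is loaded):

import os

# Must be set before importing model.py, since the flag is read at import time.
os.environ['MOCK_MODEL_CALL'] = 'true'

import model  # the Space's model.py; skips LLM(...) and tokenizer loading

# The real path applies the chat template to a single message dict; the mock
# branch ignores the prompt and returns the canned assessment.
result = model.generate_text({"role": "user", "content": "hello"})
print(result)  # {'assessment': 'Yes', 'certainty': 0.97}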