Martín Santillán Cooper committed on
Commit 5269ad1
1 Parent(s): 665f60a
Files changed (2)
  1. model.py +63 -34
  2. requirements_frozen.txt +89 -2
model.py CHANGED
@@ -1,22 +1,61 @@
-import torch
-from torch.nn.functional import softmax
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
 import os
-from time import time
+from time import time, sleep
 from logger import logger
-from time import sleep
+import math
+import torch
+from vllm import LLM, SamplingParams
+from transformers import AutoTokenizer
+
+
+
+safe_token = "No"
+unsafe_token = "Yes"
+nlogprobs = 5
+
+def parse_output(output):
+    label, prob = None, None
+
+    if nlogprobs > 0:
+        logprobs = next(iter(output.outputs)).logprobs
+        if logprobs is not None:
+            prob = get_probablities(logprobs)
+            prob_of_risk = prob[1]
+
+    res = next(iter(output.outputs)).text.strip()
+    if unsafe_token.lower() == res.lower():
+        label = unsafe_token
+    elif safe_token.lower() == res.lower():
+        label = safe_token
+    else:
+        label = "Failed"
+
+    return label, prob_of_risk.item()
+
+def get_probablities(logprobs):
+    safe_token_prob = 1e-50
+    unsafe_token_prob = 1e-50
+    for gen_token_i in logprobs:
+        for token_prob in gen_token_i.values():
+            decoded_token = token_prob.decoded_token
+            if decoded_token.strip().lower() == safe_token.lower():
+                safe_token_prob += math.exp(token_prob.logprob)
+            if decoded_token.strip().lower() == unsafe_token.lower():
+                unsafe_token_prob += math.exp(token_prob.logprob)
+
+    probabilities = torch.softmax(
+        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
+    )
+
+    return probabilities
+
 
 mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
 if not mock_model_call:
     use_conda = os.getenv('USE_CONDA', "false") == "true"
-    device = "cuda"
     model_path = os.getenv('MODEL_PATH')#"granite-guardian-3b-pipecleaner-r241024a"
-    logger.info(f'Model path is "{model_path}"')
+    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
+    model = LLM(model=model_path, tensor_parallel_size=1)
     tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        device_map=device if use_conda else None
-    )
 
 def generate_text(prompt):
     logger.debug('Starting evaluation...')
@@ -28,32 +67,22 @@ def generate_text(prompt):
         return {'assessment': 'Yes', 'certainty': 0.97}
     else:
         start = time()
-        tokenized_chat = tokenizer.apply_chat_template(
-            [prompt],
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt")
-        if use_conda:
-            tokenized_chat = tokenized_chat.to(device)
+
+        tokenized_chat = tokenizer.apply_chat_template([prompt], tokenize=False, add_generation_prompt=True)
+
         with torch.no_grad():
-            logits = model(tokenized_chat).logits
-            gen_outputs = model.generate(tokenized_chat, max_new_tokens=128)
-
-        generated_text = tokenizer.decode(gen_outputs[0])
-        logger.debug(f'Model generated text: \n{generated_text}')
-        vocab = tokenizer.get_vocab()
-        selected_logits = logits[0, -1, [vocab['No'], vocab['Yes']]]
-        probabilities = softmax(selected_logits, dim=0)
-
-        prob = probabilities[1].item()
-        logger.debug(f'Certainty is: {prob} from probabilities {probabilities}')
-        certainty = prob
-        assessment = 'Yes' if certainty > 0.5 else 'No'
-        certainty = 1 - certainty if certainty < 0.5 else certainty
-        certainty = f'{round(certainty,3)}'
+            output = model.generate(tokenized_chat, sampling_params, use_tqdm=False)
+
+
+        predicted_label = output[0].outputs[0].text.strip()
+
+        label, prob_of_risk = parse_output(output[0])
+
+        logger.debug(f'Model generated label: \n{label}')
+        logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
 
         end = time()
         total = end - start
         logger.debug(f'it took {round(total/60, 2)} mins')
 
-        return {'assessment': assessment, 'certainty': certainty}
+        return {'assessment': label, 'certainty': prob_of_risk}
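For context on the new scoring path: get_probablities() sums the probability mass that each generated position assigns to the safe/unsafe tokens, then renormalizes the two totals with a softmax over their logs (softmax(log a, log b) is just (a, b) / (a + b)). Below is a minimal, self-contained sketch of that aggregation, using a hypothetical TokenProb stand-in for vLLM's per-token logprob entries and made-up sample values; it is an illustration under those assumptions, not part of the commit.

import math
from dataclasses import dataclass

import torch

safe_token = "No"
unsafe_token = "Yes"

@dataclass
class TokenProb:
    # Hypothetical stand-in for vLLM's per-token logprob entry; only the
    # .decoded_token / .logprob attributes used by model.py are modeled.
    decoded_token: str
    logprob: float

def get_probabilities(logprobs):
    # Same aggregation as the committed get_probablities(): sum the mass
    # assigned to "No"/"Yes" across generated positions, then renormalize.
    safe_token_prob = 1e-50    # floor keeps math.log() defined below
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:               # one dict per generated position
        for token_prob in gen_token_i.values():
            decoded = token_prob.decoded_token.strip().lower()
            if decoded == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded == unsafe_token.lower():
                unsafe_token_prob += math.exp(token_prob.logprob)
    return torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
    )

# Made-up top-3 candidates for a single generated position.
position = {
    0: TokenProb("Yes", math.log(0.90)),
    1: TokenProb("No", math.log(0.08)),
    2: TokenProb("yes", math.log(0.01)),
}
probs = get_probabilities([position])
print(f"prob_of_risk = {probs[1].item():.3f}")  # ~0.919: 0.91 / (0.91 + 0.08)

The 1e-50 floor keeps the log defined when a token never appears in the top-n logprobs; that side then contributes essentially zero mass after renormalization.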
requirements_frozen.txt CHANGED
@@ -1,67 +1,154 @@
+accelerate==0.34.2
+aiobotocore==2.15.1
 aiofiles==23.2.1
+aiohappyeyeballs==2.4.3
+aiohttp==3.10.8
+aioitertools==0.12.0
+aiosignal==1.3.1
 annotated-types==0.7.0
 anyio==4.6.0
+async-timeout==4.0.3
+attrs==24.2.0
+botocore==1.35.23
 certifi==2024.8.30
 charset-normalizer==3.3.2
 click==8.1.7
+cloudpickle==3.0.0
 contourpy==1.3.0
 cycler==0.12.1
+datasets==3.0.1
+deprecation==2.1.0
+dill==0.3.8
+diskcache==5.6.3
+distro==1.9.0
+dmf-lib @ git+ssh://git@github.ibm.com/arc/dmf-library.git@6acf931132183153684c1c9a8edd6dbfec6f0372
+duckdb==1.1.1
+einops==0.8.0
 exceptiongroup==1.2.2
 fastapi==0.115.0
 ffmpy==0.4.0
 filelock==3.16.1
 fonttools==4.54.1
-fsspec==2024.9.0
+frozenlist==1.4.1
+fsspec==2024.6.1
+gguf==0.10.0
 gradio==4.44.1
 gradio_client==1.3.0
 h11==0.14.0
 httpcore==1.0.6
+httptools==0.6.1
 httpx==0.27.2
 huggingface-hub==0.25.1
+ibm-cos-sdk==2.13.6
+ibm-cos-sdk-core==2.13.6
+ibm-cos-sdk-s3transfer==2.13.6
 idna==3.10
+importlib_metadata==8.5.0
 importlib_resources==6.4.5
+interegular==0.3.3
 Jinja2==3.1.4
+jiter==0.6.1
+jmespath==1.0.1
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
 kiwisolver==1.4.7
+lark==1.2.2
+llvmlite==0.43.0
+lm-format-enforcer==0.10.6
 markdown-it-py==3.0.0
 MarkupSafe==2.1.5
 matplotlib==3.9.2
 mdurl==0.1.2
+mistral_common==1.4.4
+mmh3==4.1.0
 mpmath==1.3.0
+msgpack==1.1.0
+msgspec==0.18.6
+multidict==6.1.0
+multiprocess==0.70.16
+nest-asyncio==1.6.0
 networkx==3.3
+numba==0.60.0
 numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-ml-py==12.560.30
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.77
+nvidia-nvtx-cu12==12.1.105
+openai==1.51.2
 orjson==3.10.7
+outlines==0.0.46
 packaging==24.1
 pandas==2.2.3
+partial-json-parser==0.2.1.1.post4
 pillow==10.4.0
+progressbar==2.5
+prometheus-fastapi-instrumentator==7.0.0
+prometheus_client==0.21.0
+protobuf==5.28.2
+psutil==6.0.0
+py-cpuinfo==9.0.0
+pyairports==2.1.1
+pyarrow==17.0.0
+pycountry==24.6.1
 pydantic==2.9.2
 pydantic_core==2.23.4
 pydub==0.25.1
 Pygments==2.18.0
+pyiceberg==0.7.1
 pyparsing==3.1.4
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
 python-multipart==0.0.12
 pytz==2024.2
 PyYAML==6.0.2
+pyzmq==26.2.0
+ray==2.37.0
+referencing==0.35.1
 regex==2024.9.11
 requests==2.32.3
 rich==13.9.2
+rpds-py==0.20.0
 ruff==0.6.9
+s3fs==2023.12.2
 safetensors==0.4.5
 semantic-version==2.10.0
+sentencepiece==0.2.0
 shellingham==1.5.4
 six==1.16.0
 sniffio==1.3.1
+sortedcontainers==2.4.0
 starlette==0.38.6
+strictyaml==1.7.3
 sympy==1.13.3
+tenacity==8.5.0
+tiktoken==0.7.0
 tokenizers==0.20.0
 tomlkit==0.12.0
-torch==2.2.2
+torch==2.4.0
+torchvision==0.19.0
 tqdm==4.66.5
 transformers==4.45.1
+triton==3.0.0
 typer==0.12.5
 typing_extensions==4.12.2
 tzdata==2024.2
 urllib3==2.2.3
 uvicorn==0.31.0
+uvloop==0.20.0
+vllm==0.6.2
+watchfiles==0.24.0
 websockets==12.0
+wrapt==1.16.0
+xformers==0.0.27.post2
+xxhash==3.5.0
+yarl==1.13.1
+zipp==3.20.2
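The freeze now pulls in the full vLLM serving stack (vllm==0.6.2, torch==2.4.0 with CUDA 12.1 wheels, xformers, triton). A small sanity-check sketch, assuming the app's Python environment, that compares a few of the pins above against what actually resolved (stdlib only, not part of the commit):

# Compare selected pins from requirements_frozen.txt with installed versions.
import importlib.metadata as md

for pkg, pinned in [("vllm", "0.6.2"), ("torch", "2.4.0"), ("transformers", "4.45.1")]:
    installed = md.version(pkg)
    status = "ok" if installed == pinned else f"MISMATCH (pinned {pinned})"
    print(f"{pkg}=={installed} -> {status}")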