hanchier commited on
Commit
d75dc6d
·
1 Parent(s): 261016b
.gitignore ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # repo-specific
163
+ **/.DS_Store
164
+ _logs
165
+ _logs/
166
+ checkpoints/
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # https://huggingface.co/spaces/Glaciohound/LM-Steer
2
 
3
  import torch
 
4
  import streamlit as st
5
  import random
6
  import numpy as np
@@ -23,39 +24,53 @@ def st_get_model(model_name, low_resource_mode):
23
  return model, tokenizer
24
 
25
 
26
- def word_embedding_space_analysis(model, tokenizer, dim):
27
- matrix = model.steer.projector1.data[dim].matmul(
28
- model.steer.projector2.data[dim].transpose(0, 1))
29
- S, V, D = torch.linalg.svd(matrix)
 
 
 
30
  embeddings = model.steer.lm_head.weight
 
 
31
 
32
  data = []
33
- for _i in range(10):
34
- left_tokens = embeddings.matmul(D[_i]).argsort()[-20:].flip(0)
35
- right_tokens = embeddings.matmul(D[_i]).argsort()[:20]
 
 
 
36
 
37
  def filter_words(side_tokens):
38
  output = []
39
  for t in side_tokens:
40
  word = tokenizer.decode([t])
41
- if not word[0].isalpha() and word[1:].isalpha():
42
- output.append(word[1:]+"-")
43
- return output
 
 
 
 
 
44
 
45
  data.append([
46
  ", ".join(filter_words(side_tokens))
47
  for side_tokens in [left_tokens, right_tokens]
48
  ])
49
- st.table(pd.DataFrame(
50
  data,
51
  columns=["One Direction", "Another Direction"],
52
  index=[f"Dim {_i}" for _i in range(10)],
53
- ))
54
 
55
 
56
  def main():
57
  # set up the page
58
  random.seed(0)
 
59
  title = "LM-Steer: Word Embeddings Are Steers for Language Models"
60
  st.set_page_config(
61
  layout="wide",
@@ -92,14 +107,14 @@ def main():
92
  '''
93
  Due to resource limits, we are only able to provide a few models for
94
  steering. You can also refer to the Github repository:
95
- https://github.com/Glaciohound/LM-Steer for hosting larger models.
96
  Some generated texts may contain toxic or offensive content. Please be
97
  cautious when using the generated texts.
98
  Note that for these smaller models, the generation quality may not be as
99
  good as the larger models (GPT-4, Llama, etc.).
100
  '''
101
  col1, col2 = st.columns(2)
102
- st.session_state.model_name = col1.selectbox(
103
  "Select a model to steer",
104
  [
105
  "gpt2",
@@ -108,48 +123,57 @@ def main():
108
  "EleutherAI/pythia-70m",
109
  "EleutherAI/pythia-160m",
110
  "EleutherAI/pythia-410m",
111
- # "EleutherAI/pythia-1b", "EleutherAI/pythia-1.4b",
112
- # "EleutherAI/pythia-2.8b", "EleutherAI/pythia-6.9b",
 
 
113
  # "EleutherAI/gpt-j-6B",
114
  ],
115
  )
116
- low_resource_mode = True if st.session_state.model_name in (
117
- "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b",
118
- "EleutherAI/pythia-6.9b", "EleutherAI/gpt-j-6B",
119
- ) else False
 
120
  model, tokenizer = st_get_model(
121
- st.session_state.model_name, low_resource_mode)
 
 
122
  num_param = model.steer.projector1.data.shape[1] ** 2 / 1024 ** 2
123
  total_param = sum(p.numel() for _, p in model.named_parameters()) / \
124
  1024 ** 2
125
  ratio = num_param / total_param
126
- col2.write(f"Steered {num_param:.1f}M out of {total_param:.1f}M "
127
- "parameters, ratio: {:.2%}".format(ratio))
128
 
129
  # steering
130
- steer_range = 4.
131
- steer_interval = 0.5
132
  st.subheader("Enter a sentence and steer the model")
133
  st.session_state.prompt = st.text_input(
134
  "Enter a prompt",
135
  st.session_state.get("prompt", "My life")
136
  )
137
- # col1, col2, col3 = st.columns(3, gap="medium")
138
  col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
139
  sentiment = col1.slider(
140
  "Sentiment (the larger the more positive)",
141
- -steer_range, steer_range, 3.0, steer_interval)
142
  detoxification = col2.slider(
143
  "Detoxification Strength (the larger the less toxic)",
144
  -steer_range, steer_range, 0.0,
145
  steer_interval)
146
- max_length = col3.number_input("Max length", 50, 300, 50, 50)
147
  col1, col2, col3, _ = st.columns(4)
148
  randomness = col2.checkbox("Random sampling", value=False)
149
 
150
  if "output" not in st.session_state:
151
  st.session_state.output = ""
152
  if col1.button("Steer and generate!", type="primary"):
 
 
 
 
 
153
  with st.spinner("Generating..."):
154
  steer_values = [detoxification, 0, sentiment, 0]
155
  st.session_state.output = model.generate(
@@ -159,8 +183,9 @@ def main():
159
  min_length=0,
160
  max_length=max_length,
161
  do_sample=True,
 
162
  )
163
- analyzed_text = \
164
  st.text_area("Generated text:", st.session_state.output, height=200)
165
 
166
  # Analysing the sentence
@@ -174,7 +199,7 @@ def main():
174
  text or use your own. Please note that these two dimensions can be
175
  entangled, as a negative sentiment may also detoxify the text.
176
  '''
177
- if st.session_state.get("output", "") != "" and \
178
  st.button("Analyze the styled text", type="primary"):
179
  col1, col2 = st.columns(2)
180
  for name, col, dim, color in zip(
@@ -187,9 +212,9 @@ def main():
187
  col.subheader(name)
188
  # classification
189
  col.markdown(
190
- "##### Dimension-Wise Classification Distribution")
191
  _, dist_list, _ = model.steer_analysis(
192
- analyzed_text,
193
  dim, -steer_range, steer_range,
194
  bins=2*int(steer_range)+1,
195
  )
@@ -209,10 +234,10 @@ def main():
209
  pos_steer[dim] = 1
210
  neg_steer[dim] = -1
211
  _, token_evidence = model.evidence_words(
212
- analyzed_text,
213
  [pos_steer, neg_steer],
214
  )
215
- tokens = tokenizer(analyzed_text).input_ids
216
  tokens = [f"{i:3d}: {tokenizer.decode([t])}"
217
  for i, t in enumerate(tokens)]
218
  col.markdown("##### Token's Evidence Score in the Dimension")
@@ -241,13 +266,13 @@ def main():
241
  dimension, sometimes only one side of the word embeddings is most relevant
242
  to the style (can be either left or right).
243
  '''
244
- dimension = st.selectbox(
245
- "Select a dimension to analyze",
246
- ["Sentiment", "Detoxification"],
247
- )
248
- dim = 2 if dimension == "Sentiment" else 0
249
- with st.spinner("Analyzing..."):
250
- word_embedding_space_analysis(model, tokenizer, dim)
251
 
252
 
253
  if __name__ == "__main__":
 
1
  # https://huggingface.co/spaces/Glaciohound/LM-Steer
2
 
3
  import torch
4
+ import nltk
5
  import streamlit as st
6
  import random
7
  import numpy as np
 
24
  return model, tokenizer
25
 
26
 
27
+ @st.cache_data()
28
+ def word_embedding_space_analysis(
29
+ model_name, dim):
30
+ model = st.session_state.model
31
+ tokenizer = st.session_state.tokenizer
32
+ projector1 = model.steer.projector1.data[dim]
33
+ projector2 = model.steer.projector2.data[dim]
34
  embeddings = model.steer.lm_head.weight
35
+ matrix = projector1.matmul(projector2.transpose(0, 1))
36
+ S, V, D = torch.linalg.svd(matrix)
37
 
38
  data = []
39
+ top = 30
40
+ select_words = 20
41
+ n_dim = 10
42
+ for _i in range(n_dim):
43
+ left_tokens = embeddings.matmul(D[_i]).argsort()[-top:].flip(0)
44
+ right_tokens = embeddings.matmul(D[_i]).argsort()[:top]
45
 
46
  def filter_words(side_tokens):
47
  output = []
48
  for t in side_tokens:
49
  word = tokenizer.decode([t])
50
+ if (
51
+ len(word) > 2 and not word[0].isalpha() and
52
+ word[1:].isalpha() and word[1:].lower().islower()
53
+ ):
54
+ word = word[1:]
55
+ if word.lower() in nltk.corpus.words.words():
56
+ output.append(word)
57
+ return output[:select_words]
58
 
59
  data.append([
60
  ", ".join(filter_words(side_tokens))
61
  for side_tokens in [left_tokens, right_tokens]
62
  ])
63
+ return pd.DataFrame(
64
  data,
65
  columns=["One Direction", "Another Direction"],
66
  index=[f"Dim {_i}" for _i in range(10)],
67
+ )
68
 
69
 
70
  def main():
71
  # set up the page
72
  random.seed(0)
73
+ nltk.download('words')
74
  title = "LM-Steer: Word Embeddings Are Steers for Language Models"
75
  st.set_page_config(
76
  layout="wide",
 
107
  '''
108
  Due to resource limits, we are only able to provide a few models for
109
  steering. You can also refer to the Github repository:
110
+ https://github.com/Glaciohound/LM-Steer to host larger models.
111
  Some generated texts may contain toxic or offensive content. Please be
112
  cautious when using the generated texts.
113
  Note that for these smaller models, the generation quality may not be as
114
  good as the larger models (GPT-4, Llama, etc.).
115
  '''
116
  col1, col2 = st.columns(2)
117
+ model_name = col1.selectbox(
118
  "Select a model to steer",
119
  [
120
  "gpt2",
 
123
  "EleutherAI/pythia-70m",
124
  "EleutherAI/pythia-160m",
125
  "EleutherAI/pythia-410m",
126
+ # "EleutherAI/pythia-1b",
127
+ # "EleutherAI/pythia-1.4b",
128
+ # "EleutherAI/pythia-2.8b",
129
+ # "EleutherAI/pythia-6.9b",
130
  # "EleutherAI/gpt-j-6B",
131
  ],
132
  )
133
+ # low_resource_mode = True if st.session_state.model_name in (
134
+ # "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b",
135
+ # "EleutherAI/pythia-6.9b", "EleutherAI/gpt-j-6B",
136
+ # ) else False
137
+ low_resource_mode = False
138
  model, tokenizer = st_get_model(
139
+ model_name, low_resource_mode)
140
+ st.session_state.model = model
141
+ st.session_state.tokenizer = tokenizer
142
  num_param = model.steer.projector1.data.shape[1] ** 2 / 1024 ** 2
143
  total_param = sum(p.numel() for _, p in model.named_parameters()) / \
144
  1024 ** 2
145
  ratio = num_param / total_param
146
+ st.write(f"Steered {num_param:.1f}M out of {total_param:.1f}M "
147
+ "parameters, ratio: {:.2%}".format(ratio))
148
 
149
  # steering
150
+ steer_range = 3.
151
+ steer_interval = 0.2
152
  st.subheader("Enter a sentence and steer the model")
153
  st.session_state.prompt = st.text_input(
154
  "Enter a prompt",
155
  st.session_state.get("prompt", "My life")
156
  )
 
157
  col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
158
  sentiment = col1.slider(
159
  "Sentiment (the larger the more positive)",
160
+ -steer_range, steer_range, 0.0, steer_interval)
161
  detoxification = col2.slider(
162
  "Detoxification Strength (the larger the less toxic)",
163
  -steer_range, steer_range, 0.0,
164
  steer_interval)
165
+ max_length = col3.number_input("Max length", 20, 200, 20, 20)
166
  col1, col2, col3, _ = st.columns(4)
167
  randomness = col2.checkbox("Random sampling", value=False)
168
 
169
  if "output" not in st.session_state:
170
  st.session_state.output = ""
171
  if col1.button("Steer and generate!", type="primary"):
172
+ if sentiment == 0 and detoxification == 0:
173
+ '''
174
+ **The steer values are both 0, which means the steered model
175
+ is the same as the original model.**
176
+ '''
177
  with st.spinner("Generating..."):
178
  steer_values = [detoxification, 0, sentiment, 0]
179
  st.session_state.output = model.generate(
 
183
  min_length=0,
184
  max_length=max_length,
185
  do_sample=True,
186
+ top_p=0.9,
187
  )
188
+ st.session_state.analyzed_text = \
189
  st.text_area("Generated text:", st.session_state.output, height=200)
190
 
191
  # Analysing the sentence
 
199
  text or use your own. Please note that these two dimensions can be
200
  entangled, as a negative sentiment may also detoxify the text.
201
  '''
202
+ if st.session_state.get("analyzed_text", "") != "" and \
203
  st.button("Analyze the styled text", type="primary"):
204
  col1, col2 = st.columns(2)
205
  for name, col, dim, color in zip(
 
212
  col.subheader(name)
213
  # classification
214
  col.markdown(
215
+ "##### Sentence Classification Distribution")
216
  _, dist_list, _ = model.steer_analysis(
217
+ st.session_state.analyzed_text,
218
  dim, -steer_range, steer_range,
219
  bins=2*int(steer_range)+1,
220
  )
 
234
  pos_steer[dim] = 1
235
  neg_steer[dim] = -1
236
  _, token_evidence = model.evidence_words(
237
+ st.session_state.analyzed_text,
238
  [pos_steer, neg_steer],
239
  )
240
+ tokens = tokenizer(st.session_state.analyzed_text).input_ids
241
  tokens = [f"{i:3d}: {tokenizer.decode([t])}"
242
  for i, t in enumerate(tokens)]
243
  col.markdown("##### Token's Evidence Score in the Dimension")
 
266
  dimension, sometimes only one side of the word embeddings is most relevant
267
  to the style (can be either left or right).
268
  '''
269
+ for dimension in ["Sentiment", "Detoxification"]:
270
+ f'##### {dimension} Dimension'
271
+ dim = 2 if dimension == "Sentiment" else 0
272
+ analysis_result = word_embedding_space_analysis(
273
+ model_name, dim)
274
+ with st.expander("Show the analysis results"):
275
+ st.table(analysis_result)
276
 
277
 
278
  if __name__ == "__main__":
lm_steer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (166 Bytes)
 
lm_steer/__pycache__/utils.cpython-310.pyc DELETED
Binary file (1.44 kB)
 
lm_steer/models/__pycache__/get_model.cpython-310.pyc DELETED
Binary file (1.48 kB)
 
lm_steer/models/__pycache__/model_base.cpython-310.pyc DELETED
Binary file (4.88 kB)
 
lm_steer/models/__pycache__/model_gpt_neo.cpython-310.pyc DELETED
Binary file (2.6 kB)
 
lm_steer/models/__pycache__/model_gpt_neox.cpython-310.pyc DELETED
Binary file (3.7 kB)
 
lm_steer/models/__pycache__/model_utils.cpython-310.pyc DELETED
Binary file (2.23 kB)
 
lm_steer/models/__pycache__/steers.cpython-310.pyc DELETED
Binary file (3.07 kB)
 
lm_steer/models/model_base.py CHANGED
@@ -26,8 +26,8 @@ class LMSteerBase(nn.Module):
26
  if isinstance(comparing_steer_values, list):
27
  comparing_steer_values = \
28
  torch.Tensor(comparing_steer_values).to(self.device)
29
- if (comparing_steer_values[0] - comparing_steer_values[1]
30
- ).abs().sum() <= 0.2:
31
  return [(prompt, None)]
32
  tokenized = self.tokenizer(
33
  prompt, return_tensors="pt",
@@ -162,12 +162,77 @@ class LMSteerBase(nn.Module):
162
  self.device)
163
  self.steer.set_value(steer_values[None])
164
  with torch.no_grad():
165
- text = self.generator(
166
- prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
 
 
 
167
  do_sample=do_sample, temperature=temperature, top_p=top_p,
168
  min_length=min_length, max_length=max_length,
169
  pad_token_id=self.tokenizer.pad_token_id,
170
  )
171
- text = text[0]["generated_text"]
172
 
173
  return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  if isinstance(comparing_steer_values, list):
27
  comparing_steer_values = \
28
  torch.Tensor(comparing_steer_values).to(self.device)
29
+ if (comparing_steer_values[0] - comparing_steer_values[1]).abs().sum()\
30
+ <= 0.2:
31
  return [(prompt, None)]
32
  tokenized = self.tokenizer(
33
  prompt, return_tensors="pt",
 
162
  self.device)
163
  self.steer.set_value(steer_values[None])
164
  with torch.no_grad():
165
+ inputs = self.tokenizer(
166
+ prompt, return_tensors="pt").to(self.device)
167
+ text = self.model.generate(
168
+ **inputs,
169
+ num_beams=num_beams, num_beam_groups=num_beam_groups,
170
  do_sample=do_sample, temperature=temperature, top_p=top_p,
171
  min_length=min_length, max_length=max_length,
172
  pad_token_id=self.tokenizer.pad_token_id,
173
  )
174
+ text = self.tokenizer.decode(text[0], skip_special_tokens=True)
175
 
176
  return text
177
+
178
+ def generate_low_resource(
179
+ self, prompt, steer_values, min_length=20, max_length=100,
180
+ seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
181
+ temperature=1, top_p=1
182
+ ):
183
+ '''
184
+ prompt: a string
185
+ steer_values
186
+ min_length: minimum generation length
187
+ max_length: maximum generation length
188
+ seed: seed for generation. None if not specified.
189
+ '''
190
+ if seed is not None:
191
+ set_seed(seed)
192
+ steer_values = torch.Tensor(steer_values).to(
193
+ self.device)
194
+ fp16 = torch.float16
195
+ steer_values = steer_values.to(fp16)
196
+ self.steer.projector1.data = self.steer.projector1.to(fp16)
197
+ self.steer.projector2.data = self.steer.projector2.to(fp16)
198
+ self.steer.set_value(steer_values[None])
199
+ with torch.no_grad():
200
+ input_ids = self.tokenizer(
201
+ prompt, return_tensors="pt").input_ids.to(self.device)
202
+ gen_tokens = self.model.generate(
203
+ input_ids,
204
+ num_beams=num_beams, num_beam_groups=num_beam_groups,
205
+ do_sample=do_sample, temperature=temperature, top_p=top_p,
206
+ min_length=min_length, max_length=max_length,
207
+ pad_token_id=self.tokenizer.pad_token_id)
208
+ text = self.tokenizer.batch_decode(gen_tokens)[0]
209
+
210
+ # recovering
211
+ fp32 = torch.float32
212
+ self.steer.projector1.data = self.steer.projector1.to(fp32)
213
+ self.steer.projector2.data = self.steer.projector2.to(fp32)
214
+ return text
215
+
216
+ def state_dict(self):
217
+ return self.steer.state_dict()
218
+
219
+ def load_state_dict(self, state_dict):
220
+ self.steer.load_state_dict(state_dict)
221
+
222
+ def parameters(self):
223
+ return self.steer.parameters()
224
+
225
+ def to_device(self, device):
226
+ self.model.to(device)
227
+ self.device = device
228
+
229
+ def regularization_term(self):
230
+ return self.steer.regularization_term()
231
+
232
+ def forward(self, input_ids, attention_mask, steer_values):
233
+ self.steer.set_value(steer_values)
234
+ output = self.model(
235
+ input_ids=input_ids,
236
+ attention_mask=attention_mask,
237
+ labels=input_ids)
238
+ return output
lm_steer/models/model_gpt_j.py CHANGED
@@ -1,27 +1,14 @@
1
  import torch
2
- import numpy as np
3
- import torch.nn as nn
4
  import torch.nn.functional as F
5
  from transformers import GPTJForCausalLM, AutoTokenizer
6
 
7
  from .model_utils import Hack_no_grad, find_max_subspans
8
  from .steers import Projected_Adaptor
 
9
  from lm_steer.utils import set_seed
10
 
11
 
12
- punctuations = [
13
- '!', '"', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.',
14
- # '/', '#',
15
- ':', ';', '<', '=', '>', '?', '@',
16
- '[', '\\', ']', '^', '_', '`',
17
- '{', '|', '}', '~',
18
- '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', 'µ', '¶', '·',
19
- '¸', '¹', 'º', '»', '¼', '½', '¾',
20
- '\n', ' ',
21
- ]
22
-
23
-
24
- class Switching_GPTJModel(nn.Module):
25
  def __init__(self, model_name, adapted_component, adaptor_class,
26
  num_steers, rank, epsilon, init_var, low_resource_mode):
27
  super().__init__()
@@ -67,31 +54,6 @@ class Switching_GPTJModel(nn.Module):
67
  else:
68
  raise NotImplementedError()
69
 
70
- def forward(self, input_ids, attention_mask, steer_values):
71
- self.steer.set_value(steer_values)
72
- output = self.model(
73
- input_ids=input_ids,
74
- attention_mask=attention_mask,
75
- labels=input_ids)
76
- return output
77
-
78
- def parameters(self):
79
- return self.steer.parameters()
80
-
81
- def state_dict(self):
82
- return self.steer.state_dict()
83
-
84
- def load_state_dict(self, state_dict):
85
- self.steer.load_state_dict(state_dict)
86
-
87
- def to_device(self, device):
88
- # self.generator.device = device
89
- self.model.to(device)
90
- self.device = device
91
-
92
- def regularization_term(self):
93
- return self.steer.regularization_term()
94
-
95
  def generate(self, prompt, steer_values, min_length=20, max_length=100,
96
  seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
97
  temperature=1, top_p=1):
@@ -102,33 +64,9 @@ class Switching_GPTJModel(nn.Module):
102
  max_length: maximum generation length
103
  seed: seed for generation. None if not specified.
104
  '''
105
- if seed is not None:
106
- set_seed(seed)
107
- steer_values = torch.Tensor(steer_values).to(
108
- self.device)
109
- if self.low_resource_mode:
110
- fp16 = torch.float16
111
- steer_values = steer_values.to(fp16)
112
- self.steer.projector1.data = self.steer.projector1.to(fp16)
113
- self.steer.projector2.data = self.steer.projector2.to(fp16)
114
- self.steer.set_value(steer_values[None])
115
- with torch.no_grad():
116
- input_ids = self.tokenizer(
117
- prompt, return_tensors="pt").input_ids.to(self.device)
118
- gen_tokens = self.model.generate(
119
- input_ids,
120
- num_beams=num_beams, num_beam_groups=num_beam_groups,
121
- do_sample=do_sample, temperature=temperature, top_p=top_p,
122
- min_new_tokens=min_length, max_new_tokens=max_length,
123
- pad_token_id=self.tokenizer.pad_token_id)
124
- text = self.tokenizer.batch_decode(gen_tokens)[0]
125
-
126
- # recovering
127
- if self.low_resource_mode:
128
- fp32 = torch.float32
129
- self.steer.projector1.data = self.steer.projector1.to(fp32)
130
- self.steer.projector2.data = self.steer.projector2.to(fp32)
131
- return text
132
 
133
  def generate_multiple(
134
  self, prompts, steer_values, min_length=20, max_length=100,
@@ -167,13 +105,14 @@ class Switching_GPTJModel(nn.Module):
167
  self.steer.projector2.data = self.steer.projector2.to(fp32)
168
  return text
169
 
170
- # def evidence_words(self, prompt, original_steer_values, max_segments=4,
171
- # max_length=10):
172
  # if isinstance(original_steer_values, list):
173
  # original_steer_values = torch.Tensor(original_steer_values)
174
  # if original_steer_values.abs().sum() <= 0.2:
175
  # return [(prompt, None)]
176
- # tokenized = self.tokenizer(prompt)
 
177
  # input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
178
  # input_ids = input_ids.expand(2, -1)
179
  # attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
@@ -201,133 +140,98 @@ class Switching_GPTJModel(nn.Module):
201
  # )
202
  # loss_token = loss_token.reshape(2, length - 1)
203
 
204
- def evidence_words(self, prompt, original_steer_values,
205
- truncation_length=1024, max_segments=4, max_length=10):
206
- if isinstance(original_steer_values, list):
207
- original_steer_values = torch.Tensor(original_steer_values)
208
- if original_steer_values.abs().sum() <= 0.2:
209
- return [(prompt, None)]
210
- tokenized = self.tokenizer(
211
- prompt, return_tensors="pt", max_length=truncation_length, truncation=True)
212
- input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
213
- input_ids = input_ids.expand(2, -1)
214
- attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
215
- self.device)
216
- attention_mask = attention_mask.expand(2, -1)
217
- steer_values = torch.zeros(2, self.num_steers).to(self.device)
218
- steer_values[0] = original_steer_values
219
- steer_values[1] = (-original_steer_values > 0) * 2 - 1
220
- if self.low_resource_mode:
221
- fp16 = torch.float16
222
- steer_values = steer_values.to(fp16)
223
- self.steer.projector1.data = self.steer.projector1.to(fp16)
224
- self.steer.projector2.data = self.steer.projector2.to(fp16)
225
- self.steer.set_value(steer_values)
226
- with torch.no_grad():
227
- output = self.model(
228
- input_ids=input_ids,
229
- attention_mask=attention_mask,
230
- labels=input_ids)
231
- length = input_ids.shape[1]
232
- loss_token = F.cross_entropy(
233
- output.logits[:, :-1].reshape((2)*(length-1), -1),
234
- input_ids[:, 1:].reshape(-1),
235
- reduction="none"
236
- )
237
- loss_token = loss_token.reshape(2, length - 1)
238
-
239
- token_evidence = (- loss_token[0] + loss_token[1])
240
- tokens = input_ids[0]
241
- evidence_segments = find_max_subspans(
242
- token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0]
243
- evidence_segments = [
244
- (_seg[0]+1, _seg[1]+1) for _seg in evidence_segments]
245
- start = 0
246
- output = []
247
- color = (
248
- "gray" if original_steer_values.shape[0] > 1
249
- else "red" if original_steer_values[0] > 0
250
- else "blue"
251
- )
252
- if len(evidence_segments) > 0:
253
- for _segment in evidence_segments:
254
- if _segment[0] > start:
255
- output.append((
256
- self.tokenizer.decode(tokens[start: _segment[0]]),
257
- None
258
- ))
259
- output.append((
260
- self.tokenizer.decode(tokens[_segment[0]: _segment[1]]),
261
- color
262
- ))
263
- start = _segment[1]
264
- length = tokens.shape[-1]
265
- if _segment[1] < length:
266
- output.append((
267
- self.tokenizer.decode(tokens[_segment[1]: length]),
268
- None
269
- ))
270
- else:
271
- output = [(prompt, None)]
272
-
273
- if self.low_resource_mode:
274
- fp32 = torch.float32
275
- self.steer.projector1.data = self.steer.projector1.to(fp32)
276
- self.steer.projector2.data = self.steer.projector2.to(fp32)
277
- return output
278
-
279
- def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3,
280
- bins=7, truncation_length=1024):
281
- tokenized = self.tokenizer(
282
- prompt, return_tensors="pt",
283
- max_length=truncation_length,
284
- truncation=True)
285
- input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
286
- input_ids = input_ids.expand(bins + 1, -1)
287
- attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
288
- self.device)
289
- attention_mask = attention_mask.expand(bins + 1, -1)
290
- steer_values = torch.zeros(bins+1, self.num_steers).to(self.device)
291
- for bin_i in range(bins):
292
- steer_values[bin_i, steer_dim] = (
293
- min_value + (max_value - min_value) / (bins - 1) * bin_i
294
- )
295
- if self.low_resource_mode:
296
- fp16 = torch.float16
297
- steer_values = steer_values.to(fp16)
298
- self.steer.projector1.data = self.steer.projector1.to(fp16)
299
- self.steer.projector2.data = self.steer.projector2.to(fp16)
300
- self.steer.set_value(steer_values)
301
- with torch.no_grad():
302
- output = self.model(
303
- input_ids=input_ids,
304
- attention_mask=attention_mask,
305
- labels=input_ids)
306
- length = input_ids.shape[1]
307
- loss_token = F.cross_entropy(
308
- output.logits[:, :-1].reshape((bins+1)*(length-1), -1),
309
- input_ids[:, 1:].reshape(-1),
310
- reduction="none"
311
- )
312
- loss_token = loss_token.reshape(bins + 1, length - 1)
313
- loss = loss_token.mean(-1)[:-1]
314
- dist = ((- loss + loss.mean()) * 100).softmax(0)
315
- dist_list = list(zip(
316
- [
317
- min_value + (max_value - min_value) / (bins - 1) * bin_i
318
- for bin_i in range(bins)
319
- ],
320
- dist.tolist(),
321
- ))
322
- best_guess = loss.argmin(0)
323
- best_guess_value = min_value + \
324
- (max_value - min_value) / (bins - 1) * best_guess.item()
325
 
326
- token_evidence = self.evidence_words(
327
- prompt, steer_values[best_guess],
328
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
- if self.low_resource_mode:
331
- fp32 = torch.float32
332
- self.steer.projector1.data = self.steer.projector1.to(fp32)
333
- return best_guess_value, dist_list, token_evidence
 
1
  import torch
 
 
2
  import torch.nn.functional as F
3
  from transformers import GPTJForCausalLM, AutoTokenizer
4
 
5
  from .model_utils import Hack_no_grad, find_max_subspans
6
  from .steers import Projected_Adaptor
7
+ from .model_base import LMSteerBase
8
  from lm_steer.utils import set_seed
9
 
10
 
11
+ class Switching_GPTJModel(LMSteerBase):
 
 
 
 
 
 
 
 
 
 
 
 
12
  def __init__(self, model_name, adapted_component, adaptor_class,
13
  num_steers, rank, epsilon, init_var, low_resource_mode):
14
  super().__init__()
 
54
  else:
55
  raise NotImplementedError()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def generate(self, prompt, steer_values, min_length=20, max_length=100,
58
  seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
59
  temperature=1, top_p=1):
 
64
  max_length: maximum generation length
65
  seed: seed for generation. None if not specified.
66
  '''
67
+ return super().generate_low_resource(
68
+ prompt, steer_values, min_length, max_length, seed,
69
+ num_beams, num_beam_groups, do_sample, temperature, top_p)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def generate_multiple(
72
  self, prompts, steer_values, min_length=20, max_length=100,
 
105
  self.steer.projector2.data = self.steer.projector2.to(fp32)
106
  return text
107
 
108
+ # def evidence_words(self, prompt, original_steer_values,
109
+ # truncation_length=1024, max_segments=4, max_length=10):
110
  # if isinstance(original_steer_values, list):
111
  # original_steer_values = torch.Tensor(original_steer_values)
112
  # if original_steer_values.abs().sum() <= 0.2:
113
  # return [(prompt, None)]
114
+ # tokenized = self.tokenizer(
115
+ # prompt, return_tensors="pt", max_length=truncation_length, truncation=True)
116
  # input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
117
  # input_ids = input_ids.expand(2, -1)
118
  # attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
 
140
  # )
141
  # loss_token = loss_token.reshape(2, length - 1)
142
 
143
+ # token_evidence = (- loss_token[0] + loss_token[1])
144
+ # tokens = input_ids[0]
145
+ # evidence_segments = find_max_subspans(
146
+ # token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0]
147
+ # evidence_segments = [
148
+ # (_seg[0]+1, _seg[1]+1) for _seg in evidence_segments]
149
+ # start = 0
150
+ # output = []
151
+ # color = (
152
+ # "gray" if original_steer_values.shape[0] > 1
153
+ # else "red" if original_steer_values[0] > 0
154
+ # else "blue"
155
+ # )
156
+ # if len(evidence_segments) > 0:
157
+ # for _segment in evidence_segments:
158
+ # if _segment[0] > start:
159
+ # output.append((
160
+ # self.tokenizer.decode(tokens[start: _segment[0]]),
161
+ # None
162
+ # ))
163
+ # output.append((
164
+ # self.tokenizer.decode(tokens[_segment[0]: _segment[1]]),
165
+ # color
166
+ # ))
167
+ # start = _segment[1]
168
+ # length = tokens.shape[-1]
169
+ # if _segment[1] < length:
170
+ # output.append((
171
+ # self.tokenizer.decode(tokens[_segment[1]: length]),
172
+ # None
173
+ # ))
174
+ # else:
175
+ # output = [(prompt, None)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
+ # if self.low_resource_mode:
178
+ # fp32 = torch.float32
179
+ # self.steer.projector1.data = self.steer.projector1.to(fp32)
180
+ # self.steer.projector2.data = self.steer.projector2.to(fp32)
181
+ # return output
182
+
183
+ # def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3,
184
+ # bins=7, truncation_length=1024):
185
+ # tokenized = self.tokenizer(
186
+ # prompt, return_tensors="pt",
187
+ # max_length=truncation_length,
188
+ # truncation=True)
189
+ # input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
190
+ # input_ids = input_ids.expand(bins + 1, -1)
191
+ # attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
192
+ # self.device)
193
+ # attention_mask = attention_mask.expand(bins + 1, -1)
194
+ # steer_values = torch.zeros(bins+1, self.num_steers).to(self.device)
195
+ # for bin_i in range(bins):
196
+ # steer_values[bin_i, steer_dim] = (
197
+ # min_value + (max_value - min_value) / (bins - 1) * bin_i
198
+ # )
199
+ # if self.low_resource_mode:
200
+ # fp16 = torch.float16
201
+ # steer_values = steer_values.to(fp16)
202
+ # self.steer.projector1.data = self.steer.projector1.to(fp16)
203
+ # self.steer.projector2.data = self.steer.projector2.to(fp16)
204
+ # self.steer.set_value(steer_values)
205
+ # with torch.no_grad():
206
+ # output = self.model(
207
+ # input_ids=input_ids,
208
+ # attention_mask=attention_mask,
209
+ # labels=input_ids)
210
+ # length = input_ids.shape[1]
211
+ # loss_token = F.cross_entropy(
212
+ # output.logits[:, :-1].reshape((bins+1)*(length-1), -1),
213
+ # input_ids[:, 1:].reshape(-1),
214
+ # reduction="none"
215
+ # )
216
+ # loss_token = loss_token.reshape(bins + 1, length - 1)
217
+ # loss = loss_token.mean(-1)[:-1]
218
+ # dist = ((- loss + loss.mean()) * 100).softmax(0)
219
+ # dist_list = list(zip(
220
+ # [
221
+ # min_value + (max_value - min_value) / (bins - 1) * bin_i
222
+ # for bin_i in range(bins)
223
+ # ],
224
+ # dist.tolist(),
225
+ # ))
226
+ # best_guess = loss.argmin(0)
227
+ # best_guess_value = min_value + \
228
+ # (max_value - min_value) / (bins - 1) * best_guess.item()
229
+
230
+ # token_evidence = self.evidence_words(
231
+ # prompt, steer_values[best_guess],
232
+ # )
233
 
234
+ # if self.low_resource_mode:
235
+ # fp32 = torch.float32
236
+ # self.steer.projector1.data = self.steer.projector1.to(fp32)
237
+ # return best_guess_value, dist_list, token_evidence
lm_steer/models/model_gpt_neo.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from transformers import pipeline
3
 
 
4
  from .model_utils import Hack_no_grad
5
  from .steers import Projected_Adaptor
6
  from .model_base import LMSteerBase
@@ -12,9 +13,9 @@ class Switching_GPTNeoModel(LMSteerBase):
12
  low_resource_mode):
13
  super().__init__()
14
  self.adapted_component = adapted_component
15
- self.generator = pipeline('text-generation', model=model_name)
16
- self.tokenizer = self.generator.tokenizer
17
- self.model = self.generator.model
18
  self.tokenizer.pad_token = self.tokenizer.eos_token
19
  self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
20
  self.init_var = init_var
@@ -39,28 +40,3 @@ class Switching_GPTNeoModel(LMSteerBase):
39
  self.model.transformer.set_input_embeddings(self.steer)
40
  else:
41
  raise NotImplementedError()
42
-
43
- def forward(self, input_ids, attention_mask, steer_values):
44
- self.steer.set_value(steer_values)
45
- output = self.model(
46
- input_ids=input_ids,
47
- attention_mask=attention_mask,
48
- labels=input_ids)
49
- return output
50
-
51
- def parameters(self):
52
- return self.steer.parameters()
53
-
54
- def state_dict(self):
55
- return self.steer.state_dict()
56
-
57
- def load_state_dict(self, state_dict):
58
- self.steer.load_state_dict(state_dict)
59
-
60
- def to_device(self, device):
61
- self.generator.device = device
62
- self.model.to(device)
63
- self.device = device
64
-
65
- def regularization_term(self):
66
- return self.steer.regularization_term()
 
1
  import torch
2
  from transformers import pipeline
3
 
4
+
5
  from .model_utils import Hack_no_grad
6
  from .steers import Projected_Adaptor
7
  from .model_base import LMSteerBase
 
13
  low_resource_mode):
14
  super().__init__()
15
  self.adapted_component = adapted_component
16
+ self.pipeline = pipeline('text-generation', model=model_name)
17
+ self.model = self.pipeline.model
18
+ self.tokenizer = self.pipeline.tokenizer
19
  self.tokenizer.pad_token = self.tokenizer.eos_token
20
  self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
21
  self.init_var = init_var
 
40
  self.model.transformer.set_input_embeddings(self.steer)
41
  else:
42
  raise NotImplementedError()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lm_steer/models/model_gpt_neox.py CHANGED
@@ -4,7 +4,6 @@ from transformers import GPTNeoXForCausalLM, AutoTokenizer
4
  from .model_utils import Hack_no_grad
5
  from .steers import Projected_Adaptor
6
  from .model_base import LMSteerBase
7
- from lm_steer.utils import set_seed
8
 
9
 
10
  class Switching_GPTNeoXModel(LMSteerBase):
@@ -42,30 +41,6 @@ class Switching_GPTNeoXModel(LMSteerBase):
42
  else:
43
  raise NotImplementedError()
44
 
45
- def forward(self, input_ids, attention_mask, steer_values):
46
- self.steer.set_value(steer_values)
47
- output = self.model(
48
- input_ids=input_ids,
49
- attention_mask=attention_mask,
50
- labels=input_ids)
51
- return output
52
-
53
- def parameters(self):
54
- return self.steer.parameters()
55
-
56
- def state_dict(self):
57
- return self.steer.state_dict()
58
-
59
- def load_state_dict(self, state_dict):
60
- self.steer.load_state_dict(state_dict)
61
-
62
- def to_device(self, device):
63
- self.model.to(device)
64
- self.device = device
65
-
66
- def regularization_term(self):
67
- return self.steer.regularization_term()
68
-
69
  def generate(self, prompt, steer_values, min_length=20, max_length=100,
70
  seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
71
  temperature=1, top_p=1):
@@ -76,30 +51,6 @@ class Switching_GPTNeoXModel(LMSteerBase):
76
  max_length: maximum generation length
77
  seed: seed for generation. None if not specified.
78
  '''
79
- if seed is not None:
80
- set_seed(seed)
81
- steer_values = torch.Tensor(steer_values).to(
82
- self.device)
83
- if self.low_resource_mode:
84
- fp16 = torch.float16
85
- steer_values = steer_values.to(fp16)
86
- self.steer.projector1.data = self.steer.projector1.to(fp16)
87
- self.steer.projector2.data = self.steer.projector2.to(fp16)
88
- self.steer.set_value(steer_values[None])
89
- with torch.no_grad():
90
- input_ids = self.tokenizer(
91
- prompt, return_tensors="pt").input_ids.to(self.device)
92
- gen_tokens = self.model.generate(
93
- input_ids,
94
- num_beams=num_beams, num_beam_groups=num_beam_groups,
95
- do_sample=do_sample, temperature=temperature, top_p=top_p,
96
- min_length=min_length, max_length=max_length,
97
- pad_token_id=self.tokenizer.pad_token_id)
98
- text = self.tokenizer.batch_decode(gen_tokens)[0]
99
-
100
- # recovering
101
- if self.low_resource_mode:
102
- fp32 = torch.float32
103
- self.steer.projector1.data = self.steer.projector1.to(fp32)
104
- self.steer.projector2.data = self.steer.projector2.to(fp32)
105
- return text
 
4
  from .model_utils import Hack_no_grad
5
  from .steers import Projected_Adaptor
6
  from .model_base import LMSteerBase
 
7
 
8
 
9
  class Switching_GPTNeoXModel(LMSteerBase):
 
41
  else:
42
  raise NotImplementedError()
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def generate(self, prompt, steer_values, min_length=20, max_length=100,
45
  seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
46
  temperature=1, top_p=1):
 
51
  max_length: maximum generation length
52
  seed: seed for generation. None if not specified.
53
  '''
54
+ return super().generate_low_resource(
55
+ prompt, steer_values, min_length, max_length, seed,
56
+ num_beams, num_beam_groups, do_sample, temperature, top_p)