taka-yamakoshi
commited on
Commit
•
a0471c4
1
Parent(s):
1f8519e
model type
Browse files
app.py
CHANGED
@@ -47,10 +47,10 @@ def load_css(file_name):
|
|
47 |
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
|
48 |
|
49 |
@st.cache(show_spinner=True,allow_output_mutation=True)
|
50 |
-
def load_model():
|
51 |
-
tokenizer = AlbertTokenizer.from_pretrained(
|
52 |
#model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
|
53 |
-
model = AlbertForMaskedLM.from_pretrained(
|
54 |
return tokenizer,model
|
55 |
|
56 |
def clear_data():
|
@@ -167,14 +167,23 @@ def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_i
|
|
167 |
if __name__=='__main__':
|
168 |
wide_setup()
|
169 |
load_css('style.css')
|
170 |
-
tokenizer,model = load_model()
|
171 |
-
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
172 |
-
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
173 |
-
|
174 |
-
main_area = st.empty()
|
175 |
|
176 |
if 'page_status' not in st.session_state:
|
177 |
-
st.session_state['page_status'] = '
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
if st.session_state['page_status']=='type_in':
|
180 |
show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
|
|
|
47 |
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
|
48 |
|
49 |
@st.cache(show_spinner=True,allow_output_mutation=True)
|
50 |
+
def load_model(model_name):
|
51 |
+
tokenizer = AlbertTokenizer.from_pretrained(model_name)
|
52 |
#model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
|
53 |
+
model = AlbertForMaskedLM.from_pretrained(model_name)
|
54 |
return tokenizer,model
|
55 |
|
56 |
def clear_data():
|
|
|
167 |
if __name__=='__main__':
|
168 |
wide_setup()
|
169 |
load_css('style.css')
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
if 'page_status' not in st.session_state:
|
172 |
+
st.session_state['page_status'] = 'model_selection'
|
173 |
+
|
174 |
+
if st.session_state['page_status']=='model_selection':
|
175 |
+
model_name = st.selectbox('Please select the model from below.',
|
176 |
+
('bert-base-uncased','bert-large-cased',
|
177 |
+
'roberta-base','roberta-large',
|
178 |
+
'albert-base-v2','albert-large-v2','albert-xlarge-v2','albert-xxlarge-v2'),index=3)
|
179 |
+
st.sesstion_state['model_name'] = model_name
|
180 |
+
if st.button('Confirm',key='model_name'):
|
181 |
+
st.session_state['page_status'] = 'type_in'
|
182 |
+
st.experimental_rerun()
|
183 |
+
|
184 |
+
tokenizer,model = load_model(st.session_state['model_name'])
|
185 |
+
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
186 |
+
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
187 |
|
188 |
if st.session_state['page_status']=='type_in':
|
189 |
show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
|