File size: 1,856 Bytes
7d5081f
 
4946d76
7d5081f
 
8f074bc
c9ee852
8510941
8f074bc
65aee6e
d234736
378eb4f
8f074bc
d3e6642
 
 
 
 
39d8890
d3e6642
 
 
 
 
 
 
d234736
 
80882a3
d234736
8f074bc
 
 
df9431b
8f074bc
 
 
 
 
 
 
 
662f9f2
 
 
d234736
662f9f2
d234736
39d8890
d3e6642
 
662f9f2
378eb4f
d234736
 
429d718
 
9d13851
429d718
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
import time
from transformers import pipeline
import torch

# Page heading rendered at the top of the app.
st.markdown(body='## Text-generation OPT from Meta ')

@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def get_model(model_name=None):
    """Build a Hugging Face text-generation pipeline, cached by Streamlit.

    Args:
        model_name: Hub checkpoint id to load (e.g. 'facebook/opt-350m').
            When omitted, falls back to the module-level ``model`` variable
            set by the model picker, so the existing zero-argument call
            ``get_model()`` keeps working.

    Returns:
        A ``transformers`` text-generation pipeline created with
        ``do_sample=True`` and ``skip_special_tokens=True``.
    """
    if model_name is None:
        # NOTE(review): relies on the global ``model`` being bound before the
        # first call. Legacy st.cache is documented to hash referenced globals,
        # so a different selection should map to a different cache entry —
        # passing the name explicitly makes that dependency visible.
        model_name = model
    return pipeline('text-generation', model=model_name, do_sample=True,
                    skip_special_tokens=True)
    
# Split the main area into two columns with relative widths 2:1 —
# the wide one holds the prompt, the narrow one the model picker.
col1, col2 = st.columns(spec=[2, 1])

with st.sidebar:
    st.markdown('## Model Parameters')

    # Forwarded to the generator as ``max_length``.
    max_length = st.slider('Max text length', 0, 150, 80)

    # Forwarded to the generator as ``num_beams``.
    num_beams = st.slider('N° tree beams search', 2, 15, 5)

    # The options are real booleans so the generator receives True/False.
    # Previously it received the strings 'True'/'False', and 'False' is a
    # truthy string, so any truthiness check on it sees "enabled".
    # format_func=str keeps the displayed labels exactly 'True' / 'False'.
    early_stopping = st.selectbox(
        'Early stopping text generation',
        (True, False), format_func=str, index=0)

    # Forwarded to the generator as ``no_repeat_ngram_size``.
    no_ngram_repeat = st.slider('Max repetition limit', 1, 5, 2)
    
# Wide left column: free-form prompt, pre-filled with an example question.
prompt = col1.text_area('Your prompt here', 'Who is Elon Musk?')
        
with col2:
    # Radio label -> Hugging Face Hub checkpoint id, listed in display order.
    opt_checkpoints = {
        'OPT-125m': 'facebook/opt-125m',
        'OPT-350m': 'facebook/opt-350m',
        'OPT-1.3b': 'facebook/opt-1.3b',
    }
    select_model = st.radio(
        "Select the model to use:", tuple(opt_checkpoints), index=1)

    # ``model`` must be bound at module level: get_model() reads it as a global.
    model = opt_checkpoints[select_model]

    with st.spinner('Loading Model... (This may take a while)'):
        generator = get_model()
        st.success('Model loaded correctly!')
     
# Show a status message while the blocking generation call runs.
gen = st.info('Generating text...')
try:
    answer = generator(prompt,
                       max_length=max_length,
                       no_repeat_ngram_size=no_ngram_repeat,
                       early_stopping=early_stopping,
                       num_beams=num_beams,
                       skip_special_tokens=True)
finally:
    # Always clear the status message, so it doesn't linger next to the
    # error if generation raises.
    gen.empty()
                       
# Typewriter effect: reveal the generated text one character at a time.
generated_text = answer[0]['generated_text']

placeholder = st.empty()
# Iterate up to len(generated_text) inclusive so the last frame shows the
# complete text; range(len(...)) would stop one character short.
for end in range(1, len(generated_text) + 1):
    placeholder.markdown(f"#### {generated_text[:end]}")
    time.sleep(0.04)