File size: 1,986 Bytes
eee7134 eab471f c9e3328 eab471f a26f453 eab471f 3f54553 3e2e22a a26f453 80a8daf a26f453 3f54553 80a8daf a26f453 3f54553 80a32ce 3f54553 fb908cd 7fd4b1c 3f54553 fb908cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import streamlit as st
from setfit import SetFitModel
# Load the model.
# Streamlit re-executes this entire script on every widget interaction, so an
# uncached from_pretrained() call would reload the model weights on each rerun.
# st.cache_resource keeps one shared instance for the lifetime of the server.
@st.cache_resource
def load_model(model_name="leavoigt/vulnerable_groups"):
    """Load and return the SetFit classifier ``model_name`` (cached by Streamlit).

    Args:
        model_name: Hugging Face Hub id (or local path) of the SetFit model.

    Returns:
        The loaded ``SetFitModel`` instance.
    """
    return SetFitModel.from_pretrained(model_name)


model = load_model()
# Define the classes: maps each model class id (0-29) to a human-readable
# vulnerable-group name. Ids are assigned by list position below.
# NOTE(review): the order presumably mirrors the label order the model was
# trained with — confirm against the model card before reordering or editing.
group_dict = dict(enumerate([
    'Coastal communities',
    'Small island developing states (SIDS)',
    'Landlocked countries',
    'Low-income households',
    'Informal settlements and slums',
    'Rural communities',
    'Children and youth',
    'Older adults and the elderly',
    'Women and girls',
    'People with pre-existing health conditions',
    'People with disabilities',
    'Small-scale farmers and subsistence agriculture',
    'Fisherfolk and fishing communities',
    'Informal sector workers',
    'Children with disabilities',
    'Remote communities',
    'Young adults',
    'Elderly population',
    'Urban slums',
    'Men and boys',
    'Gender non-conforming individuals',
    'Pregnant women and new mothers',
    'Mountain communities',
    'Riverine and flood-prone areas',
    'Drought-prone regions',
    'Indigenous peoples',
    'Migrants and displaced populations',
    'Outdoor workers',
    'Small-scale farmers',
    'Other',
]))
def get_labels(prediction):
    """Translate a single model prediction into a list of group names.

    NOTE(review): the exact output format of the loaded SetFit model is not
    visible in this file, so several shapes are accepted:

    * a label string (returned as-is, in case the model emits label names),
    * a single class id (multi-class head), or
    * a 0/1 indicator vector with one entry per ``group_dict`` key
      (multi-label head); every non-zero position becomes a label.

    Confirm against the model card which of these actually applies.

    Args:
        prediction: One element of ``model([...])`` output. Tensors and NumPy
            arrays are converted to plain Python values with ``tolist()``.

    Returns:
        A list of group names taken from ``group_dict`` (possibly empty).

    Raises:
        KeyError: if a class id has no entry in ``group_dict``.
    """
    values = prediction.tolist() if hasattr(prediction, "tolist") else prediction
    if isinstance(values, str):
        return [values]
    if isinstance(values, (list, tuple)):
        return [group_dict[idx] for idx, flag in enumerate(values) if flag]
    return [group_dict[int(values)]]


# App
st.title("Identify references to vulnerable groups.")
st.write("This app allows you to identify whether a text contains any references to vulnerable groups. This can, for example, be used to analyse policy documents.")

# Create text input box
input_text = st.text_area('Please enter your text here')

# Make predictions only once the user has entered some text. Streamlit runs
# this script on first page load with an empty text area, and classifying ""
# would display a meaningless result.
preds = None
if input_text.strip():
    # The model is called with a batch (list) of texts; take the single row
    # belonging to our one input text.
    preds = model([input_text])[0]
    labels = get_labels(preds)
    result = (
        "\n".join(labels)
        if labels
        else "No references to vulnerable groups were identified."
    )
    st.text_area(label="Identified vulnerable groups", value=result, height=100)