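# Streamlit app: speech-to-pictogram translation.
# The app transcribes a French audio file with Whisper, translates the transcription
# into a sequence of ARASAAC pictogram lemmas with the Propicto t2p-nllb model,
# displays the matching pictograms, and lets the user download them as a PDF.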
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from fpdf import FPDF
import whisper
import tempfile
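# st_audiorec and numpy are only needed by the in-browser recorder that is
# commented out at the end of this file.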
from st_audiorec import st_audiorec
import numpy as np
# User interface configuration
st.set_page_config(
    page_title="Traduction de la parole en pictogrammes ARASAAC",
    page_icon="📝",
    layout="wide"
)
# Load the translation model and tokenizer
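# The checkpoint translates (lower-cased) French text into a sequence of pictogram
# lemmas, which are then mapped to ARASAAC IDs through the lexicon loaded below.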
checkpoint = "Propicto/t2p-nllb-200-distilled-600M-all"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# Load the Whisper speech-to-text model
whisper_model = whisper.load_model("base")
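# Note: these models are reloaded on every Streamlit rerun; moving the loading code
# into a function decorated with @st.cache_resource would avoid the repeated loads.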
# Read the lexicon
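# The lexicon is a tab-separated file with at least the columns 'lemma' and 'id_picto'.
# 'keyword_no_cat' drops the " #category" suffix and replaces spaces with underscores
# so that entries can be matched against the lemmas predicted by the model.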
@st.cache_data
def read_lexicon(lexicon):
    df = pd.read_csv(lexicon, sep='\t')
    df['keyword_no_cat'] = df['lemma'].str.split(' #').str[0].str.strip().str.replace(' ', '_')
    return df
lexicon = read_lexicon("lexicon.csv")
# Post-process the translation output
def process_output_trad(pred):
    return pred.split()
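# Look up the ARASAAC pictogram ID for a predicted lemma; a trailing "!" is stripped
# before the lookup, and (0, lemma) is returned when the lemma is not in the lexicon.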
def get_id_picto_from_predicted_lemma(df_lexicon, lemma):
    if lemma.endswith("!"):
        lemma = lemma[:-1]
    id_picto = df_lexicon.loc[df_lexicon['keyword_no_cat'] == lemma, 'id_picto'].tolist()
    return (id_picto[0], lemma) if id_picto else (0, lemma)
# Generate the HTML content used to display the pictograms
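# Each pictogram is rendered as a <figure> containing the ARASAAC image and the
# corresponding lemma as its caption.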
def generate_html(ids):
    html_content = '<html><head><style>'
    html_content += '''
        figure {
            display: inline-block;
            text-align: center;
            font-family: Arial, sans-serif;
            margin: 0;
        }
        figcaption {
            color: black;
            background-color: white;
            border-radius: 5px;
        }
        img {
            background-color: white;
            margin: 0;
            padding: 0;
            border-radius: 6px;
        }
    '''
    html_content += '</style></head><body>'
    for picto_id, lemma in ids:
        if picto_id != 0:  # ignore invalid IDs
            img_url = f"https://static.arasaac.org/pictograms/{picto_id}/{picto_id}_500.png"
            html_content += f'''
            <figure>
                <img src="{img_url}" alt="{lemma}" width="100" height="100"/>
                <figcaption>{lemma}</figcaption>
            </figure>
            '''
    html_content += '</body></html>'
    return html_content
# Generate the PDF
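# Pictograms are laid out left to right on a landscape A4 page (50x50 mm each, with the
# lemma printed underneath) and wrap to a new row once the page width is exceeded.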
def generate_pdf(ids):
    pdf = FPDF(orientation='L', unit='mm', format='A4')  # 'L' for landscape orientation
    pdf.add_page()
    pdf.set_auto_page_break(auto=True, margin=15)
    # Start positions and layout parameters
    x_start = 10
    y_start = 10
    img_width = 50
    img_height = 50
    spacing = 1
    max_width = 297  # A4 landscape width in mm
    current_x = x_start
    current_y = y_start
    for picto_id, lemma in ids:
        if picto_id != 0:  # ignore invalid IDs
            img_url = f"https://static.arasaac.org/pictograms/{picto_id}/{picto_id}_500.png"
            pdf.image(img_url, x=current_x, y=current_y, w=img_width, h=img_height)
            pdf.set_xy(current_x, current_y + img_height + 5)
            pdf.set_font("Arial", size=12)
            pdf.cell(img_width, 10, txt=lemma, ln=1, align='C')
            current_x += img_width + spacing
            # Move to the next row if the next image would exceed the page width
            if current_x + img_width > max_width:
                current_x = x_start
                current_y += img_height + spacing + 10  # adjust for the image height and caption
    pdf_path = "pictograms.pdf"
    pdf.output(pdf_path)
    return pdf_path
# Initialize session state
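# Session state keeps the transcription and pictogram IDs across Streamlit reruns and
# remembers the previous audio file so that results are reset when a new file is uploaded.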
if 'transcription' not in st.session_state:
    st.session_state['transcription'] = None
if 'pictogram_ids' not in st.session_state:
    st.session_state['pictogram_ids'] = None
if 'previous_audio_file' not in st.session_state:
    st.session_state['previous_audio_file'] = None
# User interface for the audio input and the download button
st.title("Traduction de la parole en pictogrammes ARASAAC")
col1, col2 = st.columns(2)
with col1:
    audio_file = st.file_uploader("Ajouter un fichier audio :", type=["wav", "mp3"])
    # Reset the stored results if the audio file changes
    if audio_file is not None and audio_file != st.session_state['previous_audio_file']:
        st.session_state['transcription'] = None
        st.session_state['pictogram_ids'] = None
        st.session_state['previous_audio_file'] = audio_file
with col2:
    if audio_file is not None:
        with st.spinner("Transcription de l'audio en cours..."):
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                temp_file.write(audio_file.read())
                temp_file_path = temp_file.name
            transcription = whisper_model.transcribe(temp_file_path, language='fr')
            st.session_state['transcription'] = transcription['text']
    if st.session_state['transcription'] is not None:
        st.text_area("Transcription :", st.session_state['transcription'])
    with st.spinner("Affichage des pictogrammes..."):
        if st.session_state['transcription'] is not None:
            inputs = tokenizer(st.session_state['transcription'].lower(), return_tensors="pt").input_ids
            outputs = model.generate(inputs, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            sentence_to_map = process_output_trad(pred)
            st.session_state['pictogram_ids'] = [get_id_picto_from_predicted_lemma(lexicon, lemma) for lemma in sentence_to_map]
    if st.session_state['pictogram_ids'] is not None:
        html = generate_html(st.session_state['pictogram_ids'])
        st.components.v1.html(html, height=150, scrolling=True)
        # Generate the PDF and offer it for download
        pdf_path = generate_pdf(st.session_state['pictogram_ids'])
        with open(pdf_path, "rb") as pdf_file:
            st.download_button(label="Télécharger la traduction en PDF", data=pdf_file, file_name="pictograms.pdf", mime="application/pdf")
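# Optional in-browser recording with st_audiorec, kept for reference but currently disabled: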
# record_audio = st_audiorec()
# if record_audio:
#     audio = np.array(record_audio)
#     transcription = whisper_model.transcribe(audio, language='fr')
#     st.success("Enregistrement terminé !")