import os
import re
import tempfile
from urllib.parse import urlparse

import nltk
import pandas as pd
import plotly.graph_objects as go
import plotly_express as px
import streamlit as st
import whisper
from optimum.onnxruntime import ORTModelForSequenceClassification
from pytube import YouTube
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

# Fetch the 'punkt' resource at import time, before sent_tokenize is imported
# and used by the functions below.
nltk.download('punkt')

# Imported after the download call above; used to split text into sentences.
from nltk import sent_tokenize


# Configure the page title and icon for this Streamlit page.
st.set_page_config(
    page_title="Home",
    page_icon="📞",
)

# Sidebar heading and main page title.
st.sidebar.header("Home")
st.markdown("## Earnings Call Analysis Whisperer")

# Introductory text describing what the app does and which models it relies on.
st.markdown(
    """
    This app assists finance analysts with transcribing and analysis Earnings Calls by carrying out the following tasks:
    - Transcribing earnings calls using Open AI's [Whisper](https://github.com/openai/whisper).
    - Analysing the sentiment of transcribed text using the quantized version of [FinBert-Tone](https://huggingface.co/nickmuchi/quantized-optimum-finbert-tone).
    - Summarization of the call with [FaceBook-Bart-Large-CNN](https://huggingface.co/facebook/bart-large-cnn) model with entity extraction
    - Semantic search engine with [Sentence-Transformers](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) and reranking results with a Cross-Encoder.
    
    **👇 Enter a YouTube Earnings Call URL below and navigate to the sidebar tabs** 
    
"""
)

# Input option 1: a YouTube URL typed by the user (widget key "url").
url_input = st.text_input(
        label='Enter YouTube URL, e.g "https://www.youtube.com/watch?v=8pmbScvyfeY"', key="url")
        
# Visual "OR" separator between the two input options (raw HTML, hence unsafe_allow_html).
st.markdown(
    "<h3 style='text-align: center; color: red;'>OR</h3>",
    unsafe_allow_html=True
)

# Input option 2: an uploaded audio file (widget key "upload").
upload_wav = st.file_uploader("Upload a .wav sound file ",key="upload")

# Optional token read from the environment; None when the variable is unset.
# NOTE(review): not referenced anywhere else in this file - confirm whether it is still needed.
auth_token = os.environ.get("auth_token")

# Sidebar progress bar, initialised at 0.
progress_bar = st.sidebar.progress(0)

@st.experimental_singleton()
def load_models():
    '''Build every model object the app needs and return them together.

    Returns a 4-tuple, in this order:
      1. the Whisper "small" speech-recognition model,
      2. a text-classification pipeline backed by the quantized FinBERT-tone
         ONNX model and its tokenizer,
      3. a summarization pipeline using facebook/bart-large-cnn,
      4. a CrossEncoder (ms-marco MiniLM L-12) instance.
    '''
    finbert_repo = "nickmuchi/quantized-optimum-finbert-tone"
    bart_repo = "facebook/bart-large-cnn"

    whisper_asr = whisper.load_model("small")

    # The model is loaded before the tokenizer, same order as keyword evaluation.
    sentiment = pipeline(
        "text-classification",
        model=ORTModelForSequenceClassification.from_pretrained(finbert_repo),
        tokenizer=AutoTokenizer.from_pretrained(finbert_repo),
    )
    summarizer = pipeline("summarization", model=bart_repo, tokenizer=bart_repo)
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')

    return whisper_asr, sentiment, summarizer, reranker
    
# Bind the loaded models to module-level names; asr_model is read by
# inference() and sent_pipe by sentiment_pipe() further down.
asr_model, sent_pipe, sum_pipe, cross_encoder  = load_models()

def _is_valid_url(link):
    '''Return True when ``link`` is an absolute http(s) URL that names a host.'''
    if not isinstance(link, str):
        return False
    try:
        parts = urlparse(link.strip())
    except ValueError:
        # urlparse rejects a few malformed inputs (e.g. unbalanced IPv6 brackets).
        return False
    return parts.scheme in ("http", "https") and bool(parts.netloc)


@st.experimental_memo(suppress_st_warning=True)
def inference(link, upload):
    '''Convert a YouTube video or an audio upload to text.

    Args:
        link: YouTube URL entered by the user; takes precedence over
            ``upload`` when it is a valid http(s) URL.
        upload: Uploaded audio. Either something Whisper can consume
            directly (e.g. a file path) or a file-like object exposing
            ``getvalue()``/``read()`` such as a Streamlit upload.

    Returns:
        ``(results, title)`` where ``results`` is whatever
        ``asr_model.transcribe`` returns and ``title`` is the video title
        (or a fixed label for uploads). Returns ``None`` when neither a
        valid URL nor an upload was supplied.
    '''
    if _is_valid_url(link):
        yt = YouTube(link.strip())
        title = yt.title
        # Download the first audio-only stream to a fixed local file name.
        path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
        results = asr_model.transcribe(path)
        return results, title

    if upload:
        reader = getattr(upload, "getvalue", None) or getattr(upload, "read", None)
        if reader is None:
            # Not file-like: hand it to Whisper unchanged (e.g. a path string).
            results = asr_model.transcribe(upload)
        else:
            # Whisper's transcribe() does not take in-memory file objects, so
            # spool the bytes to a temporary file and transcribe that path.
            name = getattr(upload, "name", "")
            suffix = os.path.splitext(name)[1] if isinstance(name, str) else ""
            with tempfile.NamedTemporaryFile(suffix=suffix or ".wav", delete=False) as tmp:
                tmp.write(reader())
                tmp_path = tmp.name
            try:
                results = asr_model.transcribe(tmp_path)
            finally:
                os.remove(tmp_path)
        return results, "Transcribed Earnings Audio"

    # Nothing to transcribe.
    return None
      
@st.experimental_memo(suppress_st_warning=True)
def sentiment_pipe(earnings_text):
    '''Classify the sentiment of each sentence in ``earnings_text``.

    The text is split into sentences with ``sent_tokenize`` and the whole
    sentence list is passed to the module-level ``sent_pipe`` classifier.

    Returns:
        ``(predictions, sentences)``: the classifier output for the sentence
        list, followed by the sentence list itself.
    '''
    sentences = sent_tokenize(earnings_text)
    return sent_pipe(sentences), sentences
    
@st.experimental_memo(suppress_st_warning=True)
def preprocess_plain_text(text,window_size=3):
    '''Clean ``text`` and cut it into passages for semantic search.

    Steps:
      1. Drop non-ASCII characters, URLs, @mentions and #hashtags.
      2. Split into paragraphs on blank lines and collapse all remaining
         whitespace inside a paragraph to single spaces.
      3. Split each paragraph into sentences and join every ``window_size``
         consecutive sentences into one passage. Passages never span a
         paragraph boundary.

    Args:
        text: Raw text to preprocess.
        window_size: Maximum number of sentences per passage. Smaller values
            lose context from neighbouring sentences; larger values keep more
            context but make results longer.

    Returns:
        List of passage strings, in document order.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    '''
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    text = text.encode("ascii", "ignore").decode()  # drop non-ASCII characters
    text = re.sub(r"https*\S+", " ", text)  # urls
    text = re.sub(r"@\S+", " ", text)  # mentions
    text = re.sub(r"#\S+", " ", text)  # hashtags

    # Split on blank lines BEFORE collapsing whitespace; collapsing first
    # would erase the paragraph breaks and always yield a single paragraph.
    paragraphs = []
    for block in re.split(r"\n\s*\n", text):
        flat = " ".join(block.split())  # single-space the paragraph, trims ends
        if flat:
            paragraphs.append(sent_tokenize(flat))

    passages = []
    for sentences in paragraphs:
        for start in range(0, len(sentences), window_size):
            # Slicing past the end is safe, so the last window may be shorter.
            passages.append(" ".join(sentences[start:start + window_size]))

    print(f"Sentences: {sum(len(p) for p in paragraphs)}")
    print(f"Passages: {len(passages)}")

    return passages
    
def display_df_as_table(model,top_k,score='score',corpus=None):
    '''Build a DataFrame of the top search hits with their scores and text.

    Args:
        model: Sequence of hit dicts; each must contain the ``score`` key and
            a ``'corpus_id'`` index into the passage list.
        top_k: Number of leading hits to include.
        score: Name of the key holding the score in each hit.
        corpus: Passage list that ``'corpus_id'`` indexes into. When omitted,
            the module-level ``passages`` variable is used, which must then
            exist.

    Returns:
        DataFrame with columns ``'Score'`` (rounded to 2 decimals) and
        ``'Text'``, one row per hit, in the order given.
    '''
    if corpus is None:
        # Backward-compatible path for callers that rely on the global list.
        corpus = passages

    rows = [(hit[score], corpus[hit['corpus_id']]) for hit in model[:top_k]]
    df = pd.DataFrame(rows, columns=['Score', 'Text'])

    # An empty frame has no numeric Score column to round.
    if not df.empty:
        df['Score'] = df['Score'].round(2)

    return df
    
def make_spans(text,results):
    '''Pair each sentence of ``text`` with the label of the matching result.

    Args:
        text: Text to split into sentences with ``sent_tokenize``.
        results: Sequence of dicts, one per sentence, each with a ``'label'``
            key.

    Returns:
        List of ``(sentence, label)`` tuples. ``zip`` stops at the shorter of
        the two sequences, so surplus sentences or results are dropped.
    '''
    labels = [result['label'] for result in results]
    return list(zip(sent_tokenize(text), labels))


##Fiscal Sentiment by Sentence
def fin_ext(text):
    '''Label every sentence of ``text`` with its predicted sentiment.

    Sentences are classified with the module-level ``sent_pipe``
    text-classification pipeline.

    Returns:
        List of ``(sentence, label)`` tuples as produced by ``make_spans``.
    '''
    results = sent_pipe(sent_tokenize(text))
    return make_spans(text, results)

# Clear the sidebar progress bar that was created near the top of the script.
progress_bar.empty()