import gradio as gr
import os
import cv2
import face_recognition
from fastai.vision.all import load_learner
import time
import base64
from deepface import DeepFace
import torchaudio
import moviepy.editor as mp
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline

# Uncomment on Windows if the pickled learner was exported on a POSIX system:
# import pathlib
# temp = pathlib.PosixPath
# pathlib.PosixPath = pathlib.WindowsPath

backends = [
  'opencv', 
  'ssd', 
  'dlib', 
  'mtcnn', 
  'retinaface', 
  'mediapipe'
]

# Hugging Face pipelines for text-level emotion and sentiment analysis of the transcription
emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", return_all_scores=True)
sentiment_pipeline = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

# fastai learner that classifies a cropped face as 'on_camera' or 'off_camera'
model = load_learner("gaze-recognizer-v3.pkl")

def analyze_emotion(text):
    result = emotion_pipeline(text)
    return result

def analyze_sentiment(text):
    result = sentiment_pipeline(text)
    return result

def getTranscription(path):
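    """Extract the audio track from the video at `path` and return a Whisper-tiny transcription of it."""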
    # Extract the audio track from the video and save it as a WAV file
    clip = mp.VideoFileClip(path)
    clip.audio.write_audiofile("audio.wav")
    clip.close()

    # Whisper expects 16 kHz audio, so resample before feeding the model
    waveform, sample_rate = torchaudio.load("audio.wav")
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    waveform = resampler(waveform)[0]

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    # Named whisper_model to avoid shadowing the module-level fastai gaze model
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    whisper_model.config.forced_decoder_ids = None

    input_features = processor(waveform.squeeze(dim=0), sampling_rate=16000, return_tensors="pt").input_features
    predicted_ids = whisper_model.generate(input_features)

    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

    return transcription[0]

def video_processing(video_file, encoded_video):
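    """Analyse a video: gaze-on-camera percentage, averaged facial emotion, speech
    transcription, and text-level emotion/sentiment. If `encoded_video` is non-empty,
    it is treated as a base64-encoded MP4 and used instead of `video_file`."""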
    emotion_count = 0
    video_emotions = {
        'angry': 0,
        'disgust': 0,
        'fear': 0,
        'happy': 0,
        'sad': 0,
        'surprise': 0,
        'neutral':0
    }

    # A non-empty second input is treated as a base64-encoded video: decode it to a
    # temporary file and analyse that instead of the uploaded file
    if encoded_video != "":
    
        decoded_file_data = base64.b64decode(encoded_video)

        with open("temp_video.mp4", "wb") as f:
            f.write(decoded_file_data)
        
        video_file = "temp_video.mp4"

    start_time = time.time()

    transcription = getTranscription(video_file)
    print(transcription)
    text_emotion = analyze_emotion(transcription)
    print(text_emotion)
    text_sentiment = analyze_sentiment(transcription)
    print(text_sentiment)

    video_capture = cv2.VideoCapture(video_file)
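    # Counters for how many sampled faces are classified as looking at / away from the camera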
    on_camera = 0
    off_camera = 0
    total = 0

    while True:
        # Skip ahead ~3 seconds of footage (assuming roughly 24 fps) so that only
        # one frame per interval is analysed
        for _ in range(24 * 3):
            ret, frame = video_capture.read()
            if not ret:
                break

        # If there are no more frames, stop processing
        if not ret:
            break

        # Convert the frame to grayscale before face detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Find all faces in the frame using face_recognition's default HOG-based detector
        face_locations = face_recognition.face_locations(gray)

        if len(face_locations) > 0:
            # Crop each detected face from both the grayscale and the original colour frame
            for top, right, bottom, left in face_locations:
                # cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), 2)
                face_image = gray[top:bottom, left:right]
                color_image = frame[top:bottom, left:right]

                # Resize the grayscale face crop to the input size expected by the gaze model
                resized_face_image = cv2.resize(face_image, (128, 128))

                try:
                    # backends[2] is 'dlib'; enforce_detection=False keeps DeepFace from
                    # raising when it cannot re-detect a face inside the crop
                    detected_face_emotion = DeepFace.analyze(color_image, actions=['emotion'], detector_backend=backends[2], enforce_detection=False)
                    for emotion in detected_face_emotion:
                        for key in video_emotions.keys():
                            video_emotions[key] += emotion['emotion'][key]
                    emotion_count += 1
                except Exception as e:
                    print(f"Emotion analysis failed for this face: {e}")

                # Predict the class of the resized face image using the model
                result = model.predict(resized_face_image)
                print(result[0])
                if result[0] == 'on_camera':
                    on_camera += 1
                elif result[0] == 'off_camera':
                    off_camera += 1
                total += 1

    try:
        gaze_percentage = on_camera / total * 100
    except ZeroDivisionError:
        print("An error occurred while processing the video: no faces were detected")
        gaze_percentage = 'ERROR : no face detected'
    print(f'Total = {total}, on_camera = {on_camera}, off_camera = {off_camera}')
    # Release the video capture object and close all windows
    video_capture.release()
    cv2.destroyAllWindows()
    end_time = time.time()
    print(f'Time taken: {end_time-start_time}')
    if os.path.exists("temp_video.mp4"): 
        os.remove("temp_video.mp4")
    if os.path.exists("audio.wav"): 
        os.remove("audio.wav")
    print(gaze_percentage)

    # Average the accumulated emotion scores over the number of analysed faces
    if emotion_count > 0:
        for key in video_emotions.keys():
            video_emotions[key] /= emotion_count

    # Rename DeepFace's emotion keys to match the text emotion model's labels
    video_emotions['anger'] = video_emotions.pop('angry')
    video_emotions['joy'] = video_emotions.pop('happy')
    video_emotions['sadness'] = video_emotions.pop('sad')

    final_result_dict = {
        "gaze_percentage" : gaze_percentage,
        "face_emotion" : video_emotions,
        "text_emotion" : text_emotion[0],
        "transcription" : transcription,
        "text_sentiment" : text_sentiment
    }
    
    return final_result_dict


demo = gr.Interface(fn=video_processing,
                     inputs=["video", "text"],
                     outputs="json")

if __name__ == "__main__":
    demo.launch()