import streamlit as st
import cv2
import numpy as np
import tempfile
import time
from random import randint

# Configuration constants
LINE_IN_COLOR = (64, 255, 0)
LINE_OUT_COLOR = (0, 0, 255)
BOUNDING_BOX_COLOR = (255, 128, 0)
TRACKER_COLOR = (randint(0, 255), randint(0, 255), randint(0, 255))
CENTROID_COLOR = (randint(0, 255), randint(0, 255), randint(0, 255))
TEXT_COLOR = (randint(0, 255), randint(0, 255), randint(0, 255))
TEXT_POSITION_BGS = (10, 50)
TEXT_POSITION_COUNT_CARS = (10, 100)
TEXT_POSITION_COUNT_TRUCKS = (10, 150)
TEXT_SIZE = 1.2
FONT = cv2.FONT_HERSHEY_SIMPLEX
VIDEO_SOURCE = "videos/Traffic_4.mp4"

# Background Subtraction Method
BGS_TYPES = ["GMG", "MOG", "MOG2", "KNN", "CNT"]
BGS_TYPE = BGS_TYPES[2]

def getKernel(KERNEL_TYPE):
    if KERNEL_TYPE == "dilation":
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
    if KERNEL_TYPE == "opening":
        kernel = np.ones((3, 3), np.uint8)
    if KERNEL_TYPE == "closing":
        kernel = np.ones((11, 11), np.uint8)
    return kernel

def getFilter(img, filter):
    if filter == 'closing':
        return cv2.morphologyEx(img, cv2.MORPH_CLOSE, getKernel("closing"), iterations=2)
    if filter == 'opening':
        return cv2.morphologyEx(img, cv2.MORPH_OPEN, getKernel("opening"), iterations=2)
    if filter == 'dilation':
        return cv2.dilate(img, getKernel("dilation"), iterations=2)
    if filter == 'combine':
        closing = cv2.morphologyEx(img, cv2.MORPH_CLOSE, getKernel("closing"), iterations=2)
        opening = cv2.morphologyEx(closing, cv2.MORPH_OPEN, getKernel("opening"), iterations=2)
        dilation = cv2.dilate(opening, getKernel("dilation"), iterations=2)
        return dilation

def getBGSubtractor(BGS_TYPE):
    if BGS_TYPE == "GMG":
        return cv2.bgsegm.createBackgroundSubtractorGMG(initializationFrames=120, decisionThreshold=.8)
    if BGS_TYPE == "MOG":
        return cv2.bgsegm.createBackgroundSubtractorMOG(history=200, nmixtures=5, backgroundRatio=.7, noiseSigma=0)
    if BGS_TYPE == "MOG2":
        return cv2.createBackgroundSubtractorMOG2(history=50, detectShadows=False, varThreshold=200)
    if BGS_TYPE == "KNN":
        return cv2.createBackgroundSubtractorKNN(history=100, dist2Threshold=400, detectShadows=True)
    if BGS_TYPE == "CNT":
        return cv2.bgsegm.createBackgroundSubtractorCNT(minPixelStability=15, useHistory=True,
                                                        maxPixelStability=15 * 60, isParallel=True)
    print("Invalid detector")
    sys.exit(1)

def getCentroid(x, y, w, h):
    x1 = int(w / 2)
    y1 = int(h / 2)
    cx = x + x1
    cy = y + y1
    return (cx, cy)

def process_video(video_path):
    cap = cv2.VideoCapture(video_path)
    hasFrame, frame = cap.read()

    if not hasFrame:
        st.error("Failed to load the video.")
        return None

    # ROI
    bbox = cv2.selectROI(frame, False)
    (w1, h1, w2, h2) = bbox

    frameArea = h2 * w2
    minArea = int(frameArea / 250)
    maxArea = 15000

    line_IN = int(h1)
    line_OUT = int(h2 - 20)

    DOWN_limit = int(h1 / 4)

    bg_subtractor = getBGSubtractor(BGS_TYPE)

    frame_number = -1
    cnt_cars, cnt_trucks = 0, 0
    objects = []
    max_p_age = 2
    pid = 1

    with tempfile.NamedTemporaryFile(delete=False) as output_file:
        output_path = output_file.name + ".avi"
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        writer_video = cv2.VideoWriter(output_path, fourcc, 25, (frame.shape[1], frame.shape[0]), True)

        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break

            roi = frame[h1:h1 + h2, w1:w1 + w2]

            for i in objects:
                i.age_one()

            frame_number += 1
            bg_mask = bg_subtractor.apply(roi)
            bg_mask = getFilter(bg_mask, 'combine')
            (contours, _) = cv2.findContours(bg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for cnt in contours:
                area = cv2.contourArea(cnt)

                if area > minArea and area <= maxArea:
                    x, y, w, h = cv2.boundingRect(cnt)
                    centroid = getCentroid(x, y, w, h)
                    cx = centroid[0]
                    cy = centroid[1]
                    new = True
                    cv2.rectangle(roi, (x, y), (x + 50, y - 13), TRACKER_COLOR, -1)
                    cv2.putText(roi, 'CAR', (x, y - 2), FONT, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

                    for i in objects:
                        if abs(x - i.getX()) <= w and abs(y - i.getY()) <= h:
                            new = False
                            i.updateCoords(cx, cy)

                            if i.going_DOWN(DOWN_limit):
                                cnt_cars += 1
                            break
                        if i.getState() == '1' and i.getDir() == 'down' and i.getY() > line_OUT:
                            i.setDone()
                        if i.timedOut():
                            objects.pop(objects.index(i))
                            del i
                    if new:
                        p = validator.MyValidator(pid, cx, cy, max_p_age)
                        objects.append(p)
                        pid += 1
                    cv2.circle(roi, (cx, cy), 5, CENTROID_COLOR, -1)

                elif area >= maxArea:
                    x, y, w, h = cv2.boundingRect(cnt)
                    centroid = getCentroid(x, y, w, h)
                    cx = centroid[0]
                    cy = centroid[1]

                    new = True
                    cv2.rectangle(roi, (x, y), (x + 50, y - 13), TRACKER_COLOR, -1)
                    cv2.putText(roi, 'TRUCK', (x, y - 2), FONT, .5, (255, 255, 255), 1, cv2.LINE_AA)

                    for i in objects:
                        if abs(x - i.getX()) <= w and abs(y - i.getY()) <= h:
                            new = False
                            i.updateCoords(cx, cy)

                            if i.going_DOWN(DOWN_limit):
                                cnt_trucks += 1
                            break
                        if i.getState() == '1' and i.getDir() == 'down' and i.getY() > line_OUT:
                            i.setDone()
                        if i.timedOut():
                            objects.pop(objects.index(i))
                            del i
                    if new:
                        p = validator.MyValidator(pid, cx, cy, max_p_age)
                        objects.append(p)
                        pid += 1
                    cv2.circle(roi, (cx, cy), 5, CENTROID_COLOR, -1)

            for i in objects:
                cv2.putText(roi, str(i.getId()), (i.getX(), i.getY()), FONT, 0.3, TEXT_COLOR, 1, cv2.LINE_AA)

            str_cars = 'Cars: ' + str(cnt_cars)
            str_trucks = 'Trucks: ' + str(cnt_trucks)

            frame = cv2.line(frame, (w1, line_IN), (w1 + w2, line_IN), LINE_IN_COLOR, 2)
            frame = cv2.line(frame, (w1, h1 + line_OUT), (w1 + w2, h1 + line_OUT), LINE_OUT_COLOR, 2)

            cv2.putText(frame, str_cars, TEXT_POSITION_COUNT_CARS, FONT, 1, (255, 255, 255), 3, cv2.LINE_AA)
            cv2.putText(frame, str_cars, TEXT_POSITION_COUNT_CARS, FONT, 1, (232, 162, 0), 2, cv2.LINE_AA)
            cv2.putText(frame, str_trucks, TEXT_POSITION_COUNT_TRUCKS, FONT, 1, (255, 255, 255), 3, cv2.LINE_AA)
            cv2.putText(frame, str_trucks, TEXT_POSITION_COUNT_TRUCKS, FONT, 1, (232, 162, 0), 2, cv2.LINE_AA)

            writer_video.write(frame)

            if frame_number % 30 == 0:
                st.image(frame, channels="BGR", use_column_width=True)

    cap.release()
    writer_video.release()
    return output_path

def main():
    st.title("Vehicle Counting and Tracking")
    st.write("This application processes the video, counts vehicles, and tracks them.")

    video_path = VIDEO_SOURCE
    if video_path:
        st.video(video_path)
        st.write("Processing video...")

        output_path = process_video(video_path)

        if output_path:
            st.write("Processing completed.")
            st.video(output_path)
        else:
            st.write("Failed to process video.")
    else:
        st.error("Please upload a video file.")

if __name__ == "__main__":
    main()