Spaces:
Sleeping
Sleeping
# Import necessary libraries | |
import matplotlib | |
# Use Agg backend for Matplotlib | |
matplotlib.use("Agg") | |
# Libraries for the app | |
import streamlit as st | |
import time | |
import io | |
import argparse | |
import sys | |
import os.path | |
import subprocess | |
import tempfile | |
import logging | |
import torch | |
# Visualization libraries | |
import altair as alt | |
import av | |
# Machine Learning and Image Processing libraries | |
import numpy as np | |
import pandas as pd | |
import cv2 as cv | |
from PIL import Image, ImageOps | |
from tqdm import tqdm | |
# Custom modules | |
import inference | |
from app_utils import * | |
from app_plot_utils import * | |
def load_video(video_url):
    """Read a local video file and return its raw bytes.

    Args:
        video_url: Path to a video file on the local filesystem. Despite the
            name, the value is passed straight to ``open()``, so remote URLs
            are not supported.

    Returns:
        The entire file contents as ``bytes``.
    """
    # A context manager closes the handle deterministically; the previous
    # ``open(...).read()`` left the file open until garbage collection.
    with open(video_url, "rb") as video_file:
        return video_file.read()
def load_historical(fpath, date_column="Date"):
    """Load historical (manually counted) herring data from a CSV file.

    Args:
        fpath: Any path or buffer accepted by ``pandas.read_csv``.
        date_column: Name of the column to convert with ``pd.to_datetime``
            so callers can use the ``.dt`` accessor on it. If the column is
            not present it is left alone; pass ``None`` to skip parsing and
            get the raw ``read_csv`` result.

    Returns:
        A ``pandas.DataFrame`` with one row per CSV record.
    """
    df = pd.read_csv(fpath)
    # read_csv leaves dates as strings; downstream code groups by
    # ``df[date_column].dt.year``, which requires a datetime dtype.
    if date_column is not None and date_column in df.columns:
        df[date_column] = pd.to_datetime(df[date_column])
    return df
def _yearly_totals(df, date_column="Date"):
    """Return per-calendar-year sums of every numeric column in *df*.

    The date column is coerced with ``pd.to_datetime`` so this works whether
    or not the caller already parsed it. ``numeric_only=True`` keeps
    non-numeric columns, including the date column itself, out of the sum;
    pandas 2.x raises ``TypeError`` when asked to sum a datetime column.
    """
    years = pd.to_datetime(df[date_column]).dt.year
    return df.groupby(years).sum(numeric_only=True)


def _backend_message(gpu_available, mps_available):
    """Return the system-information banner text for the detected backend.

    CUDA takes precedence over Apple MPS; with neither, report CPU only.
    """
    if gpu_available:
        return "GPU is available."
    if mps_available:
        return "MPS is available."
    return "Only CPU is available."


# Define the main function to run the Streamlit app
def run_app():
    """Build and render the fish-counter Streamlit page.

    Layout:
      * "Analysis" tab. The left column shows the historical herring chart,
        a yearly-totals table and a fishery-location map. The right column
        shows an example processed video and compute-backend information.
      * "Upload video for analysis" tab, delegated to
        ``process_uploaded_file()``.

    Reads ``yolo2_out_py.mp4`` and ``herring_count_all.csv`` via relative
    paths, i.e. from the current working directory.
    """
    # set_page_config must be the first Streamlit command on the page.
    st.set_page_config(layout="wide")
    # NOTE(review): this config option appears to have been removed in later
    # Streamlit releases, where set_option raises; confirm against the
    # pinned Streamlit version before upgrading.
    st.set_option("deprecation.showfileUploaderEncoding", False)

    # App title and description
    st.title("MIT Count Fish Counter")
    st.text("Upload a video file to detect and count fish")

    # Example of already-processed output, shown in the right-hand column.
    video_url = "yolo2_out_py.mp4"
    video_bytes = load_video(video_url)

    df_historical_herring = load_historical(fpath="herring_count_all.csv")

    gpu_available = torch.cuda.is_available()
    # torch.backends.mps is missing on older PyTorch builds; treat its
    # absence as "not available" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()

    main_tab, upload_tab = st.tabs(["Analysis", "Upload video for analysis"])

    with main_tab:
        col1, col2 = st.columns(2)

        with col1:
            st.altair_chart(
                plot_historical_data(df_historical_herring),
                use_container_width=True,
            )

            st.subheader("Yearly Totals (from manual counts)")
            st.dataframe(_yearly_totals(df_historical_herring).T)

            st.subheader("Map of Fishery Locations")
            # NOTE(review): placeholder data. These are five random points
            # scattered around (42.41, -71.38), regenerated on every run.
            st.map(
                pd.DataFrame(
                    np.random.randn(5, 2) / [50, 50] + [42.41, -71.38],
                    columns=["lat", "lon"],
                ),
                use_container_width=True,
            )

        with col2:
            st.subheader("Example of processed video")
            st.video(video_bytes)

            st.subheader("System Information")
            st.info(_backend_message(gpu_available, mps_available))

    with upload_tab:
        process_uploaded_file()
# Entry point: render the page only when this module is executed as a
# script rather than imported.
if __name__ == "__main__":
    run_app()