import torch
import numpy as np
from PIL import Image
from transformers import pipeline
from ultralytics import YOLO
class FishFeeding:
    def __init__(self, focal_length: float = 27.4) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.collected_lengths = []
        self.focal_length = focal_length
        self.final_weight = None
        # Model checkpoints: a YOLO keypoint (pose) model for length estimation,
        # GLPN for monocular depth estimation, and a YOLO detector for counting.
        self.length_model_name = "length_model.pt"
        self.depth_model_name = "vinvino02/glpn-nyu"
        self.counting_model_name = "counting_model.pt"

    def load_models(self) -> None:
        self.fish_keypoints_model = YOLO(self.length_model_name)
        self.depth_model = pipeline(task="depth-estimation", model=self.depth_model_name, device=self.device)
        self.fish_detection_model = YOLO(self.counting_model_name)
    def predict_fish_length(self, frame):
        # Estimate monocular depth on a resized copy of the frame.
        image_obj = Image.fromarray(frame)
        image_obj = image_obj.resize((640, 640))  # Adjust size as per requirement
        depth = self.depth_model(image_obj)
        depth = np.array(depth["predicted_depth"]).squeeze()

        # Run the keypoint model; points are ordered head, back, belly, tail.
        results = self.fish_keypoints_model(frame)[0]
        if results.keypoints is None:
            raise ValueError("No fish detected in the image")
        keypoints = results.keypoints.xyn[0].detach().cpu().numpy()
        head, back, belly, tail = keypoints[0], keypoints[1], keypoints[2], keypoints[3]

        # Map the normalized keypoints onto the depth map grid (rows are y, columns are x).
        depth_h, depth_w = depth.shape[:2]
        head_x, head_y = int(head[0] * depth_w), int(head[1] * depth_h)
        tail_x, tail_y = int(tail[0] * depth_w), int(tail[1] * depth_h)
        back_x, back_y = int(back[0] * depth_w), int(back[1] * depth_h)
        belly_x, belly_y = int(belly[0] * depth_w), int(belly[1] * depth_h)
        head_depth = depth[head_y, head_x]
        tail_depth = depth[tail_y, tail_x]

        # Pinhole-style estimate: scale each endpoint by its depth, take the
        # Euclidean distance between them, and divide by the focal length.
        fish_length = (
            np.sqrt(
                (head_x * head_depth - tail_x * tail_depth) ** 2
                + (head_y * head_depth - tail_y * tail_depth) ** 2
            )
            / self.focal_length
        )
        # girth = (
        #     np.sqrt(
        #         (back_x * head_depth - belly_x * tail_depth) ** 2
        #         + (back_y * head_depth - belly_y * tail_depth) ** 2
        #     )
        #     / self.focal_length
        # )
        return fish_length
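
    # Usage note (a sketch, not part of the original code): predict_fish_length takes a
    # single RGB numpy array and returns one scalar length estimate, e.g.
    #   frame = np.array(Image.open("fish.jpg"))    # hypothetical image path
    #   length = feeder.predict_fish_length(frame)  # raises ValueError if no keypoints are found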
    # def videocapture(self):
    #     cap = cv2.VideoCapture(self.video_path)
    #     assert cap.isOpened(), "Error reading video file"
    #     while True:
    #         ret, frame = cap.read()
    #         if not ret:
    #             break
    #         output = self.predict_fish_length(frame)
    #         self.collected_lengths.append(output)
    #     cap.release()
    #     return self.collected_lengths
    def get_average_weight(self):
        if not self.collected_lengths:
            return 0
        length_average = sum(self.collected_lengths) / len(self.collected_lengths)
        # Standard length-weight relationship W = a * L**b, here with a = 0.014 and b = 3.02.
        final_weight = 0.014 * length_average ** 3.02
        return final_weight
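
    # Worked example (illustrative only; the source does not state the units): an
    # average length of 20 gives 0.014 * 20 ** 3.02, i.e. roughly 119.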
    def fish_counting(self, images):
        # Use the maximum number of detections across the frames as the fish count.
        counting_output = 0
        for im0 in images:
            detections = self.fish_detection_model(im0)[0]
            counting_output = max(counting_output, len(detections.boxes))
        return counting_output
    def final_fish_feed(self, images: list):
        # Collect a length estimate for every frame in which a fish is detected.
        for image in images:
            try:
                output = self.predict_fish_length(image)
            except ValueError:
                continue
            self.collected_lengths.append(output)
        average_weight = self.get_average_weight()
        # Look up the per-fish feed amount and the number of feedings for the
        # average-weight bracket.
        if 0 <= average_weight <= 50:
            feed, times = 3.3, 2
        elif 50 < average_weight <= 100:
            feed, times = 4.8, 2
        elif 100 < average_weight <= 250:
            feed, times = 5.8, 2
        elif 250 < average_weight <= 500:
            feed, times = 8.4, 2
        elif 500 < average_weight <= 750:
            feed, times = 9.4, 1
        elif 750 < average_weight <= 1000:
            feed, times = 10.5, 1
        elif 1000 < average_weight <= 1500:
            feed, times = 11.0, 1
        else:
            feed, times = 12.0, 1
        fish_count = self.fish_counting(images)
        total_feed = feed * fish_count
        return total_feed, times
# if __name__ == "__main__":
#     to_collect = 6
#     collected = []
#     video_path = "object_counting.mp4"
#     cap = cv2.VideoCapture(video_path)
#     fish_feeding = FishFeeding()
#     fish_feeding.load_models()
#     d = {"images": []}
#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break
#         if len(collected) == to_collect:
#             total_feed, times = fish_feeding.final_fish_feed(collected)
#             print(f"Total feed: {total_feed}, Feed times: {times}")
#             collected = []
#             break
#         collected.append(frame)
#         d["images"].append(frame.tolist())
#         if cv2.waitKey(1) & 0xFF == ord("q"):
#             break
#     cap.release()
#     cv2.destroyAllWindows()
#     # save d to json file
#     import json
#     with open("data.json", "w") as f:
#         json.dump(d, f)
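
# A minimal usage sketch (not part of the original script); it assumes the model files
# referenced in __init__ are available locally and that `frames` is a list of RGB numpy
# arrays obtained elsewhere (the image paths below are hypothetical):
#
#     feeder = FishFeeding()
#     feeder.load_models()
#     frames = [np.array(Image.open(p)) for p in ("fish_0.jpg", "fish_1.jpg")]
#     total_feed, times = feeder.final_fish_feed(frames)
#     print(f"Total feed: {total_feed:.1f}, Feed times: {times}")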