# fish / fish_feeding.py
# Raniahossam33's picture
# Upload 5 files
# 82b369d verified
import torch
import numpy as np
from PIL import Image
from transformers import pipeline
from ultralytics import YOLO
class FishFeeding:
    """Estimate a feed amount for a group of fish from camera frames.

    The pipeline is: estimate each fish's length from pose keypoints and a
    monocular depth map, convert the average length to a weight with a
    power law, look the weight up in a feed table, and multiply the
    per-fish feed by the number of fish counted in the frames.

    ``load_models()`` must be called before any method that runs inference.
    Units of length, weight and feed are not stated in this file.
    """

    # (inclusive upper weight bound, feed per fish, feedings) in ascending
    # order of weight. Weights above the last bound use _FEED_FALLBACK.
    _FEED_TABLE = (
        (50, 3.3, 2),
        (100, 4.8, 2),
        (250, 5.8, 2),
        (500, 8.4, 2),
        (750, 9.4, 1),
        (1000, 10.5, 1),
        (1500, 11.0, 1),
    )
    _FEED_FALLBACK = (12.0, 1)

    def __init__(self, focal_length: float = 27.4) -> None:
        """Store configuration; no model is loaded here.

        Args:
            focal_length: Divisor applied to the depth-scaled pixel distance
                in the length estimate.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Lengths accumulate across calls to final_fish_feed().
        self.collected_lengths = []
        self.focal_length = focal_length
        # Set by get_average_weight() once at least one length is available.
        self.final_weight = None
        self.length_model_name = "length_model.pt"
        self.depth_model_name = "vinvino02/glpn-nyu"
        self.counting_model_name = "counting_model.pt"

    def load_models(self) -> None:
        """Instantiate the keypoint, depth-estimation and detection models."""
        self.fish_keypoints_model = YOLO(self.length_model_name)
        self.depth_model = pipeline(
            task="depth-estimation",
            model=self.depth_model_name,
            device=self.device,
        )
        self.fish_detection_model = YOLO(self.counting_model_name)

    @staticmethod
    def _to_pixel(point, width: int, height: int) -> tuple:
        """Map a normalized (x, y) point to in-bounds integer pixel indices."""
        # A normalized coordinate of exactly 1.0 would otherwise index one
        # past the last row/column, so clamp to the valid range.
        x = min(max(int(point[0] * width), 0), width - 1)
        y = min(max(int(point[1] * height), 0), height - 1)
        return x, y

    def _length_from_keypoints(self, keypoints, depth) -> float:
        """Compute the head-to-tail length from keypoints and a depth map.

        Args:
            keypoints: Normalized (x, y) keypoints of one fish; index 0 is
                used as the head and index 3 as the tail.
            depth: 2-D depth array indexed as ``depth[row, col]``.

        Raises:
            ValueError: If fewer than four keypoints are supplied.
        """
        if len(keypoints) < 4:
            raise ValueError("No fish detected in the image")

        # Keypoint order used by this class: 0=head, 1=back, 2=belly, 3=tail.
        # Back and belly are not used; no girth estimate is computed.
        head, tail = keypoints[0], keypoints[3]

        # NumPy image arrays are (rows, cols), i.e. (height, width).
        depth_h, depth_w = depth.shape[:2]
        head_x, head_y = self._to_pixel(head, depth_w, depth_h)
        tail_x, tail_y = self._to_pixel(tail, depth_w, depth_h)

        head_depth = depth[head_y, head_x]
        tail_depth = depth[tail_y, tail_x]

        # Scale each endpoint's pixel coordinates by its own depth, take the
        # planar distance between them, and divide by the focal length.
        dx = head_x * head_depth - tail_x * tail_depth
        dy = head_y * head_depth - tail_y * tail_depth
        return float(np.sqrt(dx ** 2 + dy ** 2) / self.focal_length)

    def predict_fish_length(self, frame) -> float:
        """Estimate the length of the first fish found in ``frame``.

        Args:
            frame: Image array accepted by ``PIL.Image.fromarray`` and by
                the YOLO model.

        Returns:
            The estimated length as a float.

        Raises:
            ValueError: If the keypoint model reports no fish in the frame.
        """
        image_obj = Image.fromarray(frame)
        image_obj = image_obj.resize((640, 640))  # fixed input size for the depth model
        depth = self.depth_model(image_obj)["predicted_depth"]
        depth = np.array(depth).squeeze()

        # Use the keypoint model here. The counting model is presumably
        # detection-only, in which case its results carry no keypoints.
        results = self.fish_keypoints_model(frame)[0]
        if results.keypoints is None:
            raise ValueError("No fish detected in the image")
        all_keypoints = results.keypoints.xyn
        if len(all_keypoints) == 0:
            # Raise the error callers already catch instead of an IndexError.
            raise ValueError("No fish detected in the image")

        # Only the first detected fish is measured.
        keypoints = all_keypoints[0].detach().cpu().numpy()
        return self._length_from_keypoints(keypoints, depth)

    def get_average_weight(self):
        """Convert the mean collected length to a weight.

        Uses the power law ``0.014 * length ** 3.02`` and stores the result
        in ``self.final_weight``.

        Returns:
            The estimated weight, or 0 when no lengths have been collected.
        """
        if not self.collected_lengths:
            return 0
        length_average = sum(self.collected_lengths) / len(self.collected_lengths)
        self.final_weight = 0.014 * length_average ** 3.02
        return self.final_weight

    def fish_counting(self, images) -> int:
        """Return the largest number of fish detected in any single image.

        Args:
            images: Iterable of images accepted by the detection model.

        Returns:
            The maximum per-image detection count, or 0 for no images.
        """
        counting_output = 0
        for im0 in images:
            # The model returns one Results object per input image; the
            # number of fish is the number of boxes in it, not the length
            # of the returned list.
            boxes = self.fish_detection_model(im0)[0].boxes
            detected = 0 if boxes is None else len(boxes)
            counting_output = max(counting_output, detected)
        return counting_output

    @classmethod
    def _feed_plan(cls, average_weight) -> tuple:
        """Look up (feed per fish, feedings) for an average weight."""
        if average_weight >= 0:
            for upper_bound, feed, times in cls._FEED_TABLE:
                if average_weight <= upper_bound:
                    return feed, times
        # Weights above the last bound (and any non-comparable value such
        # as NaN) use the last tier.
        return cls._FEED_FALLBACK

    def final_fish_feed(self, images: list):
        """Compute the total feed and number of feedings for ``images``.

        Frames in which no fish is found are skipped for the length
        estimate. Measured lengths are appended to
        ``self.collected_lengths``.

        Args:
            images: List of frames to analyse.

        Returns:
            A ``(total_feed, times)`` tuple: feed per fish multiplied by
            the fish count, and the number of feedings.
        """
        for image in images:
            try:
                output = self.predict_fish_length(image)
            except ValueError:
                continue
            self.collected_lengths.append(output)

        feed, times = self._feed_plan(self.get_average_weight())
        fish_count = self.fish_counting(images)
        total_feed = feed * fish_count
        return total_feed, times
# if __name__ == "__main__":
# to_collect = 6
# collected = []
# video_path = "object_counting.mp4"
# cap = cv2.VideoCapture(video_path)
# fish_feeding = FishFeeding()
# fish_feeding.load_models()
# d = {"images": []}
# while True:
# ret, frame = cap.read()
# if not ret:
# break
# if len(collected) == to_collect:
# total_feed, times = fish_feeding.final_fish_feed(collected)
# print(f"Total feed: {total_feed}, Feed times: {times}")
# collected = []
# break
# collected.append(frame)
# d["images"].append(frame.tolist())
# if cv2.waitKey(1) & 0xFF == ord("q"):
# break
# cap.release()
# cv2.destroyAllWindows()
# # save d to json file
# import json
# with open("data.json", "w") as f:
# json.dump(d, f)