import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from config import LABELS_TO_IDS
from utils.vis_utils import visualize_mask_with_overlay

def load_model(task, version):
    from config import SAPIENS_LITE_MODELS_PATH
    import os

    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
        if not os.path.exists(model_path):
            print(f"Warning: model file not found at {model_path}")
            return None, None

        # Checkpoints are TorchScript archives; run on GPU when available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = torch.jit.load(model_path)
        model.eval()
        model.to(device)
        return model, device
    except KeyError as e:
        print(f"Error: invalid task or version. {e}")
        return None, None

def process_image_or_video(input_data, task='seg', version='sapiens_0.3b'):
    # Load the model for the requested task and version
    model, device = load_model(task, version)
    if model is None or device is None:
        return None

    # Input preprocessing: resize to the model input resolution and normalize with ImageNet statistics
    transform_fn = transforms.Compose([
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Run segmentation on a single frame and return the blended overlay
    def process_frame(frame):
        # Accept raw numpy frames (e.g., from a video reader) as well as PIL images
        if isinstance(frame, np.ndarray):
            frame = Image.fromarray(frame)

        # The model expects 3-channel RGB input, so drop any alpha channel
        if frame.mode == 'RGBA':
            frame = frame.convert('RGB')

        input_tensor = transform_fn(frame).unsqueeze(0).to(device)

        with torch.inference_mode():
            output = model(input_tensor)
            # Upsample the logits back to the original frame resolution
            output = torch.nn.functional.interpolate(output, size=(frame.height, frame.width), mode="bilinear", align_corners=False)
            # Per-pixel class prediction via argmax over the channel dimension
            _, preds = torch.max(output, 1)

        mask = preds.squeeze(0).cpu().numpy()
        mask_image = Image.fromarray(mask.astype("uint8"))
        # Overlay the predicted mask on the original frame
        blended_image = visualize_mask_with_overlay(frame, mask_image, LABELS_TO_IDS, alpha=0.5)
        return blended_image

    # Dispatch: both supported input types go through the same frame path
    if isinstance(input_data, (np.ndarray, Image.Image)):  # video frame or PIL image
        return process_frame(input_data)
    else:
        print("Unsupported input type. Please provide a PIL image or a numpy video frame.")
        return None
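

# A minimal usage sketch, assuming the config module maps 'seg' /
# 'sapiens_0.3b' to an existing checkpoint, "example.jpg" is a
# hypothetical input path, and visualize_mask_with_overlay returns
# a PIL image (none of these are guaranteed by this module alone).
if __name__ == "__main__":
    image = Image.open("example.jpg")
    result = process_image_or_video(image, task='seg', version='sapiens_0.3b')
    if result is not None:
        result.save("example_seg_overlay.png")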