import torch from typing import Tuple, Callable def hacer_nada(x: torch.Tensor, modo: str = None): return x def brujeria_mps(entrada, dim, indice): if entrada.shape[-1] == 1: return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, indice.unsqueeze(-1)).squeeze(-1) else: return torch.gather(entrada, dim, indice) def emparejamiento_suave_aleatorio_2d( metrica: torch.Tensor, ancho: int, alto: int, paso_x: int, paso_y: int, radio: int, sin_aleatoriedad: bool = False, generador: torch.Generator = None ) -> Tuple[Callable, Callable]: lote, num_nodos, _ = metrica.shape if radio <= 0: return hacer_nada, hacer_nada recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather with torch.no_grad(): alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x if sin_aleatoriedad: indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64) else: indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device) vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64) vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype)) vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x) if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho: buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64) buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice else: buffer_indice = vista_buffer_indice indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1) del buffer_indice, vista_buffer_indice num_destino = alto_paso_y * ancho_paso_x indices_a = indice_aleatorio[:, num_destino:, :] indices_b = indice_aleatorio[:, :num_destino, :] def dividir(x): canales = x.shape[-1] origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales)) destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales)) return origen, destino metrica = metrica / metrica.norm(dim=-1, keepdim=True) a, b = dividir(metrica) puntuaciones = a @ b.transpose(-1, -2) radio = min(a.shape[1], radio) nodo_max, nodo_indice = puntuaciones.max(dim=-1) indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None] indice_no_emparejado = indice_borde[..., radio:, :] indice_origen = indice_borde[..., :radio, :] indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen) def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor: origen, destino = dividir(x) n, t1, c = origen.shape no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c)) origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c)) destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo) return torch.cat([no_emparejado, destino], dim=1) def desfusionar(x: torch.Tensor) -> torch.Tensor: longitud_no_emparejado = indice_no_emparejado.shape[1] no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :] _, _, c = no_emparejado.shape origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c)) salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype) salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino) salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado) salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen) return salida return fusionar, desfusionar