""" See https://github.com/baaivision/Uni3D for source code """ import os import torch import torch.nn as nn import timm import numpy as np from pointnet2_ops import pointnet2_utils import open_clip from huggingface_hub import hf_hub_download import sys sys.path.append('') from feature_extractors import FeatureExtractor from utils.tokenizer import SimpleTokenizer import logging def fps(data, number): ''' data B N 3 number int ''' fps_idx = pointnet2_utils.furthest_point_sample(data, number) fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() return fps_data # https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py def knn_point(nsample, xyz, new_xyz): """ Input: nsample: max sample number in local region xyz: all points, [B, N, C] new_xyz: query points, [B, S, C] Return: group_idx: grouped points index, [B, S, nsample] """ sqrdists = square_distance(new_xyz, xyz) _, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False) return group_idx def square_distance(src, dst): """ Calculate Euclid distance between each two points. src^T * dst = xn * xm + yn * ym + zn * zm; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst Input: src: source points, [B, N, C] dst: target points, [B, M, C] Output: dist: per-point square distance, [B, N, M] """ B, N, _ = src.shape _, M, _ = dst.shape dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist += torch.sum(src ** 2, -1).view(B, N, 1) dist += torch.sum(dst ** 2, -1).view(B, 1, M) return dist class PatchDropout(nn.Module): """ https://arxiv.org/abs/2212.00794 """ def __init__(self, prob, exclude_first_token=True): super().__init__() assert 0 <= prob < 1. 


class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info("patch dropout prob is {}".format(prob))

    def forward(self, x):
        # if not self.training or self.prob == 0.:
        #     return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        batch_indices = torch.arange(batch)
        batch_indices = batch_indices[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        rand = torch.randn(batch, num_tokens)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x


class Group(nn.Module):
    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, color):
        '''
            input: B N 3
            ---------------------------
            output: B G M 3
            center: B G 3
        '''
        batch_size, num_points, _ = xyz.shape
        # fps the centers out
        center = fps(xyz, self.num_group)  # B G 3
        # knn to get the neighborhood
        # _, idx = self.knn(xyz, center)  # B G M
        idx = knn_point(self.group_size, xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
        idx = idx.view(-1)
        neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
        neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
        neighborhood_color = color.view(batch_size * num_points, -1)[idx, :]
        neighborhood_color = neighborhood_color.view(batch_size, self.num_group, self.group_size, 3).contiguous()

        # normalize
        neighborhood = neighborhood - center.unsqueeze(2)

        features = torch.cat((neighborhood, neighborhood_color), dim=-1)
        return neighborhood, center, features


class Encoder(nn.Module):
    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.first_conv = nn.Sequential(
            nn.Conv1d(6, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1)
        )

    def forward(self, point_groups):
        '''
            point_groups : B G N 6
            -----------------
            feature_global : B G C
        '''
        bs, g, n, _ = point_groups.shape
        point_groups = point_groups.reshape(bs * g, n, 6)
        # encoder
        feature = self.first_conv(point_groups.transpose(2, 1))  # BG 256 n
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # BG 256 1
        feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)  # BG 512 n
        feature = self.second_conv(feature)  # BG C n
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # BG C
        return feature_global.reshape(bs, g, self.encoder_channel)
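

# Minimal shape sketch for the grouping/encoding pipeline above. Illustrative
# only and not used elsewhere in this file: it assumes a CUDA device (the
# pointnet2_ops furthest-point-sampling kernel is GPU-only) and substitutes
# random tensors for a real colored point cloud.
def _example_group_and_encode():
    xyz = torch.rand(2, 2048, 3, device="cuda")   # B N 3 coordinates
    rgb = torch.rand(2, 2048, 3, device="cuda")   # B N 3 colors
    group = Group(num_group=512, group_size=64)
    encoder = Encoder(encoder_channel=512).to("cuda").eval()
    _, center, features = group(xyz, rgb)         # center: B G 3, features: B G M 6
    with torch.no_grad():
        tokens = encoder(features)                # B G C -> (2, 512, 512)
    return center.shape, tokens.shape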


class PointcloudEncoder(nn.Module):
    def __init__(self, point_transformer):
        # use the giant branch of uni3d
        super().__init__()
        from easydict import EasyDict
        self.trans_dim = 1408
        self.embed_dim = 1024
        self.group_size = 64
        self.num_group = 512
        # grouper
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
        # define the encoder
        self.encoder_dim = 512
        self.encoder = Encoder(encoder_channel=self.encoder_dim)
        # bridge encoder and transformer
        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
        # bridge transformer and clip embedding
        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim)
        )
        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(0.) if 0. > 0. else nn.Identity()
        self.visual = point_transformer

    def forward(self, pts, colors):
        # divide the point cloud in the same form. This is important
        _, center, features = self.group_divider(pts, colors)

        # encode the input cloud patches
        group_input_tokens = self.encoder(features)  # B G C
        group_input_tokens = self.encoder2trans(group_input_tokens)
        # prepare cls
        cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
        cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
        # add pos embedding
        pos = self.pos_embed(center)
        # final input
        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
        pos = torch.cat((cls_pos, pos), dim=1)
        # transformer
        x = x + pos
        # x = x.half()

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.patch_dropout(x)

        x = self.visual.pos_drop(x)

        # ModuleList does not support forward, so run the blocks one by one
        for i, blk in enumerate(self.visual.blocks):
            x = blk(x)
        x = self.visual.norm(x[:, 0, :])
        x = self.visual.fc_norm(x)

        x = self.trans2embed(x)
        return x


class Uni3D(nn.Module):
    def __init__(self, point_encoder):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.point_encoder = point_encoder

    def encode_pc(self, pc):
        xyz = pc[:, :, :3].contiguous()
        color = pc[:, :, 3:].contiguous()
        pc_feat = self.point_encoder(xyz, color)
        return pc_feat

    def forward(self, pc, text, image):
        text_embed_all = text
        image_embed = image
        pc_embed = self.encode_pc(pc)
        return {'text_embed': text_embed_all,
                'pc_embed': pc_embed,
                'image_embed': image_embed,
                'logit_scale': self.logit_scale.exp()}


def get_metric_names(model):
    return ['loss', 'uni3d_loss', 'pc_image_acc', 'pc_text_acc']


def create_uni3d(uni3d_path):
    # create transformer blocks for point cloud via timm
    point_transformer = timm.create_model("eva_giant_patch14_560")

    # create whole point cloud encoder
    point_encoder = PointcloudEncoder(point_transformer)

    # uni3d model
    model = Uni3D(point_encoder=point_encoder)

    checkpoint = torch.load(uni3d_path, map_location='cpu')
    logging.info('loaded checkpoint {}'.format(uni3d_path))
    sd = checkpoint['module']
    if next(iter(sd.items()))[0].startswith('module'):
        sd = {k[len('module.'):]: v for k, v in sd.items()}
    model.load_state_dict(sd)
    return model
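

# Usage sketch for the full Uni3D point encoder. The checkpoint path below is
# only a placeholder for a locally downloaded "uni3d-g" checkpoint (see the
# commented-out hf_hub_download call in Uni3dEmbeddingEncoder); loading the
# EVA-giant backbone is memory-heavy, so treat this as illustrative only.
def _example_encode_pc(uni3d_ckpt="path/to/uni3d-g/model.pt"):
    model = create_uni3d(uni3d_ckpt).to("cuda").eval()
    pc = torch.rand(1, 10000, 6, device="cuda")  # xyz in [:, :, :3], rgb in [:, :, 3:]
    with torch.no_grad():
        pc_feat = model.encode_pc(pc)            # (1, 1024) CLIP-space embedding
        pc_feat = pc_feat / pc_feat.norm(dim=-1, keepdim=True)
    return pc_feat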


class Uni3dEmbeddingEncoder(FeatureExtractor):
    def __init__(self, cache_dir, **kwargs) -> None:
        bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
        # uni3d_path = os.path.join(cache_dir, "Uni3D", "modelzoo", "uni3d-g", "model.pt")  # concat the subfolder as hf_hub_download will put it here
        clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")

        # if not os.path.exists(uni3d_path):
        #     hf_hub_download("BAAI/Uni3D", "model.pt", subfolder="modelzoo/uni3d-g",
        #                     cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D")
        if not os.path.exists(clip_path):
            hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k",
                            "open_clip_pytorch_model.bin",
                            cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = SimpleTokenizer(bpe_path)
        # self.model = create_uni3d(uni3d_path)
        # self.model.eval()
        # self.model.to(self.device)
        self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="EVA02-E-14-plus", pretrained=clip_path)
        self.clip_model.to(self.device)

    def pc_norm(self, pc):
        """ pc: NxC, return NxC """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        pc = pc / m
        return pc

    @torch.no_grad()
    def encode_3D(self, data):
        pass
        # pc = data.to(device=self.device, non_blocking=True)
        # pc_features = self.model.encode_pc(pc)
        # pc_features = pc_features / pc_features.norm(dim=-1, keepdim=True)
        # return pc_features.float()

    @torch.no_grad()
    def encode_text(self, input_text):
        texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
        if len(texts.shape) < 2:
            texts = texts[None, ...]
        class_embeddings = self.clip_model.encode_text(texts)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        return class_embeddings.float()

    @torch.no_grad()
    def encode_image(self, img_tensor_list):
        image = img_tensor_list.to(device=self.device, non_blocking=True)
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.float()

    def encode_query(self, query_list):
        return self.encode_text(query_list)

    def get_img_transform(self):
        return self.preprocess
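

# Smoke-test sketch: instantiating the encoder downloads the EVA02-E-14-plus
# open_clip weights into `cache_dir` on first use (a multi-GB file), and the
# relative `bpe_path` assumes the script is run from the repository root. The
# cache directory below is only a placeholder.
if __name__ == "__main__":
    encoder = Uni3dEmbeddingEncoder(cache_dir="./checkpoints")
    query_feat = encoder.encode_query(["a chair", "a table"])  # (2, 1024), L2-normalized
    print(query_feat.shape)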