LD-T3D / feature_extractors /uni3d_embedding_encoder.py
yuanze1024's picture
bugfix & remove redundent uni3d
03e01a8
raw
history blame
12.4 kB
"""
See https://github.com/baaivision/Uni3D for source code
"""
import os
import torch
import torch.nn as nn
import timm
import numpy as np
from pointnet2_ops import pointnet2_utils
import open_clip
from huggingface_hub import hf_hub_download
import sys
sys.path.append('')
from feature_extractors import FeatureExtractor
from utils.tokenizer import SimpleTokenizer
import logging
def fps(data, number):
    """Farthest-point-sample ``number`` points from each cloud in ``data``.

    Args:
        data: point coordinates, shape (B, N, 3).
        number: how many points to keep per cloud (int).

    Returns:
        The sampled coordinates as a contiguous tensor, shape (B, number, 3).
    """
    sampled_idx = pointnet2_utils.furthest_point_sample(data, number)
    # The gather is done on the (B, 3, N) transposed layout and transposed back afterwards.
    channels_first = data.transpose(1, 2).contiguous()
    gathered = pointnet2_utils.gather_operation(channels_first, sampled_idx)
    return gathered.transpose(1, 2).contiguous()
# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
def knn_point(nsample, xyz, new_xyz):
    """Find the ``nsample`` nearest neighbours in ``xyz`` of every query point.

    Args:
        nsample: number of neighbours to return per query point.
        xyz: candidate points, shape (B, N, C).
        new_xyz: query points, shape (B, S, C).

    Returns:
        Long index tensor of shape (B, S, nsample) pointing into the N axis of
        ``xyz``. The neighbours are NOT ordered by distance (``sorted=False``).
    """
    dists = square_distance(new_xyz, xyz)  # (B, S, N)
    return dists.topk(nsample, dim=-1, largest=False, sorted=False).indices
def square_distance(src, dst):
    """Pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b, so the whole
    computation is one batched matmul plus two broadcast additions. Because of
    floating-point cancellation, entries for (near-)coincident points may come
    out very slightly negative.

    Args:
        src: source points, shape (B, N, C).
        dst: target points, shape (B, M, C).

    Returns:
        Tensor of shape (B, N, M); entry [b, n, m] is the squared distance
        between ``src[b, n]`` and ``dst[b, m]``.
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))      # (B, N, M) dot products
    src_sq = (src ** 2).sum(-1).view(B, N, 1)             # broadcast over M
    dst_sq = (dst ** 2).sum(-1).view(B, 1, M)             # broadcast over N
    return -2 * cross + src_sq + dst_sq
class PatchDropout(nn.Module):
    """Randomly drop a fraction of patch tokens during training.

    https://arxiv.org/abs/2212.00794

    Args:
        prob: fraction of tokens to drop, in [0, 1).
        exclude_first_token: when True, the first token (CLS) is never dropped
            and stays at position 0 of the output.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info("patch dropout prob is {}".format(prob))

    def forward(self, x):
        """Keep a random subset of tokens of ``x`` (shape (batch, tokens, dim)).

        In eval mode, or when ``prob == 0``, ``x`` is returned unchanged.
        """
        # Dropout is a training-time regulariser. Without this guard the layer
        # would still randomly drop (or, with prob == 0, randomly permute)
        # tokens at inference time.
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            # Only needed so the variable has a type under TorchScript.
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size(0)
        num_tokens = x.size(1)

        # Create index helpers on x's device to avoid host/device index mixing.
        batch_indices = torch.arange(batch, device=x.device)[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Top-k over random scores picks a random subset of tokens per sample.
        rand = torch.randn(batch, num_tokens, device=x.device)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
class Group(nn.Module):
    """Partition point clouds into local patches: FPS centers plus kNN neighbourhoods.

    Args:
        num_group: number of patch centers sampled per cloud (G).
        group_size: number of neighbours gathered around each center (M).
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, color):
        """Group points and their colors into patches.

        Args:
            xyz: point coordinates, (B, N, 3). Should be contiguous, since ``.view`` is used.
            color: per-point colors with 3 channels, (B, N, 3). Should be contiguous.

        Returns:
            neighborhood: (B, G, M, 3) coordinates relative to their patch center.
            center: (B, G, 3) centers chosen by farthest-point sampling.
            features: (B, G, M, 6) concatenation of ``neighborhood`` and the
                gathered colors along the last axis.
        """
        B, N, _ = xyz.shape
        center = fps(xyz, self.num_group)                  # (B, G, 3)
        nbr_idx = knn_point(self.group_size, xyz, center)  # (B, G, M), per-cloud indices
        assert nbr_idx.size(1) == self.num_group
        assert nbr_idx.size(2) == self.group_size

        # Shift cloud b's indices by b * N so a single flat gather serves the whole batch.
        offsets = torch.arange(0, B, device=xyz.device).view(-1, 1, 1) * N
        flat_idx = (nbr_idx + offsets).view(-1)

        patch_shape = (B, self.num_group, self.group_size, 3)
        neighborhood = xyz.view(B * N, -1)[flat_idx, :].view(patch_shape).contiguous()
        neighborhood_color = color.view(B * N, -1)[flat_idx, :].view(patch_shape).contiguous()

        # Express patch coordinates relative to their own center.
        neighborhood = neighborhood - center.unsqueeze(2)
        features = torch.cat((neighborhood, neighborhood_color), dim=-1)
        return neighborhood, center, features
class Encoder(nn.Module):
    """Mini-PointNet mapping each point patch to a single feature vector.

    Args:
        encoder_channel: width C of the per-patch output feature.
    """

    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        # Per-point MLP (kernel-size-1 convs) over the 6 input channels.
        self.first_conv = nn.Sequential(
            nn.Conv1d(6, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
        )
        # Second per-point MLP over [patch max-pool (256) | per-point (256)] = 512 channels.
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1),
        )

    def forward(self, point_groups):
        """Encode point patches.

        Args:
            point_groups: (B, G, N, 6) per-point features; ``Group`` produces
                coordinates concatenated with colors.

        Returns:
            (B, G, C) tensor, where C is ``encoder_channel``.
        """
        batch, groups, npts, _ = point_groups.shape
        # Fold patches into the batch axis and go channels-first for Conv1d: (B*G, 6, N).
        flat = point_groups.reshape(batch * groups, npts, 6).transpose(2, 1)
        local = self.first_conv(flat)                                   # (B*G, 256, N)
        pooled = local.max(dim=2, keepdim=True).values                  # (B*G, 256, 1)
        fused = torch.cat([pooled.expand(-1, -1, npts), local], dim=1)  # (B*G, 512, N)
        fused = self.second_conv(fused)                                 # (B*G, C, N)
        return fused.max(dim=2).values.reshape(batch, groups, self.encoder_channel)
class PointcloudEncoder(nn.Module):
    """Uni3D point-cloud encoder.

    The pipeline is: patchify -> embed patches -> transformer blocks -> CLIP
    embedding space. The defaults reproduce the "giant" branch of Uni3D.

    Args:
        point_transformer: a ViT-style model. Only its ``pos_drop``, ``blocks``,
            ``norm`` and ``fc_norm`` submodules are used here.
        trans_dim: token width fed to the transformer blocks. It has to agree
            with ``point_transformer``'s embedding width.
        embed_dim: size of the output embedding.
        group_size: points per local patch (kNN size).
        num_group: number of patches (FPS centers) per cloud.
        encoder_dim: width of the per-patch mini-PointNet feature.
        patch_dropout: token drop probability. 0 disables it (``nn.Identity``).
    """

    def __init__(self, point_transformer, *, trans_dim=1408, embed_dim=1024,
                 group_size=64, num_group=512, encoder_dim=512, patch_dropout=0.):
        super().__init__()
        self.trans_dim = trans_dim
        self.embed_dim = embed_dim
        self.group_size = group_size
        self.num_group = num_group
        # Grouper: splits the cloud into num_group patches of group_size points.
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
        # Per-patch feature encoder.
        self.encoder_dim = encoder_dim
        self.encoder = Encoder(encoder_channel=self.encoder_dim)
        # Bridge encoder -> transformer width, and transformer -> CLIP embedding.
        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))

        # Positional embedding computed from each patch center's xyz.
        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim),
        )
        # A patch_dropout of 0 disables the layer entirely.
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
        self.visual = point_transformer

    def forward(self, pts, colors):
        """Embed a batch of colored point clouds.

        Args:
            pts: point coordinates, passed to ``Group`` as ``xyz``.
            colors: per-point colors, passed to ``Group`` as ``color``.

        Returns:
            (B, embed_dim) embedding taken from the CLS token.
        """
        # Divide the point cloud into patches; this must match how the model was trained.
        _, center, features = self.group_divider(pts, colors)

        # Encode the patches and project to the transformer width.
        group_input_tokens = self.encoder(features)  # (B, G, encoder_dim)
        group_input_tokens = self.encoder2trans(group_input_tokens)

        batch = group_input_tokens.size(0)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        cls_pos = self.cls_pos.expand(batch, -1, -1)
        pos = self.pos_embed(center)

        # Prepend CLS, then add positional embeddings.
        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
        pos = torch.cat((cls_pos, pos), dim=1)
        x = x + pos

        x = self.patch_dropout(x)
        x = self.visual.pos_drop(x)

        # The transformer blocks live in a ModuleList, which has no forward(),
        # so apply them one by one.
        for blk in self.visual.blocks:
            x = blk(x)

        x = self.visual.norm(x[:, 0, :])  # keep only the CLS token
        x = self.visual.fc_norm(x)
        x = self.trans2embed(x)
        return x
class Uni3D(nn.Module):
    """Wrapper pairing a point-cloud encoder with a learnable CLIP-style logit scale.

    Args:
        point_encoder: callable invoked as ``point_encoder(xyz, color)`` that
            returns point-cloud embeddings.
    """

    def __init__(self, point_encoder):
        super().__init__()
        # Stored in log space, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.point_encoder = point_encoder

    def encode_pc(self, pc):
        """Encode ``pc``, whose last axis holds xyz in channels 0-2 and color after."""
        coords = pc[:, :, :3]
        rgb = pc[:, :, 3:]
        return self.point_encoder(coords.contiguous(), rgb.contiguous())

    def forward(self, pc, text, image):
        """Encode ``pc``; ``text`` and ``image`` are passed through unchanged.

        Returns:
            Dict with keys 'text_embed', 'pc_embed', 'image_embed' and
            'logit_scale' (the exponentiated scale).
        """
        outputs = {
            'text_embed': text,
            'pc_embed': self.encode_pc(pc),
            'image_embed': image,
            'logit_scale': self.logit_scale.exp(),
        }
        return outputs
def get_metric_names(model):
    """Return the fixed list of Uni3D metric names; ``model`` is ignored."""
    del model  # unused; kept for signature compatibility
    return ['loss', 'uni3d_loss'] + ['pc_image_acc', 'pc_text_acc']
def create_uni3d(uni3d_path):
    """Build the giant Uni3D model and load weights from ``uni3d_path``.

    Args:
        uni3d_path: path to a checkpoint whose ``'module'`` entry holds the
            model state dict. Keys may carry a ``module.`` prefix, which is
            stripped.

    Returns:
        The ``Uni3D`` model on CPU with the checkpoint weights loaded. It is
        left in nn.Module's default (training) mode; call ``.eval()`` before
        inference.

    Raises:
        KeyError: if the checkpoint has no ``'module'`` entry.
    """
    # Transformer blocks for the point cloud are taken from a timm EVA giant ViT.
    point_transformer = timm.create_model("eva_giant_patch14_560")
    point_encoder = PointcloudEncoder(point_transformer)
    model = Uni3D(point_encoder=point_encoder)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(uni3d_path, map_location='cpu')
    logging.info('loaded checkpoint {}'.format(uni3d_path))
    sd = checkpoint['module']

    # Strip the wrapper prefix only from keys that actually carry it, instead of
    # chopping a fixed number of characters off every key.
    prefix = 'module.'
    if any(k.startswith(prefix) for k in sd):
        sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

    model.load_state_dict(sd)
    return model
class Uni3dEmbeddingEncoder(FeatureExtractor):
    """Feature extractor backed by the EVA02-E-14-plus open_clip model used by Uni3D.

    Text and image features come from the open_clip model. The Uni3D
    point-cloud branch is not loaded, so ``encode_3D`` returns ``None``.

    Args:
        cache_dir: directory under which the CLIP weights are stored, at
            ``<cache_dir>/Uni3D/open_clip_pytorch_model.bin``. They are
            downloaded from the Hugging Face Hub if that file is missing.
        **kwargs: accepted for interface compatibility; ignored.
    """

    def __init__(self, cache_dir, **kwargs) -> None:
        # Relative path: resolved against the current working directory.
        # NOTE(review): assumes the process runs from the repo root -- confirm.
        bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
        local_dir = os.path.join(cache_dir, "Uni3D")
        clip_path = os.path.join(local_dir, "open_clip_pytorch_model.bin")
        if not os.path.exists(clip_path):
            hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k",
                            "open_clip_pytorch_model.bin",
                            cache_dir=cache_dir, local_dir=local_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = SimpleTokenizer(bpe_path)
        # The third value returned is open_clip's evaluation-time image transform.
        self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="EVA02-E-14-plus", pretrained=clip_path)
        # open_clip returns models in training mode. Switch to eval so that
        # dropout / drop-path style layers are inactive and embeddings are deterministic.
        self.clip_model.eval()
        self.clip_model.to(self.device)

    def pc_norm(self, pc):
        """Center ``pc`` (N x C numpy array) at its centroid and scale it into the unit ball.

        If every point coincides (zero radius), the centered cloud is returned
        unscaled instead of dividing by zero and producing NaNs.
        """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        if m > 0:
            pc = pc / m
        return pc

    @torch.no_grad()
    def encode_3D(self, data):
        """Point-cloud encoding is disabled for this extractor; always returns ``None``."""
        # TODO: to re-enable, build the model in __init__ with create_uni3d(<path to
        # Uni3D model.pt>), call .eval() and move it to self.device; then here run
        # model.encode_pc(data on self.device), L2-normalise and return .float().
        return None

    @torch.no_grad()
    def encode_text(self, input_text):
        """Tokenize and embed ``input_text``.

        Returns:
            L2-normalised float32 text features, one row per input text.
        """
        texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
        if len(texts.shape) < 2:
            # A single un-batched token sequence: add the batch axis.
            texts = texts[None, ...]
        class_embeddings = self.clip_model.encode_text(texts)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        return class_embeddings.float()

    @torch.no_grad()
    def encode_image(self, img_tensor_list):
        """Embed a batched image tensor.

        The tensor is presumably produced by ``get_img_transform()`` -- confirm at call sites.

        Returns:
            L2-normalised float32 image features, one row per image.
        """
        image = img_tensor_list.to(device=self.device, non_blocking=True)
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.float()

    def encode_query(self, query_list):
        """Queries are text; delegate to ``encode_text``."""
        return self.encode_text(query_list)

    def get_img_transform(self):
        """Return the image preprocessing transform paired with the CLIP model."""
        return self.preprocess