Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| HuggingFace-compatible vec2vec implementation for embedding translation. | |
| Based on: "Harnessing the Universal Geometry of Embeddings" (arXiv:2505.12540) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from dataclasses import dataclass | |
| from typing import Dict, Optional, List | |
| from transformers import PreTrainedModel, PretrainedConfig | |
| from transformers.modeling_outputs import ModelOutput | |
| # ============================================================================= | |
| # Configuration | |
| # ============================================================================= | |
class Vec2VecConfig(PretrainedConfig):
    """Hyper-parameter container for the Vec2Vec model.

    Every constructor argument is stored unchanged as an attribute of the
    same name. ``encoder_names`` and ``encoder_dims`` are parallel lists;
    when either is missing or empty it falls back to a two-encoder default
    (``["model_a", "model_b"]`` / ``[768, 768]``).
    """

    model_type = "vec2vec"

    def __init__(
        self,
        encoder_names: List[str] = None,
        encoder_dims: List[int] = None,
        d_adapter: int = 1024,
        d_hidden: int = 1024,
        d_transform: int = 1024,
        adapter_depth: int = 3,
        transform_depth: int = 4,
        disc_dim: int = 1024,
        disc_depth: int = 5,
        weight_init: str = "kaiming",
        norm_style: str = "batch",
        normalize_embeddings: bool = True,
        # Loss coefficients
        loss_coefficient_rec: float = 1.0,
        loss_coefficient_vsp: float = 1.0,
        loss_coefficient_cc_trans: float = 10.0,
        loss_coefficient_cc_vsp: float = 10.0,
        loss_coefficient_cc_rec: float = 0.0,
        loss_coefficient_gen: float = 1.0,
        loss_coefficient_latent_gen: float = 1.0,
        loss_coefficient_similarity_gen: float = 0.0,
        loss_coefficient_disc: float = 1.0,
        loss_coefficient_r1_penalty: float = 0.0,
        # Training settings
        noise_level: float = 0.0,
        max_grad_norm: float = 1000.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Encoder registry (parallel lists), with the two-encoder fallback.
        self.encoder_names = encoder_names if encoder_names else ["model_a", "model_b"]
        self.encoder_dims = encoder_dims if encoder_dims else [768, 768]

        settings = {
            # Architecture
            "d_adapter": d_adapter,
            "d_hidden": d_hidden,
            "d_transform": d_transform,
            "adapter_depth": adapter_depth,
            "transform_depth": transform_depth,
            "disc_dim": disc_dim,
            "disc_depth": disc_depth,
            "weight_init": weight_init,
            "norm_style": norm_style,
            "normalize_embeddings": normalize_embeddings,
            # Loss coefficients
            "loss_coefficient_rec": loss_coefficient_rec,
            "loss_coefficient_vsp": loss_coefficient_vsp,
            "loss_coefficient_cc_trans": loss_coefficient_cc_trans,
            "loss_coefficient_cc_vsp": loss_coefficient_cc_vsp,
            "loss_coefficient_cc_rec": loss_coefficient_cc_rec,
            "loss_coefficient_gen": loss_coefficient_gen,
            "loss_coefficient_latent_gen": loss_coefficient_latent_gen,
            "loss_coefficient_similarity_gen": loss_coefficient_similarity_gen,
            "loss_coefficient_disc": loss_coefficient_disc,
            "loss_coefficient_r1_penalty": loss_coefficient_r1_penalty,
            # Training settings
            "noise_level": noise_level,
            "max_grad_norm": max_grad_norm,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

    def get_encoder_dims_dict(self) -> Dict[str, int]:
        """Pair each encoder name with its dimension, in list order."""
        return {name: dim for name, dim in zip(self.encoder_names, self.encoder_dims)}
| # ============================================================================= | |
| # Model Outputs | |
| # ============================================================================= | |
# ModelOutput subclasses must be dataclasses: transformers raises a TypeError
# when an undecorated subclass is instantiated.
@dataclass
class Vec2VecOutput(ModelOutput):
    """Output type for Vec2Vec forward pass.

    All fields default to ``None``. ``Vec2VecModel.forward`` fills
    ``reconstructions``, ``translations`` and (on request) ``latents``;
    it does not set ``loss`` or ``metrics``.
    """

    # Not populated by Vec2VecModel.forward in this file.
    loss: Optional[torch.FloatTensor] = None
    # Encoder name -> embedding decoded back into its own space.
    reconstructions: Optional[Dict[str, torch.Tensor]] = None
    # Target encoder name -> source encoder name -> translated embedding.
    translations: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
    # Source encoder name -> latent representation.
    latents: Optional[Dict[str, torch.Tensor]] = None
    # Not populated by Vec2VecModel.forward in this file.
    metrics: Optional[Dict[str, float]] = None
| # ============================================================================= | |
| # Model Components | |
| # ============================================================================= | |
def add_residual(input_x: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Add residual connection with dimension matching.

    ``input_x`` is matched to the width of ``x`` along dim 1 before the
    addition: it is right-padded with zeros when narrower and truncated
    when wider.

    Args:
        input_x: Tensor that entered the layer (the skip path).
        x: Tensor that left the layer.

    Returns:
        ``x + input_x`` with ``input_x`` resized along dim 1.
    """
    in_width, out_width = input_x.shape[1], x.shape[1]
    if in_width < out_width:
        # Build the padding in the dtype of the skip tensor. A default
        # (float32) padding would make torch.cat promote half/bfloat16
        # activations and silently change the dtype of the whole result.
        padding = torch.zeros(
            x.shape[0],
            out_width - in_width,
            device=x.device,
            dtype=input_x.dtype,
        )
        input_x = torch.cat([input_x, padding], dim=1)
    elif in_width > out_width:
        input_x = input_x[:, :out_width]
    return x + input_x
class MLPWithResidual(nn.Module):
    """Stack of feed-forward blocks, each wrapped in a residual connection.

    Layout for ``depth`` blocks:
      * block 0: ``Linear -> SiLU`` (maps straight to ``out_dim`` when
        ``depth == 1``, otherwise to ``hidden_dim``);
      * middle blocks: ``Linear -> SiLU -> norm -> Dropout(0.1)``;
      * last block (``depth >= 2``): ``Linear -> Dropout(0.1) -> SiLU -> Linear``
        ending at ``out_dim``.
    """

    def __init__(
        self,
        depth: int,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        norm_style: str = "batch",
        weight_init: str = "kaiming",
    ):
        super().__init__()
        # Anything other than "batch" selects LayerNorm.
        make_norm = nn.LayerNorm if norm_style != "batch" else nn.BatchNorm1d
        final_idx = depth - 1

        blocks = []
        for idx in range(depth):
            if idx == 0:
                first_width = hidden_dim if depth != 1 else out_dim
                block = nn.Sequential(nn.Linear(in_dim, first_width), nn.SiLU())
            elif idx < final_idx:
                block = nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.SiLU(),
                    make_norm(hidden_dim),
                    nn.Dropout(p=0.1),
                )
            else:
                block = nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.Dropout(p=0.1),
                    nn.SiLU(),
                    nn.Linear(hidden_dim, out_dim),
                )
            blocks.append(block)

        self.layers = nn.ModuleList(blocks)
        self._initialize_weights(weight_init)

    def _initialize_weights(self, weight_init: str):
        """Re-initialise Linear and norm parameters.

        Linear weights follow ``weight_init`` ("kaiming", "xavier" or
        "orthogonal"; any other value leaves them untouched) and Linear
        biases are always zeroed.
        """
        weight_schemes = {
            "kaiming": lambda w: nn.init.kaiming_normal_(
                w, a=0, mode="fan_in", nonlinearity="relu"
            ),
            "xavier": nn.init.xavier_normal_,
            "orthogonal": nn.init.orthogonal_,
        }
        apply_scheme = weight_schemes.get(weight_init)

        for sub in self.modules():
            if isinstance(sub, nn.Linear):
                if apply_scheme is not None:
                    apply_scheme(sub.weight)
                sub.bias.data.fill_(0)
            elif isinstance(sub, nn.BatchNorm1d):
                nn.init.normal_(sub.weight, mean=1.0, std=0.02)
                nn.init.normal_(sub.bias, mean=0.0, std=0.02)
            elif isinstance(sub, nn.LayerNorm):
                nn.init.constant_(sub.bias, 0)
                nn.init.constant_(sub.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every block, adding each block's input back onto its output."""
        out = x
        for block in self.layers:
            out = add_residual(out, block(out))
        return out
class Discriminator(nn.Module):
    """Feed-forward network that maps an input vector to a single score.

    With ``depth >= 2`` the network is one ``nn.Sequential``:
    ``Linear -> Dropout(0)`` followed by ``depth - 2`` repetitions of
    ``SiLU -> Linear -> LayerNorm -> Dropout(0)`` and a closing
    ``SiLU -> Linear(hidden_dim, 1)``. With a smaller ``depth`` it is a
    single ``Linear(latent_dim, 1)``.
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int = 1024,
        depth: int = 5,
        weight_init: str = "kaiming",
    ):
        super().__init__()
        self.layers = nn.ModuleList()

        if depth < 2:
            self.layers.append(nn.Linear(latent_dim, 1))
        else:
            stack = [nn.Linear(latent_dim, hidden_dim), nn.Dropout(0.0)]
            for _ in range(depth - 2):
                stack += [
                    nn.SiLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.Dropout(0.0),
                ]
            stack += [nn.SiLU(), nn.Linear(hidden_dim, 1)]
            self.layers.append(nn.Sequential(*stack))

        self._initialize_weights(weight_init)

    def _initialize_weights(self, weight_init: str):
        """Re-initialise Linear and LayerNorm parameters.

        Linear weights follow ``weight_init`` ("kaiming", "xavier" or
        "orthogonal"; any other value leaves them untouched) and Linear
        biases are always zeroed.
        """
        weight_schemes = {
            "kaiming": lambda w: nn.init.kaiming_normal_(
                w, a=0, mode="fan_in", nonlinearity="relu"
            ),
            "xavier": nn.init.xavier_normal_,
            "orthogonal": nn.init.orthogonal_,
        }
        apply_scheme = weight_schemes.get(weight_init)

        for sub in self.modules():
            if isinstance(sub, nn.Linear):
                if apply_scheme is not None:
                    apply_scheme(sub.weight)
                sub.bias.data.fill_(0)
            elif isinstance(sub, nn.LayerNorm):
                nn.init.constant_(sub.bias, 0)
                nn.init.constant_(sub.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the registered layers in order."""
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
| # ============================================================================= | |
| # Main Model | |
| # ============================================================================= | |
class Vec2VecModel(PreTrainedModel):
    """
    Vec2Vec model for embedding translation between different spaces.

    Architecture:
        Input -> In Adapter -> Transform -> Out Adapter -> Output

    Each encoder named in the config gets its own input adapter, output
    adapter and discriminator; the transform is shared by all encoders.
    An additional discriminator is registered under the key ``"latent"``.
    """

    config_class = Vec2VecConfig
    # NOTE(review): presumably declared for transformers versions that read
    # this attribute when loading weights - confirm before removing.
    all_tied_weights_keys = {}

    def __init__(self, config: Vec2VecConfig):
        super().__init__(config)
        self.config = config
        encoder_dims = config.get_encoder_dims_dict()

        # Shared transform
        self.transform = MLPWithResidual(
            depth=config.transform_depth,
            in_dim=config.d_adapter,
            hidden_dim=config.d_transform,
            out_dim=config.d_adapter,
            norm_style=config.norm_style,
            weight_init=config.weight_init,
        )

        # Adapters for each encoder
        self.in_adapters = nn.ModuleDict()
        self.out_adapters = nn.ModuleDict()
        for name, dim in encoder_dims.items():
            self.in_adapters[name] = MLPWithResidual(
                config.adapter_depth, dim, config.d_hidden, config.d_adapter,
                config.norm_style, config.weight_init,
            )
            self.out_adapters[name] = MLPWithResidual(
                config.adapter_depth, config.d_adapter, config.d_hidden, dim,
                config.norm_style, config.weight_init,
            )

        # Discriminators
        self.discriminators = nn.ModuleDict()
        for name, dim in encoder_dims.items():
            self.discriminators[name] = Discriminator(
                dim, config.disc_dim, config.disc_depth, config.weight_init
            )
        self.discriminators["latent"] = Discriminator(
            config.d_adapter, config.disc_dim, config.disc_depth, config.weight_init
        )

        self.post_init()

    def add_encoder(self, name: str, dim: int, overwrite: bool = False):
        """Add a new encoder to the model.

        Creates freshly initialised input/output adapters and a
        discriminator for ``name`` and records the encoder in the config.

        Args:
            name: Key under which the encoder is registered.
            dim: Embedding dimension of that encoder.
            overwrite: When False (default) an already registered encoder is
                left untouched; when True its modules are replaced.
        """
        if name in self.in_adapters and not overwrite:
            print(f"Encoder {name} already exists, skipping...")
            return

        cfg = self.config
        self.in_adapters[name] = MLPWithResidual(
            cfg.adapter_depth, dim, cfg.d_hidden, cfg.d_adapter,
            cfg.norm_style, cfg.weight_init,
        )
        self.out_adapters[name] = MLPWithResidual(
            cfg.adapter_depth, cfg.d_adapter, cfg.d_hidden, dim,
            cfg.norm_style, cfg.weight_init,
        )
        self.discriminators[name] = Discriminator(
            dim, cfg.disc_dim, cfg.disc_depth, cfg.weight_init
        )

        # Update config. When an existing encoder is overwritten, its stored
        # dimension must follow the new modules; otherwise a saved config
        # would rebuild adapters with the old width.
        if name in cfg.encoder_names:
            cfg.encoder_dims[cfg.encoder_names.index(name)] = dim
        else:
            cfg.encoder_names.append(name)
            cfg.encoder_dims.append(dim)

    def _get_latent(self, emb: torch.Tensor, encoder_name: str) -> torch.Tensor:
        """Get latent representation from embedding.

        Applies the input adapter of ``encoder_name`` and then the shared
        transform.
        """
        z = self.in_adapters[encoder_name](emb)
        return self.transform(z)

    def _decode(self, latent: torch.Tensor, encoder_name: str) -> torch.Tensor:
        """Decode latent to target embedding space.

        The result is L2-normalised along dim 1 when
        ``config.normalize_embeddings`` is set.
        """
        out = self.out_adapters[encoder_name](latent)
        if self.config.normalize_embeddings:
            out = F.normalize(out, p=2, dim=1)
        return out

    def translate(self, embeddings: torch.Tensor, src: str, tgt: str) -> torch.Tensor:
        """Translate embeddings from source to target space."""
        latent = self._get_latent(embeddings, src)
        return self._decode(latent, tgt)

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        noise_level: Optional[float] = None,
        return_latents: bool = False,
    ) -> Vec2VecOutput:
        """
        Forward pass computing reconstructions and translations.

        Args:
            inputs: Dict mapping encoder names to embeddings
            noise_level: Optional noise for training; falls back to
                ``config.noise_level`` when None
            return_latents: Whether to return latent representations

        Returns:
            Vec2VecOutput whose ``reconstructions`` is keyed by encoder name
            and whose ``translations`` is keyed ``[target][source]``.
            ``translations`` is empty when ``inputs`` has a single encoder.
        """
        noise_level = noise_level if noise_level is not None else self.config.noise_level

        reconstructions = {}
        translations = {}
        latents = {}

        for src_name, emb in inputs.items():
            # Add noise during training, then re-normalise the noisy input.
            if self.training and noise_level > 0.0:
                emb = emb + torch.randn_like(emb) * noise_level
                emb = F.normalize(emb, p=2, dim=1)

            latent = self._get_latent(emb, src_name)
            if return_latents:
                latents[src_name] = latent

            # Decode the latent into every space present in the batch.
            for tgt_name in inputs:
                decoded = self._decode(latent, tgt_name)
                if tgt_name == src_name:
                    reconstructions[src_name] = decoded
                else:
                    translations.setdefault(tgt_name, {})[src_name] = decoded

        return Vec2VecOutput(
            reconstructions=reconstructions,
            translations=translations,
            latents=latents if return_latents else None,
        )
| # ============================================================================= | |
| # Loss Functions | |
| # ============================================================================= | |
def reconstruction_loss(inputs: Dict[str, torch.Tensor], recons: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Reconstruction loss (1 - cosine similarity).

    For every key in ``inputs`` the cosine similarity (along dim 1) between
    the input and ``recons[key]`` is averaged over the batch; the result is
    the mean of ``1 - similarity`` over all keys.

    Returns:
        Scalar tensor; zero when ``inputs`` is empty.
    """
    if not inputs:
        # Nothing to compare: avoid dividing by zero, as the sibling losses do.
        return torch.tensor(0.0)
    total = sum(
        1 - F.cosine_similarity(emb, recons[name], dim=1).mean()
        for name, emb in inputs.items()
    )
    return total / len(inputs)
def translation_loss(inputs: Dict[str, torch.Tensor], translations: Dict[str, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Translation loss (1 - cosine similarity).

    ``translations`` is keyed ``[target][source]``. Every translation into a
    target space is compared with the real embedding of that target; the
    result is the mean of ``1 - cosine similarity`` over all such pairs.

    Returns:
        Scalar tensor; zero when there is no translation to score.
    """
    total = 0.0
    count = 0
    for tgt, emb in inputs.items():
        # A target may have no translations at all (e.g. a batch with a
        # single encoder yields an empty dict), so do not index blindly.
        for trans in translations.get(tgt, {}).values():
            total = total + (1 - F.cosine_similarity(emb, trans, dim=1).mean())
            count += 1
    if count == 0:
        return torch.tensor(0.0)
    return total / count
def vsp_loss(inputs: Dict[str, torch.Tensor], translations: Dict[str, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Vector Space Preservation (VSP) loss.

    For every translation into a target space, the pairwise similarity
    matrix of the (detached, L2-normalised) real target embeddings is
    compared with (a) the pairwise similarities among the translated
    embeddings and (b) the similarities between translated and real
    embeddings. Both mean absolute differences are added per translation,
    and the sum is divided by the number of translations.

    Returns:
        Scalar tensor; zero when there is no translation to score.
    """
    total = 0.0
    count = 0
    for out_name, target in inputs.items():
        # Targets without translations (e.g. single-encoder batches) are skipped.
        sources = translations.get(out_name)
        if not sources:
            continue

        # The target-side quantities do not depend on the source; compute once.
        B = F.normalize(target.detach(), p=2, dim=1)
        in_sims = B @ B.T

        for translated in sources.values():
            A = F.normalize(translated, p=2, dim=1)
            out_sims = A @ A.T
            out_sims_reflected = A @ B.T
            total = total + (in_sims - out_sims).abs().mean()
            total = total + (in_sims - out_sims_reflected).abs().mean()
            count += 1

    if count == 0:
        return torch.tensor(0.0)
    return total / count
| from typing import Optional, Union, List, Dict | |
| from transformers import AutoModel, AutoTokenizer | |
| from .base_tokenizer import BaseSequenceTokenizer | |
| from .supported_models import all_presets_with_paths | |
| from pooler import Pooler | |
# ESM2 checkpoint sizes, smallest to largest; each preset pairs a smaller
# size with a larger one.
_ESM2_SIZES = ('8', '35', '150', '650', '3B')

# Preset name -> Hugging Face repository id, one entry per ordered size pair.
presets = {
    f'vec2vec-ESM2-{small}-ESM2-{large}': f'Synthyra/ESM2-{small}-ESM2-{large}-sequence-sequence'
    for idx, small in enumerate(_ESM2_SIZES)
    for large in _ESM2_SIZES[idx + 1:]
}
class Vec2VecTokenizerWrapper(BaseSequenceTokenizer):
    """Callable wrapper that tokenizes one sequence or a list of sequences."""

    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        """Tokenize ``sequences`` with the wrapped tokenizer.

        A single string is treated as a batch of one. Unless overridden by
        the caller, ``return_tensors='pt'``, ``padding='longest'`` and
        ``add_special_tokens=True`` are used.
        """
        batch = [sequences] if isinstance(sequences, str) else sequences
        # Caller-supplied keyword arguments win over the defaults.
        options = {
            'return_tensors': 'pt',
            'padding': 'longest',
            'add_special_tokens': True,
            **kwargs,
        }
        return self.tokenizer(batch, **options)
class Vec2VecForEmbedding(nn.Module):
    """Embed sequences with a base model, then translate the pooled vector.

    The base model's last hidden state is pooled with a ``['mean', 'var']``
    pooler, optionally L2-normalised, and translated by the Vec2Vec model
    from the ``model_name_a`` space into the ``model_name_b`` space.
    """

    def __init__(
        self,
        config: Vec2VecConfig,
        base_model: AutoModel,
        vec2vec_model: Vec2VecModel,
        model_name_a: str,
        model_name_b: str,
    ):
        super().__init__()
        # Sub-modules
        self.base_model = base_model
        self.vec2vec_model = vec2vec_model
        self.config = config
        self.pooler = Pooler(['mean', 'var'])
        # Translation direction: a (source) -> b (target)
        self.model_name_a, self.model_name_b = model_name_a, model_name_b
        self.normalize = config.normalize_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Return the translated embedding for a tokenized batch.

        ``output_attentions``, ``output_hidden_states`` and any extra
        keyword arguments are accepted but ignored: only the translated
        vector is produced.
        """
        hidden = self.base_model(input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pooler(hidden, attention_mask=attention_mask)
        if self.normalize:
            pooled = F.normalize(pooled, p=2, dim=1)
        return self.vec2vec_model.translate(
            pooled,
            src=self.model_name_a,
            tgt=self.model_name_b,
        )
def get_vec2vec_tokenizer(preset: str, model_path: Optional[str] = None):
    """Load the tokenizer for a preset and wrap it.

    The tokenizer is first loaded from ``model_path`` (or, when that is not
    given, from ``all_presets_with_paths[preset]``). If that fails, the model
    at the same path is loaded and the tokenizer named by
    ``model.config.tokenizer_name`` is used instead.

    Returns:
        A ``Vec2VecTokenizerWrapper`` around the loaded tokenizer.
    """
    # TODO work with new Vec2Vec .tokenizer_a and .tokenizer_b
    path = model_path or all_presets_with_paths[preset]
    try:
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    except Exception:
        # Best-effort fallback. Catch Exception rather than everything so
        # that KeyboardInterrupt / SystemExit still propagate.
        model = AutoModel.from_pretrained(path, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model.config.tokenizer_name)
    return Vec2VecTokenizerWrapper(tokenizer)
def build_vec2vec_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Build a ``Vec2VecForEmbedding`` model and its tokenizer.

    Args:
        preset: Key into ``presets``; used only when ``model_path`` is not given.
        masked_lm: Must be False; masked LM is not supported.
        dtype: Forwarded to the base model's ``from_pretrained``.
        model_path: Optional explicit location of the Vec2Vec checkpoint.
        **kwargs: Accepted but unused.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If ``masked_lm`` is True.
    """
    if masked_lm:
        raise ValueError("Masked LM is not supported for Vec2VecForEmbedding")

    model_path = model_path or presets[preset]
    config = Vec2VecConfig.from_pretrained(model_path)
    encoder_names = config.encoder_names
    encoder_dims = config.encoder_dims

    # The encoder with the larger (or equal) dimension is the source "a",
    # so embeddings are translated from it into the other encoder's space.
    if encoder_dims[0] >= encoder_dims[1]:
        model_name_a = encoder_names[0]
        model_name_b = encoder_names[1]
    else:
        model_name_a = encoder_names[1]
        model_name_b = encoder_names[0]

    base_model = AutoModel.from_pretrained(all_presets_with_paths[model_name_a], dtype=dtype, trust_remote_code=True)
    base_tokenizer = base_model.tokenizer

    # from_pretrained is a classmethod: call it on the class rather than on a
    # throwaway, randomly initialised instance.
    # NOTE(review): `dtype` is applied to the base model only; confirm the
    # pooled vectors match the Vec2Vec weights' dtype when dtype is not None.
    vec2vec_model = Vec2VecModel.from_pretrained(model_path, config=config)

    model = Vec2VecForEmbedding(config, base_model, vec2vec_model, model_name_a, model_name_b)
    tokenizer = Vec2VecTokenizerWrapper(base_tokenizer)
    return model, tokenizer
def get_vec2vec_for_training(preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False):
    """Placeholder for a trainable Vec2Vec variant.

    Raises:
        ValueError: Always; training is not supported yet.
    """
    message = "Vec2VecForTraining is not supported yet"
    raise ValueError(message)
if __name__ == '__main__':
    # py -m src.protify.base_models.vec2vec
    # Smoke test: build one preset and tokenize a sample sequence.
    # Preset keys carry the 'vec2vec-' prefix; the bare name is not in `presets`.
    model, tokenizer = build_vec2vec_model('vec2vec-ESM2-8-ESM2-35')
    print(model)
    print(tokenizer)
    print(tokenizer('MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'))