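"""Model and tokenizer components for embedding assembly functions and text.

Contents:

* ``AsmTokenizer``: an ``MPNetTokenizerFast`` subclass that tokenizes a
  disassembled function given as a dict mapping instruction indices to
  instruction strings, emitting ``INSTR<i>`` tokens as ``token_type_ids``.
* ``JRoFormer*`` classes: RoFormer variants whose token-type embeddings share
  the word-embedding table and whose query/key/value projections respect
  ``config.use_bias``.
* ``AsmEncoder`` and ``TextEncoder``: encoders that mean-pool token embeddings
  over the attention mask and return L2-normalized embeddings.
"""
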
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers import BatchEncoding, MPNetTokenizerFast
from transformers.models.mpnet.modeling_mpnet import MPNetModel
from transformers.models.roformer.modeling_roformer import (
    RoFormerAttention,
    RoFormerEmbeddings,
    RoFormerEncoder,
    RoFormerIntermediate,
    RoFormerLayer,
    RoFormerModel,
    RoFormerOutput,
    RoFormerPreTrainedModel,
    RoFormerSelfAttention,
)
|
|
class AsmTokenizer(MPNetTokenizerFast):
    """Tokenizer for disassembled functions.

    A function is given as a dict that maps an instruction index (as a string)
    to the instruction text; the index is also emitted as an ``INSTR<i>``
    vocabulary token so that ``token_type_ids`` carry the instruction position.
    """

    @property
    def pad_token_type_id(self) -> int:
        """`int`: Id used to pad `token_type_ids`.

        Token type ids here are regular vocabulary ids (the ``INSTR<i>``
        tokens), so they are padded with the pad token id rather than a
        separate segment id.
        """
        return self.pad_token_id
|
    def tokenize_function(self, function):
        """Tokenize one function dict into parallel token / instruction-tag lists."""
        total_len = 0
        tokenized_functions = {"token": [], "instr": []}
        for key, value in function.items():
            # Strip commas and tokenize the instruction, capping it at 20 tokens.
            tokens = self.tokenize(value.replace(",", ""), max_length=20, truncation=True, add_special_tokens=False)
            # Tag every token of this instruction with its INSTR<index> token.
            instr_index = "INSTR" + key
            instructions = [instr_index] * len(tokens)
            tokenized_functions["token"].extend(tokens)
            tokenized_functions["instr"].extend(instructions)
            total_len += len(tokens)
            if total_len > self.model_max_length:
                # Truncate both streams to the model's maximum sequence length.
                tokenized_functions["token"] = tokenized_functions["token"][: self.model_max_length]
                tokenized_functions["instr"] = tokenized_functions["instr"][: self.model_max_length]
                break
        return tokenized_functions
|
    def encode_function(self, function):
        """Encode one function dict into input ids, attention mask and instruction-tag ids."""
        tokenized_functions = self.tokenize_function(function)
        token_ids = self.convert_tokens_to_ids(tokenized_functions["token"])
        instr_ids = self.convert_tokens_to_ids(tokenized_functions["instr"])
        return BatchEncoding(
            {
                "input_ids": token_ids,
                "attention_mask": [1] * len(token_ids),
                "token_type_ids": instr_ids,
            }
        )
|
    def __call__(self, functions, **kwargs):
        if not isinstance(functions, list) or (functions and not isinstance(functions[0], dict)):
            raise ValueError("functions must be a list of dict")
        batch_encode_result = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
        }
        if len(functions) == 0:
            return BatchEncoding(batch_encode_result)
        for function in functions:
            encoded = self.encode_function(function)
            batch_encode_result["input_ids"].append(encoded["input_ids"])
            batch_encode_result["attention_mask"].append(encoded["attention_mask"])
            batch_encode_result["token_type_ids"].append(encoded["token_type_ids"])
        # Pad the batch (padding kwargs such as `padding` / `return_tensors` are forwarded).
        return self.pad(BatchEncoding(batch_encode_result), **kwargs)
|
    @property
    def vocab_size(self) -> int:
        return len(self.vocab)
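
# Illustrative sketch of the per-function input format (the instruction strings
# are assumptions for the example, not part of this module). For a function
#   {"0": "push rbp", "1": "mov rbp, rsp"}
# `tokenize_function` returns parallel lists roughly like
#   {"token": ["push", "rbp", "mov", "rbp", "rsp"],
#    "instr": ["INSTR0", "INSTR0", "INSTR1", "INSTR1", "INSTR1"]}
# where the exact sub-tokens depend on the trained vocabulary, and
# `encode_function` converts both streams to ids via `convert_tokens_to_ids`.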
|
|
|
class JRoFormerEmbeddings(RoFormerEmbeddings):
    """Construct the embeddings from word and token_type embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Token type ids are vocabulary ids (INSTR<i> tokens), so the token type
        # embedding table is tied to the word embedding table.
        self.token_type_embeddings = self.word_embeddings
|
|
class JRoFormerSelfAttention(RoFormerSelfAttention):
    def __init__(self, config):
        super().__init__(config)
        # Re-create the query/key/value projections so that the bias term is
        # controlled by `config.use_bias`.
        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
|
|
class JRoFormerAttention(RoFormerAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = JRoFormerSelfAttention(config)
|
|
class JRoFormerLayer(RoFormerLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = JRoFormerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(
                    f"{self} should be used as a decoder model if cross attention is added"
                )
            self.crossattention = RoFormerAttention(config)
        self.intermediate = RoFormerIntermediate(config)
        self.output = RoFormerOutput(config)
|
|
class JRoFormerEncoder(RoFormerEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList(
            [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
        )
|
|
class JRoFormerModel(RoFormerModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = JRoFormerEmbeddings(config)

        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )

        self.encoder = JRoFormerEncoder(config)

        self.post_init()
|
|
class AsmEncoder(RoFormerPreTrainedModel):
    """Assembly encoder: JRoFormer backbone, linear projection, and mean pooling."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.jroformer = JRoFormerModel(config)
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.jroformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Mean-pool the token embeddings over the attention mask, project, and
        # L2-normalize to obtain a single embedding per function.
        token_embeddings = outputs[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
        asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        asm_embedding = self.projection(asm_embedding)
        asm_embedding = F.normalize(asm_embedding, p=2, dim=1)

        return asm_embedding
|
|
class TextEncoder(MPNetModel):
    """Text encoder: MPNet backbone with mean pooling over the attention mask."""

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer=add_pooling_layer)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        # Mean-pool the token embeddings over the attention mask and
        # L2-normalize to obtain a single embedding per text.
        token_embeddings = output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        text_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        text_embedding = F.normalize(text_embedding, p=2, dim=1)
        return text_embedding
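

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch of how the pieces fit together. The checkpoint
# paths, instruction strings, and query text are placeholders (assumptions for
# the example), not artifacts shipped with this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    asm_tokenizer = AsmTokenizer.from_pretrained("path/to/asm-tokenizer")            # hypothetical path
    asm_encoder = AsmEncoder.from_pretrained("path/to/asm-encoder").eval()           # hypothetical path
    text_tokenizer = MPNetTokenizerFast.from_pretrained("path/to/text-tokenizer")    # hypothetical path
    text_encoder = TextEncoder.from_pretrained("path/to/text-encoder").eval()        # hypothetical path

    # One disassembled function: instruction index -> instruction text (assumed example).
    function = {"0": "push rbp", "1": "mov rbp, rsp", "2": "ret"}
    asm_batch = asm_tokenizer([function], padding=True, return_tensors="pt")
    text_batch = text_tokenizer(["set up a stack frame and return"], padding=True, return_tensors="pt")

    with torch.no_grad():
        asm_embedding = asm_encoder(**asm_batch)     # (1, hidden_size), L2-normalized
        text_embedding = text_encoder(**text_batch)  # (1, hidden_size), L2-normalized

    # Both outputs are L2-normalized, so a dot product is a cosine similarity
    # (assuming the two encoders share the same hidden size).
    print((asm_embedding @ text_embedding.T).item())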