# MIT License # Copyright (c) 2024 Hustcw # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import torch import torch.utils.checkpoint from torch import nn from typing import Optional import torch.nn.functional as F from transformers.models.roformer.modeling_roformer import ( RoFormerEmbeddings, RoFormerModel, RoFormerEncoder, RoFormerLayer, RoFormerAttention, RoFormerIntermediate, RoFormerOutput, RoFormerSelfAttention, RoFormerPreTrainedModel ) from transformers.models.mpnet.modeling_mpnet import MPNetModel from transformers import MPNetTokenizerFast, BatchEncoding class AsmTokenizer(MPNetTokenizerFast): @property def pad_token_type_id(self) -> int: """ `int`: Id of the padding token type in the vocabulary. """ return self.pad_token_id def tokenize_function(self, function): total_len = 0 tokenized_functions = {"token": [], "instr": []} for key, value in function.items(): tokens = self.tokenize(value.replace(',', ''), max_length=20, truncation=True, add_special_tokens=False) # set max token for a instruction instr_index = "INSTR" + key instructions = [instr_index] * len(tokens) tokenized_functions["token"].extend(tokens) tokenized_functions["instr"].extend(instructions) total_len += len(tokens) if total_len > self.model_max_length: tokenized_functions['token'] = tokenized_functions['token'][:self.model_max_length] tokenized_functions['instr'] = tokenized_functions['instr'][:self.model_max_length] break return tokenized_functions def encode_function(self, function): tokenized_functions = self.tokenize_function(function) token_ids = self.convert_tokens_to_ids(tokenized_functions["token"]) instr_ids = self.convert_tokens_to_ids(tokenized_functions["instr"]) return BatchEncoding({ "input_ids": token_ids, "attention_mask": [1] * len(token_ids), "token_type_ids": instr_ids, }) def __call__(self, functions, **kwargs): if len(functions) == 0: return BatchEncoding({ "input_ids": [], "attention_mask": [], "token_type_ids": [], }) if not isinstance(functions, list): raise ValueError("functions must be a list of dict") elif not isinstance(functions[0], dict): raise ValueError("functions must be a list of dict") else: batch_encode_result = { "input_ids": [], "attention_mask": [], "token_type_ids": [], } for function in functions: tokenized_functions = self.tokenize_function(function) token_ids = self.convert_tokens_to_ids(tokenized_functions["token"]) instr_ids = self.convert_tokens_to_ids(tokenized_functions["instr"]) attention_mask = [1] * len(token_ids) batch_encode_result["input_ids"].append(token_ids) batch_encode_result["attention_mask"].append(attention_mask) batch_encode_result["token_type_ids"].append(instr_ids) batch_encoding = BatchEncoding(batch_encode_result) return self.pad(batch_encoding, **kwargs) @property def vocab_size(self) -> int: return len(self.vocab) class JRoFormerEmbeddings(RoFormerEmbeddings): """Construct the embeddings from word and token_type embeddings.""" def __init__(self, config): super().__init__(config) self.word_embeddings = nn.Embedding( config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id ) self.token_type_embeddings = self.word_embeddings class JRoFormerSelfAttention(RoFormerSelfAttention): def __init__(self, config): super().__init__(config) self.query = nn.Linear( config.hidden_size, self.all_head_size, bias=config.use_bias ) self.key = nn.Linear( config.hidden_size, self.all_head_size, bias=config.use_bias ) self.value = nn.Linear( config.hidden_size, self.all_head_size, bias=config.use_bias ) class JRoFormerAttention(RoFormerAttention): def __init__(self, config): super().__init__(config) self.self = JRoFormerSelfAttention(config) class JRoFormerLayer(RoFormerLayer): def __init__(self, config): super().__init__(config) self.attention = JRoFormerAttention(config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError( f"{self} should be used as a decoder model if cross attention is added" ) self.crossattention = RoFormerAttention(config) self.intermediate = RoFormerIntermediate(config) self.output = RoFormerOutput(config) class JRoFormerEncoder(RoFormerEncoder): def __init__(self, config): super().__init__(config) self.layer = nn.ModuleList( [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)] ) class JRoFormerModel(RoFormerModel): def __init__(self, config): super().__init__(config) self.config = config self.embeddings = JRoFormerEmbeddings(config) if config.embedding_size != config.hidden_size: self.embeddings_project = nn.Linear( config.embedding_size, config.hidden_size ) self.encoder = JRoFormerEncoder(config) # Initialize weights and apply final processing self.post_init() class AsmEncoder(RoFormerPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config self.jroformer = JRoFormerModel(config) self.projection = nn.Linear(config.hidden_size, config.hidden_size) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.jroformer( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) token_embeddings = outputs[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype) asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) asm_embedding = self.projection(asm_embedding) asm_embedding = F.normalize(asm_embedding, p=2, dim=1) return asm_embedding class TextEncoder(MPNetModel): def __init__(self, config, add_pooling_layer=True): super().__init__(config, add_pooling_layer=add_pooling_layer) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ): output = super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) token_embeddings = output[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() text_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) text_embedding = F.normalize(text_embedding, p=2, dim=1) return text_embedding