import torch
from torch.utils.data import Dataset
from args import args
class items_dataset(Dataset):
    """Span-labelling dataset that also acts as its own DataLoader collator.

    ``__getitem__`` returns the stored sample unchanged. ``collate_fn``
    tokenizes a batch of samples and attaches token-level labels:

    - Each sample is read through the keys ``'original_text'`` and
      ``'span_labels'``.
    - Labels are BIO codes: 0 = "O", 1 = "B", 2 = "I".
    - A boolean CRF mask is added alongside the labels.

    Args:
        tokenizer: tokenizer callable. It must accept ``return_offsets_mapping``
            and ``return_overflowing_tokens``. This is presumably a HuggingFace
            fast tokenizer (TODO confirm).
        data_set: indexable collection of sample dicts.
        label_dict: stored as ``self.label_dict``; not read by this class.
        stride: token overlap passed to the tokenizer for overflowing windows.
        max_length: total tokenized length, including [CLS] and [SEP].
    """

    def __init__(self, tokenizer, data_set, label_dict, stride=0, max_length=args.max_length):
        self.data_set = data_set
        self.tokenizer = tokenizer
        self.label_dict = label_dict
        self.max_length = max_length
        # Width of the label / mask tensors: [CLS] and [SEP] carry no label.
        self.encode_max_length = max_length - 2
        # Overwritten by collate_fn with the (clipped) longest text length of the last batch.
        self.batch_max_lenght = max_length
        self.stride = stride

    def __getitem__(self, index):
        return self.data_set[index]

    def __len__(self):
        return len(self.data_set)

    def create_label_list(self, span_label, max_len):
        """Build a character-level BIO table for a single text.

        Args:
            span_label: iterable of ``(start, end)`` character spans, end exclusive.
            max_len: number of characters in the text (size of the table).

        Returns:
            A 1-D ``torch.zeros``-style float tensor of length ``max_len``:
            0 = "O", 1 = "B" (first character of a span), 2 = "I" (the rest).
        """
        table = torch.zeros(max_len)
        for start, end in span_label:
            table[start:end] = 2  # "I"
            table[start] = 1  # "B" overrides the first character of the span
        return table

    def encode_lable(self, encoded, batch_table):
        """Project character-level BIO tables onto token positions.

        A text may be split into several tokenized windows; each window looks up
        its source table through ``overflow_to_sample_mapping``. Every token takes
        the label of its first character.

        Tokens whose offset is ``(0, 0)`` are skipped; this covers [CLS], [SEP]
        and padding. The token at position ``i`` is written to label slot
        ``i - 1``, so slot 0 is the first token after [CLS].

        Args:
            encoded: tokenizer output with ``overflow_to_sample_mapping`` and
                ``offset_mapping``.
            batch_table: per-text tables as produced by ``create_label_list``.

        Returns:
            ``(labels, seq_lens)``:

            - ``labels``: a long tensor of shape
              ``(num_windows, encode_max_length)``.
            - ``seq_lens``: per window, the number of labelled (non-special)
              tokens.
        """
        sample_mapping = encoded["overflow_to_sample_mapping"]
        offset_mapping = encoded["offset_mapping"]
        encoded_label = torch.zeros(len(sample_mapping), self.encode_max_length, dtype=torch.long)
        batch_encode_seq_lens = []
        for window_id, sample_id in enumerate(sample_mapping):
            table = batch_table[sample_id]
            encode_len = 0
            # Only the first max_length positions can fit in the label tensor.
            window_offsets = offset_mapping[window_id][: self.max_length]
            for i, (char_start, char_end) in enumerate(window_offsets):
                # (0, 0) marks special / padding tokens, which get no label.
                if char_start == 0 and char_end == 0:
                    continue
                encode_len += 1
                encoded_label[window_id][i - 1] = table[char_start].long()
            batch_encode_seq_lens.append(encode_len)
        return encoded_label, batch_encode_seq_lens

    def create_crf_mask(self, batch_encode_seq_lens):
        """Return a bool mask of shape ``(len(batch_encode_seq_lens), encode_max_length)``.

        Row ``r`` is True on its first ``batch_encode_seq_lens[r]`` positions and
        False elsewhere.
        """
        mask = torch.zeros(len(batch_encode_seq_lens), self.encode_max_length, dtype=torch.bool)
        for row, seq_len in enumerate(batch_encode_seq_lens):
            mask[row][:seq_len] = True
        return mask

    def boundary_encoded(self, encodings, batch_boundary):
        """Convert character-length segments into token-length segments.

        ``batch_boundary[b]`` holds consecutive segment lengths, in characters,
        for text ``b``. The conversion works in three steps:

        1. Each cumulative character end is mapped to one past the last token
           that covers a character before it.
        2. An end beyond ``encode_max_length`` is clipped to it, and the rest of
           that text's segments are dropped.
        3. The cumulative token ends are turned back into per-segment lengths.
           The first entry stays absolute.
        """
        batch_boundary_encoded = []
        for batch_id, span_labels in enumerate(batch_boundary):
            encoding = encodings[batch_id]
            boundary_encoded = []
            end = 0
            for boundary in span_labels:
                end += boundary
                # Walk back from the last character of the segment until a
                # character that maps to a token is found. Never query a
                # negative character index.
                pos = end - 1
                encoded_end = encoding.char_to_token(pos) if pos >= 0 else None
                while encoded_end is None and pos > 0:
                    pos -= 1
                    encoded_end = encoding.char_to_token(pos)
                # NOTE(review): if no character before `end` maps to a token, the
                # previous code crashed (None + 1). 0 is used as a fallback here;
                # confirm that this is the desired value.
                encoded_end = 0 if encoded_end is None else encoded_end + 1

                if encoded_end > self.encode_max_length:
                    boundary_encoded.append(self.encode_max_length)
                    break
                boundary_encoded.append(encoded_end)
            # Cumulative ends -> segment lengths (iterate backwards so each
            # subtraction still sees the previous cumulative value).
            for i in range(len(boundary_encoded) - 1, 0, -1):
                boundary_encoded[i] -= boundary_encoded[i - 1]
            batch_boundary_encoded.append(boundary_encoded)
        return batch_boundary_encoded

    def cal_agreement_span(self, agreement_table, min_agree=2, max_agree=3):
        """Find spans of positions whose agreement lies in ``[min_agree, max_agree]``.

        Consecutive positions with the same agreement value form a run. A run
        becomes a ``[start, end)`` span when its value ``v`` satisfies
        ``min_agree <= v <= max_agree``. When ``min_agree != max_agree``,
        touching spans (one's end equals the next one's start) are merged.

        Args:
            agreement_table: sequence of per-position agreement counts.
            min_agree: smallest agreement value accepted into a span.
            max_agree: largest agreement value accepted into a span.

        Returns:
            List of ``[start, end]`` pairs; empty for an empty table.
        """
        if len(agreement_table) == 0:
            return []

        ans_span = []
        start = end = 0
        pre_p = agreement_table[0]
        for i, curr_p in enumerate(agreement_table):
            if curr_p != pre_p:
                # A new run starts here: close the previous one if it was kept.
                if start != end:
                    ans_span.append([start, end])
                start = end = i
            # Below min_agree: start moves with end, so the run stays empty.
            if curr_p < min_agree:
                start += 1
            # Above max_agree: end does not move, so the run stays empty.
            if curr_p <= max_agree:
                end += 1
            pre_p = curr_p
        if start != end:
            ans_span.append([start, end])

        if len(ans_span) <= 1 or min_agree == max_agree:
            return ans_span

        # Merge spans that touch each other.
        merged = [list(ans_span[0])]
        for span_start, span_end in ans_span[1:]:
            if merged[-1][1] == span_start:
                merged[-1][1] = span_end
            else:
                merged.append([span_start, span_end])
        return merged

    def collate_fn(self, batch_sample):
        """Tokenize a list of samples and attach labels for training.

        Returns the tokenizer output with these extra keys:

        - ``'labels'``: token BIO labels, shape ``(num_windows, encode_max_length)``.
        - ``'crf_mask'``: bool mask of the same shape.
        - ``'span_labels'``: the raw character spans, one list per sample.
        - ``'batch_text'``: the raw texts.
        """
        batch_text = [sample['original_text'] for sample in batch_sample]
        batch_span_label = [sample["span_labels"] for sample in batch_sample]
        batch_table = [
            self.create_label_list(spans, len(text))
            for text, spans in zip(batch_text, batch_span_label)
        ]
        seq_lens = [len(text) for text in batch_text]
        self.batch_max_lenght = min(max(seq_lens), self.encode_max_length)

        # Tokenize to self.max_length so input_ids, labels and crf_mask agree
        # (labels/mask are self.max_length - 2 wide, see encode_lable).
        encoded = self.tokenizer(
            batch_text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            stride=self.stride,
            return_overflowing_tokens=True,
            return_tensors="pt",
            return_offsets_mapping=True,
        )

        encoded['labels'], batch_encode_seq_lens = self.encode_lable(encoded, batch_table)
        encoded["crf_mask"] = self.create_crf_mask(batch_encode_seq_lens)
        encoded["span_labels"] = batch_span_label
        encoded["batch_text"] = batch_text
        return encoded