import random

import numpy as np
import tqdm
from nltk import word_tokenize


class PunktTokenizer:
    def __call__(self, texts):
        return [word_tokenize(t) for t in texts]


class WhiteSpaceTokenizer:
    def __call__(self, texts):
        return [t.split() for t in texts]


class SearchState:
    """Tracks the search trajectory for one input: all visited masks,
    their summaries and scores, and the best step seen so far."""

    def __init__(self, tokens):
        self.tokens = tokens
        self.masks = []
        self.mask_set = set()
        self.summaries = []
        self.scores = []
        self.best_step = None
        self.terminated = False
        self.step = 0

    def update(self, mask, summary, score):
        if self.best_step is None or score > self.best_score():
            self.best_step = self.step
        self.masks.append(mask)
        self.mask_set.add(tuple(mask))
        self.summaries.append(summary)
        self.scores.append(score)
        self.step += 1

    def best_mask(self):
        return self.masks[self.best_step]

    def best_score(self):
        return self.scores[self.best_step]

    def best_summary(self):
        return self.summaries[self.best_step]

    def to_dict(self):
        return {
            "scores": self.scores,
            "masks": self.masks,
            "summaries": self.summaries,
            "best_summary": self.best_summary(),
            "best_score": self.best_score(),
        }


class DynamicRestartHCSC:
    """Hill climbing search with dynamic restarts: when a state has no
    unseen neighbor left, it is archived and the search restarts from a
    fresh random mask."""

    def __init__(self, tokenizer, objective):
        self.tokenizer = tokenizer
        self.objective = objective
        self.n_trials = 100

    def _mask_to_summary(self, mask, tokens):
        summary = [tokens[i] for i in range(len(mask)) if mask[i] == 1]
        return " ".join(summary)

    def _sample(self, state, sent_len, target_len, from_scratch=False):
        """
        Proposes the next mask: either a fresh random mask (first step or
        restart) or a neighbor of the current best mask, obtained by swapping
        one selected token for an unselected one while skipping masks that
        were already visited. Returns (mask, terminated).
        """
        if target_len >= sent_len:
            mask = [1 for _ in range(sent_len)]
            state.terminated = True
            return mask, True
        if state.step == 0 or from_scratch:
            indices = list(range(sent_len))
            sampled = set(random.sample(indices, min(target_len, sent_len)))
            mask = [int(i in sampled) for i in indices]
            return mask, False
        else:
            mask = state.masks[state.best_step]
            one_indices = [i for i in range(len(mask)) if mask[i] == 1]
            zero_indices = [i for i in range(len(mask)) if mask[i] == 0]
            if len(zero_indices) == 0:
                # every token is already selected; no swap is possible
                return mask, True
            terminated = True
            # try to find an unseen neighbor state, heuristically,
            # with a fixed number of trials
            for _ in range(self.n_trials):
                i = random.choice(one_indices)
                j = random.choice(zero_indices)
                new_mask = mask.copy()
                new_mask[i] = 0
                new_mask[j] = 1
                if tuple(new_mask) not in state.mask_set:
                    terminated = False
                    mask = new_mask
                    break
            # terminate if no unseen neighbor state was found
            return mask, terminated

    def aggregate_states(self, states):
        """Merges the trajectories of all (archived and current) states for
        one input and picks the overall best summary."""
        masks = [m for s in states for m in s.masks]
        summaries = [x for s in states for x in s.summaries]
        scores = [x for s in states for x in s.scores]
        best_step = np.argmax(scores)
        return {
            "masks": masks,
            "summaries": summaries,
            "scores": scores,
            "best_score": scores[best_step],
            "best_summary": summaries[best_step],
        }

    def __call__(
        self,
        sentences,
        target_lens,
        n_steps=100,
        verbose=False,
        return_states=False,
    ):
        tok_sentences = self.tokenizer(sentences)
        batch_size = len(sentences)
        terminated_states = [[] for _ in range(batch_size)]
        states = [SearchState(s) for s in tok_sentences]
        for t in tqdm.tqdm(list(range(1, n_steps + 1))):
            masks = []
            for i in range(batch_size):
                # archive exhausted states and restart from a fresh random mask
                if states[i].terminated:
                    if verbose:
                        print(
                            f"step {t}, restarting state {i} "
                            f"with score {states[i].best_score()}"
                        )
                    terminated_states[i].append(states[i])
                    states[i] = SearchState(tok_sentences[i])
                mask, terminated = self._sample(
                    states[i],
                    sent_len=len(tok_sentences[i]),
                    target_len=target_lens[i],
                )
                states[i].terminated = terminated
                masks.append(mask)
            summaries = [
                self._mask_to_summary(m, tokens)
                for m, tokens in zip(masks, tok_sentences)
            ]
            scores, _ = self.objective(sentences, summaries)
            if verbose:
                print(f"t={t}")
                for i in range(batch_size):
                    print(f"[{scores[i]:.3f}][{summaries[i]}]")
                print()
            for i in range(batch_size):
                states[i].update(masks[i], summaries[i], scores[i])
        for i in range(batch_size):
            terminated_states[i].append(states[i])
        output_states = [
            self.aggregate_states(i_states) for i_states in terminated_states
        ]
        return output_states
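
# ---------------------------------------------------------------------------
# Minimal usage sketch. The objective is expected to be a callable mapping
# (sentences, summaries) to (scores, extra); `toy_objective` below is a
# hypothetical stand-in that just favors longer tokens, not the learned
# scoring model this searcher would normally be paired with.
# ---------------------------------------------------------------------------
def toy_objective(sentences, summaries):
    # Score each candidate by the average token length of its summary,
    # a crude proxy for keeping content words over function words.
    scores = []
    for summary in summaries:
        tokens = summary.split()
        scores.append(sum(len(t) for t in tokens) / max(len(tokens), 1))
    return scores, None


if __name__ == "__main__":
    searcher = DynamicRestartHCSC(WhiteSpaceTokenizer(), toy_objective)
    results = searcher(
        ["the quick brown fox jumps over the lazy dog"],
        target_lens=[4],
        n_steps=50,
    )
    print(results[0]["best_summary"], results[0]["best_score"])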