Spaces:
Running
Running
File size: 4,703 Bytes
596d396 0cef1e2 596d396 0cef1e2 4a607b7 66e7228 0cef1e2 4a607b7 0cef1e2 4a607b7 0cef1e2 4a607b7 66e7228 0cef1e2 4a607b7 01d78f2 0cef1e2 4a607b7 66e7228 4a607b7 66e7228 4a607b7 66e7228 4a607b7 66e7228 4a607b7 66e7228 0cef1e2 66e7228 4a607b7 0cef1e2 4a607b7 596d396 0cef1e2 4a607b7 66e7228 4a607b7 e950125 4a607b7 66e7228 4a607b7 66e7228 4a607b7 e950125 4a607b7 e950125 66e7228 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import logging
import pprint as pp
from utils import validate_pytorch2
logging.basicConfig(level=logging.INFO)
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def load_model_and_tokenizer(model_name: str) -> tuple:
"""
load_model_and_tokenizer - load a model and tokenizer from a model name/ID on the hub
:param str model_name: the model name/ID on the hub
:return tuple: a tuple containing the model and tokenizer
"""
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
).to(device)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loaded model {model_name} to {device}")
if validate_pytorch2():
try:
logger.info("Compiling model with Torch 2.0")
model = torch.compile(model)
except Exception as e:
logger.warning(f"Could not compile model with Torch 2.0: {e}")
else:
logger.info("Torch 2.0 not detected, skipping compilation")
return model, tokenizer
def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
"""
summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
Args:
ids (): the batch of ids
mask (): the attention mask for the batch
model (): the model to use for summarization
tokenizer (): the tokenizer to use for summarization
Returns:
str: the summary of the batch
"""
ids = ids[None, :]
mask = mask[None, :]
input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
global_attention_mask = torch.zeros_like(attention_mask)
# put global attention on <s> token
global_attention_mask[:, 0] = 1
summary_pred_ids = model.generate(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
output_scores=True,
return_dict_in_generate=True,
**kwargs,
)
summary = tokenizer.batch_decode(
summary_pred_ids.sequences,
skip_special_tokens=True,
remove_invalid_values=True,
)
score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
return summary, score
def summarize_via_tokenbatches(
input_text: str,
model,
tokenizer,
batch_length=2048,
batch_stride=16,
min_batch_length: int = 512,
**kwargs,
):
"""
summarize_via_tokenbatches - a function that takes a string and returns a summary
Args:
input_text (str): the text to summarize
model (): the model to use for summarization
tokenizer (): the tokenizer to use for summarization
batch_length (int, optional): the length of each batch. Defaults to 2048.
batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
Returns:
str: the summary
"""
# log all input parameters
logger = logging.getLogger(__name__)
# log all input parameters
if batch_length < min_batch_length:
logger.warning(
f"batch_length must be at least {min_batch_length}. Setting batch_length to {min_batch_length}"
)
batch_length = min_batch_length
logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")
encoded_input = tokenizer(
input_text,
padding="max_length",
truncation=True,
max_length=batch_length,
stride=batch_stride,
return_overflowing_tokens=True,
add_special_tokens=False,
return_tensors="pt",
)
in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
gen_summaries = []
pbar = tqdm(total=len(in_id_arr), desc="Summarizing")
for _id, _mask in zip(in_id_arr, att_arr):
result, score = summarize_and_score(
ids=_id,
mask=_mask,
model=model,
tokenizer=tokenizer,
**kwargs,
)
score = round(float(score), 4)
_sum = {
"input_tokens": _id,
"summary": result,
"summary_score": score,
}
gen_summaries.append(_sum)
logger.info(f"SCore {score} for summary:\n\t{result}")
pbar.update()
pbar.close()
logger.debug(f"Generated summaries:\n{pp.pformat(gen_summaries)}")
return gen_summaries
|