from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch

def chunk_text(text, max_length, tokenizer):
    """Split text into chunks of at most max_length tokens."""
    # Encode without special tokens so EOS markers don't land inside chunks.
    tokens = tokenizer.encode(text, add_special_tokens=False, truncation=False)
    # Slice the token list into consecutive max_length-sized chunks.
    return [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]

def adjust_lengths(paragraph_length):
    """Choose generation (max_length, min_length) based on input length in tokens."""
    if paragraph_length < 100:
        return 100, 50   # short inputs
    elif paragraph_length < 500:
        return 300, 150  # medium-length inputs
    else:
        return 600, 300  # long inputs

def paraphrase_paragraph(paragraph, model_name='google/pegasus-multi_news'):
    """Paraphrase a paragraph with Pegasus, chunking inputs that exceed the model's limit."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Note: the model and tokenizer are loaded on every call; hoist them out
    # of this function if you paraphrase many paragraphs in a row.
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    tokenizer = PegasusTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)

    # Count tokens (excluding special tokens) to pick generation lengths
    tokens = tokenizer.encode(paragraph, add_special_tokens=False, truncation=False)
    paragraph_length = len(tokens)

    # Adjust max_length and min_length dynamically
    max_length, min_length = adjust_lengths(paragraph_length)

    # Chunk the paragraph to fit the model's input limit, leaving headroom for
    # the special tokens added when each chunk is re-encoded below.
    chunks = chunk_text(paragraph, tokenizer.model_max_length - 2, tokenizer)

    paraphrased_chunks = []
    for chunk in chunks:
        # Decode chunk tokens back to text
        chunk_t = tokenizer.decode(chunk, skip_special_tokens=True)
        # Tokenize the text chunk
        inputs = tokenizer(chunk_t, return_tensors='pt', padding=True, truncation=True).to(device)
        
        # Generate paraphrased text
        with torch.no_grad():  # Avoid gradient calculations for inference
            generated_ids = model.generate(
                **inputs,  # pass input_ids and attention_mask together
                max_length=max_length,  # dynamically adjusted
                min_length=min_length,  # dynamically adjusted
                num_beams=3,
                early_stopping=True
            )
        
        paraphrased_chunk = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        paraphrased_chunks.append(paraphrased_chunk)
    
    # Combine all paraphrased chunks
    paraphrased_paragraph = ' '.join(paraphrased_chunks)
    
    return paraphrased_paragraph
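
# Minimal usage sketch. The sample text below is illustrative only; the first
# run downloads the google/pegasus-multi_news checkpoint from the Hugging Face Hub.
if __name__ == "__main__":
    sample = (
        "The James Webb Space Telescope has captured detailed images of "
        "distant galaxies, giving astronomers new insight into how the "
        "earliest stars formed."
    )
    print(paraphrase_paragraph(sample))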