# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import evaluate
import torch
import transformers
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate.utils import set_seed

# Silence informational logging from transformers; only errors are printed.
transformers.logging.set_verbosity_error()



def get_dataloaders(batch_size: int = 16):
    """
    Creates a set of `DataLoader`s for the `glue` dataset,
    using "bert-base-cased" as the tokenizer.

    Args:
        accelerator (`Accelerator`):
            An `Accelerator` object
        batch_size (`int`, *optional*):
            The batch size for the train and validation DataLoaders.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    datasets = load_dataset("glue", "mrpc")

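    # Tokenize the sentence pairs without padding here; padding is applied
    # dynamically per batch in the collate function below.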
    def tokenize_function(examples):
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

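    # Pad each batch to the length of its longest example. Padding to a
    # multiple of 8 helps tensor-core utilization on recent NVIDIA GPUs.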
    def collate_fn(examples):
        return tokenizer.pad(
            examples,
            padding="longest",
            max_length=None,
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

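    # Build the training DataLoader (shuffled, drops the last incomplete batch)
    # and the evaluation DataLoader (fixed batch size of 32, no shuffling).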
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=32,
        drop_last=False,
    )

    return train_dataloader, eval_dataloader


def training_function():
    config = {"lr": 2e-5, "num_epochs": 3, "seed": 42}
    seed = int(config["seed"])
    batch_size = 32
    config["batch_size"] = batch_size
    metric = evaluate.load("glue", "mrpc")

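    # Seed all RNGs for reproducibility before building the dataloaders and the model.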
    set_seed(seed, device_specific=False)
    train_dataloader, eval_dataloader = get_dataloaders(batch_size)
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
    # Move the model to GPU if one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

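    # Standard AdamW optimizer with a linear learning-rate decay and no warmup steps.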
    optimizer = AdamW(params=model.parameters(), lr=config["lr"])
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * config["num_epochs"]),
    )

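    # Training loop: one forward/backward pass and optimizer step per batch.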
    for epoch in range(config["num_epochs"]):
        model.train()
        total_loss = 0
        for batch in train_dataloader:
            batch = batch.to(device)
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().cpu().float()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

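        # Evaluation loop: accumulate predictions batch by batch, then compute
        # the MRPC metrics (accuracy and F1) for the epoch.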
        model.eval()
        for batch in eval_dataloader:
            batch = batch.to(device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            metric.add_batch(
                predictions=predictions,
                references=batch["labels"],
            )

        eval_metric = metric.compute()

        print(f"epoch {epoch}:", eval_metric)
        print("train_loss:", total_loss.item() / len(train_dataloader))


def main():
    training_function()


if __name__ == "__main__":
    main()