#!/usr/bin/env python3
import argparse
import torch
import transformers

# Extend the vocabulary of a RoBERTa-style masked-LM checkpoint from
# MASK_ID + 1 rows to NEW_VOCAB rows, seeding the added rows from the UNK row.

UNK_ID = 3  # token id whose embedding/bias row seeds every new entry
MASK_ID = 51960  # last id of the original vocabulary (original size is MASK_ID + 1)
NEW_VOCAB = 51997  # target vocabulary size


def expand_vocab_rows(weight: torch.nn.Parameter, unk_id: int, mask_id: int, new_vocab: int) -> None:
    """Grow ``weight`` along dim 0 from ``mask_id + 1`` to ``new_vocab`` rows, in place.

    Rows ``0 .. mask_id`` are copied unchanged. Row ``mask_id - 1`` and every
    new row ``mask_id + 1 .. new_vocab - 1`` are set to a copy of row ``unk_id``
    of the original tensor. Only ``weight.data`` is replaced, not the Parameter
    object. Any module attribute sharing this Parameter (tied weights)
    therefore sees the new shape as well.

    Raises:
        ValueError: if ``weight`` does not have exactly ``mask_id + 1`` rows.
    """
    original = weight.data
    if original.shape[0] != mask_id + 1:
        raise ValueError(f"expected {mask_id + 1} rows, got shape {tuple(original.shape)}")
    # new_zeros keeps the dtype and device of the original tensor.
    expanded = original.new_zeros((new_vocab, *original.shape[1:]))
    expanded[: mask_id + 1] = original
    # NOTE(review): row mask_id - 1 is overwritten as well; presumably it is an
    # unused placeholder id in the original vocabulary -- TODO confirm.
    for new_unk in [mask_id - 1, *range(mask_id + 1, new_vocab)]:
        expanded[new_unk] = original[unk_id]
    weight.data = expanded


def extend_vocab(model, unk_id: int = UNK_ID, mask_id: int = MASK_ID, new_vocab: int = NEW_VOCAB) -> None:
    """Extend the vocabulary of a RoBERTa masked-LM ``model`` in place.

    Expects ``model.roberta.embeddings.word_embeddings.weight`` to be the same
    Parameter as ``model.lm_head.decoder.weight``. Also expects
    ``model.lm_head.bias`` to be ``model.lm_head.decoder.bias``. Each shared
    Parameter is grown once with :func:`expand_vocab_rows`.

    The module size attributes and ``model.config.vocab_size`` are updated as
    well, so a saved config matches the saved weight shapes.

    Raises:
        ValueError: if the weights are not tied as expected, or do not have
            ``mask_id + 1`` rows.
    """
    embeddings = model.roberta.embeddings.word_embeddings
    decoder = model.lm_head.decoder
    if embeddings.weight is not decoder.weight:
        raise ValueError("word embeddings are not tied to lm_head.decoder.weight")
    if model.lm_head.bias is not decoder.bias:
        raise ValueError("lm_head.bias is not lm_head.decoder.bias")

    # decoder.weight / decoder.bias are the same objects, so they must NOT be
    # expanded a second time -- they follow automatically.
    for weight in (embeddings.weight, model.lm_head.bias):
        expand_vocab_rows(weight, unk_id, mask_id, new_vocab)

    # Keep declared sizes in sync with the new weight shapes. Otherwise the saved
    # config.json still declares the old vocab size, and reloading the checkpoint
    # builds a smaller model than the stored weights.
    embeddings.num_embeddings = new_vocab
    decoder.out_features = new_vocab
    model.config.vocab_size = new_vocab


def main() -> None:
    """Parse CLI arguments, extend the checkpoint's vocabulary and save it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    # NOTE(review): RobertaForMaskedLM's constructor does not appear to take an
    # `add_pooling_layer` argument; confirm the installed transformers version
    # accepts this kwarg (kept unchanged from the original script).
    robeczech = transformers.AutoModelForMaskedLM.from_pretrained(args.input_path, add_pooling_layer=True)

    extend_vocab(robeczech, unk_id=UNK_ID, mask_id=MASK_ID, new_vocab=NEW_VOCAB)

    # Save twice into the same directory. The first call uses the default
    # serialization (safetensors in recent transformers releases). The second
    # uses safe_serialization=False, i.e. the pickle-based weights file.
    robeczech.save_pretrained(args.output_path)
    robeczech.save_pretrained(args.output_path, safe_serialization=False)


if __name__ == "__main__":
    main()