---
license: apache-2.0
---
# Memformers
Memformers use an external dynamic memory to store history information.
This repo contains the implementation of the pre-trained model MemBART and its training code.
See the [memformers](https://github.com/qywu/memformers) repo for details.
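The core idea can be pictured with a toy sketch: a fixed number of memory slots attend over the new token representations and are overwritten by a gated mix of their old values and what they just read. This is an illustrative PyTorch module only, not MemBart's actual implementation; the gating scheme, slot count, and dimensions here are made up for the example.

```python
import torch
import torch.nn as nn

class MemorySlotUpdate(nn.Module):
    """Toy slot-based external memory (illustration only, not MemBart's code).

    Each memory slot attends over the encoder's hidden states and is
    updated by a learned gate between its old value and the attended read.
    """

    def __init__(self, dim: int, num_slots: int):
        super().__init__()
        self.dim = dim
        self.num_slots = num_slots
        self.attn = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)
        self.gate = nn.Linear(2 * dim, dim)

    def init_memory(self, batch_size: int) -> torch.Tensor:
        # initial (empty) memory: (batch, slots, dim)
        return torch.zeros(batch_size, self.num_slots, self.dim)

    def forward(self, memory: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        # memory: (batch, slots, dim); hidden: (batch, seq_len, dim)
        read, _ = self.attn(query=memory, key=hidden, value=hidden)
        gate = torch.sigmoid(self.gate(torch.cat([memory, read], dim=-1)))
        # gated write: keep part of the old slot, mix in the new summary
        return gate * memory + (1 - gate) * read

updater = MemorySlotUpdate(dim=16, num_slots=4)
memory = updater.init_memory(batch_size=2)
hidden = torch.randn(2, 10, 16)        # fake encoder states for one segment
memory = updater(memory, hidden)       # compressed history carried forward
print(memory.shape)                    # torch.Size([2, 4, 16])
```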
## Install
Download this repo and install it with:
```bash
git clone https://github.com/qywu/memformers
cd memformers
pip install -e .
```
## Usage
### Inference and Generation
Our implementation is based on Hugging Face [transformers](https://github.com/huggingface/transformers). Currently, we provide two checkpoints: `"qywu/membart-large"` [(checkpoint)](https://huggingface.co/qywu/membart-large) and `"qywu/membart-base"` [(checkpoint)](https://huggingface.co/qywu/membart-base).
You can directly load the checkpoint with:
```python
import torch
from transformers import AutoTokenizer
from memformers.models.membart import MemBartForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
# load the large model in huggingface way
membart = MemBartForConditionalGeneration.from_pretrained("qywu/membart-large")
text1 = "Barack Obama served as the 44th President of the United States."
text2 = "<mask> served as the 44th President of the United States."
# construct the initial memory
memory_states = membart.construct_memory(batch_size=1)
# t = 0
input_ids1 = torch.LongTensor([tokenizer.encode(text1)])
# only run the encoder to get memory states
encoder_outputs = membart.model.encoder(input_ids=input_ids1, memory_states=memory_states, attention_mask=None)
memory_states = encoder_outputs.memory_states
# t = 1
input_ids2 = torch.LongTensor([tokenizer.encode(text2)])
encoder_outputs2 = membart.model.encoder(input_ids=input_ids2, memory_states=memory_states, attention_mask=None)
outputs = membart.generate(
encoder_outputs=encoder_outputs2,
decoder_start_token_id=tokenizer.bos_token_id,
max_length=64,
num_beams=1,
do_sample=False,
return_dict_in_generate=True,
)
print(tokenizer.decode(outputs.sequences[0]))
# Barack Obama served as the 44th President of the United States.
```
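When processing many turns, the encoder call above can be wrapped in a small helper so the memory is threaded through a loop. `encode_turn` below is a hypothetical convenience wrapper written for this README, not part of the memformers package; it simply restates the encoder call from the snippet above.

```python
import torch

def encode_turn(membart, tokenizer, text, memory_states):
    """Run only the encoder for one turn and return (encoder_outputs,
    updated memory_states), mirroring the usage snippet above.
    Hypothetical helper, not part of the memformers package."""
    input_ids = torch.LongTensor([tokenizer.encode(text)])
    encoder_outputs = membart.model.encoder(
        input_ids=input_ids, memory_states=memory_states, attention_mask=None
    )
    return encoder_outputs, encoder_outputs.memory_states

# Thread the memory through an arbitrary number of turns:
# memory_states = membart.construct_memory(batch_size=1)
# for text in turns:
#     encoder_outputs, memory_states = encode_turn(membart, tokenizer, text, memory_states)
```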
Note that due to [BART](https://arxiv.org/abs/1910.13461)'s denoising pre-training, the model needs to be further fine-tuned on downstream tasks to get better performance.
### Training
Training requires installing [TorchFly](https://github.com/qywu/TorchFly):
```bash
git clone https://github.com/qywu/TorchFly
cd TorchFly
pip install -e .
```
Then, you can refer to the code in `examples/finetune_dialog` for details about fine-tuning or further pre-training MemBart on your own tasks:
```bash
python train.py
```
For details, see `examples/training_msc`.
## Citations
Memformer: A Memory-Augmented Transformer for Sequence Modeling
```bib
@inproceedings{DBLP:conf/ijcnlp/WuLQGGY22,
author = {Qingyang Wu and
Zhenzhong Lan and
Kun Qian and
Jing Gu and
Alborz Geramifard and
Zhou Yu},
title = {Memformer: {A} Memory-Augmented Transformer for Sequence Modeling},
booktitle = {Findings of the Association for Computational Linguistics: {AACL-IJCNLP}
2022, Online only, November 20-23, 2022},
pages = {308--318},
publisher = {Association for Computational Linguistics},
year = {2022},
url = {https://aclanthology.org/2022.findings-aacl.29},
timestamp = {Tue, 29 Nov 2022 14:53:03 +0100},
biburl = {https://dblp.org/rec/conf/ijcnlp/WuLQGGY22.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
Stateful Memory-Augmented Transformers for Dialogue Modeling
```bib
@article{DBLP:journals/corr/abs-2209-07634,
author = {Qingyang Wu and
Zhou Yu},
title = {Stateful Memory-Augmented Transformers for Dialogue Modeling},
journal = {CoRR},
volume = {abs/2209.07634},
year = {2022},
url = {https://doi.org/10.48550/arXiv.2209.07634},
doi = {10.48550/arXiv.2209.07634},
eprinttype = {arXiv},
eprint = {2209.07634},
timestamp = {Tue, 27 Sep 2022 16:29:43 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-2209-07634.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```