---
library_name: transformers
tags:
- goldfish-loss
- memorization
- mitigation
license: apache-2.0
language:
- en
pipeline_tag: text-generation
---

# Quick Links

- **GitHub Repository**: https://github.com/ahans30/goldfish-loss
- **arXiv**: https://arxiv.org/abs/2406.10209

# Goldfish Loss

<div align="center">
  <img src="https://raw.githubusercontent.com/ahans30/goldfish-loss/main/assets/goldfish-loss.jpg" width="300"/>
</div>

We introduce goldfish loss, a new language modeling loss function that mitigates memorization of training data.
Specifically, goldfish loss pseudorandomly drops $1/k$ of the tokens seen in the forward pass from the loss computation (i.e., no loss is computed on those tokens), where $k$ is a hyperparameter.
We show that models trained with goldfish loss find it difficult to regurgitate training data verbatim, even after 100 epochs of training on the same data. Please read our paper, linked below, for more details.
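As a rough illustration of the idea, the sketch below computes a goldfish-style loss from per-token losses and a keep-mask. The names `goldfish_loss` and `static_keep_mask` are hypothetical, and the "drop every $k$-th token" rule is only the simplest illustrative drop rule; the released checkpoints use the hash-based strategy listed in the Overview table. Averaging over the kept tokens only is an assumption of this sketch; it matches what a mean-reduced cross-entropy does when dropped targets are set to an ignore index.

```python
from typing import List, Sequence


def static_keep_mask(num_tokens: int, k: int) -> List[bool]:
    """Illustrative drop rule: drop every k-th position (1-indexed), keep the rest."""
    if k < 2:
        raise ValueError("k must be at least 2")
    return [(i + 1) % k != 0 for i in range(num_tokens)]


def goldfish_loss(token_losses: Sequence[float], keep_mask: Sequence[bool]) -> float:
    """Average per-token losses over kept positions only.

    Dropped positions contribute nothing to the loss (and therefore no
    gradient). Dividing by the number of kept tokens is an assumption here.
    """
    if len(token_losses) != len(keep_mask):
        raise ValueError("token_losses and keep_mask must have the same length")
    kept = [loss for loss, keep in zip(token_losses, keep_mask) if keep]
    if not kept:
        raise ValueError("every token was dropped; the loss is undefined")
    return sum(kept) / len(kept)
```

With `k = 3`, `static_keep_mask(6, 3)` returns `[True, True, False, True, True, False]`, so `goldfish_loss([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], mask)` averages only 1, 2, 4 and 5 and returns 3.0.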

# Overview

The following checkpoints accompany our paper, *Be like a Goldfish, Don't Memorize! Mitigating Memorization in Generative LLMs* [[paper link](https://arxiv.org/abs/2406.10209)].

| Checkpoint Name                                                                                               | k (Goldfish Loss) | Token Drop Strategy       | Pretraining Tokens | Primary Dataset | Canaries Dataset (for Memorization)                                                 |
| ------------------------------------------------------------------------------------------------------------- | ----------------- | ------------------------- | ------------------ | --------------- | ----------------------------------------------------------------------------------- |
| [tomg-group-umd/3-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/3-goldfish-loss-llama-1B)     | 3                 | Hash (context width = 13) | 20B                | RedPajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/4-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/4-goldfish-loss-llama-1B)     | 4                 | Hash (context width = 13) | 20B                | RedPajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/8-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/8-goldfish-loss-llama-1B)     | 8                 | Hash (context width = 13) | 20B                | RedPajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/32-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/32-goldfish-loss-llama-1B)   | 32                | Hash (context width = 13) | 20B                | RedPajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/128-goldfish-loss-llama-1B](https://huggingface.co/tomg-group-umd/128-goldfish-loss-llama-1B) | 128               | Hash (context width = 13) | 20B                | RedPajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
| [tomg-group-umd/control-llama-1B](https://huggingface.co/tomg-group-umd/control-llama-1B)                     | \-                | No Tokens Dropped         | 20B                | RedPajama       | None                                                                                |
| [tomg-group-umd/standard-loss-llama-1B](https://huggingface.co/tomg-group-umd/standard-loss-llama-1B)         | \-                | No Tokens Dropped         | 20B                | RedPajama       | [Wikipedia](https://huggingface.co/datasets/tomg-group-umd/wikipedia-en-2k-samples) |
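The hash-based drop strategy (width 13) means, as we read it, that whether a token is dropped is decided by hashing the 13 tokens that precede it. A passage that recurs in the training data therefore has the same tokens dropped every time it is seen, whereas a fresh random mask on each pass would, over many repetitions, eventually supervise every token of that passage. The sketch below illustrates this property under stated assumptions: `hashed_keep_mask` is a hypothetical helper, SHA-256 stands in for whatever hash the repository actually uses, and always keeping the first `h` positions (which lack a full context window) is our own choice.

```python
import hashlib
from typing import List, Sequence


def hashed_keep_mask(token_ids: Sequence[int], k: int, h: int = 13) -> List[bool]:
    """Return a keep-mask where True means the token contributes to the loss.

    The decision for position i depends only on the h tokens before it, so
    identical contexts always yield identical decisions. SHA-256 is a
    stand-in hash chosen for this sketch.
    """
    keep: List[bool] = []
    for i in range(len(token_ids)):
        if i < h:
            # Assumption: positions without a full h-token context are always kept.
            keep.append(True)
            continue
        window = ",".join(str(t) for t in token_ids[i - h:i])
        digest = hashlib.sha256(window.encode("utf-8")).digest()
        # Map the hash to an integer and drop roughly 1 in k positions.
        keep.append(int.from_bytes(digest[:8], "big") % k != 0)
    return keep
```

Because the mask depends only on local context, the same passage embedded at a different offset in another sequence gets the same dropped positions once its full 13-token window lies inside the passage.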

### Description
- `standard-loss-llama-1B` and `control-llama-1B` are trained with the standard causal language modeling loss, using exactly the same setup as the goldfish models.
- The control model differs from the standard-loss model only in that it was not trained on the canaries dataset; it was pretrained on the 20B RedPajama tokens alone.
- The canaries dataset, which contains 2,000 Wikipedia documents, is repeated 50 times over the course of pretraining, for a total of about 204M tokens (including padding).
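The ~204M figure can be reproduced if we assume that each canary document is padded (or truncated) to a single 2048-token sequence, TinyLlama's context length; this per-document length is our assumption, not a number stated above:

$$
2000 \ \text{documents} \times 50 \ \text{repetitions} \times 2048 \ \tfrac{\text{tokens}}{\text{document}} = 204{,}800{,}000 \approx 204.8\text{M tokens}
$$

That is roughly 1% of the size of the 20B-token pretraining budget.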

# Technical Specification

Each checkpoint above uses the [TinyLlama-1.1B](https://huggingface.co/TinyLlama/TinyLlama_v1.1) architecture, trained from random initialization.
For pretraining details, please check our [GitHub](https://github.com/ahans30/goldfish-loss) repository.
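Below is a minimal sketch for loading one of the checkpoints with Hugging Face `transformers` and generating greedily, for example to probe whether a canary passage is reproduced. The helpers `checkpoint_id` and `greedy_continue` are hypothetical. We assume each repository ships a tokenizer alongside the weights and loads through `AutoModelForCausalLM` (TinyLlama is a decoder-only Llama model). PyTorch is required for `return_tensors="pt"`.

```python
from typing import Optional

ORG = "tomg-group-umd"
GOLDFISH_KS = (3, 4, 8, 32, 128)  # k values released in the table above


def checkpoint_id(k: Optional[int] = None, control: bool = False) -> str:
    """Build a Hub repo ID: goldfish model for k, else standard-loss or control."""
    if control:
        return f"{ORG}/control-llama-1B"
    if k is None:
        return f"{ORG}/standard-loss-llama-1B"
    if k not in GOLDFISH_KS:
        raise ValueError(f"no released goldfish checkpoint for k={k}")
    return f"{ORG}/{k}-goldfish-loss-llama-1B"


def greedy_continue(repo_id: str, prompt: str, max_new_tokens: int = 32) -> str:
    """Download a checkpoint and return its greedy continuation of `prompt`."""
    # Imported lazily so checkpoint_id() works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

For example, `greedy_continue(checkpoint_id(4), "Once upon a time")` downloads the k = 4 goldfish model and prints nothing by itself; wrap it in `print(...)` to see the continuation. Comparing the same canary prompt across `checkpoint_id(4)`, `checkpoint_id()` and `checkpoint_id(control=True)` is one way to contrast goldfish, standard-loss and control models.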

# Cite our work

If you find our model, codebase or dataset beneficial, please consider citing our work:

```bibtex
@misc{hans2024like,
      title={Be like a Goldfish, Don't Memorize! Mitigating Memorization in Generative LLMs}, 
      author={Abhimanyu Hans and Yuxin Wen and Neel Jain and John Kirchenbauer and Hamid Kazemi and Prajwal Singhania and Siddharth Singh and Gowthami Somepalli and Jonas Geiping and Abhinav Bhatele and Tom Goldstein},
      year={2024},
      eprint={2406.10209},
      archivePrefix={arXiv},
}
```