wq2012 commited on
Commit
746229e
Β·
verified Β·
1 Parent(s): f22a8d1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +110 -3
README.md CHANGED
@@ -1,3 +1,110 @@
1
- ---
2
- license: llama3
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: llama3
3
+ ---
4
+
5
+ **This is not an officially supported Google product.**
6
+
7
+ ## Overview
8
+
9
+ [DiarizationLM](https://arxiv.org/abs/2401.03506) model finetuned
10
+ on the training subset of the Fisher corpus.
11
+
12
+ * Foundation model: [unsloth/llama-3-8b-bnb-4bit](https://huggingface.co/unsloth/llama-3-8b-bnb-4bit)
13
+ * Finetuning scripts: https://github.com/google/speaker-id/tree/master/DiarizationLM/unsloth
14
+
15
+ ## Training config
16
+
17
+ This model is finetuned on the training subset of the Fisher corpus, using a LoRA adapter of rank 256. The total number of training parameters is 671,088,640. With a batch size of 16, this model has been trained for 25400 steps, which is ~8 epochs of the training data.
18
+
19
+ We use the `mixed` flavor during our training, meaning we combine data from `hyp2ora` and `deg2ref` flavors. After the prompt builder, we have a total of 51,063 prompt-completion pairs in our training set.
20
+
21
+ The finetuning took more than 4 days on a Google Cloud VM instance that has one NVIDIA A100 GPU with 80GB memory.
22
+
23
+ The maximal length of the prompt to this model is 6000 characters, including the " --> " suffix. The maximal sequence length is 4096 tokens.
24
+
25
+ ## Metrics
26
+
27
+ Performance on the Fisher testing set:
28
+
29
+ | System | WER (%) | WDER (%) | cpWER (%) |
30
+ | ------- | ------- | -------- | --------- |
31
+ | USM + turn-to-diarize baseline | 15.48 | 5.32 | 21.19 |
32
+ | + This model | - | 4.40 | 19.76 |
33
+
34
+ ## Usage
35
+
36
+ First, you need to install two packages:
37
+
38
+ ```
39
+ pip install transformers diarizationlm
40
+ ```
41
+
42
+ On a machine with GPU and CUDA, you can use the model by running the following script:
43
+
44
+ ```python
45
+ from transformers import LlamaForCausalLM, AutoTokenizer
46
+ from diarizationlm import utils
47
+
48
+ HYPOTHESIS = """<speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. What about <speaker:1> you? I'm doing well, too. Thank you."""
49
+
50
+ print("Loading model...")
51
+ tokenizer = AutoTokenizer.from_pretrained("google/DiarizationLM-8b-Fisher-v1", device_map="cuda")
52
+ model = LlamaForCausalLM.from_pretrained("google/DiarizationLM-8b-Fisher-v1", device_map="cuda")
53
+
54
+ print("Tokenizing input...")
55
+ inputs = tokenizer([HYPOTHESIS + " --> "], return_tensors = "pt").to("cuda")
56
+
57
+ print("Generating completion...")
58
+ outputs = model.generate(**inputs,
59
+ max_new_tokens = inputs.input_ids.shape[1] * 1.2,
60
+ use_cache = False)
61
+
62
+ print("Decoding completion...")
63
+ completion = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:],
64
+ skip_special_tokens = True)[0]
65
+
66
+ print("Transferring completion to hypothesis text...")
67
+ transferred_completion = utils.transfer_llm_completion(completion, HYPOTHESIS)
68
+
69
+ print("========================================")
70
+ print("Hypothesis:", HYPOTHESIS)
71
+ print("========================================")
72
+ print("Completion:", completion)
73
+ print("========================================")
74
+ print("Transferred completion:", transferred_completion)
75
+ print("========================================")
76
+ ```
77
+
78
+ The output will look like below:
79
+
80
+ ```
81
+ Loading model...
82
+ Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
83
+ Loading checkpoint shards: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:13<00:00, 3.32s/it]
84
+ generation_config.json: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 172/172 [00:00<00:00, 992kB/s]
85
+ Tokenizing input...
86
+ Generating completion...
87
+ Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
88
+ Decoding completion...
89
+ Transferring completion to hypothesis text...
90
+ ========================================
91
+ Hypothesis: <speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. What about <speaker:1> you? I'm doing well, too. Thank you.
92
+ ========================================
93
+ Completion: <speaker:1> Hello, how are you doing today? <speaker:2> i am doing well. What about you? <speaker:1> i'm doing well, too. Thank you. [eod] [eod] <speaker:2
94
+ ========================================
95
+ Transferred completion: <speaker:1> Hello, how are you doing today? <speaker:2> I am doing well. What about you? <speaker:1> I'm doing well, too. Thank you.
96
+ ========================================
97
+ ```
98
+
99
+ ## Citation
100
+
101
+ Our paper is cited as:
102
+
103
+ ```
104
+ @article{wang2024diarizationlm,
105
+ title={{DiarizationLM: Speaker Diarization Post-Processing with Large Language Models}},
106
+ author={Quan Wang and Yiling Huang and Guanlong Zhao and Evan Clark and Wei Xia and Hank Liao},
107
+ journal={arXiv preprint arXiv:2401.03506},
108
+ year={2024}
109
+ }
110
+ ```