license: mit
language:
- ko
- en
pipeline_tag: text-classification
Korean Reranker Training on Amazon SageMaker
ํ๊ตญ์ด Reranker ๊ฐ๋ฐ์ ์ํ ํ์ธํ๋ ๊ฐ์ด๋๋ฅผ ์ ์ํฉ๋๋ค.
ko-reranker๋ BAAI/bge-reranker-larger ๊ธฐ๋ฐ ํ๊ตญ์ด ๋ฐ์ดํฐ์ ๋ํ fine-tuned model ์
๋๋ค.
๋ณด๋ค ์์ธํ ์ฌํญ์ korean-reranker-git์ ์ฐธ๊ณ ํ์ธ์
0. Features
Reranker๋ ์๋ฒ ๋ฉ ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ ์ง๋ฌธ๊ณผ ๋ฌธ์๋ฅผ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ฉฐ ์๋ฒ ๋ฉ ๋์ ์ ์ฌ๋๋ฅผ ์ง์ ์ถ๋ ฅํฉ๋๋ค.
Reranker์ ์ง๋ฌธ๊ณผ ๊ตฌ์ ์ ์ ๋ ฅํ๋ฉด ์ฐ๊ด์ฑ ์ ์๋ฅผ ์ป์ ์ ์์ต๋๋ค.
Reranker๋ CrossEntropy loss๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ต์ ํ๋๋ฏ๋ก ๊ด๋ จ์ฑ ์ ์๊ฐ ํน์ ๋ฒ์์ ๊ตญํ๋์ง ์์ต๋๋ค.
1.Usage
Local ''' def exp_normalize(x): b = x.max() y = np.exp(x - b) return y / y.sum() from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForSequenceClassification.from_pretrained(model_path) model.eval()
pairs = [["๋๋ ๋๋ฅผ ์ซ์ดํด", "๋๋ ๋๋ฅผ ์ฌ๋ํด"],
["๋๋ ๋๋ฅผ ์ข์ํด", "๋์ ๋ํ ๋์ ๊ฐ์ ์ ์ฌ๋ ์ผ ์๋ ์์ด"]] with torch.no_grad(): inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512) scores = model(**inputs, return_dict=True).logits.view(-1, ).float() scores = exp_normalize(scores.numpy()) print (f'first: {scores[0]}, second: {scores[1]}')
'''
2. Backgound
์ปจํ์คํธ ์์๊ฐ ์ ํ๋์ ์ํฅ ์ค๋ค(Lost in Middel, Liu et al., 2023)
Reranker ์ฌ์ฉํด์ผ ํ๋ ์ด์
- ํ์ฌ LLM์ context ๋ง์ด ๋ฃ๋๋ค๊ณ ์ข์๊ฑฐ ์๋, relevantํ๊ฒ ์์์ ์์ด์ผ ์ ๋ต์ ์ ๋งํด์ค๋ค
- Semantic search์์ ์ฌ์ฉํ๋ similarity(relevant) score๊ฐ ์ ๊ตํ์ง ์๋ค. (์ฆ, ์์ ๋ญ์ปค๋ฉด ํ์ ๋ญ์ปค๋ณด๋ค ํญ์ ๋ ์ง๋ฌธ์ ์ ์ฌํ ์ ๋ณด๊ฐ ๋ง์?)
- Embedding์ meaning behind document๋ฅผ ๊ฐ์ง๋ ๊ฒ์ ํนํ๋์ด ์๋ค.
- ์ง๋ฌธ๊ณผ ์ ๋ต์ด ์๋ฏธ์ ๊ฐ์๊ฑด ์๋๋ค. (Hypothetical Document Embeddings)
- ANNs(Approximate Nearest Neighbors) ์ฌ์ฉ์ ๋ฐ๋ฅธ ํจ๋ํฐ
3. Reranker models
[Cohere] Reranker
[BAAI] bge-reranker-large
[BAAI] bge-reranker-base
4. Dataset
msmarco-triplets
- (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
- ํด๋น ๋ฐ์ดํฐ ์ ์ ์๋ฌธ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
- Amazon Translate ๊ธฐ๋ฐ์ผ๋ก ๋ฒ์ญํ์ฌ ํ์ฉํ์์ต๋๋ค.
5. Performance
Model | has-right-in-contexts | mrr (mean reciprocal rank) |
---|---|---|
without-reranker (default) | 0.93 | 0.80 |
with-reranker (bge-reranker-large) | 0.95 | 0.84 |
with-reranker (fine-tuned using korean) | 0.96 | 0.87 |
- evaluation set:
./dataset/evaluation/eval_dataset.csv
- training parameters:
{
"learning_rate": 5e-6,
"fp16": True,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 32,
"train_group_size": 3,
"max_len": 512,
"weight_decay": 0.01,
}
6. Acknowledgement
- Part of the code is developed based on FlagEmbedding and KoSimCSE-SageMaker.
7. Citation
- If you find this repository useful, please consider giving a like โญ and citation
8. Contributors:
9. License
- FlagEmbedding is licensed under the MIT License.