---
license: mit
language:
- ko
- en
pipeline_tag: text-classification
---

# Korean Reranker Training on Amazon SageMaker

### A fine-tuning guide for developing a **Korean reranker**

ko-reranker is a model fine-tuned on Korean data, based on [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large). <br>
For more details, see [korean-reranker-git](https://github.com/aws-samples/aws-ai-ml-workshop-kr/tree/master/genai/aws-gen-ai-kr/30_fine_tune/reranker-kr).

- - -

## 0. Features

- #### <span style="#FF69B4;"> Unlike an embedding model, a reranker takes a question and a document together as input and outputs a similarity score directly instead of an embedding.</span>
- #### <span style="#FF69B4;"> Give the reranker a question and a passage, and it returns a relevance score.</span>
- #### <span style="#FF69B4;"> Because the reranker is optimized with a cross-entropy loss, its relevance scores are not restricted to a fixed range.</span>
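
Because the raw scores are unbounded logits, a single score is hard to interpret or threshold on its own. One common post-processing step (an assumption here, not something this card prescribes) is to pass each logit through a sigmoid, which gives every (question, passage) pair an independent score in (0, 1) without changing the ranking. The sketch below uses only the Python standard library; `raw_scores` are made-up values.

```python
import math


def sigmoid(logit: float) -> float:
    """Map an unbounded reranker logit to (0, 1); stable for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    # For negative logits, rewrite the formula so exp() never overflows.
    z = math.exp(logit)
    return z / (1.0 + z)


# Hypothetical raw logits for three (question, passage) pairs.
raw_scores = [-7.3, 0.0, 5.1]
probs = [sigmoid(s) for s in raw_scores]
print([round(p, 3) for p in probs])
```

Since the sigmoid is monotonic, sorting by `probs` or by `raw_scores` gives the same order; the sigmoid only matters when you need an absolute cutoff.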

## 1. Usage

- Local

```python
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def exp_normalize(x):
    # Numerically stable softmax over the batch: subtract the max before exponentiating.
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()


model_path = "path/to/ko-reranker"  # placeholder: local directory or Hub ID of the fine-tuned model

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# Each item is a [question, passage] pair to score.
# ("I hate you", "I love you"), ("I like you", "My feelings for you might be love")
pairs = [["나는 너를 싫어해", "나는 너를 사랑해"],
         ["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]

with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    # One logit per pair, flattened to shape (num_pairs,).
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    scores = exp_normalize(scores.numpy())
    print(f'first: {scores[0]}, second: {scores[1]}')
```
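
In a retrieval pipeline you usually pair one query with every candidate passage and keep the best few. The sketch below wraps that step in a hypothetical `rerank` helper (not part of the original code). `score_pairs` stands in for the tokenizer + model call in the snippet above and should return one raw logit per pair; `toy_scorer` is a made-up word-overlap scorer so the example runs without downloading a model.

```python
import math
from typing import Callable, List, Sequence, Tuple


def exp_normalize(scores: Sequence[float]) -> List[float]:
    """Softmax with max-subtraction, the same idea as exp_normalize above."""
    b = max(scores)
    exps = [math.exp(s - b) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def rerank(
    query: str,
    passages: Sequence[str],
    score_pairs: Callable[[List[List[str]]], List[float]],
    top_k: int = 3,
) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and return the top_k passages, best first."""
    pairs = [[query, p] for p in passages]
    weights = exp_normalize(score_pairs(pairs))
    ranked = sorted(zip(passages, weights), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


def toy_scorer(pairs: List[List[str]]) -> List[float]:
    """Hypothetical stand-in for the model: counts shared whitespace tokens."""
    return [float(len(set(q.split()) & set(p.split()))) for q, p in pairs]


query = "reranker relevance score"
passages = ["a reranker returns a relevance score", "embeddings capture meaning", "score"]
for passage, weight in rerank(query, passages, toy_scorer, top_k=2):
    print(f"{weight:.3f}  {passage}")
```

The softmax only rescales the logits, so the order would be the same if you sorted by the raw scores; the normalized weights are convenient when you want the selected passages' scores to sum to one.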

## 2. Background

- #### <span style="#FF69B4;"> **The order of the context affects accuracy** ([Lost in the Middle, *Liu et al., 2023*](https://arxiv.org/pdf/2307.03172.pdf)) </span>

- #### <span style="#FF69B4;"> [Why you should use a reranker](https://www.pinecone.io/learn/series/rag/rerankers/)</span>
    - With current LLMs, adding more context does not help by itself; the relevant passages have to be ranked near the top for the model to answer correctly.
    - The similarity (relevance) scores used in semantic search are not precise. (In other words, is a higher-ranked document always more relevant to the question than a lower-ranked one?)
        * Embeddings are specialized to capture the meaning behind a document.
        * A question and its answer do not mean the same thing. ([Hypothetical Document Embeddings](https://medium.com/prompt-engineering/hyde-revolutionising-search-with-hypothetical-document-embeddings-3474df795af8))
        * Using ANNs ([Approximate Nearest Neighbors](https://towardsdatascience.com/comprehensive-guide-to-approximate-nearest-neighbors-algorithms-8b94f057d6b6)) costs some accuracy.

- - -

## 3. Reranker models

- #### <span style="#FF69B4;"> [Cohere] [Reranker](https://txt.cohere.com/rerank/)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)</span>

- - -

## 4. Dataset

- #### <span style="#FF69B4;"> [msmarco-triplets](https://github.com/microsoft/MSMARCO-Passage-Ranking) </span>
    - (Question, Answer, Negative) triplets from the MS MARCO Passages dataset, 499,184 samples
    - The dataset is in English.
    - We translated it into Korean with Amazon Translate and used the translation for training.
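
The training code builds on FlagEmbedding, whose reranker fine-tuning scripts read JSON Lines records of the form `{"query": str, "pos": List[str], "neg": List[str]}`. Assuming the translated triplets were fed in that format (the card does not say so explicitly), a minimal conversion could look like the sketch below. `triplets_to_jsonl` is a hypothetical helper, and the Korean strings are placeholders ("question 1", "positive passage A", "negative passage B/C"), not rows from the dataset.

```python
import json
from typing import Dict, List, Tuple

# (question, positive passage, negative passage), already translated.
Triplet = Tuple[str, str, str]


def triplets_to_jsonl(triplets: List[Triplet]) -> str:
    """Group triplets by question into {"query", "pos", "neg"} records, one JSON object per line."""
    grouped: Dict[str, Dict[str, List[str]]] = {}
    for query, pos, neg in triplets:
        record = grouped.setdefault(query, {"pos": [], "neg": []})
        # Keep each passage once per query, preserving first-seen order.
        if pos not in record["pos"]:
            record["pos"].append(pos)
        if neg not in record["neg"]:
            record["neg"].append(neg)
    lines = [
        json.dumps({"query": q, "pos": r["pos"], "neg": r["neg"]}, ensure_ascii=False)
        for q, r in grouped.items()
    ]
    return "\n".join(lines)


sample = [
    ("질문 1", "정답 구절 A", "오답 구절 B"),
    ("질문 1", "정답 구절 A", "오답 구절 C"),
]
print(triplets_to_jsonl(sample))
```

`ensure_ascii=False` keeps the Korean text readable in the output file instead of escaping it to `\uXXXX` sequences.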

- - -

## 5. Performance

| Model | has-right-in-contexts | MRR (mean reciprocal rank) |
|:---------------------------|:-----------------:|:--------------------------:|
| without-reranker (default) | 0.93 | 0.80 |
| with-reranker (bge-reranker-large) | 0.95 | 0.84 |
| **with-reranker (fine-tuned using Korean)** | **0.96** | **0.87** |
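
The card does not define the two metrics. Under their usual definitions (assumed here), has-right-in-contexts is the fraction of queries whose retrieved contexts include the correct passage, and MRR averages 1/rank of the correct passage, counting 0 when it is missing. The sketch below computes both; `evaluate`, `reciprocal_rank`, and the sample IDs are hypothetical.

```python
from typing import List, Sequence


def reciprocal_rank(ranked_ids: Sequence[str], gold_id: str) -> float:
    """1 / (1-based rank of the gold passage), or 0.0 if it was not retrieved."""
    for rank, doc_id in enumerate(ranked_ids, start=1):
        if doc_id == gold_id:
            return 1.0 / rank
    return 0.0


def evaluate(results: List[Sequence[str]], gold: List[str]) -> dict:
    """Hit rate (assumed meaning of has-right-in-contexts) and MRR over all queries."""
    hits = [1.0 if g in r else 0.0 for r, g in zip(results, gold)]
    rrs = [reciprocal_rank(r, g) for r, g in zip(results, gold)]
    n = len(gold)
    return {"has_right_in_contexts": sum(hits) / n, "mrr": sum(rrs) / n}


# Three hypothetical queries, each with the ranked IDs of its retrieved contexts.
results = [["d3", "d1", "d7"], ["d2", "d9", "d4"], ["d5", "d6", "d8"]]
gold = ["d1", "d2", "d0"]
print(evaluate(results, gold))
```

MRR rewards putting the correct passage first, which is exactly what a reranker changes, so it is the more sensitive of the two numbers to reordering.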

- **Evaluation set**:
```
./dataset/evaluation/eval_dataset.csv
```
- **Training parameters**:
```json
{
    "learning_rate": 5e-6,
    "fp16": true,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01
}
```
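
As we understand FlagEmbedding's cross-encoder training, which this code builds on, each query is grouped with one positive and `train_group_size - 1` sampled negatives, and the loss is cross-entropy over the group's logits with the positive as the target. This is also why the Features section notes that scores are unbounded. The sketch below reproduces that loss for a single group in plain Python and works out how many pairs each optimizer step sees with the parameters above; the positive-at-index-0 convention and `num_devices = 1` are assumptions.

```python
import math
from typing import List

# Values copied from the training parameters above.
per_device_train_batch_size = 1
gradient_accumulation_steps = 32
train_group_size = 3  # 1 positive + 2 negatives per query
num_devices = 1  # assumption: the card does not state how many GPUs were used


def group_cross_entropy(logits: List[float]) -> float:
    """Cross-entropy for one query group, assuming logits[0] is the positive passage."""
    m = max(logits)
    # log-sum-exp with max-subtraction for numerical stability.
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[0]


queries_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_devices
pairs_per_step = queries_per_step * train_group_size
print(queries_per_step, pairs_per_step)  # 32 query groups, 96 (query, passage) pairs per optimizer step

# Hypothetical logits for one group: [positive, negative, negative].
print(round(group_cross_entropy([2.0, 0.5, -1.0]), 4))
```

The loss only depends on the differences between logits within a group, so nothing pushes the scores into a fixed range; only their relative order per query is trained.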

- - -

## 6. Acknowledgement

- <span style="#FF69B4;"> Part of the code is developed based on [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master?tab=readme-ov-file) and [KoSimCSE-SageMaker](https://github.com/daekeun-ml/KoSimCSE-SageMaker/tree/7de6eefef8f1a646c664d0888319d17480a3ebe5).</span>

- - -

## 7. Citation

- <span style="#FF69B4;"> If you find this repository useful, please consider giving it a like ⭐ and citing it.</span>

- - -

## 8. Contributors

- <span style="#FF69B4;"> **Dongjin Jang, Ph.D.** (AWS AI/ML Specialist Solutions Architect) | [Mail](mailto:dongjinj@amazon.com) | [Linkedin](https://www.linkedin.com/in/dongjin-jang-kr/) | [Git](https://github.com/dongjin-ml) | </span>

- - -

## 9. License

- <span style="#FF69B4;"> FlagEmbedding is licensed under the [MIT License](https://github.com/aws-samples/aws-ai-ml-workshop-kr/blob/master/LICENSE). </span>