---
license: mit
language:
- ko
- en
pipeline_tag: text-classification
---

# Korean Reranker Training on Amazon SageMaker

### This guide covers fine-tuning for building a **Korean Reranker**.

ko-reranker is a model fine-tuned on Korean data, based on [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large). <br>
For more details, see [korean-reranker-git](https://github.com/aws-samples/aws-ai-ml-workshop-kr/tree/master/genai/aws-gen-ai-kr/30_fine_tune/reranker-kr).

- - -

## 0. Features

- #### <span style="#FF69B4;"> Unlike an embedding model, a reranker takes a question and a document as input and directly outputs a similarity score instead of an embedding.</span>
- #### <span style="#FF69B4;"> Feeding a question and a passage to the reranker yields a relevance score.</span>
- #### <span style="#FF69B4;"> The reranker is optimized with a CrossEntropy loss, so the relevance score is not confined to a specific range.</span>

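Because the raw score is an unbounded logit, it can help to squash it into (0, 1) when a bounded value is needed, for example for thresholding. The following is a minimal sketch, assuming a sigmoid is an acceptable mapping for your use case; the function name `to_unit_interval` and the sample logits are illustrative and not part of the model's API.

```python
import math


def to_unit_interval(logit: float) -> float:
    """Map an unbounded relevance logit to (0, 1) with a numerically stable sigmoid."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


raw_scores = [-4.2, 0.0, 7.5]  # made-up logits, for illustration only
bounded = [to_unit_interval(s) for s in raw_scores]
```

The mapping is monotonic, so it does not change the ranking order of the passages.
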
## 1. Usage

- using Transformers

```python
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def exp_normalize(x):
    # Numerically stable softmax: subtract the max before exponentiating.
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()


# Hub ID of this model; a local checkpoint directory also works.
model_path = "Dongjin-kr/ko-reranker"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# Each pair is [query, passage].
pairs = [
    ["나는 너를 싫어해", "나는 너를 사랑해"],  # "I hate you" / "I love you"
    ["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"],  # "I like you" / "My feelings for you might be love"
]

with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt", max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1).float()
    scores = exp_normalize(scores.numpy())
    print(f"first: {scores[0]}, second: {scores[1]}")
```

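In a retrieval pipeline the scores are typically used to reorder candidate passages for a single query. The sketch below shows that step in isolation. The names `rerank` and `toy_score_fn` are hypothetical, and the toy scorer only counts shared words so the example runs without downloading the model; in practice you would pass a function that runs the tokenizer and model on the pairs.

```python
from typing import Callable, List, Sequence, Tuple


def rerank(
    query: str,
    passages: Sequence[str],
    score_fn: Callable[[List[List[str]]], Sequence[float]],
    top_k: int = 3,
) -> List[Tuple[str, float]]:
    """Score every [query, passage] pair and return the top_k passages, best first."""
    pairs = [[query, passage] for passage in passages]
    scores = score_fn(pairs)
    ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)
    return [(passage, float(score)) for passage, score in ranked[:top_k]]


def toy_score_fn(pairs: List[List[str]]) -> List[float]:
    # Stand-in scorer (an assumption for illustration): number of words shared
    # by the query and the passage. Replace with a call to the reranker model.
    return [float(len(set(q.split()) & set(p.split()))) for q, p in pairs]


results = rerank(
    "capital of Korea",
    ["Seoul is the capital of Korea", "Tokyo is in Japan", "Korea has four seasons"],
    toy_score_fn,
    top_k=2,
)
```

Keeping the scorer injectable means the same function works with the local Transformers model or with a remote endpoint.
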
- using SageMaker

```python
import json

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFaceModel

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # Outside SageMaker, look up the execution role by name.
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID': 'Dongjin-kr/ko-reranker',
    'HF_TASK': 'text-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    env=hub,
    role=role,
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,  # number of instances
    instance_type='ml.g5.xlarge'  # ec2 instance type
)

runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
    {
        "inputs": [
            {"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
            {"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName="<endpoint-name>",  # name of the deployed endpoint
    ContentType="application/json",
    Accept="application/json",
    Body=payload
)

# deserialize the JSON response
out = json.loads(response['Body'].read().decode())
print(f'Response: {out}')
```

## 2. Background

- #### <span style="#FF69B4;"> **Context order affects accuracy** ([Lost in the Middle, *Liu et al., 2023*](https://arxiv.org/pdf/2307.03172.pdf)) </span>

- #### <span style="#FF69B4;"> [Why you should use a reranker](https://www.pinecone.io/learn/series/rag/rerankers/)</span>
    - With current LLMs, adding more context does not by itself help; the relevant passages must be near the top for the model to answer correctly.
    - The similarity (relevance) score used in semantic search is not precise. (In other words, is a higher-ranked result always more relevant to the question than a lower-ranked one?)
        * Embeddings are specialized for capturing the meaning behind a document.
        * A question and its answer are not semantically identical. ([Hypothetical Document Embeddings](https://medium.com/prompt-engineering/hyde-revolutionising-search-with-hypothetical-document-embeddings-3474df795af8))
        * Using ANNs ([Approximate Nearest Neighbors](https://towardsdatascience.com/comprehensive-guide-to-approximate-nearest-neighbors-algorithms-8b94f057d6b6)) carries an accuracy penalty.

- - -

## 3. Reranker models

- #### <span style="#FF69B4;"> [Cohere] [Reranker](https://txt.cohere.com/rerank/)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)</span>

- - -

## 4. Dataset

- #### <span style="#FF69B4;"> [msmarco-triplets](https://github.com/microsoft/MSMARCO-Passage-Ranking) </span>
    - (Question, Answer, Negative) triplets from the MS MARCO Passages dataset, 499,184 samples
    - The dataset is in English.
    - It was translated with Amazon Translate before use.

- #### <span style="#FF69B4;"> Format </span>

```
{"query": str, "pos": List[str], "neg": List[str]}
```

- `query` is the question, `pos` is a list of positive texts, and `neg` is a list of negative texts. If a query has no negative texts, you can randomly sample some from the entire corpus and use them as negatives.

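The random-negative fallback described above can be sketched as follows. This is an illustration under stated assumptions, not the repository's preprocessing code: `fill_negatives`, `num_neg`, and the toy corpus are hypothetical, and writing one JSON object per line is assumed rather than confirmed here.

```python
import json
import random
from typing import Any, Dict, List, Optional


def fill_negatives(
    record: Dict[str, Any],
    corpus: List[str],
    num_neg: int = 2,
    rng: Optional[random.Random] = None,
) -> Dict[str, Any]:
    """Return a copy of `record`; if its "neg" list is empty, sample it from `corpus`."""
    rng = rng or random.Random()
    if record.get("neg"):
        return dict(record)
    # Exclude the query's own positives so they are never used as negatives.
    candidates = [text for text in corpus if text not in record["pos"]]
    k = min(num_neg, len(candidates))
    return {**record, "neg": rng.sample(candidates, k)}


corpus = ["passage A", "passage B", "passage C", "passage D"]
record = {"query": "example question", "pos": ["passage A"], "neg": []}
filled = fill_negatives(record, corpus, num_neg=2, rng=random.Random(0))

# ensure_ascii=False keeps Korean text readable in the output file.
line = json.dumps(filled, ensure_ascii=False)
```

Randomly sampled passages can occasionally be relevant to the query, so this fallback is best treated as a baseline rather than a substitute for mined hard negatives.
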
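- #### <span style="#FF69B4;"> Examples </span>

```
{"query": "대한민국의 수도는?", "pos": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."], "neg": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다."]}
```

- In English: the query asks "What is the capital of South Korea?"; the positive passage reads "The capital of the United States is Washington, Japan's is Tokyo, and Korea's is Seoul.", and the negative passage differs only in its last clause, "North Korea's is Pyongyang."

- - -
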
## 5. Performance

| Model                                       | has-right-in-contexts | mrr (mean reciprocal rank) |
|:--------------------------------------------|:---------------------:|:--------------------------:|
| without-reranker (default)                  | 0.93                  | 0.80                       |
| with-reranker (bge-reranker-large)          | 0.95                  | 0.84                       |
| **with-reranker (fine-tuned using Korean)** | **0.96**              | **0.87**                   |

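For readers who want to reproduce metrics of this kind, the sketch below computes a hit rate and MRR from ranked document IDs. It assumes that has-right-in-contexts means the fraction of queries whose correct document appears among the returned contexts, which is an interpretation and not a definition given in this card; the function names and sample IDs are hypothetical.

```python
from typing import List, Sequence, Tuple


def reciprocal_rank(ranked_ids: Sequence[str], correct_id: str) -> float:
    """Return 1 / rank of the correct document (1-based), or 0.0 if it is absent."""
    for rank, doc_id in enumerate(ranked_ids, start=1):
        if doc_id == correct_id:
            return 1.0 / rank
    return 0.0


def evaluate(rankings: List[Sequence[str]], answers: List[str]) -> Tuple[float, float]:
    """Return (hit rate, MRR) averaged over all queries."""
    rr = [reciprocal_rank(r, a) for r, a in zip(rankings, answers)]
    hit_rate = sum(1 for value in rr if value > 0) / len(rr)
    mrr = sum(rr) / len(rr)
    return hit_rate, mrr


# Three toy queries: correct document at rank 1, at rank 2, and missing.
rankings = [["d1", "d2", "d3"], ["d5", "d4", "d6"], ["d7", "d8", "d9"]]
answers = ["d1", "d4", "d0"]
hit_rate, mrr = evaluate(rankings, answers)
```

A reranker can raise MRR by moving the correct document upward; it can raise the hit rate only when it selects the final contexts from a larger candidate pool.
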
- **evaluation set**:

```code
./dataset/evaluation/eval_dataset.csv
```

- **training parameters**:

```json
{
  "learning_rate": 5e-6,
  "fp16": true,
  "num_train_epochs": 3,
  "per_device_train_batch_size": 1,
  "gradient_accumulation_steps": 32,
  "train_group_size": 3,
  "max_len": 512,
  "weight_decay": 0.01
}
```

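To make these settings concrete, the sketch below works out the effective batch size and evaluates a cross-entropy loss over one group of passages. It assumes `train_group_size` follows the FlagEmbedding convention of one positive plus `train_group_size - 1` negatives per query, and that the positive sits at index 0 of each group; both are assumptions, and `group_cross_entropy` is a hypothetical helper, not the training code.

```python
import math
from typing import Sequence

per_device_train_batch_size = 1
gradient_accumulation_steps = 32
train_group_size = 3

# Queries contributing to one optimizer step on a single device.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
# Passages scored per optimizer step, assuming each query brings a group of 3 passages.
passages_per_step = effective_batch_size * train_group_size


def group_cross_entropy(logits: Sequence[float], positive_index: int = 0) -> float:
    """Cross-entropy over one group of passage logits, with the positive as the target class."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[positive_index]


loss_good = group_cross_entropy([5.0, -1.0, -2.0])  # positive scored highest
loss_bad = group_cross_entropy([-2.0, 5.0, -1.0])   # a negative scored highest
```

Only the differences between logits within a group affect this loss, which is consistent with the relevance scores not being confined to a fixed range.
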
- - -

## 6. Acknowledgement

- <span style="#FF69B4;"> Part of the code is developed based on [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master?tab=readme-ov-file) and [KoSimCSE-SageMaker](https://github.com/daekeun-ml/KoSimCSE-SageMaker/tree/7de6eefef8f1a646c664d0888319d17480a3ebe5).</span>

- - -

## 7. Citation

- <span style="#FF69B4;"> If you find this repository useful, please consider giving it a like ⭐ and citing it.</span>

- - -

## 8. Contributors

- <span style="#FF69B4;"> **Dongjin Jang, Ph.D.** (AWS AI/ML Specialist Solutions Architect) | [Mail](mailto:dongjinj@amazon.com) | [Linkedin](https://www.linkedin.com/in/dongjin-jang-kr/) | [Git](https://github.com/dongjin-ml) | </span>

- - -

## 9. License

- <span style="#FF69B4;"> FlagEmbedding is licensed under the [MIT License](https://github.com/aws-samples/aws-ai-ml-workshop-kr/blob/master/LICENSE). </span>