cutechicken committed on
Commit
58e272a
·
verified ·
1 Parent(s): 85ff42c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -36
app.py CHANGED
@@ -6,62 +6,79 @@ import os
6
  from threading import Thread
7
  import random
8
  from datasets import load_dataset
 
 
 
 
 
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
  MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
12
  MODELS = os.environ.get("MODELS")
13
  MODEL_NAME = MODEL_ID.split("/")[-1]
14
 
15
# HTML heading string for the app.
TITLE = "<h1><center>온디바이스 AI(Open LLM 모델)</center></h1>"

# Stylesheet string: duplicate-button look, centered h3, and distinct
# background colors for user and bot chat messages.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
.chatbox .messages .message.user {
background-color: #e1f5fe;
}
.chatbox .messages .message.bot {
background-color: #eeeeee;
}
"""

# Load the model and tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the dataset.
dataset = load_dataset("elyza/ELYZA-tasks-100")
print(dataset)

# Use the "train" split when the dataset has one, otherwise fall back to "test".
split_name = "train" if "train" in dataset else "test"
examples_list = list(dataset[split_name])

# random.sample raises ValueError when asked for more items than the
# population holds, so cap the sample size at the number of available rows.
sample_size = min(50, len(examples_list))
examples = random.sample(examples_list, sample_size)

# Each example is wrapped in its own single-element list.
example_inputs = [[example["input"]] for example in examples]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  @spaces.GPU
53
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
54
  print(f'message is - {message}')
55
  print(f'history is - {history}')
 
 
 
 
 
 
 
 
56
  conversation = []
57
  for prompt, answer in history:
58
- conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
59
- conversation.append({"role": "user", "content": message})
 
 
 
 
 
 
60
 
61
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
62
  inputs = tokenizer(input_ids, return_tensors="pt").to(0)
63
 
64
- streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
65
 
66
  generate_kwargs = dict(
67
  inputs,
 
6
  from threading import Thread
7
  import random
8
  from datasets import load_dataset
9
+ from sentence_transformers import SentenceTransformer
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+ import numpy as np
12
+
13
# GPU memory management.
torch.cuda.empty_cache()

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODELS = os.environ.get("MODELS")
MODEL_NAME = MODEL_ID.split("/")[-1]

# NOTE(review): stream_chat further down calls `tokenizer`, but no
# model/tokenizer load appears in this section -- confirm they are still
# created elsewhere in the file.

# Load the embedding model.
embedding_model = SentenceTransformer('sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')

# Load the Wikipedia QnA dataset.
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)

# Embed the dataset questions; only the first 10000 rows are used.
# Slice the rows first and then pick the column, so the full 'question'
# column is never materialised just to keep its first 10000 entries.
questions = wiki_dataset['train'][:10000]['question']
question_embeddings = embedding_model.encode(questions, convert_to_tensor=True)
31
 
32
def find_relevant_context(query, top_k=3):
    """Return the stored Q&A pairs whose questions best match *query*.

    The query is embedded with ``embedding_model`` and compared, by cosine
    similarity, against the precomputed ``question_embeddings``.

    Args:
        query: Text to look up.
        top_k: Maximum number of pairs to return. Values below 1 yield an
            empty list.

    Returns:
        A list of ``{'question': ..., 'answer': ...}`` dicts, ordered from
        most similar to least similar.
    """
    # Without this guard ``[-top_k:]`` below degenerates: ``[-0:]`` selects
    # every index instead of none, and negative values slice from the front.
    if top_k <= 0:
        return []

    # Embed the query.
    query_embedding = embedding_model.encode(query, convert_to_tensor=True)

    # Cosine similarity between the query and every indexed question.
    similarities = cosine_similarity(
        query_embedding.cpu().numpy().reshape(1, -1),
        question_embeddings.cpu().numpy(),
    )[0]

    # Indices of the most similar questions, best match first.
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    # Collect the matching contexts. Rows are fetched one at a time so the
    # whole 'answer' column is not re-read for every hit; the numpy index is
    # converted to a plain int before it is used for lookups.
    train_split = wiki_dataset['train']
    relevant_contexts = []
    for idx in top_indices:
        row_index = int(idx)
        relevant_contexts.append({
            'question': questions[row_index],
            'answer': train_split[row_index]['answer'],
        })

    return relevant_contexts
54
 
55
  @spaces.GPU
56
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
57
  print(f'message is - {message}')
58
  print(f'history is - {history}')
59
+
60
+ # RAG: 관련 컨텍스트 찾기
61
+ relevant_contexts = find_relevant_context(message)
62
+ context_prompt = "\n\n관련 참고 정보:\n"
63
+ for ctx in relevant_contexts:
64
+ context_prompt += f"Q: {ctx['question']}\nA: {ctx['answer']}\n\n"
65
+
66
+ # 대화 히스토리 구성
67
  conversation = []
68
  for prompt, answer in history:
69
+ conversation.extend([
70
+ {"role": "user", "content": prompt},
71
+ {"role": "assistant", "content": answer}
72
+ ])
73
+
74
+ # 컨텍스트를 포함한 최종 프롬프트 구성
75
+ final_message = context_prompt + "\n현재 질문: " + message
76
+ conversation.append({"role": "user", "content": final_message})
77
 
78
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
79
  inputs = tokenizer(input_ids, return_tensors="pt").to(0)
80
 
81
+ streamer = TextIteratorStreamer(tokenizer, timeout
82
 
83
  generate_kwargs = dict(
84
  inputs,