import torch import gradio as gr import spaces from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer import os from threading import Thread import random from datasets import load_dataset import numpy as np from sklearn.feature_extraction.text import TfidfVectorizer import pandas as pd from typing import List, Tuple import json from datetime import datetime import pyarrow.parquet as pq import pypdf import io import pyarrow.parquet as pq from pdfminer.high_level import extract_text from pdfminer.layout import LAParams from tabulate import tabulate # tabulate 추가 import platform import subprocess import pytesseract from pdf2image import convert_from_path # 전역 변수 추가 current_file_context = None # 환경 변수 설정 HF_TOKEN = os.environ.get("HF_TOKEN", None) MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024" MODELS = os.environ.get("MODELS") MODEL_NAME = MODEL_ID.split("/")[-1] model = None # 전역 변수로 선언 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # 위키피디아 데이터셋 로드 wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna") print("Wikipedia dataset loaded:", wiki_dataset) # TF-IDF 벡터라이저 초기화 및 학습 print("TF-IDF 벡터화 시작...") questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용 vectorizer = TfidfVectorizer(max_features=1000) question_vectors = vectorizer.fit_transform(questions) print("TF-IDF 벡터화 완료") class ChatHistory: def __init__(self): self.history = [] self.history_file = "/tmp/chat_history.json" self.load_history() def add_conversation(self, user_msg: str, assistant_msg: str): conversation = { "timestamp": datetime.now().isoformat(), "messages": [ {"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_msg} ] } self.history.append(conversation) self.save_history() def format_for_display(self): formatted = [] for conv in self.history: formatted.append([ conv["messages"][0]["content"], conv["messages"][1]["content"] ]) return formatted def get_messages_for_api(self): messages = [] for conv in self.history: messages.extend([ {"role": "user", "content": conv["messages"][0]["content"]}, {"role": "assistant", "content": conv["messages"][1]["content"]} ]) return messages def clear_history(self): self.history = [] self.save_history() def save_history(self): try: with open(self.history_file, 'w', encoding='utf-8') as f: json.dump(self.history, f, ensure_ascii=False, indent=2) except Exception as e: print(f"히스토리 저장 실패: {e}") def load_history(self): try: if os.path.exists(self.history_file): with open(self.history_file, 'r', encoding='utf-8') as f: self.history = json.load(f) except Exception as e: print(f"히스토리 로드 실패: {e}") self.history = [] # 전역 ChatHistory 인스턴스 생성 chat_history = ChatHistory() def find_relevant_context(query, top_k=3): # 쿼리 벡터화 query_vector = vectorizer.transform([query]) # 코사인 유사도 계산 similarities = (query_vector * question_vectors.T).toarray()[0] # 가장 유사한 질문들의 인덱스 top_indices = np.argsort(similarities)[-top_k:][::-1] # 관련 컨텍스트 추출 relevant_contexts = [] for idx in top_indices: if similarities[idx] > 0: relevant_contexts.append({ 'question': questions[idx], 'answer': wiki_dataset['train']['answer'][idx], 'similarity': similarities[idx] }) return relevant_contexts def analyze_file_content(content, file_type): """Analyze file content and return structural summary""" if file_type in ['parquet', 'csv']: try: lines = content.split('\n') header = lines[0] columns = header.count('|') - 1 rows = len(lines) - 3 return f"📊 데이터셋 구조: {columns}개 컬럼, {rows}개 데이터" except: return "❌ 데이터셋 구조 분석 실패" lines = content.split('\n') total_lines = len(lines) non_empty_lines = len([line for line in lines if line.strip()]) if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']): functions = len([line for line in lines if 'def ' in line]) classes = len([line for line in lines if 'class ' in line]) imports = len([line for line in lines if 'import ' in line or 'from ' in line]) return f"💻 코드 구조: {total_lines}줄 (함수: {functions}, 클래스: {classes}, 임포트: {imports})" paragraphs = content.count('\n\n') + 1 words = len(content.split()) return f"📝 문서 구조: {total_lines}줄, {paragraphs}단락, 약 {words}단어" def extract_pdf_text_with_ocr(file_path): try: # Poppler 경로 설정 if platform.system() == 'Windows': poppler_path = r"C:\Program Files\poppler-0.68.0\bin" else: poppler_path = None # Linux의 경우 기본 경로 사용 # PDF를 이미지로 변환 images = convert_from_path( file_path, poppler_path=poppler_path, fmt='jpeg', grayscale=False, size=(1700, None) # 해상도 향상 ) # 전체 텍스트 저장 text = "" # 각 페이지에 대해 OCR 수행 for i, image in enumerate(images): try: # OCR 설정 custom_config = r'--oem 3 --psm 6 -l kor+eng' # OCR 수행 page_text = pytesseract.image_to_string( image, config=custom_config ) text += f"\n--- 페이지 {i+1} ---\n{page_text}\n" except Exception as e: print(f"페이지 {i+1} OCR 오류: {str(e)}") continue return text except Exception as e: return f"PDF 텍스트 추출 오류: {str(e)}" def read_uploaded_file(file): if file is None: return "", "" try: file_ext = os.path.splitext(file.name)[1].lower() # Parquet 파일 처리 if file_ext == '.parquet': try: table = pq.read_table(file.name) df = table.to_pandas() content = f"📊 Parquet 파일 분석:\n\n" content += f"1. 기본 정보:\n" content += f"- 전체 행 수: {len(df):,}개\n" content += f"- 전체 열 수: {len(df.columns)}개\n" content += f"- 메모리 사용량: {df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB\n\n" content += f"2. 컬럼 정보:\n" for col in df.columns: content += f"- {col} ({df[col].dtype})\n" content += f"\n3. 데이터 미리보기:\n" # tabulate 사용하여 테이블 형식으로 출력 content += tabulate(df.head(5), headers='keys', tablefmt='pipe', showindex=False) content += f"\n\n4. 결측치 정보:\n" null_counts = df.isnull().sum() for col, count in null_counts[null_counts > 0].items(): content += f"- {col}: {count:,}개 ({count/len(df)*100:.1f}%)\n" # 수치형 컬럼에 대한 기본 통계 numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns if len(numeric_cols) > 0: content += f"\n5. 수치형 컬럼 통계:\n" stats_df = df[numeric_cols].describe() content += tabulate(stats_df, headers='keys', tablefmt='pipe') return content, "parquet" except Exception as e: return f"Parquet 파일 읽기 오류: {str(e)}", "error" # PDF 파일 처리 if file_ext == '.pdf': try: pdf_reader = pypdf.PdfReader(file.name) total_pages = len(pdf_reader.pages) content = f"📑 PDF 문서 분석:\n\n" content += f"1. 기본 정보:\n" content += f"- 총 페이지 수: {total_pages}페이지\n" # 메타데이터 추출 if pdf_reader.metadata: content += "\n2. 메타데이터:\n" for key, value in pdf_reader.metadata.items(): if value and str(key).startswith('/'): content += f"- {key[1:]}: {value}\n" # 먼저 pdfminer로 텍스트 추출 시도 try: text = extract_text( file.name, laparams=LAParams( line_margin=0.5, word_margin=0.1, char_margin=2.0, all_texts=True ) ) except: text = "" # pdfminer로 추출 실패시 OCR 시도 if not text.strip(): text = extract_pdf_text_with_ocr(file.name) # 텍스트 분석 if text: words = text.split() lines = text.split('\n') content += f"\n3. 텍스트 분석:\n" content += f"- 총 단어 수: {len(words):,}개\n" content += f"- 고유 단어 수: {len(set(words)):,}개\n" content += f"- 총 라인 수: {len(lines):,}개\n" # 본문 내용 content += f"\n4. 본문 내용:\n" preview_length = min(2000, len(text)) # 미리보기 길이 증가 content += f"--- 처음 {preview_length}자 ---\n" content += text[:preview_length] if len(text) > preview_length: content += f"\n... (총 {len(text):,}자 중 일부 표시)\n" else: content += "\n⚠️ 텍스트 추출 실패" return content, "pdf" except Exception as e: return f"PDF 파일 읽기 오류: {str(e)}", "error" # CSV 파일 처리 elif file_ext == '.csv': encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1'] for encoding in encodings: try: df = pd.read_csv(file.name, encoding=encoding) content = f"📊 CSV 파일 분석:\n\n" content += f"1. 기본 정보:\n" content += f"- 전체 행 수: {len(df):,}개\n" content += f"- 전체 열 수: {len(df.columns)}개\n" content += f"- 메모리 사용량: {df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB\n\n" content += f"2. 컬럼 정보:\n" for col in df.columns: content += f"- {col} ({df[col].dtype})\n" content += f"\n3. 데이터 미리보기:\n" content += df.head(5).to_markdown(index=False) content += f"\n\n4. 결측치 정보:\n" null_counts = df.isnull().sum() for col, count in null_counts[null_counts > 0].items(): content += f"- {col}: {count:,}개 ({count/len(df)*100:.1f}%)\n" return content, "csv" except UnicodeDecodeError: continue raise UnicodeDecodeError(f"지원되는 인코딩으로 파일을 읽을 수 없습니다 ({', '.join(encodings)})") # 텍스트 파일 처리 else: encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1'] for encoding in encodings: try: with open(file.name, 'r', encoding=encoding) as f: content = f.read() # 파일 내용 분석 lines = content.split('\n') total_lines = len(lines) non_empty_lines = len([line for line in lines if line.strip()]) # 코드 파일 여부 확인 is_code = any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']) analysis = f"\n📝 파일 분석:\n" if is_code: # 코드 파일 분석 functions = len([line for line in lines if 'def ' in line]) classes = len([line for line in lines if 'class ' in line]) imports = len([line for line in lines if 'import ' in line or 'from ' in line]) analysis += f"- 파일 유형: 코드\n" analysis += f"- 전체 라인 수: {total_lines:,}줄\n" analysis += f"- 함수 수: {functions}개\n" analysis += f"- 클래스 수: {classes}개\n" analysis += f"- import 문 수: {imports}개\n" else: # 일반 텍스트 파일 분석 words = len(content.split()) chars = len(content) analysis += f"- 파일 유형: 텍스트\n" analysis += f"- 전체 라인 수: {total_lines:,}줄\n" analysis += f"- 실제 내용이 있는 라인 수: {non_empty_lines:,}줄\n" analysis += f"- 단어 수: {words:,}개\n" analysis += f"- 문자 수: {chars:,}개\n" return content + analysis, "text" except UnicodeDecodeError: continue raise UnicodeDecodeError(f"지원되는 인코딩으로 파일을 읽을 수 없습니다 ({', '.join(encodings)})") except Exception as e: return f"파일 읽기 오류: {str(e)}", "error" # 파일 업로드 이벤트 핸들링 수정 def init_msg(): return "파일을 분석하고 있습니다..." CSS = """ /* 3D 스타일 CSS */ :root { --primary-color: #2196f3; --secondary-color: #1976d2; --background-color: #f0f2f5; --card-background: #ffffff; --text-color: #333333; --shadow-color: rgba(0, 0, 0, 0.1); } body { background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); min-height: 100vh; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } .container { transform-style: preserve-3d; perspective: 1000px; } .chatbot { background: var(--card-background); border-radius: 20px; box-shadow: 0 10px 20px var(--shadow-color), 0 6px 6px var(--shadow-color); transform: translateZ(0); transition: transform 0.3s ease; backdrop-filter: blur(10px); } .chatbot:hover { transform: translateZ(10px); } /* 메시지 입력 영역 */ .input-area { background: var(--card-background); border-radius: 15px; padding: 15px; margin-top: 20px; box-shadow: 0 5px 15px var(--shadow-color), 0 3px 3px var(--shadow-color); transform: translateZ(0); transition: all 0.3s ease; display: flex; align-items: center; gap: 10px; } .input-area:hover { transform: translateZ(5px); } /* 버튼 스타일 */ .custom-button { background: linear-gradient(145deg, var(--primary-color), var(--secondary-color)); color: white; border: none; border-radius: 10px; padding: 10px 20px; font-weight: 600; cursor: pointer; transform: translateZ(0); transition: all 0.3s ease; box-shadow: 0 4px 6px var(--shadow-color), 0 1px 3px var(--shadow-color); } .custom-button:hover { transform: translateZ(5px) translateY(-2px); box-shadow: 0 7px 14px var(--shadow-color), 0 3px 6px var(--shadow-color); } /* 파일 업로드 버튼 */ .file-upload-icon { background: linear-gradient(145deg, #64b5f6, #42a5f5); color: white; border-radius: 8px; font-size: 2em; cursor: pointer; display: flex; align-items: center; justify-content: center; height: 70px; width: 70px; transition: all 0.3s ease; box-shadow: 0 2px 5px rgba(0,0,0,0.1); } .file-upload-icon:hover { transform: translateY(-2px); box-shadow: 0 4px 8px rgba(0,0,0,0.2); } /* 파일 업로드 버튼 내부 요소 스타일링 */ .file-upload-icon > .wrap { display: flex !important; align-items: center; justify-content: center; width: 100%; height: 100%; } .file-upload-icon > .wrap > p { display: none !important; } .file-upload-icon > .wrap::before { content: "📁"; font-size: 2em; display: block; } /* 메시지 스타일 */ .message { background: var(--card-background); border-radius: 15px; padding: 15px; margin: 10px 0; box-shadow: 0 4px 6px var(--shadow-color), 0 1px 3px var(--shadow-color); transform: translateZ(0); transition: all 0.3s ease; } .message:hover { transform: translateZ(5px); } .chat-container { height: 600px !important; margin-bottom: 10px; } .input-container { height: 70px !important; display: flex; align-items: center; gap: 10px; margin-top: 5px; } .input-textbox { height: 70px !important; border-radius: 8px !important; font-size: 1.1em !important; padding: 10px 15px !important; display: flex !important; align-items: flex-start !important; /* 텍스트 입력 위치를 위로 조정 */ } .input-textbox textarea { padding-top: 5px !important; /* 텍스트 상단 여백 조정 */ } .send-button { height: 70px !important; min-width: 70px !important; font-size: 1.1em !important; } /* 설정 패널 기본 스타일 */ .settings-panel { padding: 20px; margin-top: 20px; } """ # GPU 메모리 관리 함수 수정 def clear_cuda_memory(): if hasattr(torch.cuda, 'empty_cache'): with torch.cuda.device('cuda'): torch.cuda.empty_cache() # 모델 로드 함수 수정 @spaces.GPU def load_model(): try: model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto", ) return model except Exception as e: print(f"모델 로드 오류: {str(e)}") raise @spaces.GPU def stream_chat(message: str, history: list, uploaded_file, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float): global model, current_file_context try: if model is None: model = load_model() print(f'message is - {message}') print(f'history is - {history}') # 파일 업로드 처리 file_context = "" if uploaded_file and message == "파일을 분석하고 있습니다...": try: content, file_type = read_uploaded_file(uploaded_file) if content: file_analysis = analyze_file_content(content, file_type) file_context = f"\n\n📄 파일 분석 결과:\n{file_analysis}\n\n파일 내용:\n```\n{content}\n```" current_file_context = file_context # 파일 컨텍스트 저장 message = "업로드된 파일을 분석해주세요." except Exception as e: print(f"파일 분석 오류: {str(e)}") file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}" elif current_file_context: # 저장된 파일 컨텍스트가 있으면 사용 file_context = current_file_context # 메모리 사용량 모니터링 if torch.cuda.is_available(): print(f"CUDA 메모리 사용량: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") # 대화 히스토리가 너무 길면 잘라내기 max_history_length = 10 # 최대 히스토리 길이 설정 if len(history) > max_history_length: history = history[-max_history_length:] # 관련 컨텍스트 찾기 try: relevant_contexts = find_relevant_context(message) wiki_context = "\n\n관련 위키피디아 정보:\n" for ctx in relevant_contexts: wiki_context += f"Q: {ctx['question']}\nA: {ctx['answer']}\n유사도: {ctx['similarity']:.3f}\n\n" except Exception as e: print(f"컨텍스트 검색 오류: {str(e)}") wiki_context = "" # 대화 히스토리 구성 conversation = [] for prompt, answer in history: conversation.extend([ {"role": "user", "content": prompt}, {"role": "assistant", "content": answer} ]) # 최종 프롬프트 구성 final_message = file_context + wiki_context + "\n현재 질문: " + message conversation.append({"role": "user", "content": final_message}) # 토큰 수 제한 input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) max_length = 4096 # 또는 모델의 최대 컨텍스트 길이 if len(input_ids.split()) > max_length: # 컨텍스트가 너무 길면 잘라내기 input_ids = " ".join(input_ids.split()[-max_length:]) inputs = tokenizer(input_ids, return_tensors="pt").to("cuda") # 메모리 사용량 체크 if torch.cuda.is_available(): print(f"입력 텐서 생성 후 CUDA 메모리: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( inputs, streamer=streamer, top_k=top_k, top_p=top_p, repetition_penalty=penalty, max_new_tokens=min(max_new_tokens, 2048), # 최대 토큰 수 제한 do_sample=True, temperature=temperature, eos_token_id=[255001], ) # 생성 시작 전 메모리 정리 clear_cuda_memory() thread = Thread(target=model.generate, kwargs=generate_kwargs) thread.start() buffer = "" for new_text in streamer: buffer += new_text yield "", history + [[message, buffer]] # 생성 완료 후 메모리 정리 clear_cuda_memory() except Exception as e: error_message = f"오류가 발생했습니다: {str(e)}" print(f"Stream chat 오류: {error_message}") # 오류 발생 시에도 메모리 정리 clear_cuda_memory() yield "", history + [[message, error_message]] def create_demo(): with gr.Blocks(css=CSS) as demo: with gr.Column(elem_classes="markdown-style"): gr.Markdown(""" # 🤖 OnDevice AI RAG #### 📊 RAG: Upload and Analyze Files (TXT, CSV, PDF, Parquet files) Upload your files for data analysis and learning """) chatbot = gr.Chatbot( value=[], height=600, label="GiniGEN AI Assistant", elem_classes="chat-container" ) with gr.Row(elem_classes="input-container"): with gr.Column(scale=1, min_width=70): file_upload = gr.File( type="filepath", elem_classes="file-upload-icon", scale=1, container=True, interactive=True, show_label=False ) with gr.Column(scale=3): msg = gr.Textbox( show_label=False, placeholder="Type your message here... 💭", container=False, elem_classes="input-textbox", scale=1 ) with gr.Column(scale=1, min_width=70): send = gr.Button( "Send", elem_classes="send-button custom-button", scale=1 ) with gr.Column(scale=1, min_width=70): clear = gr.Button( "Clear", elem_classes="clear-button custom-button", scale=1 ) with gr.Accordion("🎮 Advanced Settings", open=False): with gr.Row(): with gr.Column(scale=1): temperature = gr.Slider( minimum=0, maximum=1, step=0.1, value=0.8, label="Creativity Level 🎨" ) max_new_tokens = gr.Slider( minimum=128, maximum=8000, step=1, value=4000, label="Maximum Token Count 📝" ) with gr.Column(scale=1): top_p = gr.Slider( minimum=0.0, maximum=1.0, step=0.1, value=0.8, label="Diversity Control 🎯" ) top_k = gr.Slider( minimum=1, maximum=20, step=1, value=20, label="Selection Range 📊" ) penalty = gr.Slider( minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty 🔄" ) gr.Examples( examples=[ ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"], ["Please analyze this data and provide insights:\nAnnual Revenue (Million)\n2019: 1200\n2020: 980\n2021: 1450\n2022: 2100\n2023: 1890"], ["Please solve this math problem step by step: 'When a circle's area is twice that of its inscribed square, find the relationship between the circle's radius and the square's side length.'"], ["Please analyze this marketing campaign's ROI and suggest improvements:\nTotal Cost: $50,000\nReach: 1M users\nClick Rate: 2.3%\nConversion Rate: 0.8%\nAverage Purchase: $35"], ], inputs=msg ) def clear_conversation(): global current_file_context current_file_context = None return [], None, "Start a new conversation..." # Event bindings msg.submit( stream_chat, inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty], outputs=[msg, chatbot] ) send.click( stream_chat, inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty], outputs=[msg, chatbot] ) file_upload.change( fn=init_msg, outputs=msg, queue=False ).then( fn=stream_chat, inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty], outputs=[msg, chatbot], queue=True ) # Clear button event binding clear.click( fn=clear_conversation, outputs=[chatbot, file_upload, msg], queue=False ) return demo if __name__ == "__main__": demo = create_demo() demo.launch()