|
import os
from logging import getLogger

import torch
from dotenv import load_dotenv
from langchain_openai import OpenAI
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Pipeline, pipeline
from transformers import LEDForConditionalGeneration, LEDTokenizer
|
|
|
# Populate os.environ from a .env file via python-dotenv. This must run
# before any environment variable below is read.
load_dotenv()

# Hugging Face access token; None when HF_TOKEN is not set.
hf_token = os.getenv("HF_TOKEN")

logger = getLogger(__name__)

# Run models on the GPU whenever PyTorch reports one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
|
|
|
def get_local_model(
    model_name_or_path: str,
    *,
    torch_dtype: torch.dtype = torch.float32,
) -> "Pipeline":
    """Build a Hugging Face summarization pipeline from a seq2seq checkpoint.

    The tokenizer and model are both loaded with the module-level
    ``hf_token`` (read from the ``HF_TOKEN`` environment variable), and the
    resulting pipeline is placed on the module-level ``device``.

    Args:
        model_name_or_path: Hub model id or local path of a
            sequence-to-sequence checkpoint.
        torch_dtype: Dtype used when loading the model weights. Defaults to
            ``torch.float32``, the value this function always used before;
            pass e.g. ``torch.float16`` to reduce memory use.

    Returns:
        The summarization pipeline object. Note: ``pipeline`` is a factory
        function, not a type, so the annotation names the ``Pipeline``
        class the returned object belongs to.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=hf_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch_dtype,
        token=hf_token,
    )
    pipe = pipeline(
        task="summarization",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    # Lazy %-style args: the string is only formatted if INFO is enabled.
    logger.info("Summarization pipeline created and loaded to %s", device)
    return pipe
|
|
|
def get_endpoint(api_key: str):
    """Return a LangChain ``OpenAI`` LLM client authenticated with *api_key*.

    NOTE(review): ``get_model`` can forward ``api_key=None`` here; presumably
    the client then falls back to the ``OPENAI_API_KEY`` environment
    variable — confirm against the langchain_openai version in use.
    """
    return OpenAI(openai_api_key=api_key)
|
|
|
def get_model(model_type, model_name_or_path, api_key=None):
    """Select and construct a model backend.

    Args:
        model_type: ``"openai"`` selects the OpenAI endpoint; every other
            value loads a local Hugging Face summarization pipeline.
        model_name_or_path: Checkpoint handed to ``get_local_model``;
            ignored when ``model_type`` is ``"openai"``.
        api_key: Key handed to ``get_endpoint``; ignored for local models.

    Returns:
        Whatever the selected loader returns.
    """
    use_openai = model_type == "openai"
    # NOTE(review): any unrecognised model_type (including a typo such as
    # "OpenAI") silently falls through to local loading — confirm intended.
    if not use_openai:
        return get_local_model(model_name_or_path)
    return get_endpoint(api_key)
|
|