File size: 1,385 Bytes
99e744f
 
32e6acc
99e744f
2ac4210
55a77e3
99e744f
dad185d
99e744f
 
00742e9
2ac4210
052ff21
 
99e744f
 
 
 
 
 
 
16ecc46
2ac4210
 
 
16ecc46
2ac4210
dad185d
2ac4210
 
99e744f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
from logging import getLogger

import torch
from dotenv import load_dotenv
# from huggingface_hub import login
from langchain_openai import OpenAI
# import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Pipeline, pipeline
from transformers import LEDForConditionalGeneration, LEDTokenizer

# Pull variables from a .env file (if present) into the process environment
# before reading any of them below.
load_dotenv()

# Hugging Face access token; None when HF_TOKEN is not set in the environment.
hf_token = os.getenv("HF_TOKEN")

logger = getLogger(__name__)

# Use the GPU when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def get_local_model(model_name_or_path: str) -> "Pipeline":
    """Load a seq2seq checkpoint and wrap it in a summarization pipeline.

    Both the tokenizer and the model are loaded with ``from_pretrained``,
    authenticated with the module-level ``hf_token`` (which may be ``None``).
    The resulting pipeline is placed on the module-level ``device``.

    Args:
        model_name_or_path: Model identifier or local path passed straight to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        A ``transformers`` pipeline configured for the ``summarization`` task.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        token=hf_token,
    )
    # Weights are explicitly loaded in full precision (float32).
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float32,
        token=hf_token,
    )
    summarizer = pipeline(
        task="summarization",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("Summarization pipeline created and loaded to %s", device)

    return summarizer

def get_endpoint(api_key: str):
    """Return a LangChain ``OpenAI`` LLM client configured with *api_key*.

    The key is forwarded unchanged as ``openai_api_key``.
    """
    return OpenAI(openai_api_key=api_key)

def get_model(model_type, model_name_or_path, api_key=None):
    """Dispatch to the OpenAI endpoint or a local summarization pipeline.

    Args:
        model_type: ``"openai"`` selects the OpenAI endpoint; any other value
            selects the local Hugging Face pipeline.
        model_name_or_path: Checkpoint to load for the local pipeline; ignored
            when ``model_type == "openai"``.
        api_key: Key forwarded to ``get_endpoint``; ignored for local models.

    Returns:
        Whatever ``get_endpoint`` or ``get_local_model`` returns.
    """
    if model_type != "openai":
        return get_local_model(model_name_or_path)
    return get_endpoint(api_key)