# zhizhi-aiservice / main.py
# hujameson's picture
# zhizhi-aiservice:0.1.2
# e6b354c verified
# raw
# history blame
# 9.01 kB
#pip install fastapi ###for fastapi
#pip install uvicorn ###for server. to run the api service from terminal: uvicorn main:app --reload
#pip install gunicorn ###gunicorn --bind 0.0.0.0:8000 -k uvicorn.workers.UvicornWorker main:app
#pip install python-multipart ###for UploadFile
#pip install pillow ###for PIL
#pip install transformers ###for transformers
#pip install torch ###for torch
#pip install sentencepiece ###for AutoTokenizer
#pip install -U cos-python-sdk-v5 ###Tencent Cloud Object Storage SDK (COS-SDK)
#pip install -q -U google-generativeai
# from typing import Optional
# from fastapi import FastAPI, Header
# #from transformers import pipeline, EfficientNetImageProcessor, EfficientNetForImageClassification, AutoTokenizer, AutoModelForSeq2SeqLM
# import torch
# from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, pipeline
# from models import ItemInHistory, ItemUploaded, ServiceLoginInfo
# from openai import OpenAI
# import sys, os, logging
# import json, requests
import logging
import os
import sys
import tempfile
import urllib.parse as urlparse

import google.generativeai as genai
from fastapi import FastAPI
from PIL import Image
from qcloud_cos import CosConfig, CosS3Client

from models import Item2AI, AI2Item
# --- Application object and logging -----------------------------------------
app = FastAPI()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# --- Google Gemini model handles (text-only and vision) ---------------------
# os.environ.get() yields None when GOOGLE_API_KEY is unset; no error is
# raised here in that case.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
gemini_pro = genai.GenerativeModel("gemini-pro")
logging.info("google gemini-pro model loaded successfully.")
gemini_pro_vision = genai.GenerativeModel("gemini-pro-vision")
logging.info("google gemini-pro-vision model loaded successfully.")

# --- Tencent COS client ------------------------------------------------------
# Both credentials are looked up with os.environ[...], so a missing variable
# raises KeyError while this module is being imported.
cos_secret_id = os.environ["COS_SECRET_ID"]
cos_secret_key = os.environ["COS_SECRET_KEY"]
cos_region = "ap-shanghai"
cos_bucket = "7072-prod-3g52ms9o7a81f23c-1324125412"
token = None  # passed through to CosConfig as Token
scheme = "https"
config = CosConfig(
    Region=cos_region,
    SecretId=cos_secret_id,
    SecretKey=cos_secret_key,
    Token=token,
    Scheme=scheme,
)
client = CosS3Client(config)
logging.info("tencent cos init succeeded.")
# Route to create an item
@app.post("/firstturn/")
async def ai_first_turn(item: Item2AI):
    """Produce the first-turn AI feedback for an uploaded item.

    For an image item, the object referenced by ``item.item_fileurl`` is
    downloaded through the module-level COS ``client`` into a temporary
    directory, opened with PIL and sent to ``gemini_pro_vision`` together
    with a prompt asking for a detailed description in Chinese. Items of any
    other media type are answered with a fixed notice and nothing is
    downloaded.

    Args:
        item: Upload metadata. For image items ``item_fileurl`` is parsed as
            ``<scheme>://<label>.<bucket>[...]/<dir>/<file>``: the bucket is
            the second dot-separated label of the host, and the object key
            is the URL path without its leading slash.

    Returns:
        An ``AI2Item`` echoing the upload fields of ``item``, with the
        generated description (or the fixed notice) in ``ai_feedback``.

    Raises:
        IndexError: For an image item whose ``item_fileurl`` host has no
            second label or whose path has no second segment.
    """
    logging.info("ai_first_turn...")
    # logging formats its arguments with %-style placeholders; without "%s"
    # the extra argument triggers a formatting error and the item is not logged.
    logging.info("item: %s", item)

    ai2item = AI2Item(
        upload_id=item.upload_id,
        union_id=item.union_id,
        item_fileurl=item.item_fileurl,
        item_mediatype=item.item_mediatype,
        upload_datetime=item.upload_datetime,
        ai_feedback="",
    )

    if item.item_mediatype != "image":
        # Nothing is read from the object for other media types, so the
        # download is skipped entirely.
        ai2item.ai_feedback = "不是image类型,暂不能识别"
        logging.info(ai2item)
        return ai2item

    url = urlparse.urlparse(item.item_fileurl)
    key = url.path[1:]  # object key: path without the leading "/"
    bucket = url.netloc.split(".")[1]
    contentfile = key.split("/")[1]

    # NOTE(review): get_object() and generate_content() are called without
    # await inside this coroutine; if they block on network I/O they hold up
    # the event loop for other requests - consider offloading to a thread pool.
    #
    # A per-request temporary directory keeps concurrent requests for files
    # with the same name from overwriting each other and guarantees the
    # download is removed afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        local_path = os.path.join(workdir, contentfile)
        cos_response = client.get_object(Bucket=bucket, Key=key)
        cos_response["Body"].get_stream_to_file(local_path)

        # The model call stays inside both "with" blocks: PIL reads pixel
        # data lazily, so the file must still exist and be open at that point.
        with Image.open(local_path) as img:
            logging.info("image file %s is opened.", contentfile)
            ai_response = gemini_pro_vision.generate_content(
                ["Describe this picture in Chinese with plenty of details.", img]
            )

    ai2item.ai_feedback = ai_response.text
    logging.info(ai2item)
    return ai2item
# try:
# ai_model_bc_preprocessor = EfficientNetImageProcessor.from_pretrained("./birds-classifier-efficientnetb2")
# ai_model_bc_model = EfficientNetForImageClassification.from_pretrained("./birds-classifier-efficientnetb2")
# logging.info(f"local model dennisjooo/Birds-Classifier-EfficientNetB2 loaded.")
# except Exception as e:
# logging.error(e)
# try:
# openai_client = OpenAI(
# api_key=os.environ.get("OPENAI_API_KEY"),
# )
# # prompt = """你是一个鸟类学家,用中文回答关于鸟类的问题。你的回答需要满足以下要求:
# # 1. 你的回答必须是中文
# # 2. 回答限制在100个字以内"""
# # conv = Conversation(open_client, prompt, 3)
# logging.info(f"openai chat model loaded.")
# except Exception as e:
# logging.error(e)
# try:
# ai_model_bc_pipe= pipeline("image-classification", model="dennisjooo/Birds-Classifier-EfficientNetB2")
# logging.info(f"remote model dennisjooo/Birds-Classifier-EfficientNetB2 loaded.")
# except Exception as e:
# print(e)
#try:
# ai_model_ez_preprocessor = AutoTokenizer.from_pretrained("./opus-mt-en-zh")
# ai_model_ez_model = AutoModelForSeq2SeqLM.from_pretrained("./opus-mt-en-zh")
# print(f"local model Helsinki-NLP/opus-mt-en-zh loaded.")
#except Exception as e:
# print(e)
#try:
# ai_model_ez_pipe= pipeline(task="translation_en_to_zh", model="Helsinki-NLP/opus-mt-en-zh", device=0)
# print(f"remote model Helsinki-NLP/opus-mt-en-zh loaded.")
#except Exception as e:
# print(e)
# def bird_classifier(image_file: str) -> str:
# # Opening the image using PIL
# img = Image.open(image_file)
# logging.info(f"image file {image_file} is opened.")
# result:str = ""
# try:
# inputs = ai_model_bc_preprocessor(img, return_tensors="pt")
# # Running the inference
# with torch.no_grad():
# logits = ai_model_bc_model(**inputs).logits
# # Getting the predicted label
# predicted_label = logits.argmax(-1).item()
# result = ai_model_bc_model.config.id2label[predicted_label]
# logging.info(f"{ai_model_bc_model.config.id2label[predicted_label]}:{ai_model_bc_pipe(img)[0]['label']}")
# except Exception as e:
# logging.error(e)
# logging.info(result)
# return result
# def text_en_zh(text_en: str) -> str:
# text_zh = ""
# if ai_model_ez_status is MODEL_STATUS.LOCAL:
# input = ai_model_ez_preprocessor(text_en)
# translated = ai_model_ez_model.generate(**ai_model_ez_preprocessor(text_en, return_tensors="pt", padding=True))
# for t in translated:
# text_zh += ai_model_ez_preprocessor.decode(t, skip_special_tokens=True)
# elif ai_model_ez_status is MODEL_STATUS.REMOTE:
# text_zh = ai_model_ez_pipe(text_en)
# return text_zh
# # Route to list all items uploaded by a specific user by unionid
# # @app.get("/items/{user_unionid}")
# # def list_items(user_unionid: str) -> dict[str, list[ItemInHistory]]:
# # logging.info("list_items")
# # logging.info(user_unionid)
# # items: list[ItemInHistory] = []
# # response = client.list_objects(
# # Bucket=cos_bucket,
# # Prefix=f'{user_unionid}/history/'
# # )
# # logging.info(response['Contents'])
# # for obj in response['Contents']:
# # key:str = obj['Key']
# # response = client.get_object(
# # Bucket = cos_bucket,
# # Key = key
# # )
# # localfile = key.split('/')[2]
# # response['Body'].get_stream_to_file(localfile)
# # item = itemFromJsonFile(localfile)
# # items.append(item)
# return {"items": items}
# # Route to list all items uploaded by a specific user by unionid from header
# @app.get("/items/")
# def list_items_byheader(x_wx_openid: Optional[str]=Header(None)) -> dict[str, list[ItemInHistory]]:
# logging.info("list_items_byheader")
# logging.info(x_wx_openid)
# items: list[ItemInHistory] = []
# response = client.list_objects(
# Bucket=cos_bucket,
# Prefix=f'{x_wx_openid}/history/'
# )
# logging.info(response['Contents'])
# for obj in response['Contents']:
# key:str = obj['Key']
# response = client.get_object(
# Bucket = cos_bucket,
# Key = key
# )
# localfile = key.split('/')[2]
# response['Body'].get_stream_to_file(localfile)
# item = itemFromJsonFile(localfile)
# items.append(item)
# return {"items": items}
# def itemFromJsonFile(jsonfile: str) -> ItemInHistory:
# f = open(jsonfile, 'r')
# content = f.read()
# a = json.loads(content)
# f.close()
# return ItemInHistory(history_id = a['history_id'],union_id = a['union_id'],
# item_fileurl = a['item_fileurl'],item_mediatype = a["item_mediatype"],
# upload_datetime = a["upload_datetime"],ai_feedback = a['ai_feedback'])
# def itemToJsonFile(item: ItemInHistory):
# history_json = {
# "history_id": item.history_id,
# "union_id": item.union_id,
# "item_fileurl": item.item_fileurl,
# "item_mediatype": item.item_mediatype,
# "upload_datetime": item.upload_datetime,
# "ai_feedback": item.ai_feedback
# }
# b = json.dumps(history_json)
# historyfile = f'{item.history_id}.json'
# f = open(historyfile, 'w')
# f.write(b)
# f.close()
# return historyfile