Spaces:
Sleeping
Sleeping
File size: 3,841 Bytes
9f895e9 14a4318 1007582 1ca62db 14a4318 9f895e9 14a4318 78b79c7 14a4318 1007582 9f895e9 1007582 78b79c7 1007582 78b79c7 1007582 14a4318 1007582 14a4318 1007582 1ca62db 1007582 14a4318 1007582 2089531 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 78b79c7 1007582 14a4318 1007582 14a4318 1007582 14a4318 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import logging
import os
from enum import Enum
from typing import List, Optional, Union
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from gradio_client import Client
from pydantic import BaseModel, ConfigDict, Field, constr
from docs import description, tags_metadata
# Load variables from a .env file (when one is present) into os.environ
# before any configuration below is read from the environment.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Base URL of the Gradio app that `client` talks to.
GRADIO_URL = os.environ.get("GRADIO_URL", "http://localhost:7860/")
logger.info("GRADIO_URL: %s", GRADIO_URL)

# Shared Gradio client used by the endpoints of this module.
# NOTE(review): created at import time; gradio_client presumably contacts the
# app while constructing, so an unreachable GRADIO_URL may fail the import —
# verify.
client = Client(GRADIO_URL)

app = FastAPI(
    title="ACRES RAG API",
    openapi_tags=tags_metadata,
    description=description,
)

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any website send credentialed cross-origin requests; restrict `origins` if
# this API ever relies on cookies or auth headers.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)
class StudyVariables(str, Enum):
    """Predefined study names accepted for ``study_variable`` requests."""

    # Each value is the human-readable study name.
    ebola_virus = "Ebola Virus"
    vaccine_coverage = "Vaccine coverage"
    genexpert = "GeneXpert"
class PromptType(str, Enum):
    """Prompt style forwarded as ``prompt_type`` to the Gradio app."""

    # Values are the labels expected by the remote app.
    default = "Default"
    highlight = "Highlight"
    evidence_based = "Evidence-based"
class StudyVariableRequest(BaseModel):
    """Request body for ``POST /study_variables``."""

    # Also allow validation from arbitrary objects with matching attributes.
    model_config = ConfigDict(from_attributes=True)

    # One of the predefined StudyVariables, or any other study name as text.
    study_variable: Union[StudyVariables, str]
    # Prompt style the Gradio app should use.
    prompt_type: PromptType
    # Non-empty after surrounding whitespace is stripped,
    # e.g. "study id, study title, study design, study summary".
    text: constr(min_length=1, strip_whitespace=True)  # type: ignore
class DownloadCSV(BaseModel):
    """Request body for ``POST /download_csv``."""

    # Also allow validation from arbitrary objects with matching attributes.
    model_config = ConfigDict(from_attributes=True)

    # Markdown content to convert to CSV; non-empty after stripping whitespace.
    text: constr(min_length=1, strip_whitespace=True)  # type: ignore
class Study(BaseModel):
    """Request body for ``POST /get_study_info``."""

    # Also allow validation from arbitrary objects with matching attributes.
    model_config = ConfigDict(from_attributes=True)

    # Name of the study to query; non-empty after stripping whitespace.
    study_name: constr(min_length=1, strip_whitespace=True)  # type: ignore
class ZoteroCredentials(BaseModel):
    """Zotero credentials for ``POST /process_zotero_library_items``.

    The API key is excluded from the model's repr/str so that printing or
    logging the request object does not expose the secret.
    """

    # Also allow validation from arbitrary objects with matching attributes.
    model_config = ConfigDict(from_attributes=True)

    # Zotero library identifier; non-empty after stripping whitespace.
    library_id: constr(min_length=1, strip_whitespace=True)  # type: ignore
    # Secret Zotero API key; still required, but hidden from repr()/str().
    api_access_key: constr(min_length=1, strip_whitespace=True) = Field(repr=False)  # type: ignore
@app.post("/process_zotero_library_items", tags=["zotero"])
def process_zotero_library_items(zotero_credentials: ZoteroCredentials):
    """Forward Zotero credentials to the Gradio `/process_zotero_library_items` API.

    The Gradio output is returned unchanged as `{"result": ...}`.
    """
    forwarded = {
        "zotero_library_id_param": zotero_credentials.library_id,
        "zotero_api_access_key": zotero_credentials.api_access_key,
    }
    outcome = client.predict(**forwarded, api_name="/process_zotero_library_items")
    return {"result": outcome}
@app.post("/get_study_info", tags=["zotero"])
def get_study_info(study: Study):
    """Query the Gradio `/get_study_info` API for the named study.

    The Gradio output is returned unchanged as `{"result": ...}`.
    """
    return {
        "result": client.predict(
            study_name=study.study_name,
            api_name="/get_study_info",
        )
    }
@app.post("/study_variables", tags=["zotero"])
def process_study_variables(
    study_request: StudyVariableRequest,
):
    """Run the Gradio `/process_multi_input` API for a study and prompt type.

    Only the first element of the Gradio output is returned, as
    `{"result": ...}`.
    """

    def _plain(value):
        # The remote app knows nothing about the local Enum types. Send their
        # plain string value explicitly, because str() of a (str, Enum) member
        # gives "ClassName.member" rather than the value.
        return value.value if isinstance(value, Enum) else value

    result = client.predict(
        text=study_request.text,  # e.g. "study id, study title, study design, study summary"
        study_name=_plain(study_request.study_variable),  # e.g. "Ebola Virus"
        prompt_type=_plain(study_request.prompt_type),  # e.g. "Default"
        api_name="/process_multi_input",
    )
    # Diagnostic output goes through the module logger instead of stdout.
    logger.debug("/process_multi_input returned %s", type(result))
    # NOTE(review): assumes the Gradio endpoint returns a sequence of outputs;
    # only the first one is exposed to API callers.
    return {"result": result[0]}
@app.post("/new_study_choices", tags=["zotero"])
def new_study_choices():
    """Return the output of the Gradio `/new_study_choices` API as `{"result": ...}`."""
    choices = client.predict(api_name="/new_study_choices")
    return dict(result=choices)
@app.post("/download_csv", tags=["zotero"])
def download_csv(download_request: DownloadCSV):
    """Convert markdown content to CSV via Gradio and return it as a file download.

    The Gradio `/download_as_csv` API is expected to return a local file path.
    Responds with 404 ("File not found") if it returns anything other than a
    path to an existing regular file.
    """
    result = client.predict(
        markdown_content=download_request.text, api_name="/download_as_csv"
    )
    logger.info("/download_as_csv returned %s", result)

    # Guard the type before touching the filesystem: os.stat() raises
    # TypeError for non-path values and treats integers as file descriptors.
    # isfile() (not exists()) also rejects directories, which FileResponse
    # cannot serve.
    if not isinstance(result, (str, os.PathLike)) or not os.path.isfile(result):
        raise HTTPException(status_code=404, detail="File not found")

    # Send the file as a CSV attachment named after the generated file.
    return FileResponse(
        result,
        media_type="text/csv",
        filename=os.path.basename(result),
    )
|