File size: 3,841 Bytes
9f895e9
14a4318
1007582
1ca62db
14a4318
9f895e9
14a4318
78b79c7
14a4318
 
 
1007582
 
 
9f895e9
 
 
1007582
78b79c7
 
 
 
1007582
 
 
 
 
78b79c7
 
 
 
 
 
 
 
 
 
1007582
14a4318
1007582
 
 
 
 
 
 
 
 
 
 
14a4318
1007582
1ca62db
1007582
 
 
 
 
14a4318
1007582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2089531
14a4318
 
1007582
14a4318
1007582
 
 
 
14a4318
1007582
14a4318
1007582
 
 
14a4318
 
 
1007582
14a4318
 
 
 
1007582
 
14a4318
1007582
 
78b79c7
 
 
 
 
 
1007582
 
 
14a4318
1007582
 
14a4318
1007582
 
 
 
 
 
 
 
14a4318
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import logging
import os
from enum import Enum
from typing import List, Optional, Union

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from gradio_client import Client
from pydantic import BaseModel, ConfigDict, Field, constr

from docs import description, tags_metadata

# Load environment variables (via python-dotenv) before the os.getenv()
# lookup below so that GRADIO_URL can be supplied that way.
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Base URL of the Gradio app that every endpoint forwards to; defaults to a
# local instance when GRADIO_URL is not set.
GRADIO_URL = os.getenv("GRADIO_URL", "http://localhost:7860/")
logger.info(f"GRADIO_URL: {GRADIO_URL}")
# Single gradio_client handle, constructed at import time and shared by all
# request handlers in this module.
client = Client(GRADIO_URL)

# Application object; the description text and tag metadata come from the
# local `docs` module.
app = FastAPI(
    title="ACRES RAG API",
    description=description,
    openapi_tags=tags_metadata,
)

# NOTE(review): a wildcard origin combined with allow_credentials=True is
# very permissive — confirm this is intended outside local development.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Predefined study names accepted by StudyVariableRequest.study_variable.
# Members mix in `str`, so each member is also a plain string equal to its
# value.
class StudyVariables(str, Enum):
    ebola_virus = "Ebola Virus"
    vaccine_coverage = "Vaccine coverage"
    genexpert = "GeneXpert"


# Prompt styles accepted by StudyVariableRequest.prompt_type.
# Members mix in `str`, so each member is also a plain string equal to its
# value.
class PromptType(str, Enum):
    default = "Default"
    highlight = "Highlight"
    evidence_based = "Evidence-based"


# Request body for the POST /study_variables endpoint.
class StudyVariableRequest(BaseModel):
    # Either one of the predefined StudyVariables values or any other string.
    study_variable: Union[StudyVariables, str]
    # Must be one of the PromptType values.
    prompt_type: PromptType
    # Constrained string: surrounding whitespace is stripped and at least one
    # character is required.
    text: constr(min_length=1, strip_whitespace=True)  # type: ignore

    # from_attributes=True: pydantic may also build this model from an
    # object's attributes, not only from a mapping.
    model_config = ConfigDict(from_attributes=True)


# Request body for the POST /download_csv endpoint.
class DownloadCSV(BaseModel):
    # Content forwarded as `markdown_content` by download_csv. Constrained
    # string: surrounding whitespace is stripped and at least one character is
    # required.
    text: constr(min_length=1, strip_whitespace=True)  # type: ignore

    # from_attributes=True: pydantic may also build this model from an
    # object's attributes, not only from a mapping.
    model_config = ConfigDict(from_attributes=True)


# Request body for the POST /get_study_info endpoint.
class Study(BaseModel):
    # Name of the study to look up. Constrained string: surrounding whitespace
    # is stripped and at least one character is required.
    study_name: constr(min_length=1, strip_whitespace=True)  # type: ignore

    # from_attributes=True: pydantic may also build this model from an
    # object's attributes, not only from a mapping.
    model_config = ConfigDict(from_attributes=True)


# Request body for the POST /process_zotero_library_items endpoint.
# Both fields are constrained strings: surrounding whitespace is stripped and
# at least one character is required.
class ZoteroCredentials(BaseModel):
    # Forwarded to the Gradio app as `zotero_library_id_param`.
    library_id: constr(min_length=1, strip_whitespace=True)  # type: ignore
    # Forwarded to the Gradio app as `zotero_api_access_key`.
    api_access_key: constr(min_length=1, strip_whitespace=True)  # type: ignore

    # from_attributes=True: pydantic may also build this model from an
    # object's attributes, not only from a mapping.
    model_config = ConfigDict(from_attributes=True)


@app.post("/process_zotero_library_items", tags=["zotero"])
def process_zotero_library_items(zotero_credentials: ZoteroCredentials):
    # Forward the caller's Zotero credentials to the Gradio
    # "/process_zotero_library_items" endpoint and wrap whatever it returns
    # under the "result" key.
    gradio_kwargs = {
        "zotero_library_id_param": zotero_credentials.library_id,
        "zotero_api_access_key": zotero_credentials.api_access_key,
    }
    return {
        "result": client.predict(
            api_name="/process_zotero_library_items", **gradio_kwargs
        )
    }


@app.post("/get_study_info", tags=["zotero"])
def get_study_info(study: Study):
    # Ask the Gradio "/get_study_info" endpoint about the named study and
    # return its reply unchanged under the "result" key.
    study_info = client.predict(
        api_name="/get_study_info",
        study_name=study.study_name,
    )
    return {"result": study_info}


@app.post("/study_variables", tags=["zotero"])
def process_study_variables(
    study_request: StudyVariableRequest,
):
    # Run the Gradio "/process_multi_input" endpoint for one study and return
    # the first element of its reply under the "result" key.
    #
    # Parameters:
    #   study_request: validated body carrying the variables text, the study
    #     name (predefined or free-form) and the prompt type.
    # Raises:
    #   HTTPException(502): the upstream reply has no first element (empty,
    #     None, or not indexable).
    result = client.predict(
        text=study_request.text,  # e.g. "study id, study title, study design"
        study_name=study_request.study_variable,  # e.g. "Ebola Virus"
        prompt_type=study_request.prompt_type,  # e.g. "Default"
        api_name="/process_multi_input",
    )
    # Diagnostic only; goes through the module logger instead of stdout.
    logger.debug("/process_multi_input returned a %s", type(result).__name__)

    # Only the first output is exposed to API clients. Turn a malformed
    # upstream reply into an explicit gateway error rather than letting an
    # IndexError/TypeError surface as an opaque 500.
    try:
        first_output = result[0]
    except (IndexError, KeyError, TypeError) as exc:
        raise HTTPException(
            status_code=502,
            detail="Unexpected response from upstream service",
        ) from exc
    return {"result": first_output}


@app.post("/new_study_choices", tags=["zotero"])
def new_study_choices():
    # Takes no input: relay the reply of the Gradio "/new_study_choices"
    # endpoint under the "result" key.
    choices = client.predict(api_name="/new_study_choices")
    return dict(result=choices)


@app.post("/download_csv", tags=["zotero"])
def download_csv(download_request: DownloadCSV):
    # Convert the submitted content to CSV via the Gradio "/download_as_csv"
    # endpoint and send the resulting file back as a download.
    #
    # Parameters:
    #   download_request: validated body whose `text` is forwarded as
    #     `markdown_content`.
    # Returns:
    #   FileResponse with media type "text/csv", named after the file's
    #   basename.
    # Raises:
    #   HTTPException(404): the upstream reply is not a path to an existing
    #     regular file on this machine.
    file_path = client.predict(
        markdown_content=download_request.text, api_name="/download_as_csv"
    )
    logger.info("CSV export path returned by upstream: %s", file_path)

    # isfile (rather than exists) also rejects directories, which cannot be
    # served as a file; the isinstance check keeps a non-path reply from
    # raising TypeError inside os.path.
    if not isinstance(file_path, (str, os.PathLike)) or not os.path.isfile(
        file_path
    ):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(
        file_path,
        media_type="text/csv",
        filename=os.path.basename(file_path),
    )