# retriever-api / app.py
# CarlosMalaga's picture
# Update app.py
# 2d301d0 verified
# raw
# history blame
# 1.48 kB
from flask import Flask, request, jsonify
from goldenretriever import GoldenRetriever
# Flask application object; the route decorators further down attach to it.
app = Flask(__name__)


def _load_retriever(model_dir):
    """Build a GoldenRetriever on CPU from the artifacts under ``model_dir``.

    ``model_dir`` must contain a ``question_encoder`` entry and a
    ``document_index/`` directory.
    """
    return GoldenRetriever(
        question_encoder=f"{model_dir}/question_encoder",
        document_index=f"{model_dir}/document_index/",
        device="cpu",
    )


# One retriever per endpoint, each created once at import time.
retriever_interventions = _load_retriever("models/interventions")
retriever_outcomes = _load_retriever("models/outcomes")
def retrieve_documents(retriever, text, k=5):
    """Return the texts of the top-``k`` documents retrieved for ``text``.

    Args:
        retriever: Object exposing ``retrieve(text, k=..., batch_size=...,
            progress_bar=...)``. Its result is indexed at ``0`` and the
            entries found there must expose ``.document.text``.
        text: Query handed to the retriever unchanged.
        k: Number of documents to request. Defaults to 5, the value that
            used to be hard-coded.

    Returns:
        A list of document texts in the order the retriever produced them.
    """
    # Only the first result set is used; presumably the retriever returns
    # one result set per query and a single query is submitted here.
    hits = retriever.retrieve(text, k=k, batch_size=1, progress_bar=False)[0]
    return [hit.document.text for hit in hits]
@app.route('/retrieve-intervention', methods=['POST'])
def retrieve_intervention():
    """Handle ``POST /retrieve-intervention``.

    Expects a JSON object body with a non-empty string under ``"text"``.

    Returns:
        ``(JSON list of document texts, 200)`` obtained from
        ``retrieve_documents`` with ``retriever_interventions``, or
        ``({"error": "No text provided"}, 400)`` when the body is not a
        JSON object or carries no usable text.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so every bad request gets the same JSON error below.
    payload = request.get_json(silent=True)
    # Valid JSON need not be an object (null, list, string, ...); calling
    # .get() on such a value used to raise and surface as a 500.
    text = payload.get('text', '') if isinstance(payload, dict) else ''
    # Reject non-string values too: they are truthy but not a valid query.
    if not isinstance(text, str) or not text:
        return jsonify({'error': 'No text provided'}), 400
    return jsonify(retrieve_documents(retriever_interventions, text)), 200
@app.route('/retrieve-outcomes', methods=['POST'])
def retrieve_outcomes():
    """Handle ``POST /retrieve-outcomes``.

    Expects a JSON object body with a non-empty string under ``"text"``.

    Returns:
        ``(JSON list of document texts, 200)`` obtained from
        ``retrieve_documents`` with ``retriever_outcomes``, or
        ``({"error": "No text provided"}, 400)`` when the body is not a
        JSON object or carries no usable text.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so every bad request gets the same JSON error below.
    payload = request.get_json(silent=True)
    # Valid JSON need not be an object (null, list, string, ...); calling
    # .get() on such a value used to raise and surface as a 500.
    text = payload.get('text', '') if isinstance(payload, dict) else ''
    # Reject non-string values too: they are truthy but not a valid query.
    if not isinstance(text, str) or not text:
        return jsonify({'error': 'No text provided'}), 400
    return jsonify(retrieve_documents(retriever_outcomes, text)), 200
if __name__ == '__main__':
    # ``host`` is the local interface address to bind, not the public URL of
    # the deployment; a full ``https://...`` URL cannot be resolved as a bind
    # address. '0.0.0.0' listens on all interfaces so a fronting proxy can
    # reach the server.
    # Debug mode stays off: the interactive debugger allows arbitrary code
    # execution and must not be reachable on a public interface.
    # NOTE(review): the port is left at Flask's default; set ``port=`` if the
    # hosting platform expects a specific one.
    app.run(debug=False, host='0.0.0.0')