Spaces:
Sleeping
Sleeping
from flask import Flask, request, jsonify | |
from goldenretriever import GoldenRetriever | |
app = Flask(__name__)


def _load_retriever(question_encoder_path, document_index_path):
    """Construct a GoldenRetriever pinned to the CPU.

    Args:
        question_encoder_path: Value passed as GoldenRetriever's
            ``question_encoder`` argument.
        document_index_path: Value passed as GoldenRetriever's
            ``document_index`` argument.

    Returns:
        The constructed ``GoldenRetriever`` (``device="cpu"``).
    """
    return GoldenRetriever(
        question_encoder=question_encoder_path,
        document_index=document_index_path,
        device="cpu",
    )


# One retriever per task, each built from its own encoder/index pair when the
# module is imported.
retriever_interventions = _load_retriever(
    "models/interventions/question_encoder",
    "models/interventions/document_index/",
)
retriever_outcomes = _load_retriever(
    "models/outcomes/question_encoder",
    "models/outcomes/document_index/",
)
def retrieve_documents(retriever, text, *, k=5):
    """Return the texts of the top-``k`` documents ``retriever`` finds for ``text``.

    Args:
        retriever: Object exposing
            ``retrieve(text, k=..., batch_size=..., progress_bar=...)``
            (a ``GoldenRetriever`` in this module).
        text: The query text.
        k: Number of documents to request from the retriever. Defaults to 5.

    Returns:
        A list of ``str``: ``hit.document.text`` for each hit, in the order
        the retriever returned them.
    """
    # Only a single query is sent; element [0] of the result holds its hits
    # (presumably retrieve() returns one result list per query — confirm).
    hits = retriever.retrieve(text, k=k, batch_size=1, progress_bar=False)[0]
    return [hit.document.text for hit in hits]
def _retrieve_from_request(retriever):
    """Validate the current request's JSON body and run ``retriever`` on it.

    The body is expected to be a JSON object with a non-empty string
    ``text`` field.

    Args:
        retriever: Retriever passed through to ``retrieve_documents``.

    Returns:
        A ``(response, status)`` tuple. On success this is the JSON list of
        document texts with status 200. It is ``{'error': ...}`` with status
        400 when the body is not a JSON object, or when ``text`` is missing,
        empty, or not a string.
    """
    # silent=True makes get_json return None instead of raising when the body
    # is missing or is not valid JSON. A JSON value that is not an object
    # (e.g. null or a list) has no .get(), so treat it as an empty payload
    # rather than crashing with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    text = data.get('text', '')
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    if not isinstance(text, str):
        return jsonify({'error': "'text' must be a string"}), 400
    return jsonify(retrieve_documents(retriever, text)), 200


# NOTE(review): the URL paths mirror the function names (Flask's default
# endpoint names) — confirm they match what API clients call.
@app.route('/retrieve_intervention', methods=['POST'])
def retrieve_intervention():
    """Return documents from ``retriever_interventions`` for the body's ``text``."""
    return _retrieve_from_request(retriever_interventions)


@app.route('/retrieve_outcomes', methods=['POST'])
def retrieve_outcomes():
    """Return documents from ``retriever_outcomes`` for the body's ``text``."""
    return _retrieve_from_request(retriever_outcomes)
if __name__ == '__main__':
    import os

    # `host` is the network interface to bind, not a public URL, so a URL such
    # as https://...hf.space cannot be bound to. Bind all interfaces so the
    # server is reachable from outside its container (assumed Space deployment).
    # The default port 7860 is assumed to be the Hugging Face Spaces app port;
    # override with PORT/HOST if the Space is configured differently.
    # The Werkzeug debugger allows arbitrary code execution, so never enable it
    # by default on a publicly bound server; opt in with FLASK_DEBUG=1.
    app.run(
        host=os.environ.get('HOST', '0.0.0.0'),
        port=int(os.environ.get('PORT', '7860')),
        debug=os.environ.get('FLASK_DEBUG', '0') == '1',
    )