Spaces:
Runtime error
Runtime error
File size: 2,364 Bytes
683c41b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import ctranslate2
from transformers import AutoTokenizer
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import argparse
import time
# Checkpoint name handed to AutoTokenizer.from_pretrained.
model_name = "BAAI/bge-base-en-v1.5"
# Path handed to ctranslate2.Encoder as the model location.
model_save_path = "bge_model_ctranslate2"
# (alternative path left commented out in the original: "bge_model_ctranslate2_base")

# Target device for the encoder; set to "cuda" to run on GPU.
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# float16 is requested only on CUDA; on any other device the encoder is
# built without an explicit compute_type.
translator = (
    ctranslate2.Encoder(model_save_path, device=device, compute_type="float16")
    if device == "cuda"
    else ctranslate2.Encoder(model_save_path, device=device)
)
def generate_embeddings(text):
    """Return the pooled encoder output for ``text`` as a plain Python list.

    The text is tokenized with the module-level ``tokenizer``, the first
    sequence of token ids is sent through the module-level ``translator``
    as a batch of one, and the first row of ``pooler_output`` is returned.

    Args:
        text: Input passed straight to the tokenizer. Only the first
            tokenized sequence is encoded.

    Returns:
        The first row of the encoder's ``pooler_output``, converted with
        ``tolist()``.
    """
    # NOTE(review): this returns the encoder's pooler_output as-is; confirm
    # that is the intended sentence representation for this model.
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    first_sequence_ids = encoded["input_ids"].tolist()[0]

    pooled = translator.forward_batch([first_sequence_ids]).pooler_output

    on_gpu = device == "cuda"
    # Off-GPU the output goes through NumPy before torch wraps it; on CUDA
    # it is handed to torch directly.
    source = pooled if on_gpu else np.array(pooled)
    tensor = torch.as_tensor(source, device=device).detach()
    if on_gpu:
        tensor = tensor.cpu()
    return tensor.tolist()[0]
# Application object; the @app.post / @app.get decorators in this module
# register their handlers on it, and the __main__ guard serves it with uvicorn.
app = FastAPI()
class EmbeddingRequest(BaseModel):
    """Request body for ``POST /v1/embeddings``."""

    # Text to embed; the handler answers HTTP 400 when this is empty.
    input: str
    # Model identifier supplied by the client; echoed back in the response.
    model: str
class EmbeddingResponse(BaseModel):
    """Response body for ``POST /v1/embeddings``."""

    # Top-level object tag; defaults to "list".
    object: str = "list"
    # Embedding entries; the handler fills in exactly one, at index 0.
    data: list
    # Model identifier copied from the request.
    model: str
    # Usage counters ("prompt_tokens" and "total_tokens").
    usage: dict
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def embeddings(request: EmbeddingRequest):
    """Embed ``request.input`` and return it in an OpenAI-style payload.

    Args:
        request: Parsed request body. ``input`` is the text to embed and
            ``model`` is echoed back unchanged.

    Returns:
        A dict with ``object``, ``data`` (a single embedding entry at
        index 0), ``model`` and ``usage``; FastAPI validates it against
        ``EmbeddingResponse``.

    Raises:
        HTTPException: With status 400 when ``input`` is empty.
    """
    text = request.input
    if not text:
        raise HTTPException(status_code=400, detail="No input text provided")

    vector = generate_embeddings(text)

    # Usage must report tokens, not whitespace-separated words. Count the ids
    # the tokenizer emits under the same truncation setting that
    # generate_embeddings applies, so the figure reflects what was encoded
    # (including any special tokens the tokenizer adds).
    token_count = len(tokenizer(text, truncation=True)["input_ids"])

    # Response laid out in the OpenAI embeddings format.
    return {
        "object": "list",
        "data": [{"object": "embedding", "embedding": vector, "index": 0}],
        "model": request.model,
        "usage": {
            "prompt_tokens": token_count,
            "total_tokens": token_count,
        },
    }
@app.get("/ping")
async def ping():
    """Liveness probe: always answers with a fixed ``pong`` status."""
    return dict(status="pong")
if __name__ == "__main__":
    # Script entry point: read CLI options, then serve ``app`` with uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=5001)
    # The bind address is configurable; the default keeps the previous
    # behaviour of listening on all interfaces.
    parser.add_argument("--host", default="0.0.0.0")
    args = parser.parse_args()

    # Imported here so that importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port)