File size: 5,702 Bytes
334dcac e4bcc80 334dcac 781e740 1a2c8bd 781e740 334dcac 55f430c 334dcac 55f430c 334dcac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# File name: model.py
import json
import os
import numpy as np
import torch
from starlette.requests import Request
from PIL import Image
import ray
from ray import serve
from clip_retrieval.load_clip import load_clip, get_tokenizer
# from clip_retrieval.clip_client import ClipClient, Modality
@serve.deployment(num_replicas=6, ray_actor_options={"num_cpus": .2, "num_gpus": 0.1})
class CLIPTransform:
    """Ray Serve deployment that maps text or images to CLIP embeddings.

    Every replica loads the "ViT-L/14" CLIP model through
    ``clip_retrieval.load_clip`` (with ``use_jit=True``) on ``cuda:0`` when
    CUDA is available, otherwise on the CPU. Requests are multipart forms;
    see ``__call__`` for the accepted fields.
    """

    # (channels, height, width) of one image on the "preprocessed_image" route.
    # NOTE(review): 224x224 is assumed to match ViT-L/14's input resolution;
    # update it together with ``_clip_model``.
    _PREPROCESSED_CHW = (3, 224, 224)

    # dtype names accepted in the "dtype" form field -> NumPy dtype used to
    # decode the raw "preprocessed_image" bytes. Add entries as needed.
    _WIRE_DTYPES = {
        "torch.float32": np.float32,
        "torch.float64": np.float64,
        "torch.float16": np.float16,
        "torch.uint8": np.uint8,
        "torch.int8": np.int8,
        "torch.int16": np.int16,
        "torch.int32": np.int32,
        "torch.int64": np.int64,
    }

    # Timeout (seconds) passed to requests for "image_url" downloads. In
    # requests this bounds connecting and each read, not the whole download.
    _IMAGE_DOWNLOAD_TIMEOUT = 30

    def __init__(self):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self._clip_model = "ViT-L/14"
        self._clip_model_id = "laion5B-L-14"
        self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
        self.tokenizer = get_tokenizer(self._clip_model)
        print("using device", self.device)

    def text_to_embeddings(self, prompt):
        """Return the L2-normalised CLIP text embedding of *prompt*.

        The prompt is tokenised as a batch of one, so the result is a tensor
        on ``self.device`` with one row, normalised along its last dimension.
        """
        text = self.tokenizer([prompt]).to(self.device)
        with torch.no_grad():
            embeddings = self.model.encode_text(text)
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
        return embeddings

    def image_to_embeddings(self, input_im):
        """Return the L2-normalised CLIP embedding of a single image.

        ``input_im`` is a NumPy array accepted by ``PIL.Image.fromarray``
        (e.g. an (H, W, 3) ``uint8`` RGB array). It is passed through the
        model's own preprocessing and embedded as a batch of one.
        """
        prepro = self.preprocess(Image.fromarray(input_im)).unsqueeze(0).to(self.device)
        return self.preprocessed_image_to_emdeddings(prepro)

    def preprocessed_image_to_emdeddings(self, prepro):
        """Embed an already-preprocessed image batch, L2-normalising each row.

        ``prepro`` must already live on ``self.device``; presumably a float
        (N, 3, 224, 224) tensor like ``self.preprocess`` produces — confirm.
        (The misspelt method name is kept for existing callers.)
        """
        with torch.no_grad():
            embeddings = self.model.encode_image(prepro)
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
        return embeddings

    async def __call__(self, http_request: Request) -> list:
        """Embed the payload of a multipart-form request.

        The first of these fields that is present selects the input:

        * ``text``: UTF-8 prompt, embedded with the text encoder.
        * ``image_url``: URL of an image, downloaded and embedded.
        * ``preprocessed_image``: raw tensor bytes already preprocessed for
          CLIP, plus a ``dtype`` field naming their element type (a key of
          ``_WIRE_DTYPES``, e.g. ``"torch.float32"``). The batch size is
          inferred from the byte count; a ``shape`` field is not used.

        Fields may be sent either as file parts or as plain form values.

        Returns:
            The normalised embeddings as nested Python lists, one row per input.

        Raises:
            ValueError: if no recognised field is present or the
                preprocessed-image payload is malformed.
        """
        form_data = await http_request.form()
        if "text" in form_data:
            prompt = (await self._read_field(form_data, "text")).decode()
            embeddings = self.text_to_embeddings(prompt)
        elif "image_url" in form_data:
            image_url = (await self._read_field(form_data, "image_url")).decode()
            embeddings = self.image_to_embeddings(self._download_image(image_url))
        elif "preprocessed_image" in form_data:
            prepro = await self._decode_preprocessed_image(form_data)
            embeddings = self.preprocessed_image_to_emdeddings(prepro)
        else:
            raise ValueError("Invalid request")
        return embeddings.cpu().numpy().tolist()

    @staticmethod
    async def _read_field(form_data, key):
        """Return the raw bytes of form field *key* (``str`` or ``UploadFile``)."""
        value = form_data[key]
        # Starlette gives plain form fields as str and file parts as UploadFile;
        # only the latter has an async read().
        if isinstance(value, str):
            return value.encode()
        return await value.read()

    def _download_image(self, image_url):
        """Fetch *image_url* and return it as an RGB NumPy array."""
        # Local imports: only this route needs them.
        from io import BytesIO
        import requests

        # NOTE(review): image_url comes straight from the client, so this is a
        # server-side request to an arbitrary URL (SSRF risk) — restrict hosts
        # if the endpoint is public. The call is synchronous and blocks this
        # replica's event loop while downloading.
        response = requests.get(image_url, timeout=self._IMAGE_DOWNLOAD_TIMEOUT)
        # Surface HTTP errors directly instead of handing an error page to PIL.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as image:
            return np.array(image.convert("RGB"))

    async def _decode_preprocessed_image(self, form_data):
        """Rebuild the raw "preprocessed_image" bytes as a tensor on ``self.device``."""
        tensor_bytes = await self._read_field(form_data, "preprocessed_image")
        dtype_str = (await self._read_field(form_data, "dtype")).decode()
        try:
            dtype_numpy = self._WIRE_DTYPES[dtype_str]
        except KeyError:
            raise ValueError(f"Unsupported dtype: {dtype_str!r}") from None

        flat = np.frombuffer(tensor_bytes, dtype=dtype_numpy)
        if flat.size == 0:
            raise ValueError("preprocessed_image is empty")
        # TODO: honour a client-supplied shape once its wire encoding is agreed
        # on. Until then the per-image shape is fixed and the batch size is
        # inferred (reshape raises ValueError if the size does not divide).
        batch = flat.reshape((-1,) + self._PREPROCESSED_CHW)
        # frombuffer returns a read-only view of the request bytes; copy to a
        # writable array before handing it to torch.from_numpy.
        batch = np.require(batch, requirements="W")
        return torch.from_numpy(batch).to(self.device)
# Bound (constructor-argument-free) CLIPTransform deployment.
# NOTE(review): presumably the target Ray Serve imports, e.g.
# `serve run model:deployment_graph` — confirm against the deploy config.
deployment_graph = CLIPTransform.bind()
|