Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,543 Bytes
0e6708a 20b8a33 0e6708a 20b8a33 0e6708a 20b8a33 bd5e5cb e224f20 0e6708a d551587 0e6708a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
import celldetection as cd
import cv2
import numpy as np
__all__ = ['contours2labels', 'CpnInterface']
def contours2labels(contours, size, overlap=False, max_iter=999):
    """Rasterize contours into a label image, optionally flattening overlaps.

    Args:
        contours: Object contours; converted with ``cd.asnumpy`` and passed
            to ``cd.data.contours2labels``.
        size: Image size forwarded to ``cd.data.contours2labels``.
        overlap: If True, the channel-stacked label array produced by
            ``cd.data.contours2labels`` is returned unchanged. If False, it
            is collapsed into a single 2D integer label map.
        max_iter: Upper bound on the number of dilation passes used to
            assign labels to pixels covered by more than one object.

    Returns:
        The stacked label array when ``overlap`` is True, otherwise a 2D
        integer label map.
    """
    stacked = cd.data.contours2labels(cd.asnumpy(contours), size, initial_depth=3)
    if overlap:
        return stacked

    coverage = np.sum(stacked > 0, axis=-1)  # number of objects per pixel
    contested = coverage > 1  # pixels claimed by several objects
    top_label = stacked.max(-1)

    if not contested.any():
        # Nothing overlaps, so the channel-wise maximum is already the answer.
        return top_label.astype('int')

    # Seed with pixels owned by exactly one object; contested pixels start
    # at 0 and are filled in from their labelled neighbours below.
    flat = np.where(coverage == 1, top_label, 0).astype('float64')
    cross = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))

    for _ in range(max_iter):
        pending = contested & (flat <= 0)
        if not pending.any():
            break
        before = flat.copy()
        flat[pending] = cv2.dilate(flat, kernel=cross)[pending]
        if np.allclose(before, flat):
            # No pixel changed: the remaining contested pixels are not
            # reachable from any seeded label.
            break

    return flat.astype('int')
class CpnInterface:
    """Callable wrapper that runs a CPN model on one image.

    The model is resolved via ``cd.resolve_model``, wrapped in
    ``cd.models.LitCpn`` if it is not one already, moved to the selected
    device and put into inference mode (eval, gradients disabled).
    """

    def __init__(self, model, device=None, **kwargs):
        """Resolve the model and prepare it for inference.

        Args:
            model: Model specification accepted by ``cd.resolve_model``.
            device: Device to run on. If None, ``'cuda'`` is used when
                available, otherwise ``'cpu'``.
            **kwargs: Forwarded to ``cd.resolve_model``.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        model = cd.resolve_model(model, **kwargs)
        if not isinstance(model, cd.models.LitCpn):
            model = cd.models.LitCpn(model)

        # Move to the *resolved* device. Passing the raw ``device`` argument
        # would leave the model where it is whenever the caller relies on
        # the default (None), contradicting ``self.device``.
        self.model = model.to(self.device)
        self.model.eval()
        self.model.requires_grad_(False)

        # Tiling parameters passed to the model's forward call.
        self.tile_size = 1664
        self.overlap = 384

    def __call__(
            self,
            img,
            div=255,
            reduce_labels=True,
            return_labels=True,
            return_viewable_contours=True,
    ):
        """Run the model on ``img`` and return its detections.

        Args:
            img: Image array. A 2D (grayscale) image is converted to RGB.
            div: Divisor applied to the image before inference.
            reduce_labels: If True, labels are collapsed into a single 2D
                map; if False, the channel-stacked label array with
                overlaps is kept.
            return_labels: Compute the label image.
            return_viewable_contours: Also triggers label computation; the
                flag has no other effect in this method.

        Returns:
            dict with keys ``contours``, ``labels``, ``boxes`` and
            ``scores``. ``labels`` is None when neither ``return_labels``
            nor ``return_viewable_contours`` is set.
        """
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = img / div
        x = cd.data.to_tensor(img, transpose=True, dtype=torch.float32)[None]

        # NOTE(review): the input is not moved to ``self.device`` here;
        # presumably the model's tiled forward handles the transfer - confirm.
        stride = max(64, self.tile_size - self.overlap)
        with torch.no_grad():
            out = cd.asnumpy(self.model(x, crop_size=self.tile_size, stride=stride))

        # The batch holds exactly one image, so unpack the single entry.
        contours, = out['contours']
        boxes, = out['boxes']
        scores, = out['scores']

        labels = None
        if return_labels or return_viewable_contours:
            labels = contours2labels(contours, img.shape[:2], overlap=not reduce_labels)

        return dict(
            contours=contours,
            labels=labels,
            boxes=boxes,
            scores=scores
        )
|