Spaces:
Running
on
Zero
Running
on
Zero
drscotthawley
commited on
Commit
•
3f93e88
1
Parent(s):
b98fe4a
mods to get ZeroGPU working
Browse files
app.py
CHANGED
@@ -42,15 +42,16 @@ def infer_mask_from_init_img(img, mask_with='grey'):
|
|
42 |
"note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"
|
43 |
assert mask_with in ['blue','white','grey']
|
44 |
"given an image with mask areas marked, extract the mask itself"
|
|
|
45 |
if not torch.is_tensor(img):
|
46 |
img = ToTensor()(img)
|
47 |
-
print("img.shape: ", img.shape)
|
48 |
# shape of mask should be img shape without the channel dimension
|
49 |
if len(img.shape) == 3:
|
50 |
mask = torch.zeros(img.shape[-2:])
|
51 |
elif len(img.shape) == 2:
|
52 |
mask = torch.zeros(img.shape)
|
53 |
-
print("mask.shape: ", mask.shape)
|
54 |
if mask_with == 'white':
|
55 |
mask[ (img[0,:,:]==1) & (img[1,:,:]==1) & (img[2,:,:]==1)] = 1
|
56 |
elif mask_with == 'blue':
|
|
|
42 |
"note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"
|
43 |
assert mask_with in ['blue','white','grey']
|
44 |
"given an image with mask areas marked, extract the mask itself"
|
45 |
+
print("\n in infer_mask_from_init_img: ")
|
46 |
if not torch.is_tensor(img):
|
47 |
img = ToTensor()(img)
|
48 |
+
print(" img.shape: ", img.shape)
|
49 |
# shape of mask should be img shape without the channel dimension
|
50 |
if len(img.shape) == 3:
|
51 |
mask = torch.zeros(img.shape[-2:])
|
52 |
elif len(img.shape) == 2:
|
53 |
mask = torch.zeros(img.shape)
|
54 |
+
print(" mask.shape: ", mask.shape)
|
55 |
if mask_with == 'white':
|
56 |
mask[ (img[0,:,:]==1) & (img[1,:,:]==1) & (img[2,:,:]==1)] = 1
|
57 |
elif mask_with == 'blue':
|
sample.py
CHANGED
@@ -522,7 +522,7 @@ def get_init_image_and_mask(args, device):
|
|
522 |
init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
|
523 |
return init_image.to(device), init_mask.to(device)
|
524 |
|
525 |
-
|
526 |
def main():
|
527 |
global init_image, init_mask
|
528 |
p = argparse.ArgumentParser(description=__doc__,
|
|
|
522 |
init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
|
523 |
return init_image.to(device), init_mask.to(device)
|
524 |
|
525 |
+
#@spaces.GPU # generates an error
|
526 |
def main():
|
527 |
global init_image, init_mask
|
528 |
p = argparse.ArgumentParser(description=__doc__,
|