espejelomar commited on
Commit
e149a99
1 Parent(s): 48fb736

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +76 -76
utils.py CHANGED
@@ -1,96 +1,96 @@
1
  import torch
2
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
- from datasets import load_dataset
4
- from PIL import Image
5
  import numpy as np
6
- import paddlehub as hub
7
- import random
8
- from PIL import ImageDraw,ImageFont
9
 
10
- import streamlit as st
11
 
12
- @st.experimental_singleton
13
- def load_bg_model():
14
- bg_model = hub.Module(name='U2NetP', directory='assets/models/')
15
- return bg_model
16
 
17
 
18
- bg_model = load_bg_model()
19
- def remove_bg(img):
20
- result = bg_model.Segmentation(
21
- images=[np.array(img)[:,:,::-1]],
22
- paths=None,
23
- batch_size=1,
24
- input_size=320,
25
- output_dir=None,
26
- visualization=False)
27
- output = result[0]
28
- mask=Image.fromarray(output['mask'])
29
- front=Image.fromarray(output['front'][:,:,::-1]).convert("RGBA")
30
- front.putalpha(mask)
31
- return front
32
 
33
- meme_template=Image.open("./assets/pigeon_meme.jpg").convert("RGBA")
34
- def make_meme(pigeon,text="Is this a pigeon?",show_text=True,remove_background=True):
35
 
36
- meme=meme_template.copy()
37
- approx_butterfly_center=(850,30)
38
 
39
- if remove_background:
40
- pigeon=remove_bg(pigeon)
41
 
42
- else:
43
- pigeon=Image.fromarray(pigeon).convert("RGBA")
44
 
45
- random_rotate=random.randint(-30,30)
46
- random_size=random.randint(150,200)
47
- pigeon=pigeon.resize((random_size,random_size)).rotate(random_rotate,expand=True)
48
 
49
- meme.alpha_composite(pigeon, approx_butterfly_center)
50
 
51
- #ref: https://blog.lipsumarium.com/caption-memes-in-python/
52
- def drawTextWithOutline(text, x, y):
53
- draw.text((x-2, y-2), text,(0,0,0),font=font)
54
- draw.text((x+2, y-2), text,(0,0,0),font=font)
55
- draw.text((x+2, y+2), text,(0,0,0),font=font)
56
- draw.text((x-2, y+2), text,(0,0,0),font=font)
57
- draw.text((x, y), text, (255,255,255), font=font)
58
 
59
- if show_text:
60
- draw = ImageDraw.Draw(meme)
61
- font_size=52
62
- font = ImageFont.truetype("assets/impact.ttf", font_size)
63
- w, h = draw.textsize(text, font) # measure the size the text will take
64
- drawTextWithOutline(text, meme.width/2 - w/2, meme.height - font_size*2)
65
- meme = meme.convert("RGB")
66
- return meme
67
 
68
- def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
69
- dataset=load_dataset(dataset_name)
70
- dataset=dataset.sort("sim_score")
71
- return dataset["train"]
72
 
73
- from transformers import BeitFeatureExtractor, BeitForImageClassification
74
- emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
75
- emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
76
- def embed(images):
77
- inputs = emb_feature_extractor(images=images, return_tensors="pt")
78
- outputs = emb_model(**inputs,output_hidden_states= True)
79
- last_hidden=outputs.hidden_states[-1]
80
- pooler=emb_model.base_model.pooler
81
- final_emb=pooler(last_hidden).detach().numpy()
82
- return final_emb
83
 
84
- def build_index():
85
- dataset=get_train_data()
86
- ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings":embed(x["image"])},batched=True,batch_size=20)
87
- ds_with_embeddings.add_faiss_index(column='beit_embeddings')
88
- ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')
89
 
90
- def get_dataset():
91
- dataset=get_train_data()
92
- dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
93
- return dataset
94
 
95
  def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',model_version=None):
96
  gan = LightweightGAN.from_pretrained(model_name,version=model_version)
@@ -103,5 +103,5 @@ def generate(gan,batch_size=1):
103
  ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
104
  return ims
105
 
106
- def interpolate():
107
- pass
 
1
  import torch
2
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
+ # from datasets import load_dataset
4
+ # from PIL import Image
5
  import numpy as np
6
+ # import paddlehub as hub
7
+ # import random
8
+ # from PIL import ImageDraw,ImageFont
9
 
10
+ # import streamlit as st
11
 
12
+ # @st.experimental_singleton
13
+ # def load_bg_model():
14
+ # bg_model = hub.Module(name='U2NetP', directory='assets/models/')
15
+ # return bg_model
16
 
17
 
18
+ # bg_model = load_bg_model()
19
+ # def remove_bg(img):
20
+ # result = bg_model.Segmentation(
21
+ # images=[np.array(img)[:,:,::-1]],
22
+ # paths=None,
23
+ # batch_size=1,
24
+ # input_size=320,
25
+ # output_dir=None,
26
+ # visualization=False)
27
+ # output = result[0]
28
+ # mask=Image.fromarray(output['mask'])
29
+ # front=Image.fromarray(output['front'][:,:,::-1]).convert("RGBA")
30
+ # front.putalpha(mask)
31
+ # return front
32
 
33
+ # meme_template=Image.open("./assets/pigeon_meme.jpg").convert("RGBA")
34
+ # def make_meme(pigeon,text="Is this a pigeon?",show_text=True,remove_background=True):
35
 
36
+ # meme=meme_template.copy()
37
+ # approx_butterfly_center=(850,30)
38
 
39
+ # if remove_background:
40
+ # pigeon=remove_bg(pigeon)
41
 
42
+ # else:
43
+ # pigeon=Image.fromarray(pigeon).convert("RGBA")
44
 
45
+ # random_rotate=random.randint(-30,30)
46
+ # random_size=random.randint(150,200)
47
+ # pigeon=pigeon.resize((random_size,random_size)).rotate(random_rotate,expand=True)
48
 
49
+ # meme.alpha_composite(pigeon, approx_butterfly_center)
50
 
51
+ # #ref: https://blog.lipsumarium.com/caption-memes-in-python/
52
+ # def drawTextWithOutline(text, x, y):
53
+ # draw.text((x-2, y-2), text,(0,0,0),font=font)
54
+ # draw.text((x+2, y-2), text,(0,0,0),font=font)
55
+ # draw.text((x+2, y+2), text,(0,0,0),font=font)
56
+ # draw.text((x-2, y+2), text,(0,0,0),font=font)
57
+ # draw.text((x, y), text, (255,255,255), font=font)
58
 
59
+ # if show_text:
60
+ # draw = ImageDraw.Draw(meme)
61
+ # font_size=52
62
+ # font = ImageFont.truetype("assets/impact.ttf", font_size)
63
+ # w, h = draw.textsize(text, font) # measure the size the text will take
64
+ # drawTextWithOutline(text, meme.width/2 - w/2, meme.height - font_size*2)
65
+ # meme = meme.convert("RGB")
66
+ # return meme
67
 
68
+ # def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
69
+ # dataset=load_dataset(dataset_name)
70
+ # dataset=dataset.sort("sim_score")
71
+ # return dataset["train"]
72
 
73
+ # from transformers import BeitFeatureExtractor, BeitForImageClassification
74
+ # emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
75
+ # emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
76
+ # def embed(images):
77
+ # inputs = emb_feature_extractor(images=images, return_tensors="pt")
78
+ # outputs = emb_model(**inputs,output_hidden_states= True)
79
+ # last_hidden=outputs.hidden_states[-1]
80
+ # pooler=emb_model.base_model.pooler
81
+ # final_emb=pooler(last_hidden).detach().numpy()
82
+ # return final_emb
83
 
84
+ # def build_index():
85
+ # dataset=get_train_data()
86
+ # ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings":embed(x["image"])},batched=True,batch_size=20)
87
+ # ds_with_embeddings.add_faiss_index(column='beit_embeddings')
88
+ # ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')
89
 
90
+ # def get_dataset():
91
+ # dataset=get_train_data()
92
+ # dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
93
+ # return dataset
94
 
95
  def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',model_version=None):
96
  gan = LightweightGAN.from_pretrained(model_name,version=model_version)
 
103
  ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
104
  return ims
105
 
106
+ # def interpolate():
107
+ # pass