Spaces:
Running
on
T4
Running
on
T4
mv-lab
commited on
Commit
•
39417b0
1
Parent(s):
616408c
InstructIR x HF
Browse files- .gitignore +4 -0
- app.py +157 -0
- configs/eval5d.yml +40 -0
- images/a0010.jpg +0 -0
- images/frog.png +0 -0
- images/gopro.png +0 -0
- images/gradio_demo_images/bear.png +0 -0
- images/gradio_demo_images/city.jpg +0 -0
- images/gradio_demo_images/frog.png +0 -0
- images/lol_1.png +0 -0
- images/lol_748.png +0 -0
- images/noise50.png +0 -0
- images/rain-020.png +0 -0
- models/instructir.py +134 -0
- models/nafnet.py +201 -0
- models/nafnet_utils.py +146 -0
- requirements_gradio.txt +6 -0
- text/models.py +65 -0
- text/sample_prompts.json +55 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
*.pt
|
3 |
+
*.gif
|
4 |
+
*.pth
|
app.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
#from gradio_imageslider import ImageSlider
|
11 |
+
|
12 |
+
## local code
|
13 |
+
from models import instructir
|
14 |
+
from text.models import LanguageModel, LMHead
|
15 |
+
|
16 |
+
|
17 |
+
def dict2namespace(config):
|
18 |
+
namespace = argparse.Namespace()
|
19 |
+
for key, value in config.items():
|
20 |
+
if isinstance(value, dict):
|
21 |
+
new_value = dict2namespace(value)
|
22 |
+
else:
|
23 |
+
new_value = value
|
24 |
+
setattr(namespace, key, new_value)
|
25 |
+
return namespace
|
26 |
+
|
27 |
+
|
28 |
+
CONFIG = "configs/eval5d.yml"
|
29 |
+
LM_MODEL = "models/lm_instructir-7d.pt"
|
30 |
+
MODEL_NAME = "models/im_instructir-7d.pt"
|
31 |
+
|
32 |
+
# parse config file
|
33 |
+
with open(os.path.join(CONFIG), "r") as f:
|
34 |
+
config = yaml.safe_load(f)
|
35 |
+
|
36 |
+
cfg = dict2namespace(config)
|
37 |
+
|
38 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
39 |
+
model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
|
40 |
+
middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim)
|
41 |
+
model = model.to(device)
|
42 |
+
print ("IMAGE MODEL CKPT:", MODEL_NAME)
|
43 |
+
model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"), strict=True)
|
44 |
+
|
45 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
46 |
+
LMODEL = cfg.llm.model
|
47 |
+
language_model = LanguageModel(model=LMODEL)
|
48 |
+
lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
|
49 |
+
lm_head = lm_head.to(device)
|
50 |
+
|
51 |
+
print("LMHEAD MODEL CKPT:", LM_MODEL)
|
52 |
+
lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
|
53 |
+
|
54 |
+
|
55 |
+
def load_img (filename, norm=True,):
|
56 |
+
img = np.array(Image.open(filename).convert("RGB"))
|
57 |
+
if norm:
|
58 |
+
img = img / 255.
|
59 |
+
img = img.astype(np.float32)
|
60 |
+
return img
|
61 |
+
|
62 |
+
|
63 |
+
def process_img (image, prompt):
|
64 |
+
img = np.array(image)
|
65 |
+
img = img / 255.
|
66 |
+
img = img.astype(np.float32)
|
67 |
+
y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
|
68 |
+
|
69 |
+
lm_embd = language_model(prompt)
|
70 |
+
lm_embd = lm_embd.to(device)
|
71 |
+
|
72 |
+
with torch.no_grad():
|
73 |
+
text_embd, deg_pred = lm_head (lm_embd)
|
74 |
+
x_hat = model(y, text_embd)
|
75 |
+
|
76 |
+
restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
|
77 |
+
restored_img = np.clip(restored_img, 0. , 1.)
|
78 |
+
|
79 |
+
restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
|
80 |
+
return Image.fromarray(restored_img) #(image, Image.fromarray(restored_img))
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
title = "InstructIR ✏️🖼️ 🤗"
|
85 |
+
description = ''' ## [High-Quality Image Restoration Following Human Instructions](https://github.com/mv-lab/InstructIR)
|
86 |
+
|
87 |
+
[Marcos V. Conde](https://scholar.google.com/citations?user=NtB1kjYAAAAJ&hl=en), [Gregor Geigle](https://scholar.google.com/citations?user=uIlyqRwAAAAJ&hl=en), [Radu Timofte](https://scholar.google.com/citations?user=u3MwH5kAAAAJ&hl=en)
|
88 |
+
|
89 |
+
Computer Vision Lab, University of Wuerzburg | Sony PlayStation, FTG
|
90 |
+
|
91 |
+
### TL;DR: quickstart
|
92 |
+
InstructIR takes as input an image and a human-written instruction for how to improve that image. The neural model performs all-in-one image restoration. InstructIR achieves state-of-the-art results on several restoration tasks including image denoising, deraining, deblurring, dehazing, and (low-light) image enhancement.
|
93 |
+
|
94 |
+
**🚀 You can start with the [demo tutorial](https://github.com/mv-lab/InstructIR/blob/main/demo.ipynb)**
|
95 |
+
|
96 |
+
<details>
|
97 |
+
<summary> <b> Abstract</b> (click me to read)</summary>
|
98 |
+
<p>
|
99 |
+
Image restoration is a fundamental problem that involves recovering a high-quality clean image from its degraded observation. All-In-One image restoration models can effectively restore images from various types and levels of degradation using degradation-specific information as prompts to guide the restoration model. In this work, we present the first approach that uses human-written instructions to guide the image restoration model. Given natural language prompts, our model can recover high-quality images from their degraded counterparts, considering multiple degradation types. Our method, InstructIR, achieves state-of-the-art results on several restoration tasks including image denoising, deraining, deblurring, dehazing, and (low-light) image enhancement. InstructIR improves +1dB over previous all-in-one restoration methods. Moreover, our dataset and results represent a novel benchmark for new research on text-guided image restoration and enhancement.
|
100 |
+
</p>
|
101 |
+
</details>
|
102 |
+
|
103 |
+
> Disclaimer: please remember this is not a product, thus, you will notice some limitations.
|
104 |
+
|
105 |
+
**This demo expects an image with some degradations (blur, noise, rain, low-light, haze) and a prompt requesting what should be done.**
|
106 |
+
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
|
107 |
+
|
108 |
+
<br>
|
109 |
+
'''
|
110 |
+
# **Demo notebook can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Swin2SR/Perform_image_super_resolution_with_Swin2SR.ipynb).
|
111 |
+
|
112 |
+
article = "<p style='text-align: center'><a href='https://github.com/mv-lab/InstructIR' target='_blank'>High-Quality Image Restoration Following Human Instructions</a></p>"
|
113 |
+
|
114 |
+
examples = [['images/rain-020.png', "I love this photo, could you remove the raindrops? please keep the content intact"],
|
115 |
+
['images/gradio_demo_images/city.jpg', "I took this photo during a foggy day, can you improve it?"],
|
116 |
+
['images/gradio_demo_images/frog.png', "can you remove the tiny dots in the image? it is very unpleasant"],
|
117 |
+
["images/lol_748.png", "my image is too dark, I cannot see anything, can you fix it?"],
|
118 |
+
["images/gopro.png", "I took this photo while I was running, can you stabilize the image? it is too blurry"],
|
119 |
+
["images/a0010.jpg", "please I want this image for my photo album, can you edit it as a photographer"]]
|
120 |
+
|
121 |
+
css = """
|
122 |
+
.image-frame img, .image-container img {
|
123 |
+
width: auto;
|
124 |
+
height: auto;
|
125 |
+
max-width: none;
|
126 |
+
}
|
127 |
+
"""
|
128 |
+
|
129 |
+
demo = gr.Interface(
|
130 |
+
fn=process_img,
|
131 |
+
inputs=[
|
132 |
+
gr.Image(type="pil", label="Input"),
|
133 |
+
gr.Text(label="Prompt")
|
134 |
+
],
|
135 |
+
outputs=[gr.Image(type="pil", label="Ouput")], #ImageSlider(position=0.5, type="pil", label="SideBySide")], #gr.Image(type="pil", label="Ouput"), #
|
136 |
+
title=title,
|
137 |
+
description=description,
|
138 |
+
article=article,
|
139 |
+
examples=examples,
|
140 |
+
css=css,
|
141 |
+
)
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
demo.launch()
|
145 |
+
|
146 |
+
# with gr.Blocks() as demo:
|
147 |
+
# with gr.Row(equal_height=True):
|
148 |
+
# with gr.Column(scale=1):
|
149 |
+
# input = gr.Image(type="pil", label="Input")
|
150 |
+
# with gr.Column(scale=1):
|
151 |
+
# prompt = gr.Text(label="Prompt")
|
152 |
+
# process_btn = gr.Button("Process")
|
153 |
+
# with gr.Row(equal_height=True):
|
154 |
+
# output = gr.Image(type="pil", label="Ouput")
|
155 |
+
# slider = ImageSlider(position=0.5, type="pil", label="SideBySide")
|
156 |
+
# process_btn.click(fn=process_img, inputs=[input, prompt], outputs=[output, slider])
|
157 |
+
# demo.launch(share=True)
|
configs/eval5d.yml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
llm:
|
2 |
+
model: 'TaylorAI/bge-micro-v2' # See Paper Sec. 3.2 and Appendix
|
3 |
+
model_dim: 384
|
4 |
+
embd_dim: 256
|
5 |
+
nclasses: 7 # noise, blur, rain, haze, lol, enhancement, upsampling (Paper Sec. 4.3)
|
6 |
+
weights: False
|
7 |
+
|
8 |
+
model:
|
9 |
+
arch: "instructir"
|
10 |
+
use_text: True
|
11 |
+
in_ch: 3
|
12 |
+
out_ch: 3
|
13 |
+
width : 32
|
14 |
+
enc_blks: [2, 2, 4, 8]
|
15 |
+
middle_blk_num: 4
|
16 |
+
dec_blks: [2, 2, 2, 2]
|
17 |
+
textdim: 256
|
18 |
+
weights: False
|
19 |
+
|
20 |
+
test:
|
21 |
+
batch_size: 1
|
22 |
+
num_workers: 3
|
23 |
+
|
24 |
+
dn_datapath: "data/denoising_testsets/"
|
25 |
+
dn_datasets: ["CBSD68", "urban100", "Kodak24", "McMaster"]
|
26 |
+
dn_sigmas: [15, 25, 50]
|
27 |
+
|
28 |
+
rain_targets: ["data/Rain/rain_test/Rain100L/target/"]
|
29 |
+
rain_inputs: ["data/Rain/rain_test/Rain100L/input/"]
|
30 |
+
|
31 |
+
haze_targets: "data/SOTS-OUT/GT/"
|
32 |
+
haze_inputs : "data/SOTS-OUT/IN/"
|
33 |
+
|
34 |
+
lol_targets: "data/LOL/eval15/high/"
|
35 |
+
lol_inputs : "data/LOL/eval15/low/"
|
36 |
+
|
37 |
+
gopro_targets: "data/gopro_test/GoPro/target/"
|
38 |
+
gopro_inputs: "data/gopro_test/GoPro/input/"
|
39 |
+
|
40 |
+
|
images/a0010.jpg
ADDED
images/frog.png
ADDED
images/gopro.png
ADDED
images/gradio_demo_images/bear.png
ADDED
images/gradio_demo_images/city.jpg
ADDED
images/gradio_demo_images/frog.png
ADDED
images/lol_1.png
ADDED
images/lol_748.png
ADDED
images/noise50.png
ADDED
images/rain-020.png
ADDED
models/instructir.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.nn import init as init
|
6 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
7 |
+
|
8 |
+
from models.nafnet_utils import Local_Base, LayerNorm2d
|
9 |
+
from models.nafnet import SimpleGate, NAFBlock
|
10 |
+
|
11 |
+
|
12 |
+
class ICB(nn.Module):
|
13 |
+
"""
|
14 |
+
Instruction Condition Block (ICB)
|
15 |
+
Paper Section 3.3
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, feature_dim, text_dim=768):
|
19 |
+
super(ICB, self).__init__()
|
20 |
+
self.fc = nn.Linear(text_dim, feature_dim)
|
21 |
+
self.block = NAFBlock(feature_dim)
|
22 |
+
self.beta = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)
|
23 |
+
self.gamma = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)
|
24 |
+
|
25 |
+
def forward(self, x, text_embedding):
|
26 |
+
gating_factors = torch.sigmoid(self.fc(text_embedding))
|
27 |
+
gating_factors = gating_factors.unsqueeze(-1).unsqueeze(-1)
|
28 |
+
|
29 |
+
f = x * self.gamma + self.beta # 1) learned feature scaling/modulation
|
30 |
+
f = f * gating_factors # 2) (soft) feature routing based on text
|
31 |
+
f = self.block(f) # 3) block feature enhancement
|
32 |
+
return f + x
|
33 |
+
|
34 |
+
|
35 |
+
class InstructIR(nn.Module):
|
36 |
+
"""
|
37 |
+
InstructIR model using NAFNet (ECCV 2022) as backbone.
|
38 |
+
The model takes as input an RGB image and a text embedding (encoded instruction).
|
39 |
+
Described in Paper Section 3.3
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], txtdim=768):
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
|
46 |
+
bias=True)
|
47 |
+
self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
|
48 |
+
bias=True)
|
49 |
+
|
50 |
+
self.encoders = nn.ModuleList()
|
51 |
+
self.decoders = nn.ModuleList()
|
52 |
+
self.middle_blks = nn.ModuleList()
|
53 |
+
self.ups = nn.ModuleList()
|
54 |
+
self.downs = nn.ModuleList()
|
55 |
+
self.enc_cond = nn.ModuleList()
|
56 |
+
self.dec_cond = nn.ModuleList()
|
57 |
+
|
58 |
+
chan = width
|
59 |
+
for num in enc_blk_nums:
|
60 |
+
self.encoders.append(
|
61 |
+
nn.Sequential(
|
62 |
+
*[NAFBlock(chan) for _ in range(num)]
|
63 |
+
)
|
64 |
+
)
|
65 |
+
|
66 |
+
self.enc_cond.append(ICB(chan, txtdim))
|
67 |
+
|
68 |
+
self.downs.append(
|
69 |
+
nn.Conv2d(chan, 2*chan, 2, 2)
|
70 |
+
)
|
71 |
+
chan = chan * 2
|
72 |
+
|
73 |
+
self.middle_blks = nn.Sequential(
|
74 |
+
*[NAFBlock(chan) for _ in range(middle_blk_num)]
|
75 |
+
)
|
76 |
+
|
77 |
+
for num in dec_blk_nums:
|
78 |
+
self.ups.append(
|
79 |
+
nn.Sequential(
|
80 |
+
nn.Conv2d(chan, chan * 2, 1, bias=False),
|
81 |
+
nn.PixelShuffle(2)
|
82 |
+
)
|
83 |
+
)
|
84 |
+
chan = chan // 2
|
85 |
+
self.decoders.append(
|
86 |
+
nn.Sequential(
|
87 |
+
*[NAFBlock(chan) for _ in range(num)]
|
88 |
+
)
|
89 |
+
)
|
90 |
+
# Add text embedding as modulation
|
91 |
+
self.dec_cond.append(ICB(chan, txtdim))
|
92 |
+
|
93 |
+
self.padder_size = 2 ** len(self.encoders)
|
94 |
+
|
95 |
+
def forward(self, inp, txtembd):
|
96 |
+
B, C, H, W = inp.shape
|
97 |
+
inp = self.check_image_size(inp)
|
98 |
+
|
99 |
+
x = self.intro(inp)
|
100 |
+
encs = []
|
101 |
+
|
102 |
+
for encoder, enc_mod, down in zip(self.encoders, self.enc_cond, self.downs):
|
103 |
+
x = encoder(x)
|
104 |
+
x = enc_mod(x, txtembd)
|
105 |
+
encs.append(x)
|
106 |
+
x = down(x)
|
107 |
+
|
108 |
+
x = self.middle_blks(x)
|
109 |
+
|
110 |
+
for decoder, up, enc_skip, dec_mod in zip(self.decoders, self.ups, encs[::-1], self.dec_cond):
|
111 |
+
x = up(x)
|
112 |
+
x = x + enc_skip
|
113 |
+
x = decoder(x)
|
114 |
+
x = dec_mod(x, txtembd)
|
115 |
+
|
116 |
+
x = self.ending(x)
|
117 |
+
x = x + inp
|
118 |
+
|
119 |
+
return x[:, :, :H, :W]
|
120 |
+
|
121 |
+
def check_image_size(self, x):
|
122 |
+
_, _, h, w = x.size()
|
123 |
+
mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
|
124 |
+
mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
|
125 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
def create_model(input_channels = 3, width = 32, enc_blks = [2, 2, 4, 8], middle_blk_num = 12, dec_blks = [2, 2, 2, 2], txtdim=768):
|
130 |
+
|
131 |
+
net = InstructIR(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
|
132 |
+
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, txtdim=txtdim)
|
133 |
+
|
134 |
+
return net
|
models/nafnet.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
3 |
+
# ------------------------------------------------------------------------
|
4 |
+
# Source: https://github.com/megvii-research/NAFNet
|
5 |
+
|
6 |
+
'''
|
7 |
+
Simple Baselines for Image Restoration
|
8 |
+
|
9 |
+
@article{chen2022simple,
|
10 |
+
title={Simple Baselines for Image Restoration},
|
11 |
+
author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
|
12 |
+
journal={arXiv preprint arXiv:2204.04676},
|
13 |
+
year={2022}
|
14 |
+
}
|
15 |
+
'''
|
16 |
+
|
17 |
+
import math
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.nn import init as init
|
22 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
23 |
+
from models.nafnet_utils import Local_Base, LayerNorm2d
|
24 |
+
|
25 |
+
|
26 |
+
class SimpleGate(nn.Module):
|
27 |
+
def forward(self, x):
|
28 |
+
x1, x2 = x.chunk(2, dim=1)
|
29 |
+
return x1 * x2
|
30 |
+
|
31 |
+
class NAFBlock(nn.Module):
|
32 |
+
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
|
33 |
+
super().__init__()
|
34 |
+
dw_channel = c * DW_Expand
|
35 |
+
self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
36 |
+
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
|
37 |
+
bias=True)
|
38 |
+
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
39 |
+
|
40 |
+
# Simplified Channel Attention
|
41 |
+
self.sca = nn.Sequential(
|
42 |
+
nn.AdaptiveAvgPool2d(1),
|
43 |
+
nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
|
44 |
+
groups=1, bias=True),
|
45 |
+
)
|
46 |
+
|
47 |
+
# SimpleGate
|
48 |
+
self.sg = SimpleGate()
|
49 |
+
|
50 |
+
ffn_channel = FFN_Expand * c
|
51 |
+
self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
52 |
+
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
53 |
+
|
54 |
+
self.norm1 = LayerNorm2d(c)
|
55 |
+
self.norm2 = LayerNorm2d(c)
|
56 |
+
|
57 |
+
self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
58 |
+
self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
59 |
+
|
60 |
+
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
|
61 |
+
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
|
62 |
+
|
63 |
+
def forward(self, inp):
|
64 |
+
x = inp
|
65 |
+
|
66 |
+
x = self.norm1(x)
|
67 |
+
|
68 |
+
x = self.conv1(x)
|
69 |
+
x = self.conv2(x)
|
70 |
+
x = self.sg(x)
|
71 |
+
x = x * self.sca(x)
|
72 |
+
x = self.conv3(x)
|
73 |
+
|
74 |
+
x = self.dropout1(x)
|
75 |
+
|
76 |
+
y = inp + x * self.beta
|
77 |
+
|
78 |
+
x = self.conv4(self.norm2(y))
|
79 |
+
x = self.sg(x)
|
80 |
+
x = self.conv5(x)
|
81 |
+
|
82 |
+
x = self.dropout2(x)
|
83 |
+
|
84 |
+
return y + x * self.gamma
|
85 |
+
|
86 |
+
|
87 |
+
class NAFNet(nn.Module):
|
88 |
+
|
89 |
+
def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]):
|
90 |
+
super().__init__()
|
91 |
+
|
92 |
+
self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
|
93 |
+
bias=True)
|
94 |
+
self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
|
95 |
+
bias=True)
|
96 |
+
|
97 |
+
self.encoders = nn.ModuleList()
|
98 |
+
self.decoders = nn.ModuleList()
|
99 |
+
self.middle_blks = nn.ModuleList()
|
100 |
+
self.ups = nn.ModuleList()
|
101 |
+
self.downs = nn.ModuleList()
|
102 |
+
|
103 |
+
chan = width
|
104 |
+
for num in enc_blk_nums:
|
105 |
+
self.encoders.append(
|
106 |
+
nn.Sequential(
|
107 |
+
*[NAFBlock(chan) for _ in range(num)]
|
108 |
+
)
|
109 |
+
)
|
110 |
+
self.downs.append(
|
111 |
+
nn.Conv2d(chan, 2*chan, 2, 2)
|
112 |
+
)
|
113 |
+
chan = chan * 2
|
114 |
+
|
115 |
+
self.middle_blks = \
|
116 |
+
nn.Sequential(
|
117 |
+
*[NAFBlock(chan) for _ in range(middle_blk_num)]
|
118 |
+
)
|
119 |
+
|
120 |
+
for num in dec_blk_nums:
|
121 |
+
self.ups.append(
|
122 |
+
nn.Sequential(
|
123 |
+
nn.Conv2d(chan, chan * 2, 1, bias=False),
|
124 |
+
nn.PixelShuffle(2)
|
125 |
+
)
|
126 |
+
)
|
127 |
+
chan = chan // 2
|
128 |
+
self.decoders.append(
|
129 |
+
nn.Sequential(
|
130 |
+
*[NAFBlock(chan) for _ in range(num)]
|
131 |
+
)
|
132 |
+
)
|
133 |
+
|
134 |
+
self.padder_size = 2 ** len(self.encoders)
|
135 |
+
|
136 |
+
def forward(self, inp):
|
137 |
+
B, C, H, W = inp.shape
|
138 |
+
inp = self.check_image_size(inp)
|
139 |
+
|
140 |
+
x = self.intro(inp)
|
141 |
+
|
142 |
+
encs = []
|
143 |
+
|
144 |
+
for encoder, down in zip(self.encoders, self.downs):
|
145 |
+
x = encoder(x)
|
146 |
+
encs.append(x)
|
147 |
+
x = down(x)
|
148 |
+
|
149 |
+
x = self.middle_blks(x)
|
150 |
+
|
151 |
+
for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
|
152 |
+
x = up(x)
|
153 |
+
x = x + enc_skip
|
154 |
+
x = decoder(x)
|
155 |
+
|
156 |
+
x = self.ending(x)
|
157 |
+
x = x + inp
|
158 |
+
|
159 |
+
return x[:, :, :H, :W]
|
160 |
+
|
161 |
+
def check_image_size(self, x):
|
162 |
+
_, _, h, w = x.size()
|
163 |
+
mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
|
164 |
+
mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
|
165 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
|
166 |
+
return x
|
167 |
+
|
168 |
+
class NAFNetLocal(Local_Base, NAFNet):
|
169 |
+
def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
|
170 |
+
Local_Base.__init__(self)
|
171 |
+
NAFNet.__init__(self, *args, **kwargs)
|
172 |
+
|
173 |
+
N, C, H, W = train_size
|
174 |
+
base_size = (int(H * 1.5), int(W * 1.5))
|
175 |
+
|
176 |
+
self.eval()
|
177 |
+
with torch.no_grad():
|
178 |
+
self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
|
179 |
+
|
180 |
+
|
181 |
+
def create_nafnet(input_channels = 3, width = 32, enc_blks = [2, 2, 4, 8], middle_blk_num = 12, dec_blks = [2, 2, 2, 2]):
|
182 |
+
"""
|
183 |
+
Create Nafnet model
|
184 |
+
https://github.com/megvii-research/NAFNet/blob/main/options/test/SIDD/NAFNet-width32.yml
|
185 |
+
"""
|
186 |
+
|
187 |
+
net = NAFNet(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
|
188 |
+
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
|
189 |
+
|
190 |
+
# inp_shape = (3, 256, 256)
|
191 |
+
|
192 |
+
# from ptflops import get_model_complexity_info
|
193 |
+
|
194 |
+
# macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
|
195 |
+
|
196 |
+
# params = float(params[:-3])
|
197 |
+
# macs = float(macs[:-4])
|
198 |
+
|
199 |
+
# print(macs, params)
|
200 |
+
|
201 |
+
return net
|
models/nafnet_utils.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
3 |
+
# ------------------------------------------------------------------------
|
4 |
+
# Source: https://github.com/megvii-research/NAFNet
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import math
|
11 |
+
|
12 |
+
class LayerNormFunction(torch.autograd.Function):
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def forward(ctx, x, weight, bias, eps):
|
16 |
+
ctx.eps = eps
|
17 |
+
N, C, H, W = x.size()
|
18 |
+
mu = x.mean(1, keepdim=True)
|
19 |
+
var = (x - mu).pow(2).mean(1, keepdim=True)
|
20 |
+
y = (x - mu) / (var + eps).sqrt()
|
21 |
+
ctx.save_for_backward(y, var, weight)
|
22 |
+
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
23 |
+
return y
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def backward(ctx, grad_output):
|
27 |
+
eps = ctx.eps
|
28 |
+
|
29 |
+
N, C, H, W = grad_output.size()
|
30 |
+
y, var, weight = ctx.saved_variables
|
31 |
+
g = grad_output * weight.view(1, C, 1, 1)
|
32 |
+
mean_g = g.mean(dim=1, keepdim=True)
|
33 |
+
|
34 |
+
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
35 |
+
gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
36 |
+
return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
|
37 |
+
dim=0), None
|
38 |
+
|
39 |
+
class LayerNorm2d(nn.Module):
|
40 |
+
|
41 |
+
def __init__(self, channels, eps=1e-6):
|
42 |
+
super(LayerNorm2d, self).__init__()
|
43 |
+
self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
|
44 |
+
self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
|
45 |
+
self.eps = eps
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
class AvgPool2d(nn.Module):
|
53 |
+
def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
|
54 |
+
super().__init__()
|
55 |
+
self.kernel_size = kernel_size
|
56 |
+
self.base_size = base_size
|
57 |
+
self.auto_pad = auto_pad
|
58 |
+
|
59 |
+
# only used for fast implementation
|
60 |
+
self.fast_imp = fast_imp
|
61 |
+
self.rs = [5, 4, 3, 2, 1]
|
62 |
+
self.max_r1 = self.rs[0]
|
63 |
+
self.max_r2 = self.rs[0]
|
64 |
+
self.train_size = train_size
|
65 |
+
|
66 |
+
def extra_repr(self) -> str:
|
67 |
+
return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
|
68 |
+
self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
if self.kernel_size is None and self.base_size:
|
73 |
+
train_size = self.train_size
|
74 |
+
if isinstance(self.base_size, int):
|
75 |
+
self.base_size = (self.base_size, self.base_size)
|
76 |
+
self.kernel_size = list(self.base_size)
|
77 |
+
self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
|
78 |
+
self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]
|
79 |
+
|
80 |
+
# only used for fast implementation
|
81 |
+
self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
|
82 |
+
self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
|
83 |
+
|
84 |
+
if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
|
85 |
+
return F.adaptive_avg_pool2d(x, 1)
|
86 |
+
|
87 |
+
if self.fast_imp: # Non-equivalent implementation but faster
|
88 |
+
h, w = x.shape[2:]
|
89 |
+
if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
|
90 |
+
out = F.adaptive_avg_pool2d(x, 1)
|
91 |
+
else:
|
92 |
+
r1 = [r for r in self.rs if h % r == 0][0]
|
93 |
+
r2 = [r for r in self.rs if w % r == 0][0]
|
94 |
+
# reduction_constraint
|
95 |
+
r1 = min(self.max_r1, r1)
|
96 |
+
r2 = min(self.max_r2, r2)
|
97 |
+
s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
|
98 |
+
n, c, h, w = s.shape
|
99 |
+
k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
|
100 |
+
out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
|
101 |
+
out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
|
102 |
+
else:
|
103 |
+
n, c, h, w = x.shape
|
104 |
+
s = x.cumsum(dim=-1).cumsum_(dim=-2)
|
105 |
+
s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience
|
106 |
+
k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
|
107 |
+
s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
|
108 |
+
out = s4 + s1 - s2 - s3
|
109 |
+
out = out / (k1 * k2)
|
110 |
+
|
111 |
+
if self.auto_pad:
|
112 |
+
n, c, h, w = x.shape
|
113 |
+
_h, _w = out.shape[2:]
|
114 |
+
# print(x.shape, self.kernel_size)
|
115 |
+
pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
|
116 |
+
out = torch.nn.functional.pad(out, pad2d, mode='replicate')
|
117 |
+
|
118 |
+
return out
|
119 |
+
|
120 |
+
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
|
121 |
+
for n, m in model.named_children():
|
122 |
+
if len(list(m.children())) > 0:
|
123 |
+
## compound module, go inside it
|
124 |
+
replace_layers(m, base_size, train_size, fast_imp, **kwargs)
|
125 |
+
|
126 |
+
if isinstance(m, nn.AdaptiveAvgPool2d):
|
127 |
+
pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
|
128 |
+
assert m.output_size == 1
|
129 |
+
setattr(model, n, pool)
|
130 |
+
|
131 |
+
|
132 |
+
'''
|
133 |
+
ref.
|
134 |
+
@article{chu2021tlsc,
|
135 |
+
title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
|
136 |
+
author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin},
|
137 |
+
journal={arXiv preprint arXiv:2112.04491},
|
138 |
+
year={2021}
|
139 |
+
}
|
140 |
+
'''
|
141 |
+
class Local_Base():
|
142 |
+
def convert(self, *args, train_size, **kwargs):
|
143 |
+
replace_layers(self, *args, train_size=train_size, **kwargs)
|
144 |
+
imgs = torch.rand(train_size)
|
145 |
+
with torch.no_grad():
|
146 |
+
self.forward(imgs)
|
requirements_gradio.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
Pillow>=6.2.2
|
4 |
+
sentence-transformers==2.3.0
|
5 |
+
gradio==4.16.0
|
6 |
+
#gradio_imageslider==0.0.18
|
text/models.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Models that use mean pooling
|
8 |
+
POOL_MODELS = {"sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"}
|
9 |
+
|
10 |
+
#Mean Pooling - Take attention mask into account for correct averaging
|
11 |
+
def mean_pooling(model_output, attention_mask):
|
12 |
+
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
13 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
14 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
15 |
+
|
16 |
+
|
17 |
+
class LanguageModel(nn.Module):
|
18 |
+
def __init__(self, model='distilbert-base-uncased'):
|
19 |
+
super(LanguageModel, self).__init__()
|
20 |
+
|
21 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model)
|
22 |
+
self.model = AutoModel.from_pretrained(model)
|
23 |
+
self.model_name = model
|
24 |
+
# Remove the CLIP vision tower
|
25 |
+
if "clip" in self.model_name:
|
26 |
+
self.model.vision_model = None
|
27 |
+
# Freeze the pre-trained parameters (very important)
|
28 |
+
for param in self.model.parameters():
|
29 |
+
param.requires_grad = False
|
30 |
+
|
31 |
+
# Make sure to set evaluation mode (also important)
|
32 |
+
self.model.eval()
|
33 |
+
|
34 |
+
def forward(self, text_batch):
|
35 |
+
inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
|
36 |
+
with torch.no_grad(): # Ensure no gradients are computed for this forward pass
|
37 |
+
|
38 |
+
if "clip" in self.model_name:
|
39 |
+
sentence_embedding = self.model.get_text_features(**inputs)
|
40 |
+
return sentence_embedding
|
41 |
+
|
42 |
+
outputs = self.model(**inputs)
|
43 |
+
|
44 |
+
if any(model in self.model_name for model in POOL_MODELS):
|
45 |
+
sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
|
46 |
+
# Normalize embeddings
|
47 |
+
sentence_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
|
48 |
+
else:
|
49 |
+
sentence_embedding = outputs.last_hidden_state[:, 0, :]
|
50 |
+
return sentence_embedding
|
51 |
+
|
52 |
+
|
53 |
+
class LMHead(nn.Module):
|
54 |
+
def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
|
55 |
+
super(LMHead, self).__init__()
|
56 |
+
|
57 |
+
self.fc1 = nn.Linear(embedding_dim, hidden_dim)
|
58 |
+
#self.gelu = nn.GELU()
|
59 |
+
self.fc2 = nn.Linear(hidden_dim, num_classes)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
embd = self.fc1(x)
|
63 |
+
embd = F.normalize(embd, p=2, dim=1)
|
64 |
+
deg_pred = self.fc2(embd)
|
65 |
+
return embd, deg_pred
|
text/sample_prompts.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"denoising": [
|
3 |
+
"Help me reduce the fuzziness in this image.",
|
4 |
+
"I need this image denoised ASAP.",
|
5 |
+
"Clean up this noisy image, it's an eyesore.",
|
6 |
+
"Can you clean the dots from my image?",
|
7 |
+
"Help me with my picture, it's full of tiny spots.",
|
8 |
+
"Clean up this image, it's all grainy."
|
9 |
+
],
|
10 |
+
"deblurring": [
|
11 |
+
"Please, clean up this blurry photo.",
|
12 |
+
"My picture's not sharp, fix it.",
|
13 |
+
"Deblur my picture, it's too fuzzy.",
|
14 |
+
"Help, my photo is too blurry.",
|
15 |
+
"Please, make my image less smudgy."
|
16 |
+
],
|
17 |
+
"dehazing": [
|
18 |
+
"Please, fix the haziness in my image.",
|
19 |
+
"I need to remove the haziness from this image.",
|
20 |
+
"Get rid of the fog in my image.",
|
21 |
+
"Fix my photo, it's too misty.",
|
22 |
+
"Help me, my photo is all hazy."
|
23 |
+
],
|
24 |
+
"deraining": [
|
25 |
+
"I want to eliminate the water from this image.",
|
26 |
+
"Clear the rain from my picture.",
|
27 |
+
"I need to clear the rain from this image.",
|
28 |
+
"Can you get rid of the raindrops in my picture?"
|
29 |
+
],
|
30 |
+
"sr": [
|
31 |
+
"I need to enhance the size and quality of this image.",
|
32 |
+
"My photo is lacking size and clarity; can you improve it?",
|
33 |
+
"I'd appreciate it if you could upscale this photo.",
|
34 |
+
"My picture is too little, enlarge it."
|
35 |
+
],
|
36 |
+
"ambiguous": [
|
37 |
+
"Please, clear up the mess on this image.",
|
38 |
+
"I want this image to look good.",
|
39 |
+
"make it pop",
|
40 |
+
"Fix my photo, it's all messed up."
|
41 |
+
],
|
42 |
+
"lol": [
|
43 |
+
"I took this photo during night, enhance it",
|
44 |
+
"The photo is too dark, improve exposure",
|
45 |
+
"my image has poor lighting conditions, can you fix it?",
|
46 |
+
"Can you make the image brighter?"
|
47 |
+
],
|
48 |
+
"enhancement": [
|
49 |
+
"make my image look like DSLR",
|
50 |
+
"improve the colors of my image",
|
51 |
+
"enhance the colors of the image",
|
52 |
+
"Can you edit this to look like an award-winning photo?",
|
53 |
+
"I want the picture to be retouched for a professional portfolio."
|
54 |
+
]
|
55 |
+
}
|