Update app.py
Browse files
app.py
CHANGED
@@ -118,6 +118,25 @@ model = timm.create_model(
|
|
118 |
num_classes=9083,
|
119 |
) # type: VisionTransformer
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
|
122 |
model.eval()
|
123 |
|
@@ -134,10 +153,9 @@ def create_tags(image, threshold):
|
|
134 |
tensor = transform(img).unsqueeze(0)
|
135 |
|
136 |
with torch.no_grad():
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
values = probabilities[indices]
|
141 |
|
142 |
temp = []
|
143 |
tag_score = dict()
|
@@ -150,10 +168,10 @@ def create_tags(image, threshold):
|
|
150 |
|
151 |
with gr.Blocks() as demo:
|
152 |
gr.Markdown("""
|
153 |
-
## Joint Tagger Project: PILOT Demo
|
154 |
This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
|
155 |
|
156 |
-
This tagger is the result of joint efforts between members of the RedRocket team.
|
157 |
|
158 |
Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
|
159 |
""")
|
|
|
118 |
num_classes=9083,
|
119 |
) # type: VisionTransformer
|
120 |
|
121 |
+
class GatedHead(torch.nn.Module):
|
122 |
+
def __init__(self,
|
123 |
+
num_features: int,
|
124 |
+
num_classes: int
|
125 |
+
):
|
126 |
+
super().__init__()
|
127 |
+
self.num_classes = num_classes
|
128 |
+
self.linear = torch.nn.Linear(num_features, num_classes * 2)
|
129 |
+
|
130 |
+
self.act = torch.nn.Sigmoid()
|
131 |
+
self.gate = torch.nn.Sigmoid()
|
132 |
+
|
133 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
134 |
+
x = self.linear(x)
|
135 |
+
x = self.act(x[:, :self.num_classes]) * self.gate(x[:, self.num_classes:])
|
136 |
+
return x
|
137 |
+
|
138 |
+
model.head = GatedHead(min(model.head.weight.shape), 9083)
|
139 |
+
|
140 |
safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
|
141 |
model.eval()
|
142 |
|
|
|
153 |
tensor = transform(img).unsqueeze(0)
|
154 |
|
155 |
with torch.no_grad():
|
156 |
+
probits = model(tensor)
|
157 |
+
indices = torch.where(probits > threshold)[0]
|
158 |
+
values = probits[indices]
|
|
|
159 |
|
160 |
temp = []
|
161 |
tag_score = dict()
|
|
|
168 |
|
169 |
with gr.Blocks() as demo:
|
170 |
gr.Markdown("""
|
171 |
+
## Joint Tagger Project: JTP-PILOT² Demo **BETA**
|
172 |
This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
|
173 |
|
174 |
+
This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
|
175 |
|
176 |
Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
|
177 |
""")
|