---
tags:
- pytorch
library_tag: pytorch
license: bsd-3-clause
datasets:
- imagenet-1k
pipeline_tag: feature-extraction
---

# Model card for resnet50-truncated.tv_in1k

A truncated ResNet-50 feature extraction model, as used in [CLAM](https://www.nature.com/articles/s41551-020-00682-w).

This model features:
* ReLU activations
* single layer 7x7 convolution with pooling
* 1x1 convolution shortcut downsample

The weights are the original torchvision ResNet-50 weights, trained on ImageNet-1k and distributed via PyTorch. The model is truncated to exclude layer 4 and the fully connected layer.

This model card was adapted from https://huggingface.co/timm/resnet50.tv_in1k.

## Model Details
- **Model Type:** Feature backbone
- **Model Stats:**
  - Params (M): 8.5
  - Image size: 224 x 224
- **Papers:**
  - Deep Residual Learning for Image Recognition: https://arxiv.org/abs/1512.03385
  - Data-efficient and weakly supervised computational pathology on whole-slide images: https://www.nature.com/articles/s41551-020-00682-w
- **Original:** https://github.com/pytorch/vision
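
The parameter figure in the stats above can be sanity-checked with plain arithmetic. The following is a minimal sketch, assuming the standard torchvision ResNet-50 layout (bias-free convolutions, bottleneck expansion factor 4, block counts 3/4/6/3, and a 1000-class classifier); the helper names `bottleneck_params` and `stage_params` are hypothetical and used only for illustration. It counts learnable parameters only, not BatchNorm running statistics.

```python
def bottleneck_params(inplanes: int, planes: int, downsample: bool) -> int:
    """Count learnable parameters in one Bottleneck block (assumed layout)."""
    out = planes * 4  # bottleneck expansion factor of 4
    n = inplanes * planes + 2 * planes      # 1x1 conv (no bias) + BatchNorm
    n += planes * planes * 9 + 2 * planes   # 3x3 conv (no bias) + BatchNorm
    n += planes * out + 2 * out             # 1x1 conv (no bias) + BatchNorm
    if downsample:
        n += inplanes * out + 2 * out       # 1x1 shortcut conv + BatchNorm
    return n


def stage_params(inplanes: int, planes: int, blocks: int) -> int:
    """Count parameters in one residual stage; only the first block downsamples."""
    total = bottleneck_params(inplanes, planes, downsample=True)
    total += (blocks - 1) * bottleneck_params(planes * 4, planes, downsample=False)
    return total


stem = 3 * 64 * 7 * 7 + 2 * 64  # conv1 + bn1
stages = {
    "layer1": stage_params(64, 64, 3),
    "layer2": stage_params(256, 128, 4),
    "layer3": stage_params(512, 256, 6),
    "layer4": stage_params(1024, 512, 3),
}
fc = 2048 * 1000 + 1000  # weight + bias of the classifier

full = stem + sum(stages.values()) + fc
truncated = stem + stages["layer1"] + stages["layer2"] + stages["layer3"]

print(f"full ResNet-50: {full / 1e6:.1f} M")
print(f"truncated:      {truncated / 1e6:.1f} M")
```

Under these assumptions, layer 4 and the classifier together account for roughly two thirds of the full network's parameters, which is why the truncated model is so much smaller.
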

## Model Creation

```python
import types

import torch
from torchvision.models import resnet50


def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
    # Stem: 7x7 convolution, batch norm, ReLU, and max pooling.
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    # Residual stages. layer4 is intentionally omitted.
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)

    # Global average pooling, then flatten to (batch, 1024).
    x = self.avgpool(x)
    x = x.view(x.size(0), -1)

    return x


# Build the architecture without weights, then drop the unused modules.
model = resnet50(weights=None)
del model.layer4, model.fc

# Bind the truncated forward pass to this model instance.
model._forward_impl = types.MethodType(_forward_impl, model)

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet50-19c8e357.pth"
)
# Remove truncated keys.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("layer4.") and not k.startswith("fc.")}

# strict=True verifies that the remaining keys match the truncated model exactly.
model.load_state_dict(state_dict, strict=True)
model.eval()
```
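
To see what the key filter in the block above does without downloading the checkpoint, here is a small sketch on a hand-written dictionary. The key names imitate torchvision's ResNet naming, and the values are placeholder strings rather than tensors; both are assumptions made purely for illustration.

```python
# Stand-in for a checkpoint: keys mimic torchvision's ResNet naming,
# values are placeholders instead of tensors (illustration only).
fake_state_dict = {
    "conv1.weight": "w",
    "bn1.weight": "w",
    "layer1.0.conv1.weight": "w",
    "layer3.5.bn3.bias": "b",
    "layer4.0.conv1.weight": "w",
    "layer4.2.bn3.bias": "b",
    "fc.weight": "w",
    "fc.bias": "b",
}

# str.startswith accepts a tuple of prefixes, which is equivalent to
# chaining `not k.startswith(...) and not k.startswith(...)`.
dropped_prefixes = ("layer4.", "fc.")
kept = {k: v for k, v in fake_state_dict.items() if not k.startswith(dropped_prefixes)}

print(sorted(kept))
```

The trailing dot in each prefix keeps the filter from matching an unrelated key that merely begins with the same characters.
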

## Model Usage

### Image Embeddings
```python
from urllib.request import urlopen

from PIL import Image
import torch
from torchvision import transforms

# Convert to RGB so the three-channel normalization below always applies.
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
)).convert('RGB')

# See above for how to create the model. Alternatively, load a TorchScript
# version of the model with
# model = torch.jit.load("torchscript_model.bin")
model = model.eval()

transform = transforms.Compose([
    # Depending on the pipeline, this may be 256x256 or a different value.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)),
])

with torch.no_grad():
    output = model(transform(img).unsqueeze(0))  # unsqueeze single image into batch of 1

print(output.shape)  # torch.Size([1, 1024])
```
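
The embedding stays 1024-dimensional whatever resize value the pipeline uses, because the final pooling layer is adaptive. The sketch below traces the spatial size through the truncated network with the usual convolution output formula. It assumes torchvision's ResNet-50 hyperparameters (7x7 stride-2 stem with padding 3, 3x3 stride-2 max pool with padding 1, and one stride-2 3x3 convolution with padding 1 at the start of layers 2 and 3); the function names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling window along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def truncated_resnet50_shapes(size: int) -> list[tuple[str, int, int]]:
    """Return (stage, channels, spatial side) for a square input of side `size`."""
    shapes = []
    size = conv_out(size, kernel=7, stride=2, padding=3)
    shapes.append(("conv1", 64, size))
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("maxpool", 64, size))
    shapes.append(("layer1", 256, size))  # stride 1, spatial size unchanged
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("layer2", 512, size))
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("layer3", 1024, size))
    shapes.append(("avgpool", 1024, 1))  # adaptive pooling always yields 1x1
    return shapes


for side in (224, 256):
    print(side, truncated_resnet50_shapes(side))
```

Under these assumptions, a 224 x 224 input reaches layer 3 as a 14 x 14 map and a 256 x 256 input as a 16 x 16 map. Both are pooled to a single 1024-channel vector.
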

## Citation
```bibtex
@article{He2015,
  author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
  title = {Deep Residual Learning for Image Recognition},
  journal = {arXiv preprint arXiv:1512.03385},
  year = {2015}
}
@article{lu2021data,
  title={Data-efficient and weakly supervised computational pathology on whole-slide images},
  author={Lu, Ming Y and Williamson, Drew FK and Chen, Tiffany Y and Chen, Richard J and Barbieri, Matteo and Mahmood, Faisal},
  journal={Nature Biomedical Engineering},
  volume={5},
  number={6},
  pages={555--570},
  year={2021},
  publisher={Nature Publishing Group}
}
```