---
tags:
- pytorch
library_tag: pytorch
license: bsd-3-clause
datasets:
- imagenet-1k
pipeline_tag: feature-extraction
---
# Model card for resnet50-truncated.tv_in1k
A truncated ResNet-50 feature extraction model, as used in [CLAM](https://www.nature.com/articles/s41551-020-00682-w).
This model features:
* ReLU activations
* single layer 7x7 convolution with pooling
* 1x1 convolution shortcut downsample

The weights are the original torchvision ResNet-50 weights trained on ImageNet-1k, as distributed via PyTorch.
The model is truncated to exclude `layer4` and the fully connected layer, so it returns pooled `layer3` features.

This model card was adapted from https://huggingface.co/timm/resnet50.tv_in1k.
## Model Details
- **Model Type:** Feature backbone
- **Model Stats:**
  - Params (M): 8.5
  - Image size: 224 x 224
- **Papers:**
  - Deep Residual Learning for Image Recognition: https://arxiv.org/abs/1512.03385
  - Data-efficient and weakly supervised computational pathology on whole-slide images: https://www.nature.com/articles/s41551-020-00682-w
- **Original:** https://github.com/pytorch/vision
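
As a sanity check on the parameter count above, the figure can be reproduced with plain arithmetic. The sketch below is not part of the original model code: the helper names are hypothetical, and it assumes the standard torchvision ResNet-50 layout (bias-free convolutions, bottleneck blocks with expansion 4, a 1x1 downsample shortcut in the first block of each stage, and BatchNorm weight and bias counted as parameters).

```python
# Hypothetical helpers for illustration; they only do arithmetic and need no PyTorch.

def conv_params(c_in, c_out, k):
    # ResNet convolutions are assumed to have no bias term.
    return c_in * c_out * k * k


def bn_params(c):
    # BatchNorm2d has one learnable weight and one bias per channel.
    # Running statistics are buffers and are not counted.
    return 2 * c


def bottleneck_params(c_in, planes, downsample):
    c_out = planes * 4  # bottleneck expansion factor
    n = conv_params(c_in, planes, 1) + bn_params(planes)     # 1x1 reduce
    n += conv_params(planes, planes, 3) + bn_params(planes)  # 3x3
    n += conv_params(planes, c_out, 1) + bn_params(c_out)    # 1x1 expand
    if downsample:
        # 1x1 convolution shortcut used when the channel count changes.
        n += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return n


def stage_params(c_in, planes, blocks):
    # Only the first block of a stage has the downsample shortcut.
    n = bottleneck_params(c_in, planes, downsample=True)
    n += (blocks - 1) * bottleneck_params(planes * 4, planes, downsample=False)
    return n


def truncated_resnet50_params():
    # Stem: 7x7 convolution plus BatchNorm.
    total = conv_params(3, 64, 7) + bn_params(64)
    c_in = 64
    # layer1, layer2, layer3 only; layer4 and fc are excluded.
    for planes, blocks in [(64, 3), (128, 4), (256, 6)]:
        total += stage_params(c_in, planes, blocks)
        c_in = planes * 4
    return total


total = truncated_resnet50_params()
print(f"{total:,} parameters ({total / 1e6:.1f} M)")  # 8,543,296 parameters (8.5 M)
```

Most of the count comes from `layer3` (about 7.1 M), which is why dropping `layer4` and `fc` cuts the full ResNet-50 down to roughly a third of its size.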
## Model Creation
```python
import types

import torch
from torchvision.models import resnet50


def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
    # Same as the torchvision ResNet forward pass, minus layer4 and fc.
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.avgpool(x)
    # Flatten (N, 1024, 1, 1) to (N, 1024).
    x = x.view(x.size(0), -1)
    return x


# Build the architecture without weights, then drop the truncated modules.
model = resnet50(weights=None)
del model.layer4, model.fc
# Bind the truncated forward pass to this model instance.
model._forward_impl = types.MethodType(_forward_impl, model)

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet50-19c8e357.pth"
)
# Remove truncated keys.
state_dict = {
    k: v
    for k, v in state_dict.items()
    if not k.startswith(("layer4.", "fc."))
}
model.load_state_dict(state_dict, strict=True)
model.eval()
```
## Model Usage
### Image Embeddings
```python
from urllib.request import urlopen

import torch
from PIL import Image
from torchvision import transforms

# Convert to RGB so the three-channel normalization below applies
# even if the file has an alpha channel or is grayscale.
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
)).convert("RGB")

# See above for how to create the model. Alternatively, load a TorchScript
# version of the model with
# model = torch.jit.load("torchscript_model.bin")
model = model.eval()

transform = transforms.Compose([
    # Depending on the pipeline, this may be 256x256 or a different value.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)),
])

with torch.no_grad():
    output = model(transform(img).unsqueeze(0))  # unsqueeze single image into batch of 1

output.shape  # 1x1024
```
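
The 1x1024 output shape can be traced by hand through the truncated network. The sketch below is illustrative only: the function names are hypothetical, and it assumes the standard torchvision ResNet-50 kernel sizes, strides, and padding (7x7 stride-2 stem, 3x3 stride-2 max pool, stride 2 in `layer2` and `layer3`, then adaptive average pooling to 1x1).

```python
# Hypothetical helpers for illustration; plain arithmetic, no PyTorch required.

def conv_out(size, kernel, stride, padding):
    # Standard output-size formula for convolution and pooling layers.
    return (size + 2 * padding - kernel) // stride + 1


def trace_truncated_resnet50(height=224, width=224):
    dims = (height, width)
    # conv1: 7x7 kernel, stride 2, padding 3.
    dims = tuple(conv_out(d, 7, 2, 3) for d in dims)
    # maxpool: 3x3 kernel, stride 2, padding 1.
    dims = tuple(conv_out(d, 3, 2, 1) for d in dims)
    channels = 64
    # layer1 keeps the resolution; layer2 and layer3 each halve it.
    for planes, stride in [(64, 1), (128, 2), (256, 2)]:
        dims = tuple(conv_out(d, 3, stride, 1) for d in dims)
        channels = planes * 4  # bottleneck expansion factor
    # Adaptive average pooling collapses the spatial grid to 1x1,
    # and flattening gives one vector of length `channels` per image.
    return {"layer3_map": (channels, *dims), "embedding": (1, channels)}


print(trace_truncated_resnet50(224, 224))
# {'layer3_map': (1024, 14, 14), 'embedding': (1, 1024)}
print(trace_truncated_resnet50(256, 256))
# {'layer3_map': (1024, 16, 16), 'embedding': (1, 1024)}
```

Because the final pooling is adaptive, the embedding stays 1024-dimensional whether the pipeline resizes to 224x224 or 256x256; only the intermediate `layer3` feature map changes size.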
## Citation
```bibtex
@article{He2015,
author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
title = {Deep Residual Learning for Image Recognition},
journal = {arXiv preprint arXiv:1512.03385},
year = {2015}
}
@article{lu2021data,
title={Data-efficient and weakly supervised computational pathology on whole-slide images},
author={Lu, Ming Y and Williamson, Drew FK and Chen, Tiffany Y and Chen, Richard J and Barbieri, Matteo and Mahmood, Faisal},
journal={Nature Biomedical Engineering},
volume={5},
number={6},
pages={555--570},
year={2021},
publisher={Nature Publishing Group}
}
```