---
tags:
- pytorch
library_tag: pytorch
license: bsd-3-clause
datasets:
- imagenet-1k
pipeline_tag: feature-extraction
---
# Model card for resnet50-truncated.tv_in1k

A truncated ResNet-50 feature extraction model, as used in [CLAM](https://www.nature.com/articles/s41551-020-00682-w).

This model features:
 * ReLU activations
 * single layer 7x7 convolution with pooling
 * 1x1 convolution shortcut downsample

The weights are the original torchvision weights, trained on ImageNet-1k and distributed via PyTorch. The model is truncated to exclude layer 4 and the fully connected layer.

This model card was adapted from https://huggingface.co/timm/resnet50.tv_in1k.

## Model Details
- **Model Type:** Feature backbone
- **Model Stats:**
  - Params (M): 8.5
  - Image size: 224 x 224
- **Papers:**
  - Deep Residual Learning for Image Recognition: https://arxiv.org/abs/1512.03385
  - Data-efficient and weakly supervised computational pathology on whole-slide images: https://www.nature.com/articles/s41551-020-00682-w
- **Original:** https://github.com/pytorch/vision

## Model Creation

```python
import types

import torch
from torchvision.models import ResNet
from torchvision.models import resnet50


def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
    # Same as the torchvision ResNet forward pass, minus layer4 and fc.
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)

    # Global average pool, then flatten to (batch, channels).
    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    return x


# Build an untrained ResNet-50, drop the truncated modules, and patch the forward pass.
model = resnet50(weights=None)
del model.layer4, model.fc
model._forward_impl = types.MethodType(_forward_impl, model)

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet50-19c8e357.pth"
)
# Remove truncated keys.
state_dict = {
    k: v
    for k, v in state_dict.items()
    if not k.startswith("layer4.") and not k.startswith("fc.")
}
# strict=True verifies that the remaining keys match the truncated model exactly.
model.load_state_dict(state_dict, strict=True)
model.eval()
```

## Model Usage

### Image Embeddings

```python
from urllib.request import urlopen

from PIL import Image
import torch
from torchvision import transforms

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

# See "Model Creation" above for how to build the model. Alternatively, load a
# TorchScript version of the model with:
#   model = torch.jit.load("torchscript_model.bin")
model = model.eval()

transform = transforms.Compose([
    # Depending on the pipeline, this may be 256x256 or a different value.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

with torch.no_grad():
    output = model(transform(img).unsqueeze(0))  # unsqueeze single image into batch of 1
output.shape  # 1x1024
```

## Citation

```bibtex
@article{He2015,
  author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
  title = {Deep Residual Learning for Image Recognition},
  journal = {arXiv preprint arXiv:1512.03385},
  year = {2015}
}

@article{lu2021data,
  title={Data-efficient and weakly supervised computational pathology on whole-slide images},
  author={Lu, Ming Y and Williamson, Drew FK and Chen, Tiffany Y and Chen, Richard J and Barbieri, Matteo and Mahmood, Faisal},
  journal={Nature Biomedical Engineering},
  volume={5},
  number={6},
  pages={555--570},
  year={2021},
  publisher={Nature Publishing Group}
}
```
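
## Checking the Key Filter

The "Remove truncated keys" step in Model Creation can be exercised on its own, without downloading any weights. The following is a minimal sketch under stated assumptions: the helper name `drop_truncated_keys` is hypothetical (it is not part of torchvision), and the toy dictionary uses made-up integer placeholders where a real state dict would hold tensors.

```python
def drop_truncated_keys(state_dict, prefixes=("layer4.", "fc.")):
    """Return a copy of state_dict without entries under the given prefixes.

    Hypothetical helper for illustration; str.startswith accepts a tuple,
    so both prefixes are checked in one call.
    """
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Toy stand-in for a ResNet-50 state dict; values are placeholders, not tensors.
toy_state_dict = {
    "conv1.weight": 0,
    "layer3.5.conv3.weight": 1,
    "layer4.0.conv1.weight": 2,
    "fc.weight": 3,
    "fc.bias": 4,
}

kept = drop_truncated_keys(toy_state_dict)
print(sorted(kept))  # ['conv1.weight', 'layer3.5.conv3.weight']
```

Because the filter builds a new dictionary, the downloaded state dict is left untouched, and loading the result with `strict=True` would still fail loudly if any unexpected key slipped through.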