tags:
  - pytorch
library_tag: pytorch
license: bsd-3-clause
datasets:
  - imagenet-1k
pipeline_tag: feature-extraction

Model card for resnet50-truncated.tv_in1k

A truncated ResNet-50 feature extraction model, as used in CLAM.

This model features:

  • ReLU activations
  • single layer 7x7 convolution with pooling
  • 1x1 convolution shortcut downsample

The weights are the original torchvision ResNet-50 weights, trained on ImageNet-1k and distributed via PyTorch. The model is truncated to exclude layer 4 and the fully connected layer, so it outputs the average-pooled features of layer 3.

This model card was adapted from https://huggingface.co/timm/resnet50.tv_in1k.

Model Details

Model Creation

import types
import torch
from torchvision.models import ResNet
from torchvision.models import resnet50

def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)

    return x

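# Build the architecture without weights, then drop the truncated layers.
model = resnet50(weights=None)
del model.layer4, model.fc

# Bind the truncated forward pass to this model instance.
model._forward_impl = types.MethodType(_forward_impl, model)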

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet50-19c8e357.pth"
)
# Remove truncated keys.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("layer4.") and not k.startswith("fc.")}

model.load_state_dict(state_dict, strict=True)
model.eval()
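The key filtering above is what allows `strict=True` to succeed: once `layer4` and `fc` are deleted from the model, the state dict must not contain their parameters either. A minimal sketch of the same filter on a toy dictionary follows. The keys are illustrative stand-ins, the values are placeholders rather than tensors, and `drop_truncated_keys` is a hypothetical helper name, not part of torchvision.

```python
def drop_truncated_keys(state_dict, prefixes=("layer4.", "fc.")):
    """Return a copy of state_dict without keys under the given prefixes."""
    # str.startswith accepts a tuple, so one call covers both prefixes.
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Toy stand-in for a ResNet-50 state dict (placeholder values, not tensors).
toy_state_dict = {
    "conv1.weight": 0,
    "layer3.0.conv1.weight": 1,
    "layer4.0.conv1.weight": 2,
    "fc.weight": 3,
    "fc.bias": 4,
}

kept = drop_truncated_keys(toy_state_dict)
print(sorted(kept))  # ['conv1.weight', 'layer3.0.conv1.weight']
```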

Model Usage

Image Embeddings

from urllib.request import urlopen
from PIL import Image
import torch
from torchvision import transforms

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
)).convert("RGB")  # ensure three channels, since the model expects RGB input

# See above for how to create the model. Alternatively, load the TorchScript
# version of the model:
# model = torch.jit.load("torchscript_model.bin")
model = model.eval()

transform = transforms.Compose([
    # Depending on the pipeline, this may be 256x256 or a different value.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)),
])

with torch.no_grad():
    output = model(transform(img).unsqueeze(0))  # unsqueeze single image into batch of 1
output.shape  # 1x1024
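The 1x1024 output shape can be checked by hand. The sketch below traces the feature-map size through the truncated network using the standard convolution output-size formula. It assumes the usual torchvision ResNet-50 layout (7x7 stride-2 stem, 3x3 stride-2 max pool, and stride-2 3x3 convolutions at the start of layers 2 and 3); the function names are illustrative only.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def truncated_resnet50_shapes(size: int = 224):
    """Return (stage, channels, spatial size) up to layer3 for a square input."""
    shapes = []
    size = conv_out(size, kernel=7, stride=2, padding=3)
    shapes.append(("conv1", 64, size))
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("maxpool", 64, size))
    # layer1 keeps the resolution; layers 2 and 3 each halve it.
    shapes.append(("layer1", 256, size))
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("layer2", 512, size))
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("layer3", 1024, size))
    return shapes


for stage, channels, size in truncated_resnet50_shapes(224):
    print(f"{stage}: {channels} x {size} x {size}")
```

For a 224x224 input, layer 3 produces a 1024 x 14 x 14 feature map. Global average pooling reduces it to 1024 x 1 x 1, and flattening gives the 1024-dimensional embedding. A 256x256 input changes only the spatial size (16 x 16), not the embedding dimension.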

Citation

@article{He2015,
  author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
  title = {Deep Residual Learning for Image Recognition},
  journal = {arXiv preprint arXiv:1512.03385},
  year = {2015}
}
@article{lu2021data,
  title={Data-efficient and weakly supervised computational pathology on whole-slide images},
  author={Lu, Ming Y and Williamson, Drew FK and Chen, Tiffany Y and Chen, Richard J and Barbieri, Matteo and Mahmood, Faisal},
  journal={Nature Biomedical Engineering},
  volume={5},
  number={6},
  pages={555--570},
  year={2021},
  publisher={Nature Publishing Group}
}