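"""Tests for pretrained weight handling in open_clip.

Covers SHA256 verification and caching in ``download_pretrained_from_url`` for
openaipublic and mlfoundations URLs, plus loading a tiny model from the
Hugging Face Hub.
"""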
import hashlib
import tempfile
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import patch

import requests
import torch
from PIL import Image
from urllib3 import HTTPResponse
from urllib3._collections import HTTPHeaderDict

import open_clip
from open_clip.pretrained import download_pretrained_from_url


class DownloadPretrainedTests(unittest.TestCase):

    def create_response(self, data, status_code=200, content_type='application/octet-stream'):
        """Build a urllib3 HTTPResponse around `data`, used as the mocked return value of urlopen."""
        fp = BytesIO(data)
        headers = HTTPHeaderDict({
            'Content-Type': content_type,
            'Content-Length': str(len(data)),
        })
        return HTTPResponse(fp, preload_content=False, headers=headers, status=status_code)

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic(self, urllib):
        """A fresh download from openaipublic is fetched once and passes the full-SHA256 check."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
        """Downloaded bytes that do not match the SHA256 in the URL raise a RuntimeError."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            # The doubled 'not' mirrors the exact error message raised by open_clip.pretrained.
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
        """A cached file whose hash matches is reused; no HTTP request is made."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(file_contents)
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_not_called()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
        """A cached file whose hash does not match is discarded and downloaded again."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(b'corrupted pretrained model')
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
        """mlfoundations URLs carry only the first 8 hex chars of the SHA256 in the filename."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
        """A download that fails the short-hash check raises a RuntimeError."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_hfh(self, urllib):
        """Loads a tiny model from the Hugging Face Hub and checks zero-shot probabilities on a sample image."""
        model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model')
        tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model')
        img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
        image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0)
        text = tokenizer(["a diagram", "a dog", "a cat"])

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), rtol=1e-3))
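

# Standard unittest entry point so the module can be run directly with Python
# as well as collected by pytest.
if __name__ == '__main__':
    unittest.main()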