DHO: Simple yet Effective Semi-supervised Knowledge Distillation from Vision-Language Models via Dual-Head Optimization

arXiv: https://arxiv.org/abs/2505.07675

This repository contains pretrained checkpoints for DHO (Dual-Head Optimization), a simple yet effective approach for semi-supervised knowledge distillation from Vision-Language Models.

Model Description

DHO introduces a dual-head optimization strategy that enables efficient knowledge transfer from large Vision-Language Models (e.g., CLIP) to smaller student models using minimal labeled data. The method achieves state-of-the-art performance on ImageNet semi-supervised benchmarks using only 1% or 10% of the labels.

Paper: Simple yet Effective Semi-supervised Knowledge Distillation from Vision-Language Models via Dual-Head Optimization

Authors: Seongjae Kang, Dong Bok Lee, Hyungjoon Jang, Sung Ju Hwang

Key Features

  • ✨ Dual-head optimization strategy for semi-supervised distillation
  • πŸ† State-of-the-art performance on ImageNet with 1% and 10% labeled data
  • πŸ”„ Efficient transfer from VLMs (e.g., CLIP) to smaller student models
  • 🧩 Simple, scalable, and easy to integrate into existing pipelines

Available Checkpoints

| Checkpoint Name | Student Model | Teacher Model | Labeled Data | Top-1 Acc. | Parameters |
|-----------------|---------------|------------------|--------------|------------|------------|
| vit_b_1.pt | ViT-B/16 | ViT-H/14 (DFN5B) | 1% | 81.6% | 86M |
| vit_b_10.pt | ViT-B/16 | ViT-H/14 (DFN5B) | 10% | 82.8% | 86M |
| vit_l_1.pt | ViT-L/14 | ViT-H/14 (DFN5B) | 1% | 84.6% | 304M |
| vit_l_10.pt | ViT-L/14 | ViT-H/14 (DFN5B) | 10% | 85.9% | 304M |

Usage

Loading a Checkpoint

import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from huggingface_hub import hf_hub_download

# Define the DHO StudentModel architecture with dual heads
class StudentModel(nn.Module):
    def __init__(self, num_classes=1000, model_name='ViT-B-16'):
        super().__init__()
        # Load CLIP backbone (clip.load expects slash-style names, e.g. 'ViT-B/16')
        clip_names = {
            'RN50': 'RN50',
            'ViT-B-16': 'ViT-B/16',
            'ViT-L-14': 'ViT-L/14',
            'ViT-L-14-336px': 'ViT-L/14@336px',
        }
        clip_model, _ = clip.load(clip_names[model_name], device='cpu')
        self.backbone = clip_model.float().visual
        
        # Feature dimensions per architecture
        in_features = {
            'RN50': 1024,
            'ViT-B-16': 512,
            'ViT-L-14': 768,
            'ViT-L-14-336px': 768
        }[model_name]
        
        # Dual-head architecture
        self.ce_head = nn.Linear(in_features, num_classes)  # CE branch
        self.kd_head = nn.Linear(in_features, num_classes)  # KD branch
    
    def forward(self, x):
        features = self.backbone(x)
        ce_out = self.ce_head(features)  # CE-head logits on raw features
        # KD head uses L2-normalized features with a CLIP-style logit scale of 100
        kd_out = self.kd_head(F.normalize(features, dim=1)) * 100
        return ce_out, kd_out

# Download and load checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_path = hf_hub_download(repo_id="erjui/dho", filename="vit_b_10.pt")
checkpoint = torch.load(checkpoint_path, map_location=device)

# Initialize model
model = StudentModel(num_classes=1000, model_name='ViT-B-16').to(device)

# Handle DDP wrapped state_dict
state_dict = checkpoint['model_state_dict']
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)

# Get optimal inference parameters
alpha = checkpoint['alpha']  # Weight for CE head
beta = checkpoint['beta']    # Temperature for KD head
model.eval()

# Inference example
from PIL import Image
import torchvision.transforms as transforms

# CLIP preprocessing (CLIP resizes with bicubic interpolation)
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))
])

image = preprocess(Image.open("path/to/image.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    ce_logits, kd_logits = model(image)
    
    # Combine predictions using saved parameters
    probs_ce = F.softmax(ce_logits, dim=1)
    probs_kd = F.softmax(kd_logits / beta, dim=1)
    probs = alpha * probs_ce + (1 - alpha) * probs_kd
    
    predicted_class = probs.argmax(dim=1)
    print(f"Predicted class: {predicted_class.item()}")

Important Notes:

  • DHO checkpoints contain: model_state_dict, epoch, acc, alpha, beta
  • The model has a dual-head architecture (CE head + KD head)
  • Use the saved alpha and beta parameters for optimal inference
  • For ViT-L checkpoints, change model_name='ViT-L-14' and use image size 224 (or 336 for ViT-L-14-336px)
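As a quick sanity check, you can inspect a downloaded checkpoint before building the model (field names as listed above; the printed values vary per checkpoint):

import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="erjui/dho", filename="vit_b_10.pt")
ckpt = torch.load(path, map_location="cpu")
print(sorted(ckpt.keys()))  # expected: ['acc', 'alpha', 'beta', 'epoch', 'model_state_dict']
print(f"epoch={ckpt['epoch']}, acc={ckpt['acc']}, alpha={ckpt['alpha']}, beta={ckpt['beta']}")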

Training Your Own Model

To train your own DHO model, please visit the official GitHub repository for detailed instructions and training scripts.

Example training command:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29500 train_imgnet_semi.py \
    --teacher_model "apple/DFN5B-CLIP-ViT-H-14-378" \
    --student_model "ViT-B-16" \
    --lr 5e-5 \
    --train_epoch 32 \
    --batch_size 256 \
    --percent 10.0 \
    | tee ./logs/imagenet/imgnet_lowshot.log
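Here --percent selects the labeled fraction in percent (10.0 reproduces the 10% split; presumably 1.0 gives the 1% split), and --teacher_model points to the DFN5B CLIP ViT-H/14 teacher used for the released checkpoints.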

Model Architecture

The DHO student model consists of:

  • Backbone: CLIP Vision Transformer (ViT-B/16 or ViT-L/14)
  • Two parallel heads:
    • CE Head: Optimized with cross-entropy loss on labeled data
    • KD Head: Optimized with knowledge distillation loss from teacher predictions

During inference, the softmax outputs of the two heads are combined as alpha * p_ce + (1 - alpha) * p_kd, where alpha is a mixing weight and beta is the temperature applied to the KD logits before their softmax; both values are stored in each checkpoint. A minimal sketch of the corresponding training objective follows.
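The exact losses, schedules, and per-head batching are defined in the paper and the official code; the following is only a minimal sketch of a dual-head objective under standard assumptions (cross-entropy on labeled data for the CE head, temperature-scaled KL distillation against teacher logits for the KD head; T and lam are illustrative placeholders, not the paper's settings):

import torch.nn.functional as F

def dho_objective(ce_logits, kd_logits, labels, teacher_logits, T=2.0, lam=0.5):
    # Supervised branch: the CE head is trained on labeled examples
    loss_ce = F.cross_entropy(ce_logits, labels)
    # Distillation branch: the KD head matches the teacher's softened distribution
    loss_kd = F.kl_div(
        F.log_softmax(kd_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T * T)
    return lam * loss_ce + (1 - lam) * loss_kd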

Performance

ImageNet Semi-supervised Learning

| Student | Teacher | Labeled Data | Top-1 Accuracy |
|----------|----------|--------------|----------------|
| ViT-B/16 | ViT-H/14 | 1% | 81.6% |
| ViT-B/16 | ViT-H/14 | 10% | 82.8% |
| ViT-L/14 | ViT-H/14 | 1% | 84.6% |
| ViT-L/14 | ViT-H/14 | 10% | 85.9% |

These results establish new state-of-the-art benchmarks for semi-supervised learning on ImageNet-1K.

Citation

If you use these models in your research, please cite:

@article{kang2025simple,
  title={Simple yet Effective Semi-supervised Knowledge Distillation from Vision-Language Models via Dual-Head Optimization},
  author={Kang, Seongjae and Lee, Dong Bok and Jang, Hyungjoon and Hwang, Sung Ju},
  journal={arXiv preprint arXiv:2505.07675},
  year={2025}
}

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Acknowledgments

We appreciate the open-source implementations this work builds on.

Contact

For questions or issues, please open an issue on the GitHub repository or contact the authors.
