LegNet - Cell Type Specific Models

LegNet model with a separate set of weights trained for each of the cell types below.

Available Cell Types:

  • hepg2 - HepG2 cell line
  • k562 - K562 cell line
  • wtc11 - WTC11 cell line
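The download function further down lowercases the cell-type name and maps it to a weight file named `weights/<cell_type>_best_model_test1_val2.safetensors` in the Hub repo. The sketch below reproduces that lookup as a standalone helper. `resolve_weight_file` and `SUPPORTED_CELL_TYPES` are hypothetical names for illustration and are not part of this repository.

```python
# Hypothetical helper (not part of this repo): maps a cell-type name to the
# weight-file path used by download_and_load_model.
SUPPORTED_CELL_TYPES = ("hepg2", "k562", "wtc11")


def resolve_weight_file(cell_type: str) -> str:
    """Return the repo-relative weight path for a cell type (case-insensitive)."""
    key = cell_type.lower()
    if key not in SUPPORTED_CELL_TYPES:
        raise ValueError(
            f"Unknown cell type {cell_type!r}; expected one of {SUPPORTED_CELL_TYPES}"
        )
    # All three weight files share the same naming pattern in the Hub repo.
    return f"weights/{key}_best_model_test1_val2.safetensors"


print(resolve_weight_file("HepG2"))
```

Raising a `ValueError` that lists the valid names is friendlier than the bare `KeyError` a plain dictionary lookup would produce.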

Usage:

from model_loader import load_cell_type_model

# Load model for HepG2
model = load_cell_type_model("hepg2")

# Load model for K562
model = load_cell_type_model("k562")
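This card does not document the model's input format. LegNet is built from 1D convolutions, so it presumably expects one-hot encoded DNA in a channels-first layout of shape (batch, channels, length). The minimal sketch below rests on three assumptions that are not stated here: four channels, `A, C, G, T` channel order, and all-zero columns for unknown bases. Check `config["in_ch"]` in the repo's `config.json` before relying on it. `one_hot_encode` is a hypothetical helper.

```python
# Hypothetical preprocessing helper. The A, C, G, T channel order and the
# 4-channel layout are assumptions; verify against config["in_ch"].
from typing import List

ALPHABET = "ACGT"


def one_hot_encode(seq: str) -> List[List[float]]:
    """One-hot encode DNA as a channels-first (4 x len(seq)) nested list.

    Bases outside ACGT (e.g. N) become all-zero columns.
    """
    seq = seq.upper()
    channels = [[0.0] * len(seq) for _ in ALPHABET]
    for pos, base in enumerate(seq):
        idx = ALPHABET.find(base)
        if idx >= 0:
            channels[idx][pos] = 1.0
    return channels


x = one_hot_encode("ACGTN")
print(len(x), len(x[0]))
# With PyTorch installed, add a batch dimension before calling the model:
# torch.tensor([x]) has shape (1, 4, 5).
```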

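To download the weights directly from the Hugging Face Hub instead:

import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# LegNet must also be imported from the repository's model code;
# its module path is not given in this card.

def get_device():
    """Return the CUDA device if available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
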
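# Load pre-trained model weights for Human LegNet
def download_and_load_model(cell_type="k562", repo_id="Ni-os/Human_Legnet", device=None):
    """Download the config and cell-type weights from the Hub; return LegNet in eval mode."""
    # Fall back to automatic device selection when no device is given
    if device is None:
        device = get_device()

    # Download main config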
    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json"
    )
    
    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    # Create model
    model = LegNet(
        in_ch=config["in_ch"],
        stem_ch=config["stem_ch"],
        stem_ks=config["stem_ks"], 
        ef_ks=config["ef_ks"],
        ef_block_sizes=config["ef_block_sizes"],
        pool_sizes=config["pool_sizes"],
        resize_factor=config["resize_factor"],
        activation=torch.nn.SiLU
    ).to(device)
    
    # Map each supported cell type to its weight file in the repo
    weight_files = {
        "hepg2": "weights/hepg2_best_model_test1_val2.safetensors",
        "k562": "weights/k562_best_model_test1_val2.safetensors",
        "wtc11": "weights/wtc11_best_model_test1_val2.safetensors",
    }
    key = cell_type.lower()
    if key not in weight_files:
        raise ValueError(f"Unknown cell type {cell_type!r}; expected one of {sorted(weight_files)}")

    # Download weights
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename=weight_files[key]
    )
    
    # Load weights into model
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"✅ Model for {cell_type} loaded!")
    return model

device = get_device()

print("Loading pre-trained model weights for Human LegNet")
model_human_legnet = download_and_load_model("hepg2", device=device)
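The LegNet constructor call reads seven keys from `config.json`. If a config is incomplete, for example a locally edited copy, the loader stops at the first missing key with a bare `KeyError`. The sketch below is a small check you could run on the parsed JSON before building the model, so that every missing key is reported at once. The key names come from the constructor call above. The helper name `check_legnet_config` and the placeholder values are hypothetical.

```python
import json

# Keys that download_and_load_model passes to the LegNet constructor.
REQUIRED_KEYS = (
    "in_ch", "stem_ch", "stem_ks", "ef_ks",
    "ef_block_sizes", "pool_sizes", "resize_factor",
)


def check_legnet_config(config: dict) -> None:
    """Raise ValueError listing every constructor key missing from config."""
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise ValueError(f"config.json is missing keys: {missing}")


# Example with a deliberately incomplete config (values are placeholders).
partial = json.loads('{"in_ch": 4, "stem_ch": 64}')
try:
    check_legnet_config(partial)
except ValueError as err:
    print(err)
```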