
Overview

This is the Pythia-160m model developed by EleutherAI, fine-tuned using Cell2Sentence on full scRNA-seq cells. Cell2Sentence is a novel method for adapting large language models to single-cell transcriptomics: we transform single-cell RNA sequencing data into sequences of gene names ordered by expression level, termed "cell sentences". For more details, we refer to the paper linked below. This model was trained on the immune tissue dataset from Domínguez et al. using 8 A100 40GB GPUs for approximately 20 hours on the following tasks:

  1. conditional cell generation
  2. unconditional cell generation
  3. cell type prediction
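The cell sentence transformation described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the function name expression_to_cell_sentence is hypothetical, and it assumes genes with zero expression are dropped and that ties keep their input order. The paper and GitHub repository describe the exact preprocessing (normalization, tie handling) used for training.

```python
from typing import Sequence


def expression_to_cell_sentence(expression: Sequence[float], gene_names: Sequence[str]) -> str:
    """Hypothetical sketch: rank expressed genes from highest to lowest and join their names."""
    if len(expression) != len(gene_names):
        raise ValueError("expression and gene_names must have the same length")
    # Keep only genes with nonzero expression, paired with their values
    expressed = [(value, name) for value, name in zip(expression, gene_names) if value > 0]
    # Sort by descending expression; Python's sort is stable, so ties keep input order
    expressed.sort(key=lambda pair: -pair[0])
    return " ".join(name for _, name in expressed)


# Example: CD3E has zero expression and is dropped; MALAT1 and B2M tie at 5.0
print(expression_to_cell_sentence([0.0, 5.0, 2.0, 5.0], ["CD3E", "MALAT1", "CD8A", "B2M"]))
```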

Cell2Sentence Links:

GitHub: https://github.com/vandijklab/cell2sentence-ft
Paper: https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3

Pythia Links:

GitHub: https://github.com/EleutherAI/pythia
Paper: https://arxiv.org/abs/2304.01373
Hugging Face: https://huggingface.co/EleutherAI/pythia-160m

Evaluation

This model was evaluated on KNN classification and Gromov-Wasserstein (GW) distance. The label of each generated cell is the cell type specified in the prompt used to generate it. Ground truth cells were sampled with replacement from a held-out test dataset. Generated cells are converted to expression vectors using the method described in the paper. For complete experimental details, we refer to the paper.

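| Model             | k=3 NN (↑) | k=5 NN (↑) | k=10 NN (↑) | k=25 NN (↑) | GW (↓)    |
|-------------------|------------|------------|-------------|-------------|-----------|
| scGEN             | 0.2376     | 0.2330     | 0.2377      | 0.2335      | 315.9505  |
| scVI              | 0.2436     | 0.2400     | 0.2425      | 0.2348      | 302.1285  |
| scDiffusion       | 0.2335     | 0.2288     | 0.2368      | 0.2306      | 72.0208   |
| scGPT             | 0.1838     | 0.1788     | 0.1811      | 0.1882      | 2989.8066 |
| C2S (Pythia-160m) | 0.2588     | 0.2565     | 0.2746      | 0.2715      | 54.3040   |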
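The k-NN metric in the table can be illustrated with a small, dependency-free sketch. It is an assumption-laden illustration rather than the evaluation code from the paper: the function name knn_accuracy is hypothetical, it uses Euclidean distance with a majority vote, and which set (real or generated cells) acts as the reference and which as the query is an assumption here. The actual experiments may use a library classifier such as scikit-learn and a different distance.

```python
import math
from collections import Counter
from typing import Sequence


def knn_accuracy(
    ref_vectors: Sequence[Sequence[float]],
    ref_labels: Sequence[str],
    query_vectors: Sequence[Sequence[float]],
    query_labels: Sequence[str],
    k: int = 3,
) -> float:
    """Hypothetical sketch: fraction of query cells whose k-NN majority vote matches their label."""
    if not query_labels:
        raise ValueError("query set must not be empty")
    correct = 0
    for vec, label in zip(query_vectors, query_labels):
        # Euclidean distance from this query cell to every reference cell
        dists = [(math.dist(vec, ref), ref_label) for ref, ref_label in zip(ref_vectors, ref_labels)]
        dists.sort(key=lambda pair: pair[0])
        # Majority vote among the k closest reference cells
        votes = Counter(ref_label for _, ref_label in dists[:k])
        predicted = votes.most_common(1)[0][0]
        correct += predicted == label
    return correct / len(query_labels)


# Toy example: two well-separated clusters labeled "T" and "B"
ref = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
ref_y = ["T", "T", "B", "B"]
print(knn_accuracy(ref, ref_y, [[0.0, 0.5], [10.0, 10.5]], ["T", "B"], k=3))
```

Reporting the score for several values of k, as in the table, shows whether the result is stable across neighbourhood sizes.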

Sample Code

Below is an example of how to use the model to conditionally generate a cell, together with a post-processing function that removes invalid genes and averages the ranks of duplicated genes. To generate full cells, increase the max_length generation parameter to 9200; in that case we recommend an A100 GPU for inference speed and memory capacity. Prompts for unconditional cell generation and cell type prediction are included as comments, but we do not include an example cell sentence to fill in the cell type prediction prompt. We refer to the paper and GitHub repository for instructions on how to transform expression vectors into cell sentences.

import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str,
    gene_dictionary: List[str],
) -> List[str]:
    """
    Post-processing function for generated cell sentences.
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence:              generated cell sentence string
        gene_dictionary:            list of gene vocabulary (all uppercase)

    Returns:
        post_processed_sentence:    list of gene names after post-processing, ordered by rank
    """
    generated_gene_names = cell_sentence.split()
    generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]

    #--- Remove nonsense genes ---#
    valid_genes = set(gene_dictionary)  # set lookup is O(1); list lookup is O(n)
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in valid_genes]

    #--- Average ranks ---#
    gene_name_to_occurrences = Counter(generated_gene_names)  # get mapping of gene name --> number of occurrences
    post_processed_sentence = generated_gene_names.copy()  # copy of generated gene list

    for gene_name, num_occurrences in gene_name_to_occurrences.items():
        # Nonsense genes were already filtered out above, so only duplicates need handling here
        if num_occurrences > 1:
            # Find positions of all occurrences of duplicated generated gene in list
            # Note: using post_processed_sentence here; since duplicates are being removed, list will be
            #   getting shorter. Getting indices in original list will no longer be accurate positions
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))

            # Remove occurrences
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]

            # Reinsert gene_name at average position
            post_processed_sentence.insert(average_position, gene_name)

    return post_processed_sentence

vocab_path = "pbmc_vocab.json"  # JSON list of valid (uppercase) gene names

with open(vocab_path, "r") as f:
    gene_dictionary = json.load(f)

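model_name = "vandijklab/pythia-160m-c2s"

# attn_implementation="flash_attention_2" requires the flash-attn package and a compatible GPU;
# remove this argument to fall back to the default attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)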

cell_type = "T Cell"
# Conditional cell generation (ccg) prompt
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

# Prompts for other forms of generation.
# Unconditional cell generation:
# ucg = "Display a cell's genes by expression level, in descending order."
# Cell type prediction (replace CELL_SENTENCE with a space-separated cell sentence):
# cellsentence = "CELL_SENTENCE"
# ctp = ("Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
#        + cellsentence
#        + "Name the cell type connected to these genes, ranked from highest to lowest expression.")

tokens = tokenizer(ccg, return_tensors='pt')
input_ids = tokens['input_ids'].to(torch.device("cuda"))
attention_mask = tokens['attention_mask'].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,
        top_k=50,
        top_p=0.95,
    )

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
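
For evaluation, generated cells are converted back to expression vectors using the method described in the paper. The sketch below shows one way such a conversion can work, assuming expression is roughly linear in log10(rank) with a slope and intercept fitted beforehand on real cells. The function name cell_sentence_to_expression, the slope and intercept values, the log base, and clamping negative values to zero are all hypothetical choices for illustration; consult the paper and GitHub repository for the actual procedure.

```python
import math
from typing import Dict, List, Sequence


def cell_sentence_to_expression(
    genes: Sequence[str],
    vocabulary: Sequence[str],
    slope: float,
    intercept: float,
) -> List[float]:
    """Hypothetical sketch: map a ranked gene list to an expression vector over `vocabulary`.

    Assumes value = slope * log10(rank) + intercept, with rank starting at 1.
    Genes absent from the sentence (or not in the vocabulary) get 0.
    """
    index: Dict[str, int] = {gene: i for i, gene in enumerate(vocabulary)}
    expression = [0.0] * len(vocabulary)
    for rank, gene in enumerate(genes, start=1):
        if gene in index:
            # Clamp at zero so low-ranked genes never receive negative expression
            expression[index[gene]] = max(0.0, slope * math.log10(rank) + intercept)
    return expression


# Toy example with made-up fit parameters (slope=-1.0, intercept=2.0)
print(cell_sentence_to_expression(["A", "B"], ["A", "B", "C"], slope=-1.0, intercept=2.0))
```

Applied to the sample above, the post-processed list would be passed as genes, and gene_dictionary as vocabulary, which yields one fixed-length vector per generated cell.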

Model size: 162M params (Safetensors, tensor type F32)