Access the embeddings of METAGENE-1?

#2
by iLOVE2D - opened

Hi, thanks for your interesting work. I was wondering whether there is any way to obtain the DNA sequence embeddings. I only see the generation step. Thanks.

METAGENE · Metagenomic Foundation Model org
edited 5 days ago

Yeah of course! Below is a sample script to extract the representations of our model. We are also releasing our evaluation script for the Gene-MTEB experiments later this week. Stay tuned :-)

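```python
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
from transformers.trainer_utils import set_seed

model_name_or_path = "metagene-ai/METAGENE-1"
seed = 42
set_seed(seed)  # seeds Python's `random`, NumPy and PyTorch for reproducibility

batch_size = 32

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Batched padding needs a pad token; fall back to EOS if the tokenizer defines none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModel.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cuda" if torch.cuda.is_available() else "auto")

model.eval()  # inference mode: disables dropout

# DNA sequences to embed
sentences = [
    "ACTG",
    "CCCTAGC"
]

embeddings = []

for i in range(0, len(sentences), batch_size):
    batch = sentences[i:i + batch_size]

    # Tokenize, pad to the longest sequence in the batch, truncate at 512 tokens
    inputs = tokenizer(
        batch,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(model.device)

    # The base model's forward pass does not take `token_type_ids`; remove it if present
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    with torch.no_grad():
        outputs = model(**inputs)
        # Mean-pool the final hidden states over the token dimension -> one vector per sequence.
        # Note that this plain mean also includes padded positions.
        batch_embeddings = outputs.last_hidden_state.mean(dim=1)
    # NumPy has no bfloat16 dtype, so cast to float32 before converting
    embeddings.extend(batch_embeddings.cpu().to(torch.float32).numpy())

embeddings = np.array(embeddings)  # shape: (num_sequences, hidden_size)
```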
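Because `padding=True` pads every sequence to the longest one in its batch, the plain `.mean(dim=1)` in the script above also averages the hidden states at padded positions, so a short sequence's embedding can shift depending on what it is batched with. A common alternative is masked mean pooling, which uses the attention mask to average only real tokens. The sketch below is a NumPy illustration with a hypothetical `masked_mean_pool` helper and toy numbers; it is an assumption offered for comparison, not the pooling the METAGENE team uses.

```python
import numpy as np

def masked_mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: average token vectors per sequence, skipping padding.

    hidden_states:  (batch, seq_len, hidden) final-layer outputs
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    """
    mask = attention_mask[:, :, None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                      # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                   # (batch, 1), avoids divide-by-zero
    return summed / counts

# Toy batch: sequence 0 has 2 real tokens + 1 pad; sequence 1 has 3 real tokens
hidden = np.array([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
])
mask = np.array([[1, 1, 0], [1, 1, 1]])

pooled = masked_mean_pool(hidden, mask)  # padding ignored
naive = hidden.mean(axis=1)              # padding included, like .mean(dim=1)
print(pooled)
print(naive)
```

Here `pooled` is `[[2, 3], [2, 2]]`, while `naive` pulls the first sequence toward the padded vector. In the PyTorch loop the same arithmetic would apply to `outputs.last_hidden_state` and `inputs["attention_mask"]`.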
oliu-io changed discussion status to closed

Thank you!

I suggest adding more description for the functions being called; e.g., `set_seed` is not defined here. If this step is only used for inference, do we need to set a random seed? Thanks.

METAGENE · Metagenomic Foundation Model org

Apologies. I don't think there is any randomness in this code: the model runs in `eval()` mode under `torch.no_grad()`, so there is no dropout and no sampling. We included `set_seed` (the standard Hugging Face implementation, `from transformers.trainer_utils import set_seed`) as standard practice for cases where we evaluate on a shuffled dataset of samples. I've added the library imports to the original post above; let me know if you have any other questions.
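To illustrate the point about shuffled datasets: seeding the random number generator before shuffling makes the sample order, and so any evaluation that depends on it, reproducible across runs. The sketch below uses a hypothetical `shuffled_order` helper that seeds Python's `random` module directly; `transformers.trainer_utils.set_seed` seeds `random` (along with NumPy and PyTorch) in a similar way.

```python
import random

def shuffled_order(samples, seed):
    """Hypothetical helper: return a shuffled copy of `samples`.

    Re-seeding the global RNG first means the same seed always gives the same order.
    """
    random.seed(seed)
    order = list(samples)
    random.shuffle(order)
    return order

samples = ["ACTG", "CCCTAGC", "GGAT", "TTACG", "CAGT"]
run_a = shuffled_order(samples, seed=42)
run_b = shuffled_order(samples, seed=42)
print(run_a == run_b)  # True: both "runs" see the samples in the same order
```

Without the seed, each run could visit the samples in a different order; with it, the shuffle is repeatable.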

oliu-io changed discussion status to open
