Spaces:
Runtime error
Runtime error
import glob | |
import itertools | |
from pathlib import Path | |
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable | |
import string | |
import numpy as np | |
import torch | |
from scipy.spatial.distance import squareform, pdist, cdist | |
from Bio import SeqIO | |
#import biotite.structure as bs | |
#from biotite.structure.io.pdbx import PDBxFile, get_structure | |
#from biotite.database import rcsb | |
from tqdm import tqdm | |
import pandas as pd | |
# Translation table used to strip lowercase characters and the "." / "*"
# insertion characters from a string in a single str.translate pass.
# Every key maps to None, which tells str.translate to drop that character.
deletekeys = {symbol: None for symbol in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)
def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) record from a FASTA or MSA file.

    Returns:
        A ``(description, sequence)`` pair for the first record only; the
        sequence is returned as-is (insertions are not removed).

    Raises:
        StopIteration: propagated from ``next`` if the parser yields no records.
    """
    records = SeqIO.parse(filename, "fasta")
    first = next(records)
    header = first.description
    residues = str(first.seq)
    return header, residues
def remove_insertions(sequence: str) -> str:
    """Drop insertion characters from a sequence.

    Applies the module-level ``translation`` table, which deletes lowercase
    ASCII letters as well as "." and "*". Needed so that aligned sequences
    loaded from an MSA contain only aligned columns.
    """
    aligned_only = sequence.translate(translation)
    return aligned_only
def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Read every record of a FASTA-formatted MSA file.

    Returns:
        One ``(description, sequence)`` pair per record, in file order, with
        insertions already stripped from each sequence via
        ``remove_insertions``.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        header = record.description
        aligned = remove_insertions(str(record.seq))
        entries.append((header, aligned))
    return entries
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """Greedily pick ``num_seqs`` sequences from an MSA by hamming distance.

    Selection always starts from the first sequence (index 0). Each following
    step adds the not-yet-selected sequence whose mean hamming distance to all
    sequences selected so far is largest (``mode="max"``) or smallest
    (``mode="min"``). Ties go to the lowest index.
    Alternatively, hhfilter can be used for this kind of subsampling.

    Args:
        msa: ``(description, sequence)`` pairs. All sequences must have the
            same length so they can be stacked into a 2-D array.
        num_seqs: Number of sequences to keep.
        mode: ``"max"`` to maximise diversity, ``"min"`` to minimise it.

    Returns:
        The selected entries in their original MSA order. When the MSA has at
        most ``num_seqs`` entries the input list itself is returned; when
        ``num_seqs`` is zero or negative an empty list is returned.

    Raises:
        AssertionError: if ``mode`` is neither ``"max"`` nor ``"min"``.
    """
    assert mode in ("max", "min"), f"mode must be 'max' or 'min', got {mode!r}"
    if len(msa) <= num_seqs:
        return msa
    if num_seqs <= 0:
        # Nothing was requested; do not fall through and return the seed sequence.
        return []

    # One row per sequence, one uint8 code per residue, as input for cdist.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)
    optfunc = np.argmax if mode == "max" else np.argmin
    # Value written over already-selected columns so they can never be chosen again.
    excluded = -np.inf if mode == "max" else np.inf

    selected = [0]
    # Running per-sequence sum of distances to every selected sequence. Keeping
    # only the sum avoids rebuilding the (num selected x num sequences) matrix
    # on each iteration.
    distance_sum = np.zeros(len(msa))
    for _ in range(num_seqs - 1):
        distance_sum += cdist(array[selected[-1:]], array, "hamming")[0]
        # Mean distance to the selected set; a fresh array, so masking below
        # does not corrupt the running sum.
        mean_distance = distance_sum / len(selected)
        mean_distance[selected] = excluded
        selected.append(int(optfunc(mean_distance)))

    return [msa[idx] for idx in sorted(selected)]