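"""Gradio demo that predicts sequence conservation in plant DNA with fine-tuned
DNA language models (plant-dnabert, plant-dnagpt, agront-1b, ...) hosted on the
Hugging Face Hub under the zhangtaolab organization."""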
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import torch.nn.functional as F

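# Example plant DNA sequence; it is shown as the textbox placeholder and used
# as a fallback when the user submits an empty input.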
placeholder = 'ACATGCTAAATTAGTTGGCAATTTTTTCTCAGGTAGCTGGGCACAATTTGGTAGTCCAGTTGAACAAAATCCATTAGCTTCTTTTAGCAAGTCCCCTGGTTTGGGCCCTGCCAGTCCCATTAATACCAACCATTTGTCTGGATTGGCTGCAATTCTTTCCCCACAAGCAACAACCTCTACCAAGATTGCACCGATTGGCAAGGACCCTGGAAGGGCTGCAAATCAGATGTTTTCTAACTCTGGATCAACACAAGGAGCAGCTTTTCAGCATTCTATATCCTTTCCTGAGCAAAATGTAAAGGCAAGTCCTAGGCCTATATCTACTTTTGGTGAATCAAGTTCTAGTGCATCAAGTATTGGAACACTGTCCGGTCCTCAATTTCTTTGGGGAAGCCCAACTCCTTACTCTGAGCATTCAAACACTTCTGCCTGGTCTTCATCTTCGGTGGGGCTTCCATTTACATCTAGTGTCCAAAGGCAGGGTTTCCCATATACTAGTAATCACAGTCCTTTTCTTGGCTCCCACTCTCATCATCATGTTGGATCTGCTCCATCTGGCCTTCCGCTTGATAGGCATTTTAGCTACTTCCCTGAGTCACCTGAAGCTTCTCTCATGAGCCCGGTTGCATTTGGGAATTTAAATCACGGTGATGGGAATTTTATGATGAACAACATTAGTGCTCGTGCATCTGTAGGAGCCGGTGTTGGTCTTTCTGGAAATACCCCTGAAATTAGTTCACCCAATTTCAGAATGATGTCTCTGCCTAGGCATGGTTCCTTGTTCCATGGAAATAGTTTGTATTCTGGACCTGGAGCAACTAACATTGAGGGATTAGCTGAACGTGGACGAAGTAGACGACCTGAAAATGGTGGGAACCAAATTGATAGTAAGAAGCTGTACCAGCTTGATCTTGACAAAATCGTCTGTGGTGAAGATACAAGGACTACTTTAATGATTAAAAACATTCCTAACAAGTAAGAATAACTAAACATCTATCCT'
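# Available fine-tuned checkpoints; the plant-* models additionally carry a
# tokenizer suffix (here "6mer") in their repository names.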
model_names = ['plant-dnabert', 'plant-dnagpt', 'plant-nucleotide-transformer', 'plant-dnagemma',
               'dnabert2', 'nucleotide-transformer-v2-100m', 'agront-1b']
tokenizer_type = "6mer"
model_names = [x + '-' + tokenizer_type if x.startswith("plant") else x for x in model_names]
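# Each task maps to its ordered class labels; the order must match the logit
# indices of the corresponding fine-tuned classification head.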
task_map = {
            "promoter": ["Not promoter", "Core promoter"],
            "conservation": ["Not conserved", "Conserved"],
            "H3K27ac": ["Not H3K27ac", "H3K27ac"],
            "H3K27me3": ["Not H3K27me3", "H3K27me3"],
            "H3K4me3": ["Not H3K4me3", "H3K4me3"],
            "lncRNAs": ["Not lncRNA", "lncRNA"],
            "open_chromatin": ['Not open chromatin', 'Full open chromatin', 'Partial open chromatin'],
            }
task_lists = list(task_map.keys())

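# Note: the checkpoint is re-loaded from the Hub on every call, which is simple
# but slow for repeated predictions. A minimal caching sketch (assuming the same
# zhangtaolab/{model}-{task} naming used below) could look like:
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=4)
#     def load_model_and_tokenizer(model_name):
#         tokenizer = AutoTokenizer.from_pretrained(model_name)
#         model = AutoModelForSequenceClassification.from_pretrained(
#             model_name, ignore_mismatched_sizes=True)
#         return model, tokenizer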
def inference(seq, model, task):
    if not seq:
        gr.Warning("No sequence provided; using the default sequence.")
        seq = placeholder
    # Load the fine-tuned model and tokenizer for the selected task
    model_name = f'zhangtaolab/{model}-{task}'
    model = AutoModelForSequenceClassification.from_pretrained(model_name, ignore_mismatched_sizes=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Tokenize the sequence and run a forward pass without tracking gradients
    inputs = tokenizer(seq, return_tensors='pt', padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = F.softmax(outputs.logits, dim=-1).tolist()[0]
    # Map class probabilities to human-readable labels for the selected task
    labels = task_map[task]
    result = {labels[i]: probabilities[i] for i in range(len(labels))}
    return result


# Create Gradio interface
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style="text-align: center;">Prediction of sequence conservation in plants with LLMs</h1>
        """
    )
    with gr.Row():
        drop1 = gr.Dropdown(choices=task_lists,
                            label="Selected Task",
                            interactive=False,
                            value='conservation')
        drop2 = gr.Dropdown(choices=model_names,
                            label="Select Model",
                            interactive=True,
                            value=model_names[0])
    seq_input = gr.Textbox(label="Input Sequence", lines=6, placeholder=placeholder)
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")
    output = gr.Label(label="Prediction result")

    predict_btn.click(inference, inputs=[seq_input, drop2, drop1], outputs=output)
    clear_btn.click(lambda: ("", None), inputs=[], outputs=[seq_input, output])

# Launch Gradio app
demo.launch()