import numpy as np
import pandas as pd
import datasets
import streamlit as st
from streamlit_cytoscapejs import st_cytoscapejs
import networkx as nx

st.set_page_config(layout='wide')

# Every input can be pre-filled from the page URL, e.g.
#   ?gene_ids=CNAG_04365,CNAG_06468&coexp_score_threshold=0.85&max_per_gene=25
# so a particular network can be shared as a link. Missing parameters fall back
# to the defaults below. All values stay strings here and are parsed later.
query_params = st.query_params

input_gene_ids = query_params.get("gene_ids", "CNAG_04365,CNAG_06468")
# The text area shows one gene ID per line, so turn the comma-separated URL form
# into newline-separated text.
input_gene_ids = input_gene_ids.replace(",", "\n")

coexp_score_threshold = query_params.get("coexp_score_threshold", "0.85")
max_per_gene = query_params.get("max_per_gene", "25")

# Page header: what CryptoCEN is, how to cite it, where the code/data live, and
# how to use the gene-ID input below.
st.markdown("""
# CryptoCEN Network
**CryptoCEN** is a co-expression network for *Cryptococcus neoformans* built on 1,524 RNA-seq runs across 34 studies.
A pair of genes is said to be co-expressed when their expression is correlated across different conditions, which
is often a marker that the genes are involved in similar processes.
To Cite:
    O'Meara MJ, Rapala JR, Nichols CB, Alexandre C, Billmyre RB, Steenwyk JL, Alspaugh JA, O'Meara TR
    CryptoCEN: A Co-Expression Network for Cryptococcus neoformans reveals novel proteins involved in DNA damage repair.
    PLoS Genet 20(2): e1011158. (2024) https://doi.org/10.1371/journal.pgen.1011158
* Code available at https://github.com/maomlab/CalCEN/tree/master/vignettes/CryptoCEN
* Full network and dataset: https://huggingface.co/datasets/maomlab/CryptoCEN
## Plot a network for a set of genes
Put one ``CNAG_#####`` gene_id on each row to seed the network
""")


@st.cache_data
def _load_cryptocen_table(table_name):
    """Load ``<table_name>.tsv`` from the maomlab/CryptoCEN dataset as a DataFrame.

    Streamlit re-executes this whole script on every widget interaction, so the
    result is memoized with ``st.cache_data``. Otherwise the table would be
    reloaded and re-converted to pandas on each rerun. ``st.cache_data`` hands
    every caller its own copy, so callers may modify the returned frame freely.
    """
    dataset = datasets.load_dataset(
        path = "maomlab/CryptoCEN",
        data_files = {table_name: f"{table_name}.tsv"})
    return dataset[table_name].to_pandas()


h99_transcript_annotations = _load_cryptocen_table("h99_transcript_annotations")
top_coexp_hits = _load_cryptocen_table("top_coexp_hits")


col1, col2, col3 = st.columns(spec = [0.3, 0.2, 0.5])
with col1:
    input_gene_ids = st.text_area(
        label = "Gene IDs",
        value = f"{input_gene_ids}",
        height = 130,
        help = "CNAG Gene ID e.g. CNAG_04365")


with col2:
    # Set to False when either numeric input fails to parse or is out of range.
    # The script is halted only after both widgets are drawn, so the user can
    # still see and correct either one.
    inputs_are_valid = True

    coexp_score_threshold = st.text_input(
        label = "Co-expression threshold [0-1]",
        value = f"{coexp_score_threshold}",
        help = "Default: 0.85")

    try:
        coexp_score_threshold = float(coexp_score_threshold)
    except ValueError:
        st.error(f"Co-expression threshold should be a number between 0 and 1, instead it is '{coexp_score_threshold}'")
        inputs_are_valid = False
    else:
        # Written as `not (0 <= x <= 1)` so that NaN ("nan" parses as a float) is rejected too.
        if not 0 <= coexp_score_threshold <= 1:
            st.error(f"Co-expression threshold should be a number between 0 and 1, instead it is '{coexp_score_threshold}'")
            inputs_are_valid = False

    max_per_gene = st.text_input(
        label = "Max per gene",
        value = f"{max_per_gene}",
        help = "Default: 25")

    try:
        max_per_gene = int(max_per_gene)
    except ValueError:
        st.error(f"Max per gene should be a number greater than 0, instead it is '{max_per_gene}'")
        inputs_are_valid = False
    else:
        if max_per_gene <= 0:
            st.error(f"Max per gene should be a number greater than 0, instead it is '{max_per_gene}'")
            inputs_are_valid = False

# The code below compares and slices with these values, so it cannot run on an
# unparsed string or an out-of-range number. Stop this run; the error messages
# above stay visible.
if not inputs_are_valid:
    st.stop()
    
    
##################################
# Parse and check the user input #
##################################

# One gene ID per non-blank line of the text area. Surrounding whitespace
# (including a stray "\r") is stripped, and repeated IDs are dropped while
# keeping the user's order, so each seed becomes exactly one network node.
seed_gene_ids = list(dict.fromkeys(
    line.strip()
    for line in input_gene_ids.split("\n")
    if line.strip()))

# pd.concat below cannot combine an empty list, and there is nothing to plot anyway.
if not seed_gene_ids:
    st.info("Enter at least one gene ID to seed the network.")
    st.stop()

# For each seed, keep its co-expression partners whose score strictly exceeds
# the threshold, capped at max_per_gene rows.
neighbors = []
for seed_gene_id in seed_gene_ids:
    hits = top_coexp_hits[
        (top_coexp_hits.gene_id_1 == seed_gene_id) &
        (top_coexp_hits.coexp_score > coexp_score_threshold)]
    # NOTE(review): head() keeps the first rows in file order; this assumes
    # top_coexp_hits is already sorted best-first per gene — confirm.
    neighbors.append(hits.head(max_per_gene))

neighbors = pd.concat(neighbors)

# Neighbor nodes exclude genes that are already seeds. Otherwise the same gene
# would get two node rows (duplicate Cytoscape element ids and duplicated edges
# after the node merges). dict.fromkeys dedups in a stable order.
seed_gene_id_set = set(seed_gene_ids)
neighbor_gene_ids = [
    gene_id
    for gene_id in dict.fromkeys(neighbors.gene_id_2)
    if gene_id not in seed_gene_id_set]
gene_ids = seed_gene_ids + neighbor_gene_ids
gene_types = ['seed'] * len(seed_gene_ids) + ['neighbor'] * len(neighbor_gene_ids)

# Index the annotation table by gene_id once, instead of re-filtering the whole
# table three times per gene. If a gene_id occurs on several rows (presumably one
# per transcript — confirm), the first row is used.
annotations_by_gene = (
    h99_transcript_annotations
    .drop_duplicates(subset = "gene_id", keep = "first")
    .set_index("gene_id"))

cnag_ids = []
gene_products = []
descriptions = []

for gene_id in gene_ids:
    if gene_id in annotations_by_gene.index:
        annotation = annotations_by_gene.loc[gene_id]
        cnag_id = annotation["cnag_id"]
        gene_product = annotation["gene_product"]
        description = annotation["description"]
    else:
        # Unknown IDs are reported but kept as nodes with no annotation.
        st.error(f"Unable to locate cnag_id for Gene ID: '{gene_id}', it should be of the form 'CNAG_#####'")
        cnag_id = None
        gene_product = None
        description = None

    cnag_ids.append(cnag_id)
    gene_products.append(gene_product)
    descriptions.append(description)

# One row per node. gene_index is the node's position in gene_ids and serves as
# its graph key.
node_info = pd.DataFrame({
    "gene_index": range(len(gene_ids)),
    "gene_id" : gene_ids,
    "gene_type" : gene_types,
    "cnag_id": cnag_ids,
    "gene_product": gene_products,
    "description": descriptions})  # one description per node

# Attach node attributes to both endpoints of every edge (inner joins).
# The first merge adds the gene_id_1 node's columns without a suffix. The second
# merge collides on the same names, so pandas renames the gene_id_1 side's
# columns with "_a" and the gene_id_2 side's with "_b".
neighbors = neighbors.merge(
    right = node_info,
    left_on = "gene_id_1",
    right_on = "gene_id")

neighbors = neighbors.merge(
    right = node_info,
    left_on = "gene_id_2",
    right_on = "gene_id",
    suffixes = ("_a", "_b"))


################################
# Use NetworkX to layout graph #
################################
# note I think CytoscapeJS can layout graphs
# but I'm unsure how to do it through the streamlit-cytoscapejs interface :(

st.write(neighbors)


# Undirected graph keyed by node_info.gene_index, with each edge weighted by its
# co-expression score. Only genes with at least one edge become graph nodes, so
# only they receive a spring-layout position.
G = nx.Graph()
G.add_weighted_edges_from(
    zip(neighbors["gene_index_a"],
        neighbors["gene_index_b"],
        neighbors["coexp_score"]))
layout = nx.spring_layout(G)




# Node fill colour by role: "seed" for user-entered genes, "neighbor" for co-expression hits.
node_color_lut = {
    "seed" : "#4866F0",    # blue
    "neighbor" : "#F0C547" # gold
}

# Edge width tiers, checked strongest first. The first cutoff that the score
# strictly exceeds sets the width; anything else gets width 5.
edge_width_tiers = ((0.99, 20), (0.96, 15), (0.94, 10), (0.89, 8))

elements = []

# Nodes. Genes placed by the spring layout are scaled by 600 and shifted by 750.
# Genes without a layout position (no edges) go into a grid of 8 per row,
# 150 apart horizontally and 50 apart vertically.
unplaced_count = 0
for node in node_info.itertuples(index = False):
    spring_position = layout.get(node.gene_index)
    if spring_position is None:
        x_position = (unplaced_count % 8) * 150 + 100
        y_position = np.floor(unplaced_count / 8) * 50 + 30
        unplaced_count += 1
    else:
        x_position = 750.0 + 600 * spring_position[0]
        y_position = 750.0 + 600 * spring_position[1]

    # Fall back to the raw gene ID when the gene has no annotated product.
    node_label = node.gene_id if node.gene_product is None else node.gene_product
    elements.append({
        "data": {
            "id": node.gene_id,
            "label": node_label,
            "color": node_color_lut[node.gene_type]},
        "position": {
            "x" : x_position,
            "y" : y_position}})

# Edges, connecting nodes by gene ID.
for source_id, target_id, score in zip(
        neighbors["gene_id_1"], neighbors["gene_id_2"], neighbors["coexp_score"]):
    edge_width = next(
        (width for cutoff, width in edge_width_tiers if score > cutoff), 5)
    elements.append({
        "data" : {
            "source" : source_id,
            "target" : target_id,
            "width" : edge_width}})

# Offer the merged edge table (edges plus the attributes of both endpoint
# nodes) as a tab-separated download.
with col3:
    st.text('') # help alignment with input box
    st.download_button(
        label = "Download as TSV",
        data = neighbors.to_csv(sep = '\t').encode('utf-8'),
        file_name = "CryptoCEN_network.tsv",
        # The payload is tab-separated, so advertise the TSV media type rather than CSV.
        mime = "text/tab-separated-values")

##########################################################

# Cytoscape.js style rules. Nodes are fixed-size rectangles filled with
# data(color) and labelled with data(label); edge thickness comes from data(width).
stylesheet = [
    {"selector": "node", "style": {
        "width": 140,
        "height": 30,
        "shape": "rectangle",
        "label" : "data(label)",
        # NOTE(review): "labelFontSize" does not appear to be a Cytoscape.js style
        # property (label size is normally "font-size"), so it is likely ignored — confirm.
        "labelFontSize": 100,
        'background-color': 'data(color)',
        "text-halign": "center",
        "text-valign": "center",
    }},
    {"selector": "edge", "style": {
        "width": "data(width)"
    }}
]

st.title("CryptoCEN Network")
# Nodes carry explicit "position" entries, so no Cytoscape layout is requested here.
clicked_elements = st_cytoscapejs(
    elements = elements,
    stylesheet = stylesheet,
    width = 1000,
    height= 1000,
    key = "1")