import tree_sitter
from tree_sitter import Language, Parser

# Build the GLSL grammar into a shared library and load it. This happens at
# import time; both paths are relative to the current working directory.
_GRAMMAR_LIB = "./build/my-languages.so"
Language.build_library(_GRAMMAR_LIB, ['./tree-sitter-glsl'])
GLSL_LANGUAGE = Language(_GRAMMAR_LIB, 'glsl')

# Module-wide parser used by give_tree() and parse_functions().
parser = Parser()
parser.set_language(GLSL_LANGUAGE)

def replace_function(old_func_node, new_func_node):
    """
    Return the source code with old_func_node replaced by new_func_node's text.

    The source is taken from the root of old_func_node's tree (no re-parse).
    tree-sitter's start_byte/end_byte are absolute offsets into the buffer
    that was originally parsed, while the root node's ``text`` begins at
    ``root.start_byte`` (tree-sitter may exclude leading whitespace from the
    root node). The offsets are therefore rebased onto the root text before
    slicing. Slicing is done on bytes and decoded once, because the offsets
    are UTF-8 byte offsets, not character indices.

    :param old_func_node: node whose span is replaced.
    :param new_func_node: node whose ``text`` is inserted in its place.
    :return: the resulting source code as ``str``.
    """
    root = old_func_node
    while root.parent is not None:
        root = root.parent
    source = root.text
    # rebase absolute byte offsets onto the root's text
    start = old_func_node.start_byte - root.start_byte
    end = old_func_node.end_byte - root.start_byte
    return (source[:start] + new_func_node.text + source[end:]).decode()

def get_root(node):
    """
    Return the root node of the tree that contains the given node.

    Follows ``.parent`` links iteratively, so very deep trees cannot hit
    Python's recursion limit. A node without a parent is returned as-is.
    """
    while node.parent is not None:
        node = node.parent
    return node

def node_str_idx(node):
    """
    Return the (start, end) offsets of a node as reported by tree-sitter.

    These are ``start_byte``/``end_byte``: byte offsets into the parsed
    (UTF-8) source, not character indices. They coincide with character
    indices only for ASCII text, so slice the source *bytes* with them and
    decode afterwards.
    """
    return node.start_byte, node.end_byte

def give_tree(func_node):
    """
    Re-parse the text of func_node's parent and return the resulting tree.

    Note: this is a *new* tree built from ``func_node.parent.text``, not the
    tree func_node came from. Its byte offsets are relative to the parent's
    text, so they match the original offsets only when that text begins at
    byte 0 of the original source.
    """
    parent_source = func_node.parent.text
    return parser.parse(parent_source)

def parse_functions(in_code):
    """
    Parse GLSL source and return its top-level function definitions.

    Only ``function_definition`` nodes that are direct children of the root
    node are returned. Functions nested inside other nodes (presumably e.g.
    preprocessor conditional blocks; verify against the GLSL grammar) are not
    found. No comments are attached to the result; use get_docstrings() /
    grab_before_comments() on the returned nodes for that.

    :param in_code: source code as ``str``. It is encoded as UTF-8 before
        parsing, so byte offsets on the returned nodes refer to that encoding.
    :return: list of ``function_definition`` nodes, in source order.
    """
    tree = parser.parse(bytes(in_code, encoding="utf-8"))
    return [node for node in tree.root_node.children if node.type == "function_definition"]


def get_docstrings(func_node):
    """
    Collect the comment text that documents a function node.

    Children of the function node are visited in order:
    * a ``comment`` child of the function node itself (e.g. one inside the
      declarator area) is appended verbatim, with no separator;
    * for the body (``compound_statement``), the opening ``{`` and every
      comment that follows it are appended, each prefixed with spaces up to
      its original column and terminated by a newline.
    Collection stops at the first body child that is neither ``{`` nor a
    comment.

    :return: the collected text, or ``""`` if nothing matched.
    """
    pieces = []
    for child in func_node.children:
        if child.type == "comment":
            pieces.append(child.text.decode())
        elif child.type == "compound_statement":
            for item in child.children:
                if item.type not in ("comment", "{"):
                    return "".join(pieces)
                indent = " " * item.start_point[1]
                pieces.append(indent + item.text.decode() + "\n")
    return "".join(pieces)

def full_func_head(func_node) -> str:
    """
    Return the function head: signature, opening ``{`` and any comments that
    directly follow it, up to (not including) the first real body statement.

    The text is cut at the end byte of the last leading ``{``/comment child of
    the body, so whitespace before the first real statement is not included.
    If the body does not start with ``{``/comment, only the text before the
    body is returned.
    """
    body = func_node.child_by_field_name("body")
    head_end = body.start_byte  # fallback when no leading "{"/comment exists
    # Iterating the children (instead of a manual cursor) guarantees the loop
    # terminates even if every child is "{" or a comment.
    for child in body.children:
        if child.type not in ("comment", "{"):
            break
        head_end = child.end_byte
    # byte offsets are absolute; func_node.text starts at func_node.start_byte
    return func_node.text[:(head_end - func_node.start_byte)].decode()

def grab_before_comments(func_node):
    """
    Return the run of comments that sits directly above a function node.

    The siblings of func_node (its parent's children) are scanned in order.
    A run of comment nodes is kept as long as each node starts on the line
    right after the line where the previous comment *ended*. Tracking the end
    line means multi-line block comments do not break the run. Any gap resets
    the run.

    :return: ``(precomment, start_byte)``. *precomment* is the comments' text,
        each followed by ``"\\n"`` (``""`` if there is no such run).
        *start_byte* is the start byte of the first comment in the run, or of
        func_node itself when there is no run.
    """
    comment_parts = []
    run_start = func_node.start_byte
    last_comment_line = 0  # line on which the most recent comment ended
    for node in func_node.parent.children:
        if node.start_point[0] != last_comment_line + 1:
            comment_parts = []  # not adjacent to the previous comment: run broken
        if node.type == "comment":
            if not comment_parts:
                run_start = node.start_byte
            comment_parts.append(node.text.decode() + "\n")
            last_comment_line = node.end_point[0]
        # tree-sitter node wrappers may be distinct objects for the same node,
        # so compare with == rather than `is`
        elif node == func_node:
            if not comment_parts:
                run_start = node.start_byte
            break
    return "".join(comment_parts), run_start

def has_docstrings(func_node):
    """
    Tell whether a function node carries any documentation comment.

    Returns True when the text from get_docstrings(), stripped, is anything
    other than a lone ``{``. Otherwise returns whether grab_before_comments()
    found a non-empty comment run above the function. The second check only
    runs when the first one fails.
    """
    body_doc = get_docstrings(func_node).strip()
    if body_doc != "{":
        return True
    before, _ = grab_before_comments(func_node)
    return before != ""


def line_chr2char(text, line_idx, chr_idx):
    """
    Convert a (line, column) position in text into a flat character offset.

    Prefer a node's ``start_byte``/``end_byte`` over this helper. Each line
    before line_idx contributes its length plus one for the ``"\\n"``
    separator, and chr_idx is added at the end.

    :raises IndexError: if line_idx is greater than the number of lines in
        text.
    """
    lines = text.split("\n")
    offset = 0
    for i in range(line_idx):
        try:
            line = lines[i]
        except IndexError as err:
            raise IndexError(f"{i=} of {line_idx=} does not exist in {text=}") from err
        offset += len(line) + 1  # +1 for the "\n" removed by split()
    return offset + chr_idx