# word_graph_viz / app.py
# Author: gigant — commit 94ef5cb ("Update app.py"), 3.13 kB
import networkx as nx
import matplotlib.pyplot as plt
import jraph
import jax.numpy as jnp
from datasets import load_dataset
import spacy
import en_core_web_sm
# Module-level resources shared by every callback below: the transcript
# dataset (downloaded/loaded once at import) and the small English spaCy pipeline.
dataset = load_dataset(path="gigant/tib_transcripts")
nlp = en_core_web_sm.load()
def dependency_parser(sentences):
    """Parse each sentence with the module-level spaCy pipeline ``nlp``.

    Args:
        sentences: iterable of strings.

    Returns:
        A list with one parsed spaCy ``Doc`` per input sentence, in input order.
    """
    return list(map(nlp, sentences))
def construct_dependency_graph(docs):
    """Turn parsed spaCy documents into plain edge-list graph dicts.

    Args:
        docs: list of outputs of the spaCy dependency parser.

    Returns:
        One dict per doc with keys:
          - ``"nodes"``: the token texts, in document order;
          - ``"senders"``: the ``.i`` index of the head token of each arc;
          - ``"receivers"``: the ``.i`` index of the child token of each arc.
        ``senders[k] -> receivers[k]`` is the k-th head->child arc. Arcs are
        listed head by head (document order), children in the order spaCy
        yields them.
    """

    def as_graph(doc):
        words = [tok.text for tok in doc]
        arcs = [(head.i, dep.i) for head in doc for dep in head.children]
        return {
            "nodes": words,
            "senders": [h for h, _ in arcs],
            "receivers": [d for _, d in arcs],
        }

    return [as_graph(doc) for doc in docs]
def to_jraph(graph):
    """Convert one edge-list graph dict into a single-graph ``jraph.GraphsTuple``.

    Args:
        graph: dict with keys ``"nodes"`` (one entry per node), ``"senders"``
            and ``"receivers"`` (integer node indices). ``senders[k] ->
            receivers[k]`` is the k-th directed edge. This is the format
            produced by ``construct_dependency_graph``.

    Returns:
        A ``jraph.GraphsTuple`` holding one graph. Every node carries the
        placeholder integer feature 0, so the node contents (e.g. token text)
        are not encoded. There are no edge or global features.
    """
    nodes = graph["nodes"]
    s = graph["senders"]
    r = graph["receivers"]
    # One dummy integer feature per node; only the graph structure is used.
    node_features = jnp.zeros(len(nodes), dtype=jnp.int32)
    # Edge endpoints are index arrays and must be integers. Pin the dtype:
    # jnp.array([]) would otherwise be float32, which is what an edgeless
    # (e.g. one-token) sentence produced before.
    senders = jnp.array(s, dtype=jnp.int32)
    receivers = jnp.array(r, dtype=jnp.int32)
    # Per-graph node/edge counts; jraph uses these to batch several graphs
    # into one GraphsTuple.
    n_node = jnp.array([len(nodes)])
    n_edge = jnp.array([len(s)])
    return jraph.GraphsTuple(
        nodes=node_features,
        senders=senders,
        receivers=receivers,
        edges=None,
        n_node=n_node,
        n_edge=n_edge,
        globals=None,
    )
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Build a directed networkx graph from the first graph in a GraphsTuple.

    Nodes are numbered ``0 .. n_node[0] - 1``. Edges are read pairwise from
    ``senders``/``receivers`` for the first ``n_edge[0]`` entries. When node or
    edge features are present, they are attached as the ``node_feature`` and
    ``edge_feature`` attributes.

    Args:
        jraph_graph: the GraphsTuple to convert. Only ``n_node[0]`` and
            ``n_edge[0]`` are consulted, i.e. the first graph of a batch.

    Returns:
        An ``nx.DiGraph`` (a subclass of ``nx.Graph``).
    """
    node_feats = jraph_graph.nodes
    edge_feats = jraph_graph.edges
    src = jraph_graph.senders
    dst = jraph_graph.receivers
    num_nodes = jraph_graph.n_node[0]
    num_edges = jraph_graph.n_edge[0]

    digraph = nx.DiGraph()

    if node_feats is None:
        digraph.add_nodes_from(range(num_nodes))
    else:
        digraph.add_nodes_from(
            (idx, {"node_feature": node_feats[idx]}) for idx in range(num_nodes)
        )

    if edge_feats is None:
        digraph.add_edges_from(
            (int(src[k]), int(dst[k])) for k in range(num_edges)
        )
    else:
        digraph.add_edges_from(
            (int(src[k]), int(dst[k]), {"edge_feature": edge_feats[k]})
            for k in range(num_edges)
        )

    return digraph
def plot_graph_sentence(sentence):
    """Draw the dependency graph of one sentence as a matplotlib figure.

    The sentence is parsed with ``dependency_parser`` and turned into an edge
    list by ``construct_dependency_graph``. It is then routed through
    ``to_jraph`` and ``convert_jraph_to_networkx_graph`` and drawn with a
    spring layout. Each node is labelled with its token text.

    Args:
        sentence: text to parse. ``None`` or a blank/whitespace-only string
            yields no plot. Blank strings occur, for example, as an empty
            fragment from splitting a transcript on ".".

    Returns:
        A ``matplotlib.figure.Figure``, or ``None`` when there is nothing to draw.
    """
    # The dropdown may hand us None (no selection) or a blank fragment;
    # nlp() expects a string, and a blank one has nothing meaningful to plot.
    if sentence is None or not sentence.strip():
        return None
    docs = dependency_parser([sentence])
    graphs = construct_dependency_graph(docs)
    g = to_jraph(graphs[0])
    nx_graph = convert_jraph_to_networkx_graph(g)
    pos = nx.spring_layout(nx_graph)
    fig = plt.figure(figsize=(6, 6))
    # Create the same full-figure axes nx.draw would add on its own, but bind
    # it explicitly instead of relying on pyplot's global "current figure".
    ax = fig.add_axes((0, 0, 1, 1))
    nx.draw(
        nx_graph,
        pos=pos,
        ax=ax,
        labels=dict(enumerate(graphs[0]["nodes"])),
        with_labels=True,
        node_size=500,
        font_color="black",
        node_color="yellow",
    )
    # Unregister the figure from pyplot so repeated callbacks don't pile up
    # open figures; the returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
def get_list_sentences(id):
    """Build a Gradio update that replaces the sentence dropdown's choices.

    Args:
        id: row index into ``dataset["train"]``, taken from the transcript
            slider. The name is kept for interface compatibility even though
            it shadows the builtin locally.

    Returns:
        ``gr.update(choices=...)`` whose choices are that row's ``"transcript"``
        split on ".".
    """
    # NOTE(review): `gr` is used here, but no `import gradio as gr` appears in
    # this file's import block — add it if it is missing.
    record = dataset["train"][id]
    pieces = record["transcript"].split(".")
    return gr.update(choices=pieces)
# Gradio UI: a slider picks a transcript, a dropdown picks one of its
# "."-separated sentences, and the plot shows that sentence's dependency graph.
# NOTE(review): `gr` is used below, but no `import gradio as gr` appears in
# this file's import block — add it if it is missing.
with gr.Blocks() as demo:
    # Named `transcript_id` (not `id`) so the module doesn't shadow the builtin id().
    transcript_id = gr.Slider(maximum=len(dataset["train"]) - 1)
    sentence = gr.Dropdown(
        choices=dataset["train"][0]["transcript"].split("."),
        interactive=True,
    )
    plot = gr.Plot()
    # Moving the slider repopulates the sentence choices for that transcript...
    transcript_id.change(get_list_sentences, transcript_id, sentence)
    # ...and choosing a sentence redraws its dependency graph.
    sentence.change(plot_graph_sentence, sentence, plot)
demo.launch()