v2 / modules /discourse_analysis.py
AIdeaText's picture
Update modules/discourse_analysis.py
2f61f89 verified
raw
history blame contribute delete
No virus
2.66 kB
import streamlit as st
import spacy
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
from .semantic_analysis import visualize_semantic_relations, create_semantic_graph, POS_COLORS, POS_TRANSLATIONS
##################################################################################################################
def _draw_semantic_graph(G, ax, title, lang, seed=None):
    """Render one semantic graph onto ``ax`` with POS colouring, title and legend.

    Nodes are coloured from their ``'pos'`` attribute via ``POS_COLORS``. Nodes
    whose tag is missing or has no colour entry are drawn grey (``'#CCCCCC'``).
    Edges are labelled with their ``'label'`` attribute. A NOUN/VERB legend is
    added using the ``POS_TRANSLATIONS`` entry for ``lang``.

    Args:
        G: networkx graph as returned by ``create_semantic_graph``.
        ax: matplotlib Axes to draw on.
        title: Title shown above the graph.
        lang: Language key for ``POS_TRANSLATIONS``. An unknown key falls back
            to the untranslated tag names instead of raising ``KeyError``.
        seed: Optional seed for ``nx.spring_layout``. ``None`` keeps the
            original non-deterministic layout.
    """
    layout = nx.spring_layout(G, k=0.7, iterations=50, seed=seed)
    node_colors = [
        POS_COLORS.get(G.nodes[node].get('pos'), '#CCCCCC') for node in G.nodes()
    ]
    nx.draw(
        G, layout, ax=ax, node_color=node_colors,
        with_labels=True, node_size=4000, font_size=10, font_weight='bold',
        arrows=True, arrowsize=20, width=2, edge_color='gray',
    )
    nx.draw_networkx_edge_labels(
        G, layout, edge_labels=nx.get_edge_attributes(G, 'label'), font_size=8, ax=ax
    )
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.axis('off')

    # Degrade gracefully for a language without a translation table.
    try:
        translations = POS_TRANSLATIONS[lang]
    except KeyError:
        translations = {}
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, fc=POS_COLORS.get(tag, '#CCCCCC'), edgecolor='none',
                      label=f"{translations.get(tag, tag)}")
        for tag in ('NOUN', 'VERB')
    ]
    ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0, 1), fontsize=8)


def compare_semantic_analysis(text1, text2, nlp, lang, *, figsize=(18, 13), seed=None):
    """Build and draw the semantic graphs of two texts as two separate figures.

    Each text is parsed with ``nlp`` and turned into a graph by
    ``create_semantic_graph``. The graph is then drawn on its own matplotlib
    figure.

    Args:
        text1: First document text.
        text2: Second document text.
        nlp: Callable NLP pipeline applied to each text, e.g. a loaded spaCy
            model (presumably).
        lang: Language key passed to ``create_semantic_graph`` and used for the
            legend translations.
        figsize: Size in inches of each figure. The default matches the
            previous hard-coded value.
        seed: Optional seed for the spring layout, which makes the drawings
            reproducible. ``None`` keeps the original random layout.

    Returns:
        Tuple ``(fig1, fig2)`` of matplotlib Figures, one per document.
    """
    doc1 = nlp(text1)
    doc2 = nlp(text2)
    G1, _ = create_semantic_graph(doc1, lang)
    G2, _ = create_semantic_graph(doc2, lang)

    # One independent figure per document.
    fig1, ax1 = plt.subplots(figsize=figsize)
    fig2, ax2 = plt.subplots(figsize=figsize)

    _draw_semantic_graph(G1, ax1, "Documento 1: Relaciones Semánticas Relevantes", lang, seed=seed)
    _draw_semantic_graph(G2, ax2, "Documento 2: Relaciones Semánticas Relevantes", lang, seed=seed)

    # plt.tight_layout() would only affect the current figure (fig2), so lay
    # out each figure explicitly.
    fig1.tight_layout()
    fig2.tight_layout()
    return fig1, fig2
##################################################################################################################
def perform_discourse_analysis(text1, text2, nlp, lang):
    """Run the discourse comparison for two texts.

    Thin entry point that delegates to ``compare_semantic_analysis``.

    Returns:
        Tuple ``(first_figure, second_figure)`` with one semantic-graph figure
        per input text.
    """
    figures = compare_semantic_analysis(text1, text2, nlp, lang)
    first_figure, second_figure = figures
    return (first_figure, second_figure)