KoichiYasuoka
committed on
Commit
·
b55b5d0
1
Parent(s):
d4321fc
algorithm improved
Browse files
ud.py
CHANGED
@@ -16,6 +16,8 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
16 |
def postprocess(self,model_outputs,**kwargs):
|
17 |
if "logits" not in model_outputs:
|
18 |
return self.postprocess(model_outputs[0],**kwargs)
|
|
|
|
|
19 |
m=model_outputs["logits"][0].numpy()
|
20 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
21 |
z=e/e.sum(axis=-1,keepdims=True)
|
@@ -56,28 +58,10 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
|
|
56 |
self.right_arc[v]=0
|
57 |
def postprocess(self,model_outputs,**kwargs):
|
58 |
import torch
|
|
|
59 |
if "logits" not in model_outputs:
|
60 |
return self.postprocess(model_outputs[0],**kwargs)
|
61 |
-
|
62 |
-
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
63 |
-
z=e/e.sum(axis=-1,keepdims=True)
|
64 |
-
for i in range(m.shape[0]-1,0,-1):
|
65 |
-
m[i-1]+=numpy.max(m[i]+self.transition,axis=1)
|
66 |
-
k=[numpy.argmax(m[0]+self.transition[0])]
|
67 |
-
for i in range(1,m.shape[0]):
|
68 |
-
k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
|
69 |
-
w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
|
70 |
-
for i,t in reversed(list(enumerate(w))):
|
71 |
-
p=t.pop("entity")
|
72 |
-
if p.startswith("I-"):
|
73 |
-
w[i-1]["score"]=min(w[i-1]["score"],t["score"])
|
74 |
-
w[i-1]["end"]=w.pop(i)["end"]
|
75 |
-
elif p.startswith("B-"):
|
76 |
-
t["entity_group"]=p[2:]
|
77 |
-
else:
|
78 |
-
t["entity_group"]=p
|
79 |
-
for t in w:
|
80 |
-
t["text"]=model_outputs["sentence"][t["start"]:t["end"]]
|
81 |
off=[(t["start"],t["end"]) for t in w]
|
82 |
for i,(s,e) in reversed(list(enumerate(off))):
|
83 |
if s<e:
|
@@ -132,11 +116,11 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
|
|
132 |
e[i,i+j]=m[k]+self.right_arc
|
133 |
k+=1
|
134 |
k+=1
|
135 |
-
m,p=numpy.
|
136 |
h=self.chu_liu_edmonds(m)
|
137 |
z=[i for i,j in enumerate(h) if i==j]
|
138 |
if len(z)>1:
|
139 |
-
k,h=z[numpy.
|
140 |
m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
|
141 |
h=self.chu_liu_edmonds(m)
|
142 |
q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
|
@@ -146,7 +130,7 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
|
|
146 |
u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","_" if len(q[i])<3 else "|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(off) and e<off[i+1][0] else "SpaceAfter=No"])+"\n"
|
147 |
return u+"\n"
|
148 |
def chu_liu_edmonds(self,matrix):
|
149 |
-
h=numpy.
|
150 |
x=[-1 if i==j else j for i,j in enumerate(h)]
|
151 |
for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
|
152 |
y=[]
|
@@ -157,10 +141,10 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
|
|
157 |
if max(x)<0:
|
158 |
return h
|
159 |
y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
|
160 |
-
z=matrix-numpy.
|
161 |
-
m=numpy.block([[z[x,:][:,x],numpy.
|
162 |
-
k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.
|
163 |
h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
|
164 |
-
i=y[numpy.
|
165 |
h[i]=x[k[-1]] if k[-1]<len(x) else i
|
166 |
return h
|
|
|
16 |
def postprocess(self,model_outputs,**kwargs):
|
17 |
if "logits" not in model_outputs:
|
18 |
return self.postprocess(model_outputs[0],**kwargs)
|
19 |
+
return self.bellman_ford_token_classification(model_outputs,**kwargs)
|
20 |
+
def bellman_ford_token_classification(self,model_outputs,**kwargs):
|
21 |
m=model_outputs["logits"][0].numpy()
|
22 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
23 |
z=e/e.sum(axis=-1,keepdims=True)
|
|
|
58 |
self.right_arc[v]=0
|
59 |
def postprocess(self,model_outputs,**kwargs):
|
60 |
import torch
|
61 |
+
kwargs["aggregation_strategy"]="simple"
|
62 |
if "logits" not in model_outputs:
|
63 |
return self.postprocess(model_outputs[0],**kwargs)
|
64 |
+
w=self.bellman_ford_token_classification(model_outputs,**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
off=[(t["start"],t["end"]) for t in w]
|
66 |
for i,(s,e) in reversed(list(enumerate(off))):
|
67 |
if s<e:
|
|
|
116 |
e[i,i+j]=m[k]+self.right_arc
|
117 |
k+=1
|
118 |
k+=1
|
119 |
+
m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
|
120 |
h=self.chu_liu_edmonds(m)
|
121 |
z=[i for i,j in enumerate(h) if i==j]
|
122 |
if len(z)>1:
|
123 |
+
k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
|
124 |
m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
|
125 |
h=self.chu_liu_edmonds(m)
|
126 |
q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
|
|
|
130 |
u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","_" if len(q[i])<3 else "|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(off) and e<off[i+1][0] else "SpaceAfter=No"])+"\n"
|
131 |
return u+"\n"
|
132 |
def chu_liu_edmonds(self,matrix):
|
133 |
+
h=numpy.argmax(matrix,axis=0)
|
134 |
x=[-1 if i==j else j for i,j in enumerate(h)]
|
135 |
for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
|
136 |
y=[]
|
|
|
141 |
if max(x)<0:
|
142 |
return h
|
143 |
y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
|
144 |
+
z=matrix-numpy.max(matrix,axis=0)
|
145 |
+
m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
|
146 |
+
k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
|
147 |
h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
|
148 |
+
i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
|
149 |
h[i]=x[k[-1]] if k[-1]<len(x) else i
|
150 |
return h
|