KoichiYasuoka commited on
Commit
b55b5d0
·
1 Parent(s): d4321fc

algorithm improved

Browse files
Files changed (1) hide show
  1. ud.py +11 -27
ud.py CHANGED
@@ -16,6 +16,8 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
16
  def postprocess(self,model_outputs,**kwargs):
17
  if "logits" not in model_outputs:
18
  return self.postprocess(model_outputs[0],**kwargs)
 
 
19
  m=model_outputs["logits"][0].numpy()
20
  e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
21
  z=e/e.sum(axis=-1,keepdims=True)
@@ -56,28 +58,10 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
56
  self.right_arc[v]=0
57
  def postprocess(self,model_outputs,**kwargs):
58
  import torch
 
59
  if "logits" not in model_outputs:
60
  return self.postprocess(model_outputs[0],**kwargs)
61
- m=model_outputs["logits"][0].numpy()
62
- e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
63
- z=e/e.sum(axis=-1,keepdims=True)
64
- for i in range(m.shape[0]-1,0,-1):
65
- m[i-1]+=numpy.max(m[i]+self.transition,axis=1)
66
- k=[numpy.argmax(m[0]+self.transition[0])]
67
- for i in range(1,m.shape[0]):
68
- k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
69
- w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
70
- for i,t in reversed(list(enumerate(w))):
71
- p=t.pop("entity")
72
- if p.startswith("I-"):
73
- w[i-1]["score"]=min(w[i-1]["score"],t["score"])
74
- w[i-1]["end"]=w.pop(i)["end"]
75
- elif p.startswith("B-"):
76
- t["entity_group"]=p[2:]
77
- else:
78
- t["entity_group"]=p
79
- for t in w:
80
- t["text"]=model_outputs["sentence"][t["start"]:t["end"]]
81
  off=[(t["start"],t["end"]) for t in w]
82
  for i,(s,e) in reversed(list(enumerate(off))):
83
  if s<e:
@@ -132,11 +116,11 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
132
  e[i,i+j]=m[k]+self.right_arc
133
  k+=1
134
  k+=1
135
- m,p=numpy.nanmax(e,axis=2),numpy.nanargmax(e,axis=2)
136
  h=self.chu_liu_edmonds(m)
137
  z=[i for i,j in enumerate(h) if i==j]
138
  if len(z)>1:
139
- k,h=z[numpy.nanargmax(m[z,z])],numpy.nanmin(m)-numpy.nanmax(m)
140
  m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
141
  h=self.chu_liu_edmonds(m)
142
  q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
@@ -146,7 +130,7 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
146
  u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","_" if len(q[i])<3 else "|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(off) and e<off[i+1][0] else "SpaceAfter=No"])+"\n"
147
  return u+"\n"
148
  def chu_liu_edmonds(self,matrix):
149
- h=numpy.nanargmax(matrix,axis=0)
150
  x=[-1 if i==j else j for i,j in enumerate(h)]
151
  for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
152
  y=[]
@@ -157,10 +141,10 @@ class UniversalDependenciesPipeline(BellmanFordTokenClassificationPipeline):
157
  if max(x)<0:
158
  return h
159
  y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
160
- z=matrix-numpy.nanmax(matrix,axis=0)
161
- m=numpy.block([[z[x,:][:,x],numpy.nanmax(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.nanmax(z[y,:][:,x],axis=0),numpy.nanmax(z[y,y])]])
162
- k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.nanargmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
163
  h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
164
- i=y[numpy.nanargmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
165
  h[i]=x[k[-1]] if k[-1]<len(x) else i
166
  return h
 
16
  def postprocess(self,model_outputs,**kwargs):
17
  if "logits" not in model_outputs:
18
  return self.postprocess(model_outputs[0],**kwargs)
19
+ return self.bellman_ford_token_classification(model_outputs,**kwargs)
20
+ def bellman_ford_token_classification(self,model_outputs,**kwargs):
21
  m=model_outputs["logits"][0].numpy()
22
  e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
23
  z=e/e.sum(axis=-1,keepdims=True)
 
58
  self.right_arc[v]=0
59
  def postprocess(self,model_outputs,**kwargs):
60
  import torch
61
+ kwargs["aggregation_strategy"]="simple"
62
  if "logits" not in model_outputs:
63
  return self.postprocess(model_outputs[0],**kwargs)
64
+ w=self.bellman_ford_token_classification(model_outputs,**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  off=[(t["start"],t["end"]) for t in w]
66
  for i,(s,e) in reversed(list(enumerate(off))):
67
  if s<e:
 
116
  e[i,i+j]=m[k]+self.right_arc
117
  k+=1
118
  k+=1
119
+ m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
120
  h=self.chu_liu_edmonds(m)
121
  z=[i for i,j in enumerate(h) if i==j]
122
  if len(z)>1:
123
+ k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
124
  m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
125
  h=self.chu_liu_edmonds(m)
126
  q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
 
130
  u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","_" if len(q[i])<3 else "|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),"root" if q[i][-1]=="root" else q[i][-1][2:],"_","_" if i+1<len(off) and e<off[i+1][0] else "SpaceAfter=No"])+"\n"
131
  return u+"\n"
132
  def chu_liu_edmonds(self,matrix):
133
+ h=numpy.argmax(matrix,axis=0)
134
  x=[-1 if i==j else j for i,j in enumerate(h)]
135
  for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
136
  y=[]
 
141
  if max(x)<0:
142
  return h
143
  y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
144
+ z=matrix-numpy.max(matrix,axis=0)
145
+ m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
146
+ k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
147
  h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
148
+ i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
149
  h[i]=x[k[-1]] if k[-1]<len(x) else i
150
  return h