Update README.md
README.md (CHANGED)
@@ -38,4 +38,133 @@ output ="""
- The agent checks the customer's card number and balance, then explains additional available benefits
- The customer expresses interest in various benefits such as travel discounts, mileage, and hotel discounts
"""
```

To use this model, the following classes are required.
```
# Imports assumed by this snippet (they are not shown in the README itself);
# LongformerSelfAttention here refers to the Hugging Face implementation.
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention


class LongformerSelfAttentionForBart(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        is_cross_attention = key_value_states is not None
        bsz, tgt_len, embed_dim = hidden_states.size()

        # change the mask from bs x seq_len x seq_len to bs x seq_len
        attention_mask = attention_mask.squeeze(dim=1)
        attention_mask = attention_mask[:, 0]

        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = is_index_global_attn.flatten().any().item()

        outputs = self.longformer_self_attn(
            hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=None,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
        )

        attn_output = self.output(outputs[0])

        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None, None)
```
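
The only BART-specific adjustment above is the mask handling: BART hands each layer an additive mask of shape `(bsz, 1, tgt_len, src_len)`, while `LongformerSelfAttention` works on a per-token `(bsz, seq_len)` mask in which negative values mark padding and positive values mark global-attention tokens. A minimal sketch of just that reshape, with dummy tensors (illustrative only, not part of the original README):
```
import torch

bsz, seq_len = 2, 8
# Additive mask as BART builds it: 0 = attend, large negative = padding.
expanded_mask = torch.zeros(bsz, 1, seq_len, seq_len)
expanded_mask[1, :, :, 6:] = -10000.0  # pretend the last two tokens of sample 1 are padding

mask = expanded_mask.squeeze(dim=1)[:, 0]  # (bsz, 1, S, S) -> (bsz, S, S) -> (bsz, S)
print(mask.shape)        # torch.Size([2, 8])
print((mask < 0).sum())  # 2 positions that Longformer will treat as masked
```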

```
# Assumed imports (not shown in the README):
from transformers import BartForConditionalGeneration
from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding


class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)

        if config.attention_mode == 'n2':
            pass  # do nothing, use BertSelfAttention instead
        else:
            self.model.encoder.embed_positions = BartLearnedPositionalEmbedding(
                config.max_encoder_position_embeddings,
                config.d_model)

            self.model.decoder.embed_positions = BartLearnedPositionalEmbedding(
                config.max_decoder_position_embeddings,
                config.d_model)

            for i, layer in enumerate(self.model.encoder.layers):
                layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)
```

```
# Assumed imports (not shown in the README):
from typing import List

from transformers import BartConfig


class LongformerEncoderDecoderConfig(BartConfig):
    def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, **kwargs):
        """
        Args:
            attention_window: list of attention window sizes of length = number of layers.
                window size = number of attention locations on each side.
                For an effective window size of 512, use `attention_window=[256]*num_layers`
                which is 256 on each side.
            attention_dilation: list of attention dilations of length = number of layers.
                attention dilation of `1` means no dilation.
            autoregressive: do autoregressive attention or attend to both sides
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of
                Longformer self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
        """
        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
```
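
Following the docstring above, a hedged sketch of how such a config might be instantiated; the layer count and window sizes are illustrative only:
```
# Illustrative values only (not from the original README).
num_layers = 6
config = LongformerEncoderDecoderConfig(
    attention_window=[256] * num_layers,  # effective window of 512 tokens (256 per side)
    attention_dilation=[1] * num_layers,  # no dilation
    autoregressive=False,
    attention_mode='sliding_chunks',
    gradient_checkpointing=False,
)
```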

After loading the model object, you must download the weight file separately and load the weights with `load_state_dict`.
```
# Assumed imports (not shown in the README):
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cocoirun/longforemr-kobart-summary-v1")
model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained("cocoirun/longforemr-kobart-summary-v1")
device = torch.device('cuda')
model.load_state_dict(torch.load("summary weight.ckpt"))
model.to(device)
```
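
If you are loading the checkpoint on a machine without a GPU, `torch.load` accepts a `map_location` argument; a hedged variant of the loading step above (same assumed checkpoint filename, CPU device):
```
import torch

device = torch.device('cpu')
state_dict = torch.load("summary weight.ckpt", map_location=device)  # filename from the snippet above
model.load_state_dict(state_dict)
model.to(device)
```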

Model summarization function:
```
def summarize(text, max_len):
    max_seq_len = 4096
    context_tokens = ['<s>'] + tokenizer.tokenize(text) + ['</s>']
    input_ids = tokenizer.convert_tokens_to_ids(context_tokens)

    if len(input_ids) < max_seq_len:
        # Pad the input up to the 4096-token encoder length
        while len(input_ids) < max_seq_len:
            input_ids += [tokenizer.pad_token_id]
    else:
        # Truncate and keep a final EOS token
        input_ids = input_ids[:max_seq_len - 1] + [tokenizer.eos_token_id]

    res_ids = model.generate(torch.tensor([input_ids]).to(device),
                             max_length=max_len,
                             num_beams=5,
                             no_repeat_ngram_size=3,
                             eos_token_id=tokenizer.eos_token_id,
                             bad_words_ids=[[tokenizer.unk_token_id]])

    res = tokenizer.batch_decode(res_ids.tolist(), skip_special_tokens=True)[0]
    res = res.replace("\n\n", "\n")
    return res
```
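
A minimal usage sketch of the function above; the input text and `max_len` value are placeholders, not values from the original README:
```
# Hypothetical call; replace `document` with the long text you want to summarize.
document = "..."  # placeholder
summary = summarize(document, max_len=256)
print(summary)
```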