import numpy as np
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Applies an attention mechanism to the `context` using the `query`.

    The attention score is the dot product :math:`score(H_j, q) = H_j^T q`.

    Args:
        dimensions (int): Dimensionality of the query and context.

    Example:
        >>> attention = Attention(256)
        >>> query = torch.randn(32, 50, 256)
        >>> context = torch.randn(32, 1, 256)
        >>> output, weights = attention(query, context)
        >>> output.size()
        torch.Size([32, 50, 256])
        >>> weights.size()
        torch.Size([32, 50, 1])
    """

    def __init__(self, dimensions):
        super(Attention, self).__init__()
        self.dimensions = dimensions
        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
    def forward(self, query, context, attention_mask=None):
        """
        Args:
            query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of
                queries to query the context. Here `output length` is the utterance length
                (one query per token).
            context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data
                over which to apply the attention mechanism. Here `query length` is 1
                (a single intent context vector).
            attention_mask (:class:`torch.Tensor` [batch size, output length], optional): Mask with
                zeros at padded positions; the corresponding scores are set to ``-inf`` before the
                softmax.

        Returns:
            :class:`tuple` with `output` and `weights`:

            * **output** (:class:`torch.FloatTensor` [batch size, output length, dimensions]):
              Tensor containing the attended features.
            * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]):
              Tensor containing attention weights.
        """
        # (batch_size, output_len, dimensions) x (batch_size, dimensions, query_len) ->
        # (batch_size, output_len, query_len)
        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())

        if attention_mask is not None:
            # Mask padded positions before the softmax so they receive zero weight.
            attention_mask = torch.unsqueeze(attention_mask, 2)
            attention_scores.masked_fill_(attention_mask == 0, -np.inf)

        # Normalize the scores across the output (token) dimension.
        attention_weights = self.softmax(attention_scores)

        # (batch_size, output_len, query_len) x (batch_size, query_len, dimensions) ->
        # (batch_size, output_len, dimensions)
        mix = torch.bmm(attention_weights, context)

        # Concatenate the attended context with the query -> (batch_size, output_len, 2 * dimensions)
        combined = torch.cat((mix, query), dim=2)

        # Project back to `dimensions` and squash -> (batch_size, output_len, dimensions)
        output = self.linear_out(combined)
        output = self.tanh(output)
        return output, attention_weights
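

# A minimal usage sketch for Attention with a padding mask (illustration only; the shapes
# below are arbitrary and not part of the model graph):
#
#     attention = Attention(256)
#     slot_vectors = torch.randn(4, 50, 256)          # [batch, seq_len, dim]
#     intent_vector = torch.randn(4, 1, 256)          # [batch, 1, dim]
#     mask = torch.ones(4, 50, dtype=torch.long)      # 1 = real token, 0 = padding
#     mask[:, 30:] = 0                                 # pretend the last 20 positions are padding
#     attended, weights = attention(slot_vectors, intent_vector, mask)
#     # attended: [4, 50, 256]; weights: [4, 50, 1], with zero weight at masked positions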


class IntentClassifier(nn.Module):
    """Intent classification head: dropout followed by a linear projection to the intent labels."""

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.0):
        super(IntentClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)
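

# Usage sketch (illustrative values only): the classifier is typically fed a pooled
# sentence representation, e.g. a [CLS] vector, and returns unnormalized intent logits.
#
#     intent_clf = IntentClassifier(input_dim=768, num_intent_labels=22, dropout_rate=0.1)
#     pooled_output = torch.randn(4, 768)              # e.g. pooled [CLS] representation
#     intent_logits = intent_clf(pooled_output)        # [4, 22]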


class SlotClassifier(nn.Module):
    """Slot classification head that can optionally fuse intent information into each token
    representation, either by concatenation (`use_intent_context_concat`) or by attention
    (`use_intent_context_attn`).
    """

    def __init__(
        self,
        input_dim,
        num_intent_labels,
        num_slot_labels,
        use_intent_context_concat=False,
        use_intent_context_attn=False,
        max_seq_len=50,
        attention_embedding_size=200,
        dropout_rate=0.0,
    ):
        super(SlotClassifier, self).__init__()
        self.use_intent_context_attn = use_intent_context_attn
        self.use_intent_context_concat = use_intent_context_concat
        self.max_seq_len = max_seq_len
        self.num_intent_labels = num_intent_labels
        self.num_slot_labels = num_slot_labels
        self.attention_embedding_size = attention_embedding_size

        output_dim = self.attention_embedding_size  # base model
        if self.use_intent_context_concat:
            output_dim = self.attention_embedding_size
            self.linear_out = nn.Linear(2 * attention_embedding_size, attention_embedding_size)
        elif self.use_intent_context_attn:
            output_dim = self.attention_embedding_size
            self.attention = Attention(attention_embedding_size)

        self.linear_slot = nn.Linear(input_dim, self.attention_embedding_size, bias=False)

        if self.use_intent_context_attn or self.use_intent_context_concat:
            # Project the intent logits so they have the same dimensionality as the slot vectors.
            self.linear_intent_context = nn.Linear(self.num_intent_labels, self.attention_embedding_size, bias=False)
            self.softmax = nn.Softmax(dim=-1)  # softmax layer for the intent logits

        # Output projection to the slot label space.
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(output_dim, num_slot_labels)
    def forward(self, x, intent_context, attention_mask):
        x = self.linear_slot(x)

        if self.use_intent_context_concat:
            # Concatenate the projected intent distribution with every token representation.
            intent_context = self.softmax(intent_context)
            intent_context = self.linear_intent_context(intent_context)
            intent_context = torch.unsqueeze(intent_context, 1)
            intent_context = intent_context.expand(-1, self.max_seq_len, -1)
            x = torch.cat((x, intent_context), dim=2)
            x = self.linear_out(x)
        elif self.use_intent_context_attn:
            # Attend over the tokens using the projected intent distribution as the context.
            intent_context = self.softmax(intent_context)
            intent_context = self.linear_intent_context(intent_context)
            intent_context = torch.unsqueeze(intent_context, 1)  # 1: query length (single intent vector)
            output, weights = self.attention(x, intent_context, attention_mask)
            x = output

        x = self.dropout(x)
        return self.linear(x)
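

# End-to-end usage sketch for the slot head (hypothetical shapes; the encoder producing
# `sequence_output` and the IntentClassifier producing `intent_logits` are assumed):
#
#     slot_clf = SlotClassifier(
#         input_dim=768,
#         num_intent_labels=22,
#         num_slot_labels=9,
#         use_intent_context_attn=True,
#         max_seq_len=50,
#         attention_embedding_size=200,
#         dropout_rate=0.1,
#     )
#     sequence_output = torch.randn(4, 50, 768)        # token-level encoder states
#     intent_logits = torch.randn(4, 22)               # output of IntentClassifier
#     attention_mask = torch.ones(4, 50, dtype=torch.long)
#     slot_logits = slot_clf(sequence_output, intent_logits, attention_mask)   # [4, 50, 9]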