import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from utils import DEVICE

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
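
# RMSNorm rescales each token vector by its root-mean-square,
# y = x / sqrt(mean(x**2, dim=-1) + eps) * weight, without the mean-centering
# or bias term of LayerNorm.
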
class Attention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings (RoPE).
    """
    def __init__(self, num_heads, head_size, num_embed, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.wq = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wk = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wv = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wo = nn.Linear(num_heads * head_size, num_embed, bias=False)
        # Inverse frequencies for RoPE, one per pair of channels (base 500000)
        inv_freq = 1 / (500000 ** (torch.arange(0, head_size, 2)[: (head_size // 2)].float() / head_size))
        self.register_buffer('inv_freq', inv_freq)
        self.dropout = nn.Dropout(dropout)

    def reshape_for_broadcast(self, freq_cis, x):
        # Reshape the (T, head_size // 2) rotation factors to (1, 1, T, head_size // 2)
        # so they broadcast over the batch and head dimensions of x.
        ndim = x.ndim
        shape = [1] * (ndim - 2) + list(freq_cis.shape)
        return freq_cis.view(*shape)

    def apply_rope(self, x, position, freq):
        # Build the complex rotation factors exp(i * t * freq) for positions 0..position-1
        t = torch.arange(position, device=freq.device, dtype=torch.float32)
        freq = torch.outer(t, freq)
        freq_cis = torch.polar(torch.ones_like(freq), freq)
        # View the last dimension of x as head_size // 2 complex pairs and rotate them
        x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freq_cis = self.reshape_for_broadcast(freq_cis, x)
        x_out = torch.view_as_real(x_ * freq_cis).flatten(3)
        return x_out.type_as(x)

    def forward(self, x):
        B, T, C = x.shape
        # Causal mask: each position may only attend to itself and earlier positions
        mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(B, T, self.num_heads, self.head_size)
        xk = xk.view(B, T, self.num_heads, self.head_size)
        xv = xv.view(B, T, self.num_heads, self.head_size)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        xq = self.apply_rope(xq, T, self.inv_freq)
        xk = self.apply_rope(xk, T, self.inv_freq)
        attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_size)
        attn_weights += mask
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
        output = torch.matmul(attn_weights, xv)
        # (B, num_heads, T, head_size) -> (B, T, num_heads * head_size)
        output = output.transpose(1, 2).contiguous().view(B, T, -1)
        return self.dropout(self.wo(output))
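
# Shape walk-through of the RoPE application above (for reference): xq and xk
# arrive in apply_rope with shape (B, num_heads, T, head_size); the last
# dimension is viewed as head_size // 2 complex numbers, each multiplied by
# exp(i * t * inv_freq[k]) for its position t, then viewed back as real values.
# The net effect is a rotation of every (even, odd) channel pair by an angle
# proportional to the token position, which makes the q . k scores depend on
# relative rather than absolute positions.
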
class MLP(nn.Module):
    """
    Implementation of the feed-forward (MLP) sub-module used in each
    Transformer block: a single SiLU-activated hidden layer followed by a
    projection back to the embedding dimension.
    """
    def __init__(self, num_embed, dropout):
        """
        Constructor for the MLP.
        Args:
            num_embed (int): The number of embedding dimensions.
            dropout (float): Dropout probability applied to the output.
        """
        super().__init__()
        # Hidden size follows the 2/3 * 4d convention used by LLaMA-style feed-forward layers
        hidden = int(4 * num_embed * 2 / 3)
        # Define linear layers for the MLP
        self.w1 = nn.Linear(num_embed, hidden, bias=False)
        self.w2 = nn.Linear(hidden, num_embed, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass of the MLP.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, num_embed).
        Returns:
            torch.Tensor: Output tensor after passing through the MLP (shape: batch_size, seq_len, num_embed).
        """
        return self.dropout(self.w2(F.silu(self.w1(x))))
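
# Example (illustrative): MLP(num_embed=32, dropout=0.0) maps an input of shape
# (batch, seq_len, 32) through a SiLU-activated hidden layer of size
# int(4 * 32 * 2 / 3) = 85 and back to (batch, seq_len, 32).
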
class TransformerBlock(nn.Module):
    """
    Groups multi-head attention and the MLP into a single block
    that can be stacked repeatedly inside the Transformer.
    """
    def __init__(self, num_heads, head_size, num_embed, dropout):
        super().__init__()
        self.mha = Attention(
            num_heads=num_heads,
            head_size=head_size,
            num_embed=num_embed,
            dropout=dropout,
        )
        self.mlp = MLP(num_embed=num_embed, dropout=dropout)
        # RMSNorm layers for the two pre-norm residual connections
        self.norm1 = RMSNorm(num_embed)
        self.norm2 = RMSNorm(num_embed)

    def forward(self, x):
        """
        Applies the pre-norm attention and MLP sub-layers with residual connections.
        Args:
            x (torch.Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim).
        Returns:
            torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        x = x + self.mha(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
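
# Each block is pre-normalized: RMSNorm is applied to the input of the attention
# and MLP sub-layers, and each sub-layer's output is added back to the residual
# stream (x = x + sublayer(norm(x))) rather than normalizing the sum afterwards.
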
class Transformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.model_type = 'Prome'
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.block_size = kwargs.get("block_size", 8)
        self.num_heads = kwargs.get("num_heads", 4)
        self.head_size = kwargs.get("head_size", 128)
        self.num_layers = kwargs.get("num_layers", 4)
        self.dropout = kwargs.get("dropout", 0.2)
        self.max_seq_len = kwargs.get("max_seq_len", 1024)
        # a simple lookup table mapping each token id to a learned embedding vector
        # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
        # positions are encoded with RoPE inside the attention layers, so no
        # separate positional embedding table is needed
        #self.position_embedding_table = nn.Embedding(self.max_seq_len, self.num_embed)
        self.decoder = nn.Sequential(
            *[
                TransformerBlock(
                    num_heads=self.num_heads,
                    head_size=self.head_size,
                    num_embed=self.num_embed,
                    dropout=self.dropout,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)
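
    # Overall data flow: token ids (B, T) -> token embeddings (B, T, num_embed)
    # -> num_layers stacked TransformerBlocks -> lm_head -> logits (B, T, vocab_size).
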
    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are (B, T) tensors of integer token ids;
        # the token embedding is (B, T, C) with C = num_embed
        x = self.token_embedding_table(idx)
        #posit_emb = self.position_embedding_table(torch.arange(T, device=DEVICE))
        #x = token_emb + posit_emb
        x = self.decoder(x)
        # (B, T, vocab_size)
        logits = self.lm_head(x)
        # Compute the loss
        if targets is not None:
            # cross_entropy expects inputs of shape (N, num_classes), so flatten
            # the logits to (B * T, vocab_size) and the targets to (B * T,)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.6, top_p: float = 0.9):
        for _ in range(max_new_tokens):
            # Crop the context to the most recent max_seq_len tokens
            idx_crop = idx[:, -self.max_seq_len:]
            logits, _ = self.forward(idx_crop)
            # Keep only the logits for the last position
            logits = logits[:, -1, :]
            if temperature > 0:
                # Temperature-scaled softmax followed by nucleus (top-p) sampling
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = self.sample_top_p(probs, top_p)
            else:
                # Greedy decoding when temperature is 0
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

    def sample_top_p(self, probs: torch.Tensor, top_p: float) -> torch.Tensor:
        # Sort token probabilities in descending order and accumulate them
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # Keep the smallest set of tokens whose cumulative probability exceeds top_p:
        # shift the mask right by one so the token that crosses the threshold is kept,
        # and always keep the most probable token
        top_p_mask = cumulative_probs <= top_p
        top_p_mask[..., 1:] = top_p_mask[..., :-1].clone()
        top_p_mask[..., 0] = True
        filtered_probs = sorted_probs * top_p_mask
        filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)  # Renormalize the kept probabilities
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        # Map the sampled position back to the original vocabulary index
        return torch.gather(sorted_indices, -1, next_token)
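

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # training setup; it assumes the repo's utils module is importable). The
    # hyperparameters below are placeholders chosen so that
    # num_heads * head_size == num_embed.
    model = Transformer(
        vocab_size=100,
        num_embed=64,
        num_heads=4,
        head_size=16,
        num_layers=2,
        dropout=0.0,
        max_seq_len=64,
    )
    idx = torch.randint(0, 100, (2, 16))      # (B, T) batch of token ids
    logits, loss = model(idx, targets=idx)    # logits: (2, 16, 100)
    print(logits.shape, loss.item())
    out = model.generate(idx, max_new_tokens=8, temperature=0.6, top_p=0.9)
    print(out.shape)                          # (2, 24)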