CNN
Multi-head Attention
import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert hidden_dim % num_heads == 0, "Hidden dimension must be divisible by number of heads"
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.query_matrix = nn.Linear(hidden_dim, hidden_dim)
        self.key_matrix = nn.Linear(hidden_dim, hidden_dim)
        self.value_matrix = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.fc_out = nn.Linear(hidden_dim, hidden_dim)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, x):
        """
        x: [batch_size, seq_len, hidden_dim]
        """
        batch_size, seq_len, _ = x.shape
        # Transform the inputs to (batch_size, seq_len, num_heads, head_dim)
        Q = self.query_matrix(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.key_matrix(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = self.value_matrix(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        # Transpose for matrix multiplication: (batch_size, num_heads, seq_len, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        # Calculate attention scores and apply scaling
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # Apply softmax to get attention weights, then dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Multiply the attention weights with V
        output = torch.matmul(attn_weights, V)  # (batch_size, num_heads, seq_len, head_dim)
        # Concatenate heads and put through the final linear layer
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)  # (batch_size, seq_len, hidden_dim)
        output = self.fc_out(output)
        return output
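As a quick shape check, here is a minimal usage sketch for MultiHeadSelfAttention. This is a sketch only: the batch size, sequence length, and the d_model/num_heads values mirror the SelfAttention example further down and are otherwise arbitrary, and the variable name mha is just illustrative.

d_model = 256
num_heads = 8
mha = MultiHeadSelfAttention(d_model, num_heads)
mha = mha.to(device)
x = torch.randn(32, 100, d_model)  # [batch_size, seq_len, hidden_dim]
x = x.to(device)
output = mha(x)
print(output.shape)  # torch.Size([32, 100, 256]) -- same shape as the input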
Self Attention
import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class SelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(SelfAttention, self).__init__()
        self.query_matrix = nn.Linear(hidden_dim, hidden_dim)
        self.key_matrix = nn.Linear(hidden_dim, hidden_dim)
        self.value_matrix = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.scale = torch.sqrt(torch.FloatTensor([hidden_dim])).to(device)

    def forward(self, x):
        """
        x: [batch_size, seq_len, hidden_dim]
        """
        batch_size, seq_len, hidden_dim = x.shape
        Q = self.query_matrix(x)  # [batch_size, seq_len, hidden_dim]
        K = self.key_matrix(x)    # [batch_size, seq_len, hidden_dim]
        V = self.value_matrix(x)  # [batch_size, seq_len, hidden_dim]
        # Scaled dot-product attention scores
        scores = torch.matmul(Q, K.transpose(1, 2))
        scores = scores / self.scale
        attn_weights = torch.softmax(scores, dim=-1)  # [batch_size, seq_len, seq_len]
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)  # [batch_size, seq_len, hidden_dim]
        return output
d_model = 256
num_heads = 8

attention = SelfAttention(d_model)
attention = attention.to(device)

x = torch.randn(32, 100, d_model)  # [batch_size, seq_len, hidden_dim]
x = x.to(device)

output = attention(x)
print(x.shape)       # torch.Size([32, 100, 256])
print(output.shape)  # torch.Size([32, 100, 256]) -- attention preserves the input shape
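For reference, PyTorch also ships a built-in nn.MultiheadAttention module. Below is a hedged sketch of running it on the same input, reusing d_model, num_heads, and x from above; its output will not match the custom modules numerically, since each module initializes its own projection weights, and the names builtin and builtin_out are just illustrative.

builtin = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
builtin = builtin.to(device)
builtin_out, builtin_weights = builtin(x, x, x)  # self-attention: query = key = value = x
print(builtin_out.shape)  # torch.Size([32, 100, 256])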
Transformer