import torch
import torch.nn as nn
from math import sqrt


class CalculateAttention(nn.Module):
    """Scaled dot-product attention."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V):
        # [B, H, S_q, d] @ [B, H, d, S_kv] -> [B, H, S_q, S_kv]
        attention = torch.matmul(Q, torch.transpose(K, -1, -2))
        attention = torch.softmax(attention / sqrt(Q.size(-1)), dim=-1)
        # [B, H, S_q, S_kv] @ [B, H, S_kv, d] -> [B, H, S_q, d]
        attention = torch.matmul(attention, V)
        return attention
class Multi_CrossAttention(nn.Module):
    """
    Multi-head cross-attention: queries come from x, keys and values come from y.
    """

    def __init__(self, hidden_size, all_head_size, head_num):
        super().__init__()
        self.hidden_size = hidden_size
        self.all_head_size = all_head_size
        self.num_heads = head_num
        self.h_size = all_head_size // head_num
        assert all_head_size % head_num == 0

        # W_q: (hidden_size, all_head_size); W_k, W_v: (1024, all_head_size)
        self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False)
        self.linear_k = nn.Linear(1024, all_head_size, bias=False)
        self.linear_v = nn.Linear(1024, all_head_size, bias=False)
        self.linear_output = nn.Linear(all_head_size, hidden_size)

        # normalization factor
        self.norm = sqrt(all_head_size)

    def print(self):
        print(self.hidden_size, self.all_head_size)
        print(self.linear_k, self.linear_q, self.linear_v)
    def forward(self, x, y):
        """
        x: query source, shape [batch_size, seq_len_q, hidden_size]
        y: key/value source, shape [batch_size, seq_len_kv, 1024]
        """
        batch_size = x.size(0)
        # (B, S, D) -proj-> (B, S, all_head) -split-> (B, S, H, h) -trans-> (B, H, S, h)
        print(f"x device is {x.device}")
        print(f"self.linear_q device is {self.linear_q.weight.device}")

        # q_s: [batch_size, num_heads, seq_length, h_size]
        q_s = self.linear_q(x).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)
        # k_s: [batch_size, num_heads, seq_length, h_size]
        k_s = self.linear_k(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)
        # v_s: [batch_size, num_heads, seq_length, h_size]
        v_s = self.linear_v(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)

        attention = CalculateAttention()(q_s, k_s, v_s)
        # attention: [batch_size, seq_length, num_heads * h_size]
        attention = attention.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.h_size)

        # output: [batch_size, seq_length, hidden_size]
        output = self.linear_output(attention)
        print(output.shape)
        return output
The above is the attention module.
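For reference, here is a minimal sketch of how the module can be exercised on its own with dummy inputs, reusing the imports and class above (the batch size, sequence lengths and feature sizes are just example values, not the ones from my actual pipeline):

hidden_size, all_head_size, head_num = 256, 64, 8
attn = Multi_CrossAttention(hidden_size, all_head_size, head_num)
x = torch.randn(2, 49, hidden_size)   # query source:     [batch, seq_q, hidden_size]
y = torch.randn(2, 10, 1024)          # key/value source: [batch, seq_kv, 1024]
out = attn(x, y)                      # -> [2, 49, hidden_size]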
prototype_data = prototype_data.to(self.device)
cross_features = OrderedDict()
for key in features.keys():
    B, D, W, H = features[key].shape
    # flatten the spatial dimensions: [B, D, W, H] -> [B, D, W*H]
    flatten_features = features[key].reshape(B, D, -1).to(self.device)
    print(f"flatten_features device is {flatten_features.device}")
    print(f"prototype_data device is {prototype_data.device}")
    cross_attention = Multi_CrossAttention(flatten_features.shape[2], W ** 2, 8)
    cross_output = cross_attention(flatten_features, prototype_data)
    # reshape back to a spatial map: [B, D, W*H] -> [B, D, W, H]
    cross_output = cross_output.reshape(cross_output.shape[0], cross_output.shape[1], W, -1)
    cross_features[key] = cross_output
features = cross_features
The above is where I inserted the attention module. When I run the program, I get the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
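For context, the same error can be reproduced in isolation by feeding a CUDA tensor into an nn.Linear whose weights were left on the CPU (a minimal sketch, assuming a CUDA device is available; the sizes are arbitrary):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)                    # weights stay on the CPU
x = torch.randn(1, 4, device="cuda")     # input lives on the GPU
lin(x)                                   # raises the same device-mismatch RuntimeError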
I have found that the weights of the following layers are still on the CPU:
self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False)
self.linear_k = nn.Linear(1024, all_head_size, bias=False)
self.linear_v = nn.Linear(1024, all_head_size, bias=False)
self.linear_output = nn.Linear(all_head_size, hidden_size)
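One way to confirm which weights are affected is to list the device of every parameter (here `model` is just a placeholder for the top-level module that was sent to the GPU):

for name, param in model.named_parameters():
    if param.device.type == "cpu":
        print(f"{name} is still on {param.device}")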
But I have already sent the entire model to the GPU, so I don't know why this happens. Can anyone help me? Thanks.