python - I am plugging a cross_attention module into Faster R-CNN - Stack Overflow


import torch
import torch.nn as nn
from math import sqrt

class CalculateAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V):
        # attention scores: [batch_size, num_heads, q_len, kv_len]
        attention = torch.matmul(Q, torch.transpose(K, -1, -2))
        attention = torch.softmax(attention / sqrt(Q.size(-1)), dim=-1)
        attention = torch.matmul(attention, V)
        return attention

class Multi_CrossAttention(nn.Module):
    """
    
    """
    def __init__(self, hidden_size, all_head_size, head_num):
        super().__init__()
        self.hidden_size = hidden_size           # input feature dim of x
        self.all_head_size = all_head_size       # total dim across all heads
        self.num_heads = head_num                # number of attention heads
        self.h_size = all_head_size // head_num  # dim per head

        assert all_head_size % head_num == 0

        # projection matrices: W_q maps hidden_size -> all_head_size;
        # W_k and W_v expect y's feature dimension to be 1024
        self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False)
        self.linear_k = nn.Linear(1024, all_head_size, bias=False)
        self.linear_v = nn.Linear(1024, all_head_size, bias=False)
        self.linear_output = nn.Linear(all_head_size, hidden_size)

        # normalization factor (not used in forward; scaling is done in CalculateAttention)
        self.norm = sqrt(all_head_size)

    def print(self):
        print(self.hidden_size, self.all_head_size)
        print(self.linear_k, self.linear_q, self.linear_v)

    def forward(self, x, y):
        """
        x : query input,     [batch_size, seq_len_q, hidden_size]
        y : key/value input, [batch_size, seq_len_kv, 1024]
        """
        batch_size = x.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, d) -trans-> (B, H, S, d), where H = num_heads, d = h_size
        
        # q_s: [batch_size, num_heads, seq_length, h_size]
        print(f"x device is {x.device}")
        print(f"self.linear_q device is {self.linear_q.weight.device}")

        q_s = self.linear_q(x).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)
        print("1")

        # k_s: [batch_size, num_heads, seq_length, h_size]
        k_s = self.linear_k(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)

        # v_s: [batch_size, num_heads, seq_length, h_size]
        v_s = self.linear_v(y).view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)

        # CalculateAttention has no parameters of its own, so constructing it inline is harmless
        attention = CalculateAttention()(q_s, k_s, v_s)
        # attention : [batch_size , seq_length , num_heads * h_size]
        attention = attention.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.h_size)

        # output : [batch_size , seq_length , hidden_size]
        output = self.linear_output(attention)
        print(output.shape)

        return output
        

The above is the attention module.
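For reference, the module runs fine standalone on the CPU; here is a minimal shape sanity check (the sizes below are made up for illustration and are not from my model):

# standalone sanity check (illustrative sizes only)
x = torch.randn(2, 16, 64)     # queries:     [batch, seq_len_q, hidden_size]
y = torch.randn(2, 10, 1024)   # keys/values: [batch, seq_len_kv, 1024]
attn = Multi_CrossAttention(hidden_size=64, all_head_size=64, head_num=8)
out = attn(x, y)
print(out.shape)               # torch.Size([2, 16, 64])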

# (inside the detector's forward pass; requires `from collections import OrderedDict`)
prototype_data = prototype_data.to(self.device)
cross_features = OrderedDict()
for key in features.keys():
    B, D, W, H = features[key].shape
    flatten_features = features[key].reshape(B, D, -1).to(self.device)
    print(f"flatten_features device is {flatten_features.device}")
    print(f"prototype_data device is {prototype_data.device}")

    # a new attention module is constructed here on every forward pass
    cross_attention = Multi_CrossAttention(flatten_features.shape[2], W ** 2, 8)

    cross_output = cross_attention(flatten_features, prototype_data)
    cross_output = cross_output.reshape(cross_output.shape[0], cross_output.shape[1], W, -1)

    cross_features[key] = cross_output
features = cross_features

The above is where I inserted the attention module. When I run the program, I get the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

I have found that the weights of the following layers are still on the CPU:

self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False)
self.linear_k = nn.Linear(1024, all_head_size, bias=False)
self.linear_v = nn.Linear(1024, all_head_size, bias=False)
self.linear_output = nn.Linear(all_head_size, hidden_size)

But I have already sent the entire model to the GPU, so I don't understand why this happens. Can anyone help me? Thanks.
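From what I understand, `model.to(device)` only moves parameters that exist at the moment it is called, and in my snippet the `Multi_CrossAttention` is constructed inside the forward pass, after the model has already been moved. A minimal sketch of that mechanism (the `Detector` class and sizes below are made up for illustration):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Detector(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # stand-in for the real backbone

    def forward(self, x):
        # this module is created after Detector().to(device) has already run,
        # so its weight is still on the CPU while x may be on the GPU
        late_attention = nn.Linear(8, 8)
        print(late_attention.weight.device)  # prints: cpu
        return late_attention(x)

model = Detector().to(device)
x = torch.randn(2, 8, device=device)
# model(x)  # on a CUDA machine this raises the same device-mismatch RuntimeError

If that is indeed the cause, the usual remedies would be to register the `Multi_CrossAttention` modules in `__init__` (for example in an `nn.ModuleDict`, one per feature level) so that `.to(device)` moves them together with the rest of the model, or to move each one explicitly with `cross_attention.to(flatten_features.device)` right after constructing it. Constructing a fresh module on every forward pass would also mean its weights are re-initialized on each call and never trained.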

