python - RuntimeError: Batching rule not implemented for aten::gru.input - Stack Overflow

admin · 2025-04-18

I'm using torchrl for training my reinforcement learning agent.

I'm trying to call a GRU layer with batched data (specifically, calling GAE as an advantage module on batched data from the replay buffer), and it raises this error:
RuntimeError: Batching rule not implemented for aten::gru.input. We could not generate a fallback.
The error occurs regardless of whether the device is 'cpu' or 'cuda:0'. Does anyone know how to solve it?

Below is a simplified version of the relevant code:

data_collector = SyncDataCollector(
    train_env,
    policy_module,
    total_frames=10240,
    frames_per_batch=64,
    split_trajs=False,
)
# `input` is a TensorDict obtained by iterating data_collector

# input format:
# TensorDict(
#     fields={
#         action: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
#         action_value: Tensor(shape=torch.Size([64, 1, 15]), device=cuda:0, dtype=torch.float32, is_shared=True),
#         chosen_action_value: Tensor(shape=torch.Size([64, 1, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
#         collector: TensorDict(
#             fields={
#                 traj_ids: Tensor(shape=torch.Size([64]), device=cuda:0, dtype=torch.int64, is_shared=True)},
#             batch_size=torch.Size([64]),
#             device=cuda:0,
#             is_shared=True),
#         done: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
#         next: TensorDict(
#             fields={
#                 done: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
#                 observation: Tensor(shape=torch.Size([64, 6, 40, 5]), device=cuda:0, dtype=torch.float32, is_shared=True),
#                 reward: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
#                 step_count: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
#                 terminated: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
#             batch_size=torch.Size([64]),
#             device=cuda:0,
#             is_shared=True),
#         observation: Tensor(shape=torch.Size([64, 6, 40, 5]), device=cuda:0, dtype=torch.float32, is_shared=True),
#         step_count: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
#         terminated: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
#     batch_size=torch.Size([64]),
#     device=cuda:0,
#     is_shared=True)


rnn = torch.nn.GRU(
    input_size=5,
    hidden_size=32,
    num_layers=1,
    dropout=0,
    batch_first=True,
    bidirectional=True,
)

value_module = ValueOperator(
    module=rnn,
    in_keys=["observation"],
)

advantage_module = GAE(
    gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
)


output, _ = advantage_module(input) # <-- error
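For what it's worth, the "Batching rule not implemented" message comes from `torch.vmap`, which (as an assumption here) torchrl's GAE uses internally to vectorize calls to the value network. The failure can be reproduced without torchrl at all, by wrapping a plain GRU call in `vmap`. This is only a minimal sketch to isolate the error, not the original training code:

```python
import torch

# A GRU like the one in the question (reduced to the essentials).
gru = torch.nn.GRU(input_size=5, hidden_size=32, batch_first=True)

x = torch.randn(64, 40, 5)  # (batch, seq, features)

# A plain batched call works fine:
out, h = gru(x)

# Wrapping the same call in vmap fails: aten::gru.input has no
# vmap batching rule, and no fallback can be generated for it.
try:
    torch.vmap(lambda xi: gru(xi)[0])(x)
except RuntimeError as e:
    print(e)  # "Batching rule not implemented for aten::gru.input ..."
```

So the problem is not the device or the TensorDict contents; it is that the GRU forward is being traced through `vmap`, which does not support this operator.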