I'm using torchrl for training my reinforcement learning agent.
I'm trying to run a GRU layer on batched data (specifically, calling GAE as an advantage module on the batched data from the replay buffer), and it raises this error:
RuntimeError: Batching rule not implemented for aten::gru.input. We could not generate a fallback.
The error occurs regardless of whether the device is 'cpu' or 'cuda:0'. Does anyone know how to solve it?
Below is a minimal, simplified version of the relevant code:
data_collector = SyncDataCollector(
    train_env,
    policy_module,
    total_frames=10240,
    frames_per_batch=64,
    split_trajs=False,
)
input = next(iter(data_collector))  # one batch yielded by the collector
# input format:
# TensorDict(
#     fields={
#         action: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
#         action_value: Tensor(shape=torch.Size([64, 1, 15]), device=cuda:0, dtype=torch.float32, is_shared=True),
#         chosen_action_value: Tensor(shape=torch.Size([64, 1, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
#         collector: TensorDict(
#             fields={
#                 traj_ids: Tensor(shape=torch.Size([64]), device=cuda:0, dtype=torch.int64, is_shared=True)},
#             batch_size=torch.Size([64]),
#             device=cuda:0,
#             is_shared=True),
#         done: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
#         next: TensorDict(
#             fields={
#                 done: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
#                 observation: Tensor(shape=torch.Size([64, 6, 40, 5]), device=cuda:0, dtype=torch.float32, is_shared=True),
#                 reward: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
#                 step_count: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
#                 terminated: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
#             batch_size=torch.Size([64]),
#             device=cuda:0,
#             is_shared=True),
#         observation: Tensor(shape=torch.Size([64, 6, 40, 5]), device=cuda:0, dtype=torch.float32, is_shared=True),
#         step_count: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
#         terminated: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
#     batch_size=torch.Size([64]),
#     device=cuda:0,
#     is_shared=True)
rnn = torch.nn.GRU(
    input_size=5,
    hidden_size=32,
    num_layers=1,
    dropout=0,
    batch_first=True,
    bidirectional=True,
)
value_module = ValueOperator(
    module=rnn,
    in_keys=["observation"],
)
advantage_module = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=value_module,
    average_gae=True,
)
output = advantage_module(input)  # <-- RuntimeError raised here (GAE returns a TensorDict, not a tuple)
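For what it's worth, the error seems reproducible without torchrl at all. Below is a minimal sketch, under my assumption that GAE evaluates its value network through torch.func.vmap, which has no batching rule for the GRU op:

```python
import torch

# Minimal repro sketch (assumption about the root cause): applying
# torch.func.vmap over an nn.GRU raises the same
# "Batching rule not implemented for aten::gru.input" error.
gru = torch.nn.GRU(input_size=5, hidden_size=32, num_layers=1, batch_first=True)
x = torch.randn(4, 3, 10, 5)  # (vmap dim, batch, seq, features)

failed_with = None
try:
    # vmap over the leading dim; inside, the GRU sees a (3, 10, 5) input
    torch.func.vmap(lambda t: gru(t)[0])(x)
except RuntimeError as e:
    failed_with = str(e)

print(failed_with)
```

If that is indeed what happens inside the advantage module, the question becomes whether GAE can be told to call the value network without vmap.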