Search code examples
Tags: python, cloning, stable-baselines

Behavioural cloning (Imitation learning) for SB3-contrib RecurrentPPO


I'm working on an LSTM-based RecurrentPPO agent that needs a behavioural cloning implementation.

The imitation library, which is built to work with Stable Baselines3 (see https://imitation.readthedocs.io/en/latest/), does not seem to support SB3-contrib's RecurrentPPO.

I found this pretraining approach, which could be adapted for RecurrentPPO: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pretraining.ipynb

I assume this part of the code has to be modified to take lstm_states and episode_starts into account, but I don't know how to implement it.

def pretrain_agent(
    student,
    batch_size=64,
    epochs=1000,
    scheduler_gamma=0.7,
    learning_rate=1.0,
    log_interval=100,
    no_cuda=True,
    seed=1,
    test_batch_size=64,
):
    """Pretrain ``student``'s policy by behavioural cloning on expert data.

    The policy is fitted to expert (observation, action) pairs with plain
    supervised learning. Continuous (``gym.spaces.Box``) action spaces use MSE
    between predicted and expert actions. Other action spaces use
    cross-entropy between the policy's action logits and the expert action
    indices.

    NOTE(review): this function reads module-level names that the calling
    script must define: ``env`` (its ``action_space`` selects the loss),
    ``train_expert_dataset`` / ``test_expert_dataset`` (datasets yielding
    ``(observation, action)`` pairs), plus ``A2C``, ``PPO``, ``gym``, ``th``,
    ``nn``, ``optim`` and ``StepLR``.

    NOTE(review): recurrent policies (e.g. sb3-contrib's RecurrentPPO) are not
    supported here. Their policy methods additionally require ``lstm_states``
    and ``episode_starts``. Training them would also require feeding ordered
    trajectories rather than shuffled, independent transitions.

    :param student: SB3 agent whose ``policy`` is trained; the trained policy
        is assigned back to ``student.policy``.
    :param batch_size: Mini-batch size for training.
    :param epochs: Number of passes over the training dataset.
    :param scheduler_gamma: Multiplicative learning-rate decay applied after
        every epoch.
    :param learning_rate: Initial Adadelta learning rate.
    :param log_interval: Print the training loss every ``log_interval`` batches.
    :param no_cuda: If True, train on CPU even when CUDA is available.
    :param seed: Seed passed to ``th.manual_seed``.
    :param test_batch_size: Mini-batch size for evaluation.
    """
    use_cuda = not no_cuda and th.cuda.is_available()
    th.manual_seed(seed)
    device = th.device("cuda" if use_cuda else "cpu")
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # The action space does not change during training, so decide the loss once.
    is_continuous = isinstance(env.action_space, gym.spaces.Box)
    criterion = nn.MSELoss() if is_continuous else nn.CrossEntropyLoss()

    # Extract initial policy
    model = student.policy.to(device)

    def predict(model, data, target):
        """Return ``(prediction, target)`` in the form ``criterion`` expects."""
        if is_continuous:
            # A2C/PPO policy outputs actions, values, log_prob
            # SAC/TD3 policy outputs actions only
            if isinstance(student, (A2C, PPO)):
                action, _, _ = model(data)
            else:
                action = model(data)
            return action.double(), target
        # Retrieve the logits for A2C/PPO when using discrete actions;
        # CrossEntropyLoss expects integer class indices as targets.
        dist = model.get_distribution(data)
        return dist.distribution.logits, target.long()

    def train(model, device, train_loader, optimizer, epoch):
        """Run one training epoch over ``train_loader``."""
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            action_prediction, target = predict(model, data, target)
            loss = criterion(action_prediction, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

    def test(model, device, test_loader):
        """Print the average loss of ``model`` over ``test_loader``."""
        model.eval()
        test_loss = 0.0
        with th.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                action_prediction, target = predict(model, data, target)
                # The criterion returns a per-batch mean. Weight it by the batch
                # size and accumulate, so the division below averages over
                # the whole dataset instead of reporting only the last batch.
                test_loss += criterion(action_prediction, target).item() * len(data)
        test_loss /= len(test_loader.dataset)
        print(f"Test set: Average loss: {test_loss:.4f}")

    # Here, we use PyTorch `DataLoader` to our load previously created `ExpertDataset` for training
    # and testing
    train_loader = th.utils.data.DataLoader(
        dataset=train_expert_dataset, batch_size=batch_size, shuffle=True, **kwargs
    )
    test_loader = th.utils.data.DataLoader(
        dataset=test_expert_dataset,
        batch_size=test_batch_size,
        shuffle=True,
        **kwargs,
    )

    # Define an Optimizer and a learning rate schedule.
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=scheduler_gamma)

    # Now we are finally ready to train the policy model.
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    # Implant the trained policy network back into the RL student agent
    student.policy = model

Does anyone have a solution?


Solution

  • I just stumbled upon this problem as well:

    Traceback (most recent call last):
      File "my_imitate.py", line 49, in <module>
        bc_trainer.train(n_epochs=1)
      File "python3.8/site-packages/imitation/algorithms/bc.py", line 470, in train
        training_metrics = self.loss_calculator(self.policy, obs, acts)
      File "python3.8/site-packages/imitation/algorithms/bc.py", line 119, in __call__
        _, log_prob, entropy = policy.evaluate_actions(obs, acts)
    TypeError: evaluate_actions() missing 2 required positional arguments: 'lstm_states' and 'episode_starts'
    

    The problem is that evaluate_actions in RecurrentActorCriticPolicy has a different signature: it also requires lstm_states and episode_starts.

    My first thought was that this information also needs to be stored during rollout collection (I assumed it would be, but it is not). The solution would then be to store the missing information during rollout collection and use it during BC whenever it is present and compatible with the policy at hand.

    However, it is unclear what the expert's LSTM state would be when the expert policy used for rollout collection is not itself recurrent (e.g. a near-optimal search algorithm). Therefore, for recurrent policies, the BC algorithm should train on whole trajectories from beginning to end, passing the LSTM states (lstm_states) from one timestep to the next.