Tags: python, machine-learning, pytorch, reinforcement-learning

Slow performance of PyTorch Categorical


I have been using a PPO (Proximal Policy Optimisation) architecture to train an agent in a custom simulator. The simulator is quite fast because it is written in Rust, so the inner loop is now bottlenecked by a few functions inside the PPO agent.

When I profiled these functions with pyinstrument, it showed that most of the time is spent initialising the Categorical class and calculating the log probabilities.

I hope someone can tell me whether there is a faster way to do this in PyTorch.

    def act(self, state):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()

    def evaluate(self, state, action):
        """Evaluates the action given the state."""
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy

[Image: pyinstrument profile showing where the program spends its time.]
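
For reference, a minimal pyinstrument setup looks roughly like this (a sketch; `run_episode()` is a hypothetical stand-in for whatever drives `act()` and `evaluate()` in the training loop):

    from pyinstrument import Profiler

    profiler = Profiler()
    profiler.start()
    run_episode()  # hypothetical driver that calls act()/evaluate()
    profiler.stop()
    print(profiler.output_text())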

I have seen some other techniques for doing this, but it was not clear to me whether they would improve the speed.


Solution

  • I ran into the same problem a while back and implemented a custom Categorical class by copying from the PyTorch source code.

    It is similar to the original code but removes unnecessary functionality. It does not require initialising the class every time; instead, you initialise it once and just call set_probs() or set_probs_() to set new probability values. Also, it works only with probability values as input (not logits), but you can manually apply softmax to logits anyway.

    import torch
    from torch.distributions.utils import probs_to_logits

    class Categorical:
        def __init__(self, probs_shape):
            # NOTE: probs_shape is supposed to be the shape of the
            #       probs tensor produced by the policy network
            if len(probs_shape) < 1:
                raise ValueError("`probs_shape` must be at least 1-dimensional.")
            self.probs_dim = len(probs_shape)
            self.probs_shape = probs_shape
            self._num_events = probs_shape[-1]
            self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
            self._event_shape = torch.Size()

        def set_probs_(self, probs):
            # Fast path: assumes probs are already normalised
            self.probs = probs
            self.logits = probs_to_logits(self.probs)

        def set_probs(self, probs):
            # Renormalise to guard against floating-point errors
            self.probs = probs / probs.sum(-1, keepdim=True)
            self.logits = probs_to_logits(self.probs)

        def sample(self, sample_shape=torch.Size()):
            if not isinstance(sample_shape, torch.Size):
                sample_shape = torch.Size(sample_shape)
            probs_2d = self.probs.reshape(-1, self._num_events)
            samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
            return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

        def log_prob(self, value):
            value = value.long().unsqueeze(-1)
            value, log_pmf = torch.broadcast_tensors(value, self.logits)
            value = value[..., :1]
            return log_pmf.gather(-1, value).squeeze(-1)

        def entropy(self):
            # Clamp logits so that 0 * -inf does not produce NaN
            min_real = torch.finfo(self.logits.dtype).min
            logits = torch.clamp(self.logits, min=min_real)
            p_log_p = logits * self.probs
            return -p_log_p.sum(-1)
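
    As a sketch of the intended usage (assuming the actor outputs probabilities of a fixed, known shape; the Agent class and probs_shape parameter here are hypothetical), the act() method from the question might be wired up like this:

    # Hypothetical wiring: initialise the distribution once and
    # reuse it across calls instead of constructing a new one.
    class Agent:
        def __init__(self, actor, critic, probs_shape):
            self.actor = actor
            self.critic = critic
            self.dist = Categorical(probs_shape)  # initialise once

        def act(self, state):
            action_probs = self.actor(state)    # assumed to be probs, not logits
            self.dist.set_probs_(action_probs)  # reuse instead of re-initialising
            action = self.dist.sample()
            action_logprob = self.dist.log_prob(action)
            return action.detach(), action_logprob.detach()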
    
    
    

    Checking execution time:

    import time
    import torch as tt
    import torch.distributions as td
    

    First, check the built-in torch.distributions.Categorical:

    start=time.perf_counter()
    for _ in range(50000):
        probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
        ct = td.Categorical(probs=probs)
        entropy = ct.entropy()
        action = ct.sample()
        log_prob = ct.log_prob(action)
        entropy, action, log_prob
    end=time.perf_counter()
    print(end - start)
    

    output:

    """
    10.024958199996036
    """
    

    Now check the custom Categorical:

    start=time.perf_counter()
    ct = Categorical((3,4,2)) #<--- initialize class beforehand
    for _ in range(50000):
        probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
        ct.set_probs(probs)
        entropy = ct.entropy()
        action = ct.sample()
        log_prob = ct.log_prob(action)
        entropy, action, log_prob
    end=time.perf_counter()
    print(end - start)
    

    output:

    """
    4.565093299999717
    """
    

    The execution time dropped by a little more than half. It can be reduced further by using set_probs_() instead of set_probs(). The subtle difference between them is that set_probs_() skips the line probs / probs.sum(-1, keepdim=True), which is meant to remove floating-point errors. However, that renormalisation is not always necessary.

    start=time.perf_counter()
    ct = Categorical((3,4,2)) #<--- initialize class beforehand
    for _ in range(50000):
        probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
        ct.set_probs_(probs)
        entropy = ct.entropy()
        action = ct.sample()
        log_prob = ct.log_prob(action)
        entropy, action, log_prob
    end=time.perf_counter()
    print(end - start)
    

    output:

    """
    3.9343119999975897
    """
    

    You can check the source code for the PyTorch distributions module on your machine, somewhere under ..\Lib\site-packages\torch\distributions.