
Policy network of PPO in RLlib

I want to set "actor_hiddens", i.e. the hidden layers of the policy network, for PPO in RLlib, and also be able to set their weights. Is this possible, and if so, how? I know how to do it for DDPG in RLlib, but with PPO I can't find the policy network. Thanks.


  • You can always create your own custom policy network; that gives you full control over the layers and over the initialization of the weights.

    If you want to use the default model, you can adapt it to your needs with the following parameters:

    MODEL_DEFAULTS: ModelConfigDict = {
        # === Built-in options ===
        # FullyConnectedNetwork (tf and torch):|
        # These are used if no custom model is specified and the input space is 1D.
        # Sizes of the hidden layers to be used (one entry per layer).
        "fcnet_hiddens": [256, 256],
        # Activation function descriptor.
        # Supported values are: "tanh", "relu", "swish" (or "silu"),
        # "linear" (or None).
        "fcnet_activation": "tanh",
        # VisionNetwork (tf and torch):|
        # These are used if no custom model is specified and the input space is 2D.
        # Filter config: List of [out_channels, kernel, stride] for each filter.
        # Example:
        # Use None for making RLlib try to find a default filter setup given the
        # observation space.
        "conv_filters": None,
        # Activation function descriptor.
        # Supported values are: "tanh", "relu", "swish" (or "silu"),
        # "linear" (or None).
        "conv_activation": "relu",
        # Some default models support a final FC stack of n Dense layers with given
        # activation:
        # - Complex observation spaces: Image components are fed through
        #   VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then
        #   everything is concated and pushed through this final FC stack.
        # - VisionNets (CNNs), e.g. after the CNN stack, there may be
        #   additional Dense layers.
        # - FullyConnectedNetworks will have this additional FCStack as well
        #   (that's why it's empty by default).
        "post_fcnet_hiddens": [],
        "post_fcnet_activation": "relu",
        # For DiagGaussian action distributions, make the second half of the model
        # outputs floating bias variables instead of state-dependent. This only
        # has an effect if using the default fully connected net.
        "free_log_std": False,
        # Whether to skip the final linear layer used to resize the hidden layer
        # outputs to size `num_outputs`. If True, then the last hidden layer
        # should already match num_outputs.
        "no_final_linear": False,
        # Whether layers should be shared for the value function.
        "vf_share_layers": True,
        # == LSTM ==
        # Whether to wrap the model with an LSTM.
        "use_lstm": False,
        # Max seq len for training the LSTM, defaults to 20.
        "max_seq_len": 20,
        # Size of the LSTM cell.
        "lstm_cell_size": 256,
        # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
        "lstm_use_prev_action": False,
        # Whether to feed r_{t-1} to LSTM.
        "lstm_use_prev_reward": False,
        # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
        "_time_major": False,
        # == Attention Nets (experimental: torch-version is untested) ==
        # Whether to use a GTrXL ("Gru transformer XL"; attention net) as the
        # wrapper Model around the default Model.
        "use_attention": False,
        # The number of transformer units within GTrXL.
        # A transformer unit in GTrXL consists of a) MultiHeadAttention module and
        # b) a position-wise MLP.
        "attention_num_transformer_units": 1,
        # The input and output size of each transformer unit.
        "attention_dim": 64,
        # The number of attention heads within the MultiHeadAttention units.
        "attention_num_heads": 1,
        # The dim of a single head (within the MultiHeadAttention units).
        "attention_head_dim": 32,
        # The memory sizes for inference and training.
        "attention_memory_inference": 50,
        "attention_memory_training": 50,
        # The output dim of the position-wise MLP.
        "attention_position_wise_mlp_dim": 32,
        # The initial bias values for the 2 GRU gates within a transformer unit.
        "attention_init_gru_gate_bias": 2.0,
        # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
        "attention_use_n_prev_actions": 0,
        # Whether to feed r_{t-n:t-1} to GTrXL.
        "attention_use_n_prev_rewards": 0,
        # == Atari ==
        # Which framestacking size to use for Atari envs.
        # "auto": Use a value of 4, but only if the env is an Atari env.
        # > 1: Use the trajectory view API in the default VisionNets to request the
        #      last n observations (single, grayscaled 84x84 image frames) as
        #      inputs. The time axis in the so provided observation tensors
        #      will come right after the batch axis (channels first format),
        #      e.g. BxTx84x84, where T=num_framestacks.
        # 0 or 1: No framestacking used.
        # Use the deprecated `framestack=True` to disable the above behavior and to
        # enable legacy stacking behavior (w/o trajectory view API) instead.
        "num_framestacks": "auto",
        # Final resized frame dimension
        "dim": 84,
        # (deprecated) Converts ATARI frame to 1 Channel Grayscale image
        "grayscale": False,
        # (deprecated) Changes frame to range from [-1, 1] if true
        "zero_mean": True,
        # === Options for custom models ===
        # Name of a custom model to use
        "custom_model": None,
        # Extra options to pass to the custom classes. These will be available to
        # the Model's constructor in the model_config field. Also, they will be
        # attempted to be passed as **kwargs to ModelV2 models. For an example,
        # see rllib/models/[tf|torch]/
        "custom_model_config": {},
        # Name of a custom action distribution to use.
        "custom_action_dist": None,
        # Custom preprocessors are deprecated. Please use a wrapper class around
        # your environment instead to preprocess observations.
        "custom_preprocessor": None,
        # Deprecated keys:
        # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
        "lstm_use_prev_action_reward": DEPRECATED_VALUE,
        # Use `num_framestacks` (int) instead.
        "framestack": True,
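
To connect the defaults above back to the question: PPO does not have an "actor_hiddens" option; its policy network layout comes from the "model" sub-dict of the trainer config. The sketch below shows one way to override just the keys you care about. `MODEL_DEFAULTS_SUBSET` and `make_model_config` are hypothetical names introduced here for illustration (RLlib merges a partial "model" dict with its defaults on its own, so the helper only adds a typo check), and the trainer call in the trailing comment assumes a Ray 1.x-style API.

```python
import copy

# A few entries copied from the MODEL_DEFAULTS listing above.
MODEL_DEFAULTS_SUBSET = {
    "fcnet_hiddens": [256, 256],
    "fcnet_activation": "tanh",
    "vf_share_layers": True,
    "free_log_std": False,
}


def make_model_config(**overrides):
    """Hypothetical helper: merge overrides into the defaults, rejecting typos."""
    unknown = sorted(set(overrides) - set(MODEL_DEFAULTS_SUBSET))
    if unknown:
        raise KeyError(f"Unknown model option(s): {unknown}")
    merged = copy.deepcopy(MODEL_DEFAULTS_SUBSET)
    merged.update(overrides)
    return merged


# The hidden layers of PPO's policy network are set through "fcnet_hiddens".
ppo_config = {
    "framework": "torch",
    "model": make_model_config(
        fcnet_hiddens=[64, 64],
        fcnet_activation="relu",
        vf_share_layers=False,  # do not share layers with the value function
    ),
}

# Assumed usage with a Ray 1.x-style API (not executed here):
#   from ray.rllib.agents.ppo import PPOTrainer
#   trainer = PPOTrainer(env="CartPole-v0", config=ppo_config)
```

Passing an unknown key such as `actor_hiddens` to the helper raises a `KeyError`, which makes the DDPG/PPO naming difference visible early.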

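For the second half of the question (setting the weights), one option that may work without a custom model is to read the policy's weights, replace selected arrays, and write them back. This is a minimal sketch under assumptions: it treats the weights as a dict of parameter names to NumPy arrays, `overwrite_weights` is a hypothetical helper, and the parameter names in the stand-in dict are made up; the real names depend on the framework and model, so inspect `policy.get_weights().keys()` first.

```python
import numpy as np


def overwrite_weights(current, new_values):
    """Return a copy of `current` with selected entries replaced.

    Unknown names and shape mismatches raise immediately, rather than
    surfacing later inside the deep learning framework.
    """
    updated = {name: np.array(value, copy=True) for name, value in current.items()}
    for name, value in new_values.items():
        if name not in updated:
            raise KeyError(f"No parameter named {name!r}")
        value = np.asarray(value, dtype=updated[name].dtype)
        if value.shape != updated[name].shape:
            raise ValueError(
                f"{name}: expected shape {updated[name].shape}, got {value.shape}"
            )
        updated[name] = value
    return updated


# Stand-in for a tiny weights dict; the names are illustrative only.
current = {
    "hidden_0.weight": np.zeros((4, 3), dtype=np.float32),
    "hidden_0.bias": np.zeros(4, dtype=np.float32),
}
new_weights = overwrite_weights(
    current, {"hidden_0.weight": np.full((4, 3), 0.1)}
)

# Assumed usage with an RLlib trainer (not executed here):
#   policy = trainer.get_policy()
#   weights = policy.get_weights()
#   policy.set_weights(overwrite_weights(weights, {...}))
```

The helper copies rather than mutates, so the original weights stay available if you want to restore them.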