Module `AssetAllocator.algorithms.TD3.critic`

Source code:
import torch
import torch.nn as nn
import torch.optim as optim
class Critic(nn.Module):
    """Critic (state-action value) network for the TD3 agent.

    Maps a (state, action) pair to a single scalar Q-value estimate using an
    MLP with two hidden ReLU layers. The network owns its own Adam optimizer
    and a ``ReduceLROnPlateau`` learning-rate scheduler.

    Original paper can be found at https://arxiv.org/abs/1802.09477
    This implementation was adapted from https://github.com/saashanair/rl-series/tree/master/td3
    """

    def __init__(self, state_dim, action_dim, hidden_dim, lr=0.1):
        """Initializes the TD3 Critic Network.

        Args:
            state_dim (int): State space dimension
            action_dim (int): Action space dimension
            hidden_dim (int): Size of each hidden layer
            lr (float, optional): Learning rate of the Adam optimizer.
                Defaults to 0.1.
        """
        super().__init__()
        # Input: state and action concatenated along the feature axis.
        # Output: one Q-value per (state, action) pair.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # NOTE(review): 0.1 is a very large learning rate for Adam (the TD3
        # paper reports 1e-3). The default is kept for backward compatibility;
        # callers should usually pass ``lr`` explicitly.
        self.optimizer = optim.Adam(self.linear_relu_stack.parameters(), lr=lr)
        # Lowers the LR once the monitored metric ('min' mode) has not improved
        # for more than 2 consecutive ``scheduler.step(metric)`` calls.
        # Stepping the scheduler is left to the caller.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min', patience=2
        )

    def forward(self, state, action):
        """Forward pass: estimate Q(state, action).

        Args:
            state (torch.Tensor): Environment state(s), shape
                ``(..., state_dim)``, e.g. ``(batch, state_dim)`` or an
                unbatched ``(state_dim,)``.
            action (torch.Tensor): Agent action(s), shape
                ``(..., action_dim)`` with the same leading dimensions as
                ``state``.

        Returns:
            torch.Tensor: State-action values of shape ``(..., 1)``, e.g.
            ``(batch, 1)`` for batched input or ``(1,)`` for unbatched input.
        """
        # Concatenate along the last (feature) axis. For 2-D batched input this
        # is identical to ``dim=1``; ``-1`` additionally accepts unbatched 1-D
        # inputs and extra leading dimensions, which ``nn.Linear`` supports.
        x = torch.cat([state, action], dim=-1)
        return self.linear_relu_stack(x)
Classes

class Critic(state_dim, action_dim, hidden_dim, lr=0.1)

This is the critic network for the TD3 agent.
The original paper can be found at https://arxiv.org/abs/1802.09477
This implementation was adapted from https://github.com/saashanair/rl-series/tree/master/td3

Initializes the TD3 Critic Network.
Args:
- state_dim (int): State space dimension
- action_dim (int): Action space dimension
- hidden_dim (int): Size of each hidden layer
- lr (float, optional): Learning rate of the Adam optimizer. Defaults to 0.1.

Source code:
class Critic(nn.Module):
    """Critic (state-action value) network for the TD3 agent.

    Maps a (state, action) pair to a single scalar Q-value estimate using an
    MLP with two hidden ReLU layers. The network owns its own Adam optimizer
    and a ``ReduceLROnPlateau`` learning-rate scheduler.

    Original paper can be found at https://arxiv.org/abs/1802.09477
    This implementation was adapted from https://github.com/saashanair/rl-series/tree/master/td3
    """

    def __init__(self, state_dim, action_dim, hidden_dim, lr=0.1):
        """Initializes the TD3 Critic Network.

        Args:
            state_dim (int): State space dimension
            action_dim (int): Action space dimension
            hidden_dim (int): Size of each hidden layer
            lr (float, optional): Learning rate of the Adam optimizer.
                Defaults to 0.1.
        """
        super().__init__()
        # Input: state and action concatenated along the feature axis.
        # Output: one Q-value per (state, action) pair.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # NOTE(review): 0.1 is a very large learning rate for Adam (the TD3
        # paper reports 1e-3). The default is kept for backward compatibility;
        # callers should usually pass ``lr`` explicitly.
        self.optimizer = optim.Adam(self.linear_relu_stack.parameters(), lr=lr)
        # Lowers the LR once the monitored metric ('min' mode) has not improved
        # for more than 2 consecutive ``scheduler.step(metric)`` calls.
        # Stepping the scheduler is left to the caller.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min', patience=2
        )

    def forward(self, state, action):
        """Forward pass: estimate Q(state, action).

        Args:
            state (torch.Tensor): Environment state(s), shape
                ``(..., state_dim)``, e.g. ``(batch, state_dim)`` or an
                unbatched ``(state_dim,)``.
            action (torch.Tensor): Agent action(s), shape
                ``(..., action_dim)`` with the same leading dimensions as
                ``state``.

        Returns:
            torch.Tensor: State-action values of shape ``(..., 1)``, e.g.
            ``(batch, 1)`` for batched input or ``(1,)`` for unbatched input.
        """
        # Concatenate along the last (feature) axis. For 2-D batched input this
        # is identical to ``dim=1``; ``-1`` additionally accepts unbatched 1-D
        # inputs and extra leading dimensions, which ``nn.Linear`` supports.
        x = torch.cat([state, action], dim=-1)
        return self.linear_relu_stack(x)
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, state, action)

Forward pass: computes the state-action value Q(state, action).

Args:
- state (torch.Tensor): Current environment state
- action (torch.Tensor): Current agent's action

Returns:
- out (torch.Tensor): State-action values

Source code:
def forward(self, state, action):
    """Forward pass: estimate Q(state, action).

    Args:
        state (torch.Tensor): Environment state(s), shape
            ``(..., state_dim)``, e.g. ``(batch, state_dim)`` or an
            unbatched ``(state_dim,)``.
        action (torch.Tensor): Agent action(s), shape
            ``(..., action_dim)`` with the same leading dimensions as
            ``state``.

    Returns:
        torch.Tensor: State-action values of shape ``(..., 1)``, e.g.
        ``(batch, 1)`` for batched input or ``(1,)`` for unbatched input.
    """
    # Concatenate along the last (feature) axis. For 2-D batched input this
    # is identical to ``dim=1``; ``-1`` additionally accepts unbatched 1-D
    # inputs and extra leading dimensions, which ``nn.Linear`` supports.
    x = torch.cat([state, action], dim=-1)
    return self.linear_relu_stack(x)