Module AssetAllocator.algorithms.TRPO.trpo

Expand source code
import numpy as np

import torch
from torch.autograd import Variable
from .utils import *


def getSurrogateloss(model,states,actions,advantages,logProbabilityOld):
    log_prob = model.getLogProbabilityDensity(states,Variable(actions))
    action_loss = -advantages.squeeze() * torch.exp(log_prob - Variable(logProbabilityOld))
    return action_loss.mean()

def FisherVectorProduct(v , model, states, actions,logProbabilityOld,damping):
    kl = model.meanKlDivergence(states, actions,logProbabilityOld)

    grads = torch.autograd.grad(kl, model.parameters()
                    ,retain_graph=True, create_graph=True)
    flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

    kl_v = (flat_grad_kl * v).sum()
    grads = torch.autograd.grad(kl_v, model.parameters())
    flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1)
                                    for grad in grads]).data

    return flat_grad_grad_kl + v * damping

def trpo_step(model, states, actions, advantages, max_kl, damping):

    fixed_log_prob = model.getLogProbabilityDensity(Variable(states),actions).detach()
    get_loss = lambda x: getSurrogateloss(x,
                                    states,
                                    actions,
                                    advantages,
                                    fixed_log_prob)
    loss = get_loss(model)

    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads])

    Fvp = lambda v: FisherVectorProduct(v,
                                        model,
                                        states,
                                        actions,
                                        fixed_log_prob,
                                        damping)

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)


    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)
    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]
    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep,
                                     neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)

    return loss

def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for i in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        t= i
        if rdotr < residual_tol:
            break
    return x

def linesearch(model,
               f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1):
    fval = f(model).data
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(model).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve

        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            return True, xnew
    return False, x

Functions

def FisherVectorProduct(v, model, states, actions, logProbabilityOld, damping)
Expand source code
def FisherVectorProduct(v , model, states, actions,logProbabilityOld,damping):
    kl = model.meanKlDivergence(states, actions,logProbabilityOld)

    grads = torch.autograd.grad(kl, model.parameters()
                    ,retain_graph=True, create_graph=True)
    flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

    kl_v = (flat_grad_kl * v).sum()
    grads = torch.autograd.grad(kl_v, model.parameters())
    flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1)
                                    for grad in grads]).data

    return flat_grad_grad_kl + v * damping
def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10)
Expand source code
def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for i in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        t= i
        if rdotr < residual_tol:
            break
    return x
def getSurrogateloss(model, states, actions, advantages, logProbabilityOld)
Expand source code
def getSurrogateloss(model,states,actions,advantages,logProbabilityOld):
    log_prob = model.getLogProbabilityDensity(states,Variable(actions))
    action_loss = -advantages.squeeze() * torch.exp(log_prob - Variable(logProbabilityOld))
    return action_loss.mean()
def linesearch(model, f, x, fullstep, expected_improve_rate, max_backtracks=10, accept_ratio=0.1)
Expand source code
def linesearch(model,
               f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1):
    fval = f(model).data
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(model).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve

        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            return True, xnew
    return False, x
def trpo_step(model, states, actions, advantages, max_kl, damping)
Expand source code
def trpo_step(model, states, actions, advantages, max_kl, damping):

    fixed_log_prob = model.getLogProbabilityDensity(Variable(states),actions).detach()
    get_loss = lambda x: getSurrogateloss(x,
                                    states,
                                    actions,
                                    advantages,
                                    fixed_log_prob)
    loss = get_loss(model)

    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads])

    Fvp = lambda v: FisherVectorProduct(v,
                                        model,
                                        states,
                                        actions,
                                        fixed_log_prob,
                                        damping)

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)


    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)
    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]
    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep,
                                     neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)

    return loss