# Copyright (c) 2024 Wenyuan Zhao, Haoyuan Chen
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Linear layers with reparameterization and Flipout to perform
# variational inference in Bayesian neural networks. Reparameterization layers
# enable Monte Carlo approximation of the distribution over 'kernel' and 'bias'.
#
# The Kullback-Leibler divergence between the surrogate posterior and the prior is computed
# and returned along with the output tensors of the linear operation; it is
# required to compute the Evidence Lower Bound (ELBO).
#
# @authors: Wenyuan Zhao. Some code snippets borrowed from: Intel Labs Bayesian-Torch.
#
# ===============================================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
from .base_variational_layer import _BaseVariationalLayer
from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
from torch.quantization.qconfig import QConfig
class LinearReparameterization(_BaseVariationalLayer):
    r"""
    Implements Linear layer with reparameterization trick. Inherits from dmgp.layers._BaseVariationalLayer

    On every forward pass the weight is drawn as ``mu_weight + softplus(rho_weight) * eps``
    with ``eps ~ N(0, 1)`` (and likewise for the bias, when present).

    :param in_features: Size of each input sample.
    :type in_features: int
    :param out_features: Size of each output sample.
    :type out_features: int
    :param prior_mean: Mean of the prior arbitrary distribution to be used on the complexity cost. (Default: `0`.)
    :type prior_mean: float, optional
    :param prior_variance: Variance of the prior arbitrary distribution to be used on the complexity cost.
        Note that the value is written unchanged into the ``prior_*_sigma`` buffers handed to ``kl_div``. (Default: `1.0`.)
    :type prior_variance: float, optional
    :param posterior_mu_init: Initialized trainable mu parameter representing mean of the approximate posterior. (Default: `0`.)
    :type posterior_mu_init: float, optional
    :param posterior_rho_init: Initialized trainable rho parameter representing the sigma of the approximate posterior through softplus function. (Default: `-3.0`.)
    :type posterior_rho_init: float, optional
    :param bias: If set to False, the layer will not learn an additive bias. (Default: `True`.)
    :type bias: bool, optional
    """

    def __init__(self,
                 in_features,
                 out_features,
                 prior_mean=0,
                 prior_variance=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-3.0,
                 bias=True):
        super(LinearReparameterization, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_mean = prior_mean
        self.prior_variance = prior_variance
        # Stored as plain scalars (same as LinearFlipout); a trailing comma
        # here would silently wrap them in 1-tuples.
        self.posterior_mu_init = posterior_mu_init  # mean of weight
        # sigma of weight is obtained as sigma = log(1 + exp(rho))
        self.posterior_rho_init = posterior_rho_init
        self.bias = bias

        weight_shape = (out_features, in_features)
        self.mu_weight = Parameter(torch.empty(*weight_shape))
        self.rho_weight = Parameter(torch.empty(*weight_shape))
        # Noise and prior tensors are non-persistent buffers: they follow the
        # module across devices but are not written to the state dict.
        for name in ('eps_weight', 'prior_weight_mu', 'prior_weight_sigma'):
            self.register_buffer(name, torch.empty(*weight_shape), persistent=False)

        if bias:
            self.mu_bias = Parameter(torch.empty(out_features))
            self.rho_bias = Parameter(torch.empty(out_features))
            for name in ('eps_bias', 'prior_bias_mu', 'prior_bias_sigma'):
                self.register_buffer(name, torch.empty(out_features), persistent=False)
        else:
            self.register_buffer('prior_bias_mu', None, persistent=False)
            self.register_buffer('prior_bias_sigma', None, persistent=False)
            self.register_parameter('mu_bias', None)
            self.register_parameter('rho_bias', None)
            self.register_buffer('eps_bias', None, persistent=False)

        self.init_parameters()
        self.quant_prepare = False

    @staticmethod
    def _sample_noise(buffer):
        """Refresh ``buffer`` with standard-normal noise and return a private copy.

        The buffer is overwritten in place through ``.data``, which autograd does
        not track. Returning a clone keeps the noise saved for one backward pass
        from being overwritten by a later forward pass (e.g. when several Monte
        Carlo forward passes are run before a single ``backward()``).
        """
        return buffer.data.normal_().clone()

    def prepare(self):
        """Create the quantization stubs used by :meth:`forward` and enable them.

        Builds five symmetric per-tensor ``qint8`` stubs and two ``quint8`` stubs,
        each with its own ``MinMaxObserver``-based ``QConfig``, plus a
        ``DeQuantStub``, then sets ``quant_prepare`` to ``True``.
        """
        def qint8_stub():
            return torch.quantization.QuantStub(QConfig(
                weight=MinMaxObserver.with_args(dtype=torch.qint8,
                                                qscheme=torch.per_tensor_symmetric),
                activation=MinMaxObserver.with_args(dtype=torch.qint8,
                                                    qscheme=torch.per_tensor_symmetric)))

        def quint8_stub():
            return torch.quantization.QuantStub(QConfig(
                weight=MinMaxObserver.with_args(dtype=torch.quint8),
                activation=MinMaxObserver.with_args(dtype=torch.quint8)))

        self.qint_quant = nn.ModuleList([qint8_stub() for _ in range(5)])
        self.quint_quant = nn.ModuleList([quint8_stub() for _ in range(2)])
        self.dequant = torch.quantization.DeQuantStub()
        self.quant_prepare = True

    def init_parameters(self):
        """Fill the prior buffers and randomly initialise the posterior parameters.

        Prior buffers are filled with ``prior_mean`` / ``prior_variance``; ``mu``
        and ``rho`` are drawn from normals centred on ``posterior_mu_init`` /
        ``posterior_rho_init`` with standard deviation ``0.1``.
        """
        self.prior_weight_mu.fill_(self.prior_mean)
        self.prior_weight_sigma.fill_(self.prior_variance)
        self.mu_weight.data.normal_(mean=self.posterior_mu_init, std=0.1)
        self.rho_weight.data.normal_(mean=self.posterior_rho_init, std=0.1)
        if self.mu_bias is not None:
            self.prior_bias_mu.fill_(self.prior_mean)
            self.prior_bias_sigma.fill_(self.prior_variance)
            self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
            self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)

    def kl_loss(self):
        """Return the KL term of the weights (plus the bias, when present).

        The value is whatever ``self.kl_div`` returns for the posterior
        ``(mu, softplus(rho))`` against the prior buffers.
        """
        # softplus is the overflow-safe form of log(1 + exp(rho)).
        sigma_weight = F.softplus(self.rho_weight)
        kl = self.kl_div(self.mu_weight, sigma_weight,
                         self.prior_weight_mu, self.prior_weight_sigma)
        if self.mu_bias is not None:
            sigma_bias = F.softplus(self.rho_bias)
            # Out-of-place add: do not mutate the tensor returned by kl_div.
            kl = kl + self.kl_div(self.mu_bias, sigma_bias,
                                  self.prior_bias_mu, self.prior_bias_sigma)
        return kl

    def forward(self, x, return_kl=True):
        r"""
        Forward the bayesian Linear layer.

        :param x: Training data of shape :math:`(n,d)`.
        :type x: torch.Tensor.float
        :param return_kl: Return KL-divergence. Default: `True`. Forced to
            `False` when ``self.dnn_to_bnn_flag`` is set.
        :type return_kl: bool, optional
        :return: ``(out, kl)`` when the KL-divergence is returned, otherwise ``out`` alone.
        """
        if self.dnn_to_bnn_flag:
            return_kl = False

        # Reparameterization: weight = mu + sigma * eps, sigma = softplus(rho).
        sigma_weight = F.softplus(self.rho_weight)
        eps_weight = self._sample_noise(self.eps_weight)
        tmp_result = sigma_weight * eps_weight
        weight = self.mu_weight + tmp_result

        bias = None
        sigma_bias = None
        if self.mu_bias is not None:
            sigma_bias = F.softplus(self.rho_bias)
            bias = self.mu_bias + sigma_bias * self._sample_noise(self.eps_bias)

        out = F.linear(x, weight, bias)

        if self.quant_prepare:
            # Only the stub result for `out` is used below. The other calls are
            # kept so the stubs still see these tensors (presumably for
            # observers attached to them during calibration; confirm).
            # quint8 quantstub
            self.quint_quant[0](x)  # input
            out = self.quint_quant[1](out)  # output
            # qint8 quantstub
            self.qint_quant[0](sigma_weight)  # weight
            self.qint_quant[1](self.mu_weight)  # weight
            self.qint_quant[2](eps_weight)  # random variable
            self.qint_quant[3](tmp_result)  # multiply activation
            self.qint_quant[4](weight)  # add activation

        if not return_kl:
            return out

        kl = self.kl_div(self.mu_weight, sigma_weight,
                         self.prior_weight_mu, self.prior_weight_sigma)
        if sigma_bias is not None:
            kl = kl + self.kl_div(self.mu_bias, sigma_bias,
                                  self.prior_bias_mu, self.prior_bias_sigma)
        return out, kl
class LinearFlipout(_BaseVariationalLayer):
    r"""
    Implements Linear layer with Flipout reparameterization trick. Ref: https://arxiv.org/abs/1803.04386. Inherits from dmgp.layers._BaseVariationalLayer.

    The output is the mean path ``linear(x, mu_weight, mu_bias)`` plus a
    perturbation ``linear(x * sign_in, softplus(rho_weight) * eps, ...) * sign_out``
    where ``sign_in`` / ``sign_out`` are random sign tensors shaped like the
    input and the output.

    :param in_features: Size of each input sample.
    :type in_features: int
    :param out_features: Size of each output sample.
    :type out_features: int
    :param prior_mean: Mean of the prior arbitrary distribution to be used on the complexity cost. (Default: `0`.)
    :type prior_mean: float, optional
    :param prior_variance: Variance of the prior arbitrary distribution to be used on the complexity cost.
        Note that the value is written unchanged into the ``prior_*_sigma`` buffers handed to ``kl_div``. (Default: `1.0`.)
    :type prior_variance: float, optional
    :param posterior_mu_init: Initialized trainable mu parameter representing mean of the approximate posterior. (Default: `0`.)
    :type posterior_mu_init: float, optional
    :param posterior_rho_init: Initialized trainable rho parameter representing the sigma of the approximate posterior through softplus function. (Default: `-3.0`.)
    :type posterior_rho_init: float, optional
    :param bias: If set to False, the layer will not learn an additive bias. (Default: `True`.)
    :type bias: bool, optional
    """

    def __init__(self,
                 in_features,
                 out_features,
                 prior_mean=0,
                 prior_variance=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-3.0,
                 bias=True):
        super(LinearFlipout, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_mean = prior_mean
        self.prior_variance = prior_variance
        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        weight_shape = (out_features, in_features)
        self.mu_weight = nn.Parameter(torch.empty(*weight_shape))
        self.rho_weight = nn.Parameter(torch.empty(*weight_shape))
        # Noise and prior tensors are non-persistent buffers: they follow the
        # module across devices but are not written to the state dict.
        for name in ('eps_weight', 'prior_weight_mu', 'prior_weight_sigma'):
            self.register_buffer(name, torch.empty(*weight_shape), persistent=False)

        if bias:
            self.mu_bias = nn.Parameter(torch.empty(out_features))
            self.rho_bias = nn.Parameter(torch.empty(out_features))
            for name in ('prior_bias_mu', 'prior_bias_sigma', 'eps_bias'):
                self.register_buffer(name, torch.empty(out_features), persistent=False)
        else:
            self.register_buffer('prior_bias_mu', None, persistent=False)
            self.register_buffer('prior_bias_sigma', None, persistent=False)
            self.register_parameter('mu_bias', None)
            self.register_parameter('rho_bias', None)
            self.register_buffer('eps_bias', None, persistent=False)

        self.init_parameters()
        self.quant_prepare = False

    @staticmethod
    def _sample_noise(buffer):
        """Refresh ``buffer`` with standard-normal noise and return a private copy.

        The buffer is overwritten in place through ``.data``, which autograd does
        not track. Returning a clone keeps the noise saved for one backward pass
        from being overwritten by a later forward pass (e.g. when several Monte
        Carlo forward passes are run before a single ``backward()``).
        """
        return buffer.data.normal_().clone()

    def prepare(self):
        """Create the quantization stubs used by :meth:`forward` and enable them.

        Builds four symmetric per-tensor ``qint8`` stubs and eight ``quint8``
        stubs, each with its own ``MinMaxObserver``-based ``QConfig``, plus a
        ``DeQuantStub``, then sets ``quant_prepare`` to ``True``.
        """
        def qint8_stub():
            return torch.quantization.QuantStub(QConfig(
                weight=MinMaxObserver.with_args(dtype=torch.qint8,
                                                qscheme=torch.per_tensor_symmetric),
                activation=MinMaxObserver.with_args(dtype=torch.qint8,
                                                    qscheme=torch.per_tensor_symmetric)))

        def quint8_stub():
            return torch.quantization.QuantStub(QConfig(
                weight=MinMaxObserver.with_args(dtype=torch.quint8),
                activation=MinMaxObserver.with_args(dtype=torch.quint8)))

        self.qint_quant = nn.ModuleList([qint8_stub() for _ in range(4)])
        self.quint_quant = nn.ModuleList([quint8_stub() for _ in range(8)])
        self.dequant = torch.quantization.DeQuantStub()
        self.quant_prepare = True

    def init_parameters(self):
        """Fill the prior buffers and randomly initialise the posterior parameters.

        Prior buffers are filled with ``prior_mean`` / ``prior_variance``; ``mu``
        and ``rho`` are drawn from normals centred on ``posterior_mu_init`` /
        ``posterior_rho_init`` with standard deviation ``0.1``.
        """
        # init prior mu
        self.prior_weight_mu.fill_(self.prior_mean)
        self.prior_weight_sigma.fill_(self.prior_variance)
        # init weight and base perturbation weights
        self.mu_weight.data.normal_(mean=self.posterior_mu_init, std=0.1)
        self.rho_weight.data.normal_(mean=self.posterior_rho_init, std=0.1)
        if self.mu_bias is not None:
            self.prior_bias_mu.fill_(self.prior_mean)
            self.prior_bias_sigma.fill_(self.prior_variance)
            self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
            self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)

    def kl_loss(self):
        """Return the KL term of the weights (plus the bias, when present).

        The value is whatever ``self.kl_div`` returns for the posterior
        ``(mu, softplus(rho))`` against the prior buffers.
        """
        # softplus is the overflow-safe form of log(1 + exp(rho)).
        sigma_weight = F.softplus(self.rho_weight)
        kl = self.kl_div(self.mu_weight, sigma_weight,
                         self.prior_weight_mu, self.prior_weight_sigma)
        if self.mu_bias is not None:
            sigma_bias = F.softplus(self.rho_bias)
            # Out-of-place add: do not mutate the tensor returned by kl_div.
            kl = kl + self.kl_div(self.mu_bias, sigma_bias,
                                  self.prior_bias_mu, self.prior_bias_sigma)
        return kl

    def forward(self, x, return_kl=True):
        r"""
        Forward the bayesian Linear layer.

        :param x: Training data of shape :math:`(n,d)`.
        :type x: torch.Tensor.float
        :param return_kl: Return KL-divergence. Default: `True`. Forced to
            `False` when ``self.dnn_to_bnn_flag`` is set.
        :type return_kl: bool, optional
        :return: ``(out, kl)`` when the KL-divergence is returned, otherwise ``out`` alone.
        """
        if self.dnn_to_bnn_flag:
            return_kl = False

        # sampling delta_W
        sigma_weight = F.softplus(self.rho_weight)
        eps_weight = self._sample_noise(self.eps_weight)
        delta_weight = sigma_weight * eps_weight

        # bias perturbation (the mean bias enters through the mean path below)
        bias = None
        sigma_bias = None
        if self.mu_bias is not None:
            sigma_bias = F.softplus(self.rho_bias)
            bias = sigma_bias * self._sample_noise(self.eps_bias)

        # linear outputs
        outputs = F.linear(x, self.mu_weight, self.mu_bias)

        # Random signs shaped like the input and the output. empty_like avoids
        # copying the data only to overwrite it with uniform samples.
        sign_input = torch.empty_like(x).uniform_(-1, 1).sign()
        sign_output = torch.empty_like(outputs).uniform_(-1, 1).sign()

        x_tmp = x * sign_input
        perturbed_outputs_tmp = F.linear(x_tmp, delta_weight, bias)
        perturbed_outputs = perturbed_outputs_tmp * sign_output
        out = outputs + perturbed_outputs

        if self.quant_prepare:
            # Only the stub result for `out` is used below. The other calls are
            # kept so the stubs still see these tensors (presumably for
            # observers attached to them during calibration; confirm).
            # quint8 quantstub: input, mean output, signs and intermediates
            quint_inputs = (x, outputs, sign_input, sign_output, x_tmp,
                            perturbed_outputs_tmp, perturbed_outputs)
            for stub, tensor in zip(self.quint_quant, quint_inputs):
                stub(tensor)
            out = self.quint_quant[7](out)  # output
            # qint8 quantstub: weights, random variable, multiply activation
            qint_inputs = (sigma_weight, self.mu_weight, eps_weight, delta_weight)
            for stub, tensor in zip(self.qint_quant, qint_inputs):
                stub(tensor)

        # returning outputs + perturbations
        if not return_kl:
            return out

        # get kl divergence
        kl = self.kl_div(self.mu_weight, sigma_weight,
                         self.prior_weight_mu, self.prior_weight_sigma)
        if sigma_bias is not None:
            kl = kl + self.kl_div(self.mu_bias, sigma_bias,
                                  self.prior_bias_mu, self.prior_bias_sigma)
        return out, kl