Defining a custom example model in DMGP

In this notebook, we look at how to define a custom model in DMGP. As an example, we consider a 2-layer sparse DGP model for a regression task, in which each layer uses a level-3 sparse grid design.

[1]:
import torch
import torch.nn as nn
from dmgp.layers.linear import LinearFlipout
from dmgp.layers.activation import TMK

Defining a 2-layer DMGP Model

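In the next cell, we define a simple DMGP for regression. Each of its two layers consists of a TMK layer built on a level-3 sparse grid design (here a HyperbolicCrossDesign with a LaplaceProductKernel), followed by a Bayesian LinearFlipout layer.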

[2]:
from dmgp.utils.sparse_design.design_class import HyperbolicCrossDesign
from dmgp.kernels.laplace_kernel import LaplaceProductKernel

# Define a 2-layer DMGP model for regression
class DMGP_regression(nn.Module):
    def __init__(self, input_dim, output_dim, design_class, kernel):
        super().__init__()

        # 1st layer of DGP: input: [n, input_dim] tensor, output: [n, 7] tensor
        self.tmk1 = TMK(
            in_features=input_dim,
            n_level=3,
            kernel=kernel,
            design_class=design_class,
        )
        self.fc1 = LinearFlipout(in_features=self.tmk1.out_features, out_features=7)

        # 2nd layer of DGP: input: [n, 7] tensor, output: [n, output_dim] tensor
        self.tmk2 = TMK(
            in_features=7,
            n_level=3,
            kernel=kernel,
            design_class=design_class,
        )
        self.fc2 = LinearFlipout(in_features=self.tmk2.out_features, out_features=output_dim)

    def forward(self, x):
        kl_sum = 0
        x = self.tmk1(x)     # [n, tmk1.out_features]
        x, kl = self.fc1(x)  # LinearFlipout returns (output, KL divergence)
        kl_sum += kl
        x = self.tmk2(x)     # [n, tmk2.out_features]
        x, kl = self.fc2(x)
        kl_sum += kl
        # Squeeze only the last dimension so that a batch of size 1 keeps its batch dimension
        return x.squeeze(-1), kl_sum

model = DMGP_regression(
    input_dim=1,
    output_dim=1,
    design_class=HyperbolicCrossDesign,
    kernel=LaplaceProductKernel(1.),
)
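The forward pass returns the prediction together with the summed KL divergence of the two LinearFlipout layers. A minimal sketch of how these two outputs are typically combined in a training step follows, assuming a Gaussian likelihood (MSE) and a KL term divided by the number of training points. The names ToyBayesRegressor and negative_elbo are hypothetical, and so that the sketch runs without dmgp, the stand-in model uses a plain nn.Linear with an L2 penalty in place of a real KL divergence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyBayesRegressor(nn.Module):
    """Hypothetical stand-in that mimics the (prediction, kl) interface of DMGP_regression."""

    def __init__(self, input_dim=1):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        pred = self.linear(x).squeeze(-1)
        # Stand-in penalty used in place of a real KL divergence (LinearFlipout computes the real one)
        kl = 0.5 * self.linear.weight.pow(2).sum()
        return pred, kl


def negative_elbo(pred, target, kl, n_train):
    # Gaussian likelihood up to constants (MSE) plus the KL term spread over the training set.
    # The 1 / n_train scaling is a common convention and an assumption here.
    return F.mse_loss(pred, target) + kl / n_train


torch.manual_seed(0)
x_train = torch.linspace(-1, 1, 32).unsqueeze(-1)  # [32, 1] inputs
y_train = torch.sin(3 * x_train).squeeze(-1)       # [32] targets

toy_model = ToyBayesRegressor()
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()
    pred, kl = toy_model(x_train)
    loss = negative_elbo(pred, y_train, kl, n_train=x_train.shape[0])
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

With dmgp installed, the same loop should apply to the DMGP model by replacing toy_model with it and feeding [n, 1] inputs, since it exposes the same (prediction, kl) return signature.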

Viewing model parameters

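Let’s take a look at the model parameters. Strictly speaking, "parameters" are objects of type torch.nn.Parameter whose gradients are filled in by autograd; they are returned by model.parameters() and model.named_parameters(). Here we use model.state_dict(), which returns a broader dictionary: it contains all of these parameters plus other saved tensors, such as tmk1.design_points and tmk1.chol_inv, which appear in the state dict but are not returned by model.parameters().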

[3]:
model_params = model.state_dict()
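The difference between parameters() and state_dict() comes from buffers: tensors registered with register_buffer are saved and loaded with the module but are not trained. The toy layer below is hypothetical and only illustrates the distinction; it assumes, without confirming it from the dmgp source, that the TMK tensors behave like its design_points buffer.

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical layer with one trainable Parameter and one fixed buffer."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3, 2))             # trainable, receives .grad
        self.register_buffer("design_points", torch.zeros(4, 1))  # saved with the module, not trained


toy = ToyLayer()
param_names = [name for name, _ in toy.named_parameters()]
buffer_names = [name for name, _ in toy.named_buffers()]
state_names = list(toy.state_dict().keys())

print(param_names)   # ['weight']
print(buffer_names)  # ['design_points']
print(state_names)   # ['weight', 'design_points']
```

For the DMGP model, [name for name, _ in model.named_buffers()] would list the TMK tensors if TMK registers them as buffers.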

Counting the number of parameters

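We can count the total number of parameters in the model by summing the number of elements of every parameter tensor. Note that len(p) would only return the size of a tensor's first dimension, so p.numel() is used instead. The count includes both the mu and rho tensors of each LinearFlipout layer.

[4]:
def parameter_count(model):
    # Number of scalar entries in each torch.nn.Parameter, in registration order
    layer_parameters = [p.numel() for p in model.parameters()]
    print(layer_parameters)
    total = sum(layer_parameters)
    print("Total number of parameters in the network: ", total)
    return total

parameter_count(model)
[49, 49, 7, 7, 799, 799, 1, 1]
Total number of parameters in the network:  1712
[4]:
1712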
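len(tensor) returns only the size of a tensor's first dimension, whereas tensor.numel() counts every element. Using the tensor shapes shown in the state dict printed further down in this notebook, a short pure-Python check shows how far apart the two counts are for this model:

```python
from math import prod

# Shapes of the trainable tensors, copied from the state dict printed in this notebook
shapes = {
    "fc1.mu_weight": (7, 7),
    "fc1.rho_weight": (7, 7),
    "fc1.mu_bias": (7,),
    "fc1.rho_bias": (7,),
    "fc2.mu_weight": (1, 799),
    "fc2.rho_weight": (1, 799),
    "fc2.mu_bias": (1,),
    "fc2.rho_bias": (1,),
}

leading_dim = [s[0] for s in shapes.values()]        # what len(tensor) returns
element_counts = [prod(s) for s in shapes.values()]  # what tensor.numel() returns

print(leading_dim, sum(leading_dim))        # [7, 7, 7, 7, 1, 1, 1, 1] 32
print(element_counts, sum(element_counts))  # [49, 49, 7, 7, 799, 799, 1, 1] 1712
```

Almost all of the parameters sit in fc2, whose input width equals the number of design points of tmk2 (799).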

Viewing model architecture

We can also print the model architecture by simply printing the model object.

[5]:
print(model)
DMGP_regression(
  (tmk1): TMK(
    (kernel): LaplaceProductKernel()
  )
  (fc1): LinearFlipout()
  (tmk2): TMK(
    (kernel): LaplaceProductKernel()
  )
  (fc2): LinearFlipout()
)
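print(model) shows the module tree but not how many parameters each submodule holds. A sketch that walks the top-level children with named_children() and sums their element counts follows; ToyTwoLayer is a hypothetical stand-in with the same attribute names as DMGP_regression, using nn.Identity and nn.Linear as placeholders for TMK and LinearFlipout.

```python
import torch.nn as nn


class ToyTwoLayer(nn.Module):
    """Hypothetical stand-in with the same child names as DMGP_regression."""

    def __init__(self):
        super().__init__()
        self.tmk1 = nn.Identity()   # placeholder for TMK
        self.fc1 = nn.Linear(7, 7)  # placeholder for LinearFlipout
        self.tmk2 = nn.Identity()
        self.fc2 = nn.Linear(7, 1)


toy_net = ToyTwoLayer()
summary = []
for name, child in toy_net.named_children():
    n_params = sum(p.numel() for p in child.parameters())
    summary.append((name, type(child).__name__, n_params))
    print(f"{name:5s} {type(child).__name__:10s} {n_params}")
```

The same loop can be run on the DMGP model to see the per-layer breakdown of its parameters.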

Saving Model State

The state dictionary above contains all trainable parameters of the model together with the non-trainable tensors of the TMK layers (design_points and chol_inv). We can save this state dictionary to a file and load it back later.

[6]:
torch.save(model.state_dict(), 'model_state.pth')
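A state dict holds only tensors, so rebuilding a model still requires its original constructor arguments. A minimal sketch of bundling plain constructor arguments with the weights in one checkpoint follows. save_checkpoint, load_checkpoint and the config keys are hypothetical names, and nn.Linear stands in for the DMGP model. Objects such as the kernel or the design class are not plain values, so they would still have to be recreated in code.

```python
import io

import torch
import torch.nn as nn


def save_checkpoint(module, config, f):
    # Store the constructor arguments next to the weights so the module can be rebuilt later
    torch.save({"config": config, "state_dict": module.state_dict()}, f)


def load_checkpoint(build_fn, f):
    # weights_only=True restricts unpickling to tensors and plain containers
    checkpoint = torch.load(f, weights_only=True)
    module = build_fn(**checkpoint["config"])
    module.load_state_dict(checkpoint["state_dict"])
    return module


config = {"in_features": 1, "out_features": 1}
original = nn.Linear(**config)

buffer = io.BytesIO()  # in-memory file; a path such as 'checkpoint.pth' works the same way
save_checkpoint(original, config, buffer)
buffer.seek(0)
restored = load_checkpoint(nn.Linear, buffer)
```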

Loading Model State

Next, we load this state into a new model and check that the parameters were restored correctly.

[7]:
state_dict = torch.load('model_state.pth')
model = DMGP_regression(input_dim=1,
                        output_dim=1,
                        design_class=HyperbolicCrossDesign,
                        kernel=LaplaceProductKernel(1.))
model.load_state_dict(state_dict)
[7]:
<All keys matched successfully>
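load_state_dict returns a named tuple with missing_keys and unexpected_keys; when both are empty it is displayed as shown above. To check the values themselves rather than reading the printed dictionary, the two state dicts can be compared tensor by tensor. In the sketch below, state_dicts_equal is a hypothetical helper; it converts sparse COO tensors (such as chol_inv) to dense before comparing, and nn.Linear stands in for the DMGP model.

```python
import torch
import torch.nn as nn


def state_dicts_equal(sd_a, sd_b):
    # Same keys, and every tensor identical; sparse COO tensors are densified first
    if sd_a.keys() != sd_b.keys():
        return False
    for key in sd_a:
        a, b = sd_a[key], sd_b[key]
        if a.is_sparse:
            a = a.to_dense()
        if b.is_sparse:
            b = b.to_dense()
        if not torch.equal(a, b):
            return False
    return True


torch.manual_seed(0)
source = nn.Linear(3, 2)
target = nn.Linear(3, 2)  # freshly initialised, so its weights differ from source

before = state_dicts_equal(source.state_dict(), target.state_dict())
result = target.load_state_dict(source.state_dict())
after = state_dicts_equal(source.state_dict(), target.state_dict())

print(result.missing_keys, result.unexpected_keys)  # [] []
print(before, after)                                # False True
```

Applied to the notebook, state_dicts_equal(state_dict, model.state_dict()) compares the loaded file against the rebuilt model.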
[8]:
model.state_dict()
[8]:
OrderedDict([('tmk1.design_points',
              tensor([[ 0.0000],
                      [-1.0000],
                      [ 1.0000],
                      [-1.5000],
                      [-0.5000],
                      [ 0.5000],
                      [ 1.5000]])),
             ('tmk1.chol_inv',
              tensor(indices=tensor([[0, 0, 1, 0, 2, 1, 3, 1, 4, 0, 0, 5, 2, 2, 6, 0, 0, 0, 0],
                                     [0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 0, 0, 0, 0]]),
                     values=tensor([ 1.0000, -0.3956,  1.0754, -0.3956,  1.0754, -0.7629,
                                     1.2578, -0.6523,  1.4710, -0.6523, -0.6523,  1.4710,
                                    -0.6523, -0.7629,  1.2578,  0.0000,  0.0000,  0.0000,
                                     0.0000]),
                     size=(7, 7), nnz=19, layout=torch.sparse_coo)),
             ('fc1.mu_weight',
              tensor([[ 0.0345, -0.0331,  0.1433,  0.0244,  0.0466,  0.0648, -0.1209],
                      [-0.0860, -0.0045,  0.2755,  0.1140,  0.2405,  0.0235,  0.0019],
                      [ 0.1102, -0.1545,  0.0109,  0.0156, -0.0386,  0.0741, -0.0063],
                      [ 0.0080, -0.0239, -0.0153,  0.1926,  0.0542, -0.0050, -0.0751],
                      [-0.0851, -0.0198, -0.0946, -0.0516, -0.1105, -0.0636,  0.2105],
                      [ 0.0762, -0.1456, -0.0844,  0.0429, -0.1963,  0.0546,  0.0852],
                      [-0.2232,  0.0288, -0.0359,  0.0231,  0.0172, -0.0894,  0.0932]])),
             ('fc1.rho_weight',
              tensor([[-3.0058, -3.0219, -3.0735, -3.1648, -3.1597, -3.1056, -2.9991],
                      [-3.1995, -2.9587, -3.0155, -2.8990, -3.0019, -2.8497, -2.9690],
                      [-3.0400, -3.0067, -2.9067, -3.1761, -3.0889, -2.9560, -2.8502],
                      [-2.9557, -3.0494, -2.9368, -2.8174, -3.0653, -2.9333, -3.1594],
                      [-2.9583, -3.1364, -3.0424, -2.9543, -2.9114, -3.1319, -3.1132],
                      [-3.0926, -3.0476, -3.2349, -2.9633, -2.8772, -3.1382, -3.0902],
                      [-2.9531, -3.0845, -2.8782, -2.8496, -2.8998, -2.9290, -3.0826]])),
             ('fc1.mu_bias',
              tensor([-0.0513, -0.0315, -0.1563,  0.0572,  0.0033, -0.0065,  0.0829])),
             ('fc1.rho_bias',
              tensor([-2.9364, -3.0666, -3.0992, -3.0973, -2.9117, -3.0521, -3.0974])),
             ('tmk2.design_points',
              tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -1.0000],
                      [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  1.0000],
                      ...,
                      [ 0.7500,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 1.2500,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 1.7500,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])),
             ('tmk2.chol_inv',
              tensor(indices=tensor([[ 0,  0,  0,  ...,  0,  0,  0],
                                     [ 0,  1, 15,  ...,  0,  0,  0]]),
                     values=tensor([ 1.0000e+00, -3.9563e-01, -2.1458e-06,  ...,
                                     0.0000e+00,  0.0000e+00,  0.0000e+00]),
                     size=(799, 799), nnz=15435, layout=torch.sparse_coo)),
             ('fc2.mu_weight',
              tensor([[-9.4429e-02, -7.8046e-02,  1.2859e-02, -8.8056e-02, -6.6972e-02,
                        1.2411e-02, -3.2759e-02, -3.4576e-02,  2.6424e-02,  7.4473e-02,
                        5.1041e-02,  2.1938e-03, -5.4987e-02,  4.0834e-02,  5.7914e-02,
                       -4.2282e-02,  2.1575e-03,  1.1776e-01, -6.2707e-03,  2.4497e-01,
                        1.3857e-01, -3.7841e-02, -3.7947e-02,  1.5835e-02,  3.5271e-02,
                       -9.5948e-02,  1.5443e-01,  4.3158e-02, -2.6382e-02, -1.3318e-01,
                        9.8079e-02,  6.3483e-02, -4.1358e-03,  1.6909e-01, -6.2706e-02,
                        4.7923e-02, -5.1022e-02, -1.1936e-01,  3.3957e-02,  5.6257e-02,
                       -1.1588e-01, -1.0752e-01, -4.2216e-02, -4.3752e-02,  4.7362e-02,
                        2.6536e-02, -6.9206e-02,  2.5816e-01, -1.5933e-01, -2.0594e-01,
                       -2.7107e-02, -5.2035e-02,  5.3444e-02, -8.2445e-02, -1.8551e-02,
                        9.3028e-03, -4.8265e-03,  1.8404e-01,  3.3445e-02, -1.2936e-01,
                       -6.3829e-03, -3.6115e-02, -1.7655e-01,  9.0745e-03,  7.0674e-02,
                       -2.7392e-03,  3.7175e-02,  1.1430e-01, -1.3430e-02, -3.4107e-02,
                        1.4393e-01,  1.5431e-01,  1.3068e-01, -1.0216e-02, -3.0086e-02,
                       -2.7918e-02,  7.3348e-03, -1.5069e-01, -1.2261e-02,  1.1872e-01,
                        1.3326e-01,  5.0157e-02, -1.8507e-01,  2.3393e-02, -1.4417e-01,
                       -4.7938e-02,  1.9846e-03, -7.1360e-03, -1.1070e-01,  5.0470e-03,
                        1.2413e-01,  6.9701e-02,  3.8212e-02, -1.1223e-02, -3.8455e-02,
                        2.3074e-02,  1.6638e-01, -1.8503e-01,  1.0554e-02, -4.7552e-02,
                        1.1916e-01, -1.4256e-01, -6.8961e-02, -4.1993e-02, -2.9059e-02,
                       -5.8580e-02,  5.4103e-02, -1.7790e-02,  8.3428e-02,  2.1139e-02,
                        1.8140e-01,  1.8463e-01, -1.0903e-01, -3.0303e-02,  3.4089e-02,
                        1.2216e-01,  2.6792e-02, -1.9784e-01, -1.6897e-01, -2.4044e-01,
                        1.6775e-01, -2.6324e-02, -2.6981e-01, -4.5328e-02,  1.2873e-01,
                        3.6020e-02,  1.4180e-01, -2.0575e-02,  7.8597e-02,  3.6656e-02,
                       -1.4003e-02,  9.9597e-02, -6.2969e-03,  1.8014e-01, -5.6130e-02,
                       -1.7058e-01, -5.3616e-02, -5.3653e-02, -7.9732e-02, -8.9190e-04,
                       -5.9352e-03, -7.2967e-02, -1.6586e-01, -2.5340e-02,  7.0967e-02,
                        1.0545e-01,  8.7816e-02, -1.0111e-01, -1.7147e-02,  1.5197e-01,
                        8.0631e-02,  9.6178e-03,  6.1733e-02,  1.6722e-01, -2.8026e-02,
                        2.8775e-02, -3.8213e-02, -1.0358e-01, -1.4963e-01, -4.0325e-02,
                        9.3771e-02, -2.6921e-02,  5.0502e-02,  1.5701e-01, -6.1409e-02,
                       -4.0770e-02, -8.2082e-02, -3.0024e-02, -2.2377e-02, -9.5304e-02,
                        2.0790e-02, -9.6612e-02, -2.1830e-01, -2.3801e-03,  1.3033e-01,
                       -6.4815e-02, -1.4418e-01, -5.4029e-02, -7.9638e-02, -1.4334e-01,
                        1.4651e-02, -1.1681e-01, -5.2683e-02,  9.2108e-02,  2.1576e-01,
                        5.8409e-02,  1.3206e-01, -3.4568e-02, -6.8758e-02,  4.0455e-02,
                       -6.2726e-02, -1.1169e-01, -9.6546e-02,  6.8018e-02, -3.9184e-02,
                        6.8218e-02, -1.1212e-02, -8.4942e-02, -2.3065e-01,  1.1252e-01,
                        2.6671e-01,  1.3829e-02,  4.1824e-02,  1.5090e-01,  5.1433e-02,
                        8.2920e-02,  8.5281e-03,  1.5213e-02,  1.6357e-02,  6.6204e-02,
                        4.8588e-02, -1.6610e-02, -5.7956e-02,  1.3209e-01, -3.1444e-02,
                       -3.8161e-02,  6.8461e-02,  9.9941e-02,  5.7587e-02, -2.2967e-02,
                        1.1567e-01, -1.5147e-01,  5.4065e-02, -1.5081e-01, -2.6044e-03,
                        2.9335e-02,  1.0480e-01, -1.7155e-01, -7.3994e-02, -7.3775e-02,
                        2.8377e-02,  6.7576e-02,  2.4395e-02,  1.7674e-01,  1.1761e-01,
                       -1.3512e-01, -7.3963e-03,  1.0032e-01, -7.4382e-02,  1.5761e-03,
                       -7.2153e-02,  3.3959e-02, -6.9279e-02,  2.5722e-01, -3.7791e-02,
                        2.2146e-01, -1.5793e-01,  7.6427e-03, -7.6188e-02,  3.5461e-02,
                        1.0163e-01,  1.3135e-01,  1.2594e-01,  4.4995e-01,  2.2407e-03,
                        5.8003e-02, -2.8487e-02,  4.7213e-02,  1.2835e-02,  1.1065e-02,
                       -7.8124e-02, -3.8556e-02,  1.4246e-01, -1.2569e-01, -2.7167e-01,
                       -8.9190e-02, -7.7483e-02,  6.2573e-02, -3.5969e-02, -1.4137e-01,
                        4.2762e-03,  1.0273e-01, -8.6150e-02, -1.5294e-03,  8.1303e-02,
                       -4.0402e-02,  1.3192e-01,  7.5185e-02, -6.0120e-02,  7.6040e-03,
                        1.0638e-01,  5.9695e-02,  5.0386e-02,  5.6979e-02, -9.4388e-03,
                        8.9457e-02, -1.3666e-01,  5.3046e-04, -6.2615e-02, -1.7981e-01,
                        8.7845e-02, -1.4710e-01,  9.0020e-02, -9.6443e-02, -1.3178e-01,
                        7.9424e-02, -9.3706e-02,  5.6118e-02, -7.4922e-02, -6.2394e-02,
                        3.3124e-03, -9.1033e-02,  6.5171e-02, -1.8169e-01, -8.7617e-02,
                       -1.1573e-03, -4.0484e-02, -3.5440e-02, -4.1749e-03, -1.6652e-01,
                        6.5167e-02, -4.1553e-02, -6.2776e-02, -1.1885e-01,  8.1885e-02,
                       -6.5344e-02, -8.0583e-02, -9.4331e-02, -9.1452e-02, -9.8784e-02,
                       -6.7502e-02,  1.5669e-01, -3.3251e-02,  1.1141e-01, -1.3680e-01,
                        1.1193e-01,  3.1364e-02, -1.2166e-01, -8.4834e-02,  1.3484e-01,
                       -2.1929e-02, -1.4246e-01, -1.4735e-01, -1.8576e-02, -9.7115e-02,
                        7.5656e-02,  2.2129e-01, -1.5467e-01, -8.1069e-02,  1.9180e-03,
                        8.0702e-03, -3.2955e-02,  9.7027e-02, -3.1433e-02, -4.1769e-03,
                       -6.9799e-02, -1.0125e-01,  2.7872e-03,  1.0488e-01, -2.6389e-01,
                       -2.0264e-01,  8.0821e-02, -7.3888e-02, -7.4624e-03, -1.0062e-01,
                       -1.4872e-01, -1.0991e-02,  7.8438e-02,  3.4813e-02,  9.8469e-02,
                        3.1754e-02, -2.3251e-01,  5.1904e-02, -8.2635e-02, -2.1549e-01,
                       -6.0398e-02, -2.7530e-02, -2.7527e-02,  1.1731e-01,  1.8717e-02,
                       -1.6044e-01,  6.5581e-02, -3.6864e-02, -5.3139e-02,  1.0745e-02,
                        6.8184e-02,  5.7974e-02,  8.9485e-02, -6.8882e-02,  5.0115e-02,
                        3.6873e-02,  1.0328e-01,  2.7528e-03,  1.0575e-02,  1.0594e-01,
                        8.3048e-02, -6.5197e-02,  4.8077e-02,  1.2236e-01,  7.9487e-02,
                       -7.5863e-02, -1.4945e-02, -6.1988e-02, -1.6195e-02,  6.8115e-03,
                        1.5798e-01,  1.3103e-01, -5.0274e-03,  7.5803e-02, -4.2832e-02,
                       -1.0818e-02,  1.3495e-01,  4.5693e-02, -5.1622e-02, -1.4847e-01,
                       -1.5049e-01, -3.9425e-02,  1.3974e-01,  3.9423e-02, -6.0817e-02,
                        1.2107e-01,  2.1204e-02, -9.0793e-02, -4.4810e-02, -6.3235e-02,
                       -2.0422e-01, -5.7302e-02, -6.2192e-02, -1.1961e-01, -2.0862e-02,
                        1.1194e-01,  3.7626e-02, -3.9478e-02,  6.9883e-02, -7.7217e-02,
                        1.2580e-03,  6.4918e-02, -1.9830e-01, -1.6213e-02,  1.5035e-01,
                        1.3347e-01,  3.4958e-02, -6.8531e-02,  1.6133e-01,  1.0903e-01,
                       -1.7875e-01, -9.1240e-02,  1.6896e-01, -7.9950e-02,  1.2757e-01,
                       -1.7118e-01, -1.4887e-01, -9.5463e-02, -2.1121e-01, -6.7617e-02,
                        7.2990e-02,  3.6344e-02,  1.6013e-01,  6.6087e-02, -8.9617e-02,
                        1.4792e-01,  9.4042e-02, -8.6990e-02, -1.0462e-01,  1.1650e-01,
                       -1.3754e-01, -2.3416e-01,  1.4205e-02,  5.1774e-03,  5.2973e-02,
                        2.2291e-02,  1.0215e-01,  6.0150e-02,  4.0155e-02, -1.7372e-01,
                       -5.2399e-02, -5.1387e-02, -1.2185e-01, -2.7640e-02, -1.4897e-02,
                        6.7853e-02,  1.5557e-02,  1.0297e-01, -1.0305e-01,  1.5846e-03,
                        9.8716e-02,  1.4407e-01, -7.1671e-03,  7.3873e-02, -8.9970e-02,
                        8.3439e-03,  4.6997e-02,  1.5637e-01, -4.9752e-02,  3.3278e-02,
                       -2.0959e-01,  1.0780e-01,  1.0590e-01,  2.6830e-02, -2.2623e-01,
                        2.1198e-02,  2.8747e-02,  8.5375e-02, -5.4920e-02, -5.0830e-02,
                        1.1250e-01,  6.6987e-02, -4.0003e-02, -8.0875e-02,  2.9926e-02,
                        5.1245e-03, -8.7718e-02,  2.5203e-01,  1.6361e-02,  1.1680e-01,
                        1.0005e-01, -1.5820e-01,  1.3098e-01,  6.5367e-02, -4.0919e-02,
                        6.0854e-03,  6.2880e-02, -2.0403e-02, -9.5495e-02,  1.2390e-02,
                        1.1202e-01, -4.4130e-02, -1.3241e-01, -4.4972e-03, -1.5942e-03,
                       -8.2943e-02,  3.9465e-02,  6.7188e-02,  9.2201e-03, -1.7799e-01,
                       -7.2726e-02, -1.2609e-01, -4.0010e-02,  5.4917e-02,  8.6904e-02,
                        4.9283e-02,  6.8435e-02, -1.7060e-02,  1.7555e-01, -7.8016e-02,
                       -1.0467e-01,  2.3970e-02, -4.9032e-03, -3.4889e-03, -3.6150e-02,
                        4.0716e-02,  5.9861e-02,  7.8703e-02, -8.8551e-02, -9.0751e-02,
                       -5.3700e-02,  2.5523e-01, -4.3281e-02, -5.3891e-02,  7.4523e-02,
                        2.5913e-02,  5.6908e-02, -2.0300e-01, -2.2136e-01, -1.7140e-02,
                        6.8844e-02,  1.5783e-02, -1.2978e-01,  1.3677e-01,  7.1011e-02,
                        4.5775e-02,  1.0420e-01, -2.3742e-01,  6.8321e-02,  1.4909e-01,
                       -3.8553e-02, -1.5508e-02, -1.8003e-02, -1.2939e-01,  6.0976e-02,
                       -1.3076e-01, -1.7003e-02,  1.4587e-01,  1.0770e-01, -5.3172e-02,
                        1.5763e-01,  5.1856e-02,  1.2848e-01,  1.5859e-01, -4.1042e-02,
                        1.3093e-01, -6.5869e-03,  1.6295e-01, -1.1253e-01,  8.4330e-02,
                       -1.3044e-01, -6.4644e-02,  5.5736e-02,  3.8473e-02, -1.4547e-01,
                        1.8274e-02, -8.0717e-02,  1.3257e-01,  4.4037e-02, -7.7392e-02,
                       -1.1963e-01, -8.5312e-02, -6.1887e-02,  1.4149e-01,  8.4646e-03,
                        3.7163e-02, -2.6160e-02, -1.1007e-01, -4.6349e-02,  2.8228e-02,
                        1.4670e-01,  5.3132e-02, -4.2938e-02, -4.6612e-02,  1.3667e-02,
                        7.1889e-03,  3.2215e-02, -1.9972e-01,  2.8313e-02,  2.5599e-02,
                        8.3035e-02, -1.0410e-01,  3.7239e-02, -6.1465e-02,  1.7514e-01,
                       -7.9751e-02, -2.5842e-01,  3.3786e-02, -1.6385e-02, -3.5068e-02,
                        1.3458e-01,  1.9884e-02,  9.9673e-02, -6.9933e-02,  6.2514e-02,
                       -6.6644e-02,  1.0848e-01, -1.5673e-01,  8.0039e-02, -7.6737e-02,
                       -2.0104e-02, -2.0236e-02, -5.0319e-02, -1.6769e-02, -6.4063e-02,
                        6.7056e-03, -1.7998e-01,  4.7727e-02,  3.9858e-02, -1.4198e-01,
                        5.6424e-02,  6.1276e-02,  4.4410e-02,  1.8184e-02, -4.5429e-02,
                       -2.1573e-01,  1.0196e-01,  1.7774e-01,  1.9351e-01,  7.1946e-02,
                       -5.0811e-02,  3.9965e-02,  6.6085e-02,  7.7655e-02,  3.5857e-02,
                        1.1504e-01, -1.6815e-01, -5.7878e-03, -1.4887e-01, -1.1657e-02,
                        6.0857e-03,  1.1087e-02,  3.7716e-02, -3.9674e-03, -6.1191e-02,
                        1.0470e-01,  6.4773e-02,  8.1343e-02, -8.9914e-02,  3.1231e-03,
                        5.6842e-02, -1.0967e-01, -2.5089e-02, -1.1995e-01, -8.8970e-03,
                        9.5199e-02,  4.4099e-02,  6.5321e-02, -2.2275e-02, -4.9469e-02,
                       -7.8414e-02,  2.4580e-02,  1.2900e-01,  1.5357e-01,  1.4761e-01,
                        7.7674e-02, -6.8827e-02,  4.5120e-02,  6.8950e-02, -5.2183e-02,
                       -2.8262e-02, -1.3671e-01,  1.5802e-01,  1.3085e-01,  1.6027e-01,
                       -9.7038e-02, -5.3655e-02, -4.9763e-02,  1.1105e-01,  1.6940e-01,
                       -1.7430e-02,  9.9733e-02,  6.6155e-02, -1.3356e-01, -3.7169e-02,
                        5.8091e-02, -3.6726e-02, -6.3138e-02,  5.1737e-02,  1.5815e-01,
                        5.2862e-02, -1.0068e-01, -6.0800e-02, -6.5441e-02, -6.5780e-03,
                        2.7200e-03, -5.0979e-02,  7.5990e-02,  3.6195e-03, -3.9769e-03,
                        1.6375e-01,  8.9654e-02,  1.5271e-04,  1.0345e-01,  1.1280e-02,
                        1.1795e-01,  9.8938e-02,  8.4720e-02,  1.7462e-01,  7.1916e-03,
                        1.1038e-02,  1.0439e-01, -4.2888e-02, -9.5676e-02, -2.9342e-02,
                        5.4187e-02,  1.7264e-02,  4.0791e-02,  1.2400e-01,  5.6593e-03,
                        9.3506e-02, -2.5770e-02,  1.8773e-01,  6.0283e-03,  1.7846e-02,
                        8.9632e-03,  6.0583e-02,  5.0935e-02, -3.2352e-02,  3.0245e-03,
                       -7.6975e-02, -4.3842e-02, -1.2988e-02, -1.5275e-01, -6.9040e-02,
                       -5.5662e-02, -2.0024e-01,  4.8018e-02, -2.3397e-02,  1.1357e-01,
                        1.4834e-01, -1.0157e-01, -1.1493e-01, -8.5344e-02, -9.5561e-02,
                       -4.9137e-02,  3.1638e-03, -1.5516e-01,  1.2394e-01,  1.7985e-01,
                       -3.1265e-02, -8.7387e-02,  1.4266e-01,  2.0653e-02, -1.3971e-03,
                       -7.2030e-02,  3.9646e-02,  9.7855e-03, -1.4678e-01,  7.1699e-02,
                       -7.0665e-02, -7.3419e-02, -3.9464e-04, -7.3373e-02, -1.1660e-02,
                        1.4766e-01,  3.3607e-02,  1.7193e-01, -1.0179e-01, -3.8493e-02,
                        7.5983e-02,  1.9195e-01,  4.6902e-02, -1.2967e-01]])),
             ('fc2.rho_weight',
              tensor([[-2.9829, -3.0047, -2.9570, -3.0360, -2.8879, -3.1766, -2.9494, -3.0672,
                       -2.8386, -2.7699, -3.0606, -3.0070, -2.9377, -3.0110, -2.8399, -2.9882,
                       -3.0877, -3.0715, -2.9137, -2.9577, -2.7921, -2.9758, -3.0343, -3.0136,
                       -3.0193, -2.9281, -2.9983, -3.1204, -2.9799, -3.0695, -3.1678, -3.0130,
                       -3.0980, -2.8355, -2.9479, -3.0524, -2.8583, -2.8712, -2.9014, -2.9215,
                       -3.0726, -2.9735, -2.9525, -2.9023, -2.9228, -3.0825, -3.0013, -2.7888,
                       -2.9306, -3.0439, -2.9888, -3.0645, -3.2737, -3.1240, -3.0454, -3.0283,
                       -2.9835, -3.0823, -3.2054, -3.1669, -3.0419, -3.0074, -3.0234, -2.8896,
                       -3.0153, -2.8804, -2.8997, -3.1020, -3.0918, -2.8640, -2.9834, -3.1010,
                       -3.0409, -2.9695, -2.9525, -2.9359, -2.8352, -2.9794, -2.8859, -3.0675,
                       -2.8513, -3.0101, -3.0585, -3.0495, -3.0055, -2.8001, -2.9709, -2.9890,
                       -2.8921, -3.0771, -2.8579, -3.1639, -2.9965, -2.8490, -3.0690, -3.0394,
                       -3.1169, -2.8636, -2.8388, -3.0033, -2.9176, -2.9659, -2.8882, -2.9336,
                       -3.0109, -2.8911, -2.9826, -3.0069, -2.7588, -3.1698, -3.1057, -3.0451,
                       -2.8142, -3.0044, -3.0414, -3.1612, -2.9123, -2.9808, -3.0021, -2.9918,
                       -2.9860, -3.1009, -3.0660, -2.9780, -2.8015, -3.0815, -2.9802, -2.7907,
                       -3.0164, -3.1469, -2.9931, -2.9630, -3.0333, -3.0583, -3.1679, -3.0082,
                       -3.0418, -2.8212, -3.0690, -3.0878, -3.0754, -2.9217, -2.8843, -2.8342,
                       -2.8914, -2.9593, -2.9435, -2.8802, -3.1271, -3.0106, -2.8565, -2.9416,
                       -2.9377, -3.0922, -2.9455, -3.0195, -2.9499, -3.1423, -3.1442, -3.0080,
                       -2.9155, -3.0334, -2.9534, -3.0818, -3.0217, -3.0052, -2.9024, -2.8618,
                       -2.8283, -3.0610, -3.0181, -2.8918, -3.0782, -2.8492, -2.8782, -2.9875,
                       -3.0840, -3.1547, -2.9203, -2.9514, -2.8977, -3.0491, -2.9121, -3.0861,
                       -2.8774, -3.0128, -2.9025, -3.1787, -3.0647, -2.7906, -2.9883, -2.9587,
                       -3.0410, -2.8601, -3.1737, -3.0479, -3.0819, -3.0477, -2.9868, -2.9024,
                       -3.0275, -3.0417, -3.0774, -2.9819, -2.9000, -2.9556, -2.9307, -2.9814,
                       -3.0289, -3.0943, -2.9042, -3.0491, -3.0925, -2.9762, -3.0900, -2.8698,
                       -3.0360, -3.1329, -2.8815, -2.9002, -2.9418, -2.9736, -2.8787, -2.8967,
                       -3.0662, -2.9857, -2.9476, -2.9390, -3.0513, -3.0437, -3.1175, -3.0301,
                       -3.1483, -3.1527, -2.7861, -3.1371, -3.0505, -2.9435, -3.1437, -2.7309,
                       -3.0298, -3.0350, -2.9070, -2.9550, -3.0971, -3.1178, -3.0738, -2.9796,
                       -3.0355, -3.0059, -2.9847, -3.0803, -2.9114, -3.1475, -2.9690, -2.9070,
                       -3.0039, -2.9606, -2.8244, -2.9586, -2.9302, -3.1236, -2.9911, -3.2581,
                       -2.8627, -2.8215, -3.0391, -3.0943, -3.1396, -3.0491, -2.9677, -2.9401,
                       -3.0572, -2.8709, -2.9922, -3.0391, -2.9880, -3.0226, -3.0617, -3.0712,
                       -2.8863, -3.0599, -2.9237, -2.9731, -2.9047, -3.0713, -3.0532, -2.9893,
                       -3.1647, -2.8586, -3.0425, -2.9910, -3.3008, -3.0244, -3.0073, -3.0995,
                       -2.9264, -3.1765, -2.9276, -2.9274, -2.9101, -2.9702, -3.0714, -3.0964,
                       -3.0688, -3.1275, -2.9110, -3.0602, -3.1203, -2.9726, -3.0462, -2.9663,
                       -2.9732, -3.0669, -3.0632, -2.8376, -3.0559, -2.8104, -3.0296, -2.9035,
                       -3.1053, -2.8826, -2.9600, -2.9353, -2.9093, -2.9355, -3.1468, -2.8620,
                       -3.0436, -3.1621, -2.9782, -2.9230, -3.0041, -2.9086, -2.9366, -3.1147,
                       -3.0928, -2.9381, -3.1314, -2.9102, -3.1127, -3.0069, -2.9949, -3.2408,
                       -2.9820, -3.1620, -3.1759, -3.0649, -3.1534, -2.9196, -2.9862, -3.1150,
                       -2.9503, -3.0369, -2.9861, -3.0078, -3.1044, -2.9105, -2.9347, -3.0000,
                       -2.9461, -2.9267, -2.9945, -2.9911, -2.9804, -3.0625, -2.9553, -2.9028,
                       -2.9493, -3.0278, -3.0286, -3.0061, -2.9527, -2.8346, -2.8534, -2.8998,
                       -3.0094, -3.0435, -2.9483, -2.9080, -2.9436, -3.0654, -3.0258, -3.0415,
                       -2.9915, -2.9641, -3.0494, -2.9894, -2.9352, -2.8864, -2.8459, -2.9853,
                       -2.8688, -2.8498, -2.9809, -3.0047, -2.8838, -2.7494, -3.0872, -2.9672,
                       -3.1121, -2.9999, -3.0556, -2.9690, -2.9214, -3.0549, -3.0038, -3.1795,
                       -3.0137, -3.0741, -3.1380, -3.1081, -2.9288, -2.9251, -3.0362, -2.8812,
                       -2.9888, -3.2129, -2.7653, -2.9546, -3.0623, -2.8575, -2.7941, -3.0405,
                       -2.9822, -3.0462, -2.9562, -3.1998, -2.9719, -2.8572, -2.9658, -2.7420,
                       -3.1541, -3.0511, -3.0243, -2.8883, -2.9224, -3.1891, -2.8251, -2.9129,
                       -3.0507, -3.0758, -3.0238, -3.2220, -2.9039, -2.9713, -2.9436, -2.9332,
                       -2.9481, -2.9874, -2.9242, -2.8805, -3.1984, -3.0944, -2.9968, -3.0850,
                       -2.8526, -2.9497, -3.0334, -3.0341, -2.8833, -3.1369, -3.1160, -2.8509,
                       -2.8965, -2.8406, -2.8983, -3.0071, -2.9437, -3.1443, -2.9567, -3.0426,
                       -2.9580, -2.9034, -2.7715, -2.9652, -2.9363, -3.0531, -2.9687, -2.9758,
                       -3.0533, -3.0157, -3.0676, -3.0508, -2.9865, -2.9132, -3.0986, -2.9482,
                       -2.9567, -2.8824, -2.9814, -3.0409, -3.1294, -2.9416, -2.9321, -3.0957,
                       -3.0022, -2.9672, -2.8444, -2.9847, -3.1457, -2.9132, -2.9788, -2.8792,
                       -3.0897, -3.0332, -3.1515, -3.1739, -3.1642, -2.9052, -3.0404, -3.0191,
                       -2.9527, -2.9734, -2.9422, -3.1456, -2.9272, -3.0248, -2.9922, -2.9872,
                       -2.9983, -2.9984, -2.9674, -3.0935, -2.8651, -2.8782, -2.9344, -2.8170,
                       -2.8574, -2.8531, -2.9813, -2.8502, -3.0552, -3.1115, -3.0922, -2.9271,
                       -3.0387, -2.7450, -3.1254, -3.1245, -3.0719, -3.0280, -2.9168, -3.0306,
                       -2.9118, -2.8322, -3.0969, -2.9511, -3.1041, -3.1169, -3.1567, -3.1042,
                       -2.8920, -3.1966, -3.0437, -2.8167, -2.8032, -2.7280, -3.0565, -2.9344,
                       -2.8481, -2.8699, -3.0712, -2.9349, -3.0281, -3.0541, -3.0288, -3.0218,
                       -3.0103, -3.0491, -2.9265, -2.9017, -3.0379, -3.1720, -2.8867, -2.9611,
                       -3.0134, -2.9166, -3.0186, -2.9407, -2.9264, -3.0563, -2.9653, -2.9049,
                       -2.8952, -3.0098, -3.0592, -2.9435, -3.0893, -2.9763, -2.9282, -2.7825,
                       -3.0228, -2.9142, -2.9780, -3.1433, -3.0289, -3.1464, -2.9495, -2.9127,
                       -2.8501, -2.9135, -2.9769, -2.9334, -3.0989, -2.9429, -3.0510, -2.9790,
                       -2.9718, -3.0443, -2.9025, -3.1442, -2.9868, -3.0490, -2.9984, -3.0900,
                       -3.1665, -2.7649, -2.9574, -3.1464, -3.0246, -2.9073, -2.8108, -3.0097,
                       -2.9368, -3.0094, -2.8170, -2.9479, -2.9568, -2.9459, -2.9817, -3.0570,
                       -2.9915, -2.9170, -3.0845, -3.1740, -2.9748, -2.9488, -3.1068, -3.0409,
                       -2.9297, -2.9540, -2.9643, -3.0240, -3.0377, -3.2697, -3.0383, -2.9242,
                       -3.2493, -3.0242, -3.0675, -2.9293, -3.0722, -2.9633, -3.0721, -2.9683,
                       -3.0910, -2.9263, -3.0295, -3.0362, -2.9124, -2.8997, -3.0242, -3.1577,
                       -3.1334, -2.9843, -3.0584, -3.0196, -3.0536, -2.8790, -2.9818, -2.9272,
                       -3.1058, -2.9366, -2.9965, -3.0424, -3.0794, -3.0221, -3.0302, -2.8847,
                       -2.9373, -3.0013, -2.9063, -2.9960, -3.0191, -2.9152, -3.1098, -2.9930,
                       -3.0546, -3.0811, -3.0429, -2.9600, -2.7779, -3.2812, -3.0303, -2.9594,
                       -3.0089, -3.0883, -2.8901, -2.9762, -2.8075, -3.0997, -3.2077, -2.8847,
                       -3.0763, -3.0311, -2.9621, -2.9729, -3.0469, -3.0700, -2.8390, -2.8529,
                       -3.2463, -3.1426, -2.9858, -3.0481, -2.9259, -3.0200, -2.8609, -3.0338,
                       -3.0395, -3.1633, -3.0922, -3.1221, -2.9088, -2.9813, -2.9826, -2.9588,
                       -2.9859, -2.9231, -2.9448, -2.9169, -2.8424, -3.1484, -3.0659, -2.7960,
                       -3.0139, -2.9052, -3.0881, -3.0359, -3.0615, -3.0882, -2.8556, -3.0198,
                       -2.8864, -2.8869, -3.0947, -3.0226, -3.0377, -3.1023, -2.9630, -3.1662,
                       -2.8621, -2.8843, -3.1055, -2.9409, -3.1010, -2.9974, -2.9903, -2.9824,
                       -3.0130, -3.1127, -2.9906, -3.1545, -3.2000, -3.0201, -3.2173, -3.0799,
                       -3.0153, -3.0565, -3.1110, -3.1201, -2.9716, -2.9103, -3.2370, -3.0344,
                       -2.8454, -2.9012, -2.8696, -3.0083, -3.1281, -3.0945, -2.8451, -3.0138,
                       -2.9574, -3.0385, -2.9091, -2.8403, -3.1696, -2.8779, -3.0542, -3.0023,
                       -3.0290, -3.0893, -2.7972, -2.9449, -3.0427, -3.0695, -2.9408]])),
             ('fc2.mu_bias', tensor([0.1011])),
             ('fc2.rho_bias', tensor([-3.1962]))])
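The dump above is long because PyTorch prints tensors with up to 1000 elements (its default print threshold) in full, which includes the 799-element fc2.mu_weight and fc2.rho_weight. When only the structure matters, a compact summary is easier to read. summarize_state_dict below is a hypothetical helper, shown on a small example dictionary with illustrative key names.

```python
import torch


def summarize_state_dict(state_dict):
    # (name, shape, layout) for every entry, without printing tensor values
    return [(name, tuple(t.shape), str(t.layout)) for name, t in state_dict.items()]


example = {
    "fc.mu_weight": torch.zeros(1, 799),
    "tmk.chol_inv": torch.eye(7).to_sparse(),
}
for name, shape, layout in summarize_state_dict(example):
    print(f"{name:15s} {str(shape):10s} {layout}")
```

For the model in this notebook, summarize_state_dict(model.state_dict()) gives one line per entry instead of the full tensor values.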

See our documentation for more information on how to use the library.