Defining a custom example model in DMGP
In this notebook, we show how to define a custom model in DMGP. As an example, we consider a 2-layer sparse DGP model for a regression task. The model consists of two layers, each with a level-3 sparse grid design.
[1]:
import torch
import torch.nn as nn
from dmgp.layers.linear import LinearFlipout
from dmgp.layers.activation import TMK
Defining a 2-layer DMGP Model
In the next cell, we define our simple exact DMGP for regression. Each of its two layers consists of a TMK layer built on a level-3 sparse grid design, followed by a LinearFlipout layer.
[2]:
from dmgp.utils.sparse_design.design_class import HyperbolicCrossDesign
from dmgp.kernels.laplace_kernel import LaplaceProductKernel
# Define a 2-layer DMGP model for regression
class DMGP_regression(nn.Module):
    def __init__(self, input_dim, output_dim, design_class, kernel):
        super().__init__()
        # 1st layer of DGP: input: [n, input_dim] tensor, output: [n, 7] tensor
        self.tmk1 = TMK(
            in_features=input_dim,
            n_level=3,
            kernel=kernel,
            design_class=design_class,
        )
        self.fc1 = LinearFlipout(in_features=self.tmk1.out_features, out_features=7)
        # 2nd layer of DGP: input: [n, 7] tensor, output: [n, output_dim] tensor
        self.tmk2 = TMK(
            in_features=7,
            n_level=3,
            kernel=kernel,
            design_class=design_class,
        )
        self.fc2 = LinearFlipout(in_features=self.tmk2.out_features, out_features=output_dim)

    def forward(self, x):
        # Each LinearFlipout layer returns (output, kl); accumulate the KL terms
        kl_sum = 0
        x = self.tmk1(x)
        x, kl = self.fc1(x)
        kl_sum += kl
        x = self.tmk2(x)
        x, kl = self.fc2(x)
        kl_sum += kl
        return torch.squeeze(x), kl_sum

model = DMGP_regression(
    input_dim=1,
    output_dim=1,
    design_class=HyperbolicCrossDesign,
    kernel=LaplaceProductKernel(1.),
)
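The essential pattern in forward is that each LinearFlipout layer returns a pair (output, kl), and the model accumulates the KL terms alongside its prediction. The sketch below reproduces that pattern in plain PyTorch so it runs without DMGP. MeanFieldLinear and ToyTwoLayer are hypothetical names, and MeanFieldLinear is not LinearFlipout: it samples weights with the plain reparameterization trick instead of Flipout, and it assumes a softplus(rho) standard deviation and a standard normal prior, which may differ from DMGP's choices. An elementwise tanh stands in for the TMK layers, so unlike TMK it does not change the feature dimension.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MeanFieldLinear(nn.Module):
    """Hypothetical Bayesian linear layer (NOT dmgp's LinearFlipout).

    Assumption: weight std = softplus(rho), prior N(0, 1), plain reparameterization.
    forward returns (output, kl) like the layers used in DMGP_regression.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.mu_weight = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.rho_weight = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.mu_bias = nn.Parameter(torch.zeros(out_features))
        self.rho_bias = nn.Parameter(torch.full((out_features,), -3.0))

    @staticmethod
    def _kl(mu, sigma):
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all elements
        return 0.5 * torch.sum(sigma**2 + mu**2 - 1.0 - 2.0 * torch.log(sigma))

    def forward(self, x):
        sigma_w = F.softplus(self.rho_weight)
        sigma_b = F.softplus(self.rho_bias)
        # Draw one weight sample per forward pass
        weight = self.mu_weight + sigma_w * torch.randn_like(sigma_w)
        bias = self.mu_bias + sigma_b * torch.randn_like(sigma_b)
        kl = self._kl(self.mu_weight, sigma_w) + self._kl(self.mu_bias, sigma_b)
        return F.linear(x, weight, bias), kl


class ToyTwoLayer(nn.Module):
    """Same layer pattern as DMGP_regression, with tanh standing in for TMK."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = MeanFieldLinear(input_dim, hidden_dim)
        self.fc2 = MeanFieldLinear(hidden_dim, output_dim)

    def forward(self, x):
        kl_sum = 0
        x, kl = self.fc1(torch.tanh(x))
        kl_sum += kl
        x, kl = self.fc2(torch.tanh(x))
        kl_sum += kl
        return torch.squeeze(x), kl_sum


toy = ToyTwoLayer(input_dim=1, hidden_dim=7, output_dim=1)
y, kl = toy(torch.linspace(-1, 1, 5).unsqueeze(-1))
print(y.shape, float(kl))
```

Returning the KL term from forward keeps the model usable with any loss that adds a KL penalty to the data-fit term, without the layers needing to know about the loss.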
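Viewing model parameters and buffers
Let's take a look at the model's state. In PyTorch, "parameters" are objects of type torch.nn.Parameter that have gradients filled in by autograd; they are returned by model.parameters(). A module can also hold buffers: tensors that are saved with the model but not trained. model.state_dict() returns an ordered dictionary containing both. In this model, the trainable parameters are the mu_* and rho_* tensors of fc1 and fc2, while the TMK entries (design_points and chol_inv) are stored in the state dictionary but are not trainable parameters.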
[3]:
model_params = model.state_dict()
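A minimal sketch of the parameter/buffer split using only standard nn.Module methods. The hypothetical WithBuffer module registers one trainable parameter and one buffer; it assumes TMK stores its design points in a similar way, which is consistent with parameter_count below reporting only the eight fc1/fc2 tensors.

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    """Toy module: one trainable parameter and one fixed buffer."""

    def __init__(self):
        super().__init__()
        # Trainable: returned by parameters() and receives gradients
        self.weight = nn.Parameter(torch.randn(3, 2))
        # Fixed: saved in state_dict() but not returned by parameters()
        self.register_buffer("design_points", torch.linspace(-1.0, 1.0, 3).unsqueeze(-1))


m = WithBuffer()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_names = list(m.state_dict().keys())
print(param_names)   # ['weight']
print(buffer_names)  # ['design_points']
print(state_names)   # ['weight', 'design_points']
```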
Counting the number of parameters
We can count the total number of trainable parameters by summing the number of elements in each parameter tensor.
[4]:
def parameter_count(model):
    # numel() counts every element of a tensor; len() would only return its first dimension
    layer_parameters = [p.numel() for p in model.parameters()]
    total = sum(layer_parameters)
    print(layer_parameters)
    print("Total number of parameters in the network:", total)
    return total
parameter_count(model)
[49, 49, 7, 7, 799, 799, 1, 1]
Total number of parameters in the network: 1712
[4]:
1712
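As a cross-check, the count can be reproduced by hand from the tensor shapes in the state dictionary printed at the end of this notebook. The sketch contrasts counting only the first dimension of each tensor (what len() returns for a tensor) with counting all of its elements (what Tensor.numel() returns); only the latter gives the number of trainable scalars.

```python
from math import prod

# Shapes of the trainable tensors, read off the state dictionary printed at
# the end of this notebook (fc1 maps 7 -> 7 features, fc2 maps 799 -> 1).
param_shapes = {
    "fc1.mu_weight": (7, 7),
    "fc1.rho_weight": (7, 7),
    "fc1.mu_bias": (7,),
    "fc1.rho_bias": (7,),
    "fc2.mu_weight": (1, 799),
    "fc2.rho_weight": (1, 799),
    "fc2.mu_bias": (1,),
    "fc2.rho_bias": (1,),
}

first_dims = [shape[0] for shape in param_shapes.values()]  # what len(p) returns
elements = [prod(shape) for shape in param_shapes.values()]  # what p.numel() returns

print(first_dims, sum(first_dims))  # [7, 7, 7, 7, 1, 1, 1, 1] 32
print(elements, sum(elements))      # [49, 49, 7, 7, 799, 799, 1, 1] 1712
```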
Viewing model architecture
We can also print the model architecture by simply printing the model object.
[5]:
print(model)
DMGP_regression(
(tmk1): TMK(
(kernel): LaplaceProductKernel()
)
(fc1): LinearFlipout()
(tmk2): TMK(
(kernel): LaplaceProductKernel()
)
(fc2): LinearFlipout()
)
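The printed summary lists the submodules but not their tensor sizes (the LinearFlipout layers print as LinearFlipout()). A minimal sketch for listing every state-dictionary entry with its shape and whether it is a trainable parameter or a buffer, using only standard nn.Module methods; describe_state is a hypothetical helper name, shown here on a small stand-in model.

```python
import torch.nn as nn


def describe_state(module: nn.Module):
    """Return (name, kind, shape) for every entry in module.state_dict()."""
    param_names = {name for name, _ in module.named_parameters()}
    rows = []
    for name, tensor in module.state_dict().items():
        kind = "parameter" if name in param_names else "buffer"
        rows.append((name, kind, tuple(tensor.shape)))
    return rows


# Stand-in model; calling describe_state(model) on the DMGP model would list
# entries such as tmk1.design_points, tmk1.chol_inv and fc1.mu_weight instead.
toy = nn.Sequential(nn.Linear(1, 7), nn.Tanh(), nn.Linear(7, 1))
for name, kind, shape in describe_state(toy):
    print(f"{name:12s} {kind:9s} {shape}")
```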
Saving Model State
The state dictionary holds everything needed to restore the model: its trainable parameters and its buffers (here, the TMK design points and chol_inv matrices). We can save this state dictionary to a file and load it back later.
[6]:
torch.save(model.state_dict(), 'model_state.pth')
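Saving the state dictionary rather than the whole model object means the file holds only tensors, so the class definition is needed again to rebuild the model before loading. The sketch below shows the same save/load round trip on a small nn.Linear, writing to an in-memory io.BytesIO buffer instead of model_state.pth so it leaves no file behind.

```python
import io

import torch
import torch.nn as nn

# Save the state dictionary of a source model to an in-memory file
source = nn.Linear(3, 2)
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# A freshly constructed model of the same architecture starts with different weights
target = nn.Linear(3, 2)
state = torch.load(buffer)
result = target.load_state_dict(state)
print(result)  # <All keys matched successfully>
print(all(torch.equal(source.state_dict()[k], target.state_dict()[k]) for k in state))
```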
Loading Model State
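Next, we load this state into a freshly constructed model. The message returned by load_state_dict confirms that every key in the saved dictionary matched a parameter or buffer of the new model, and printing the new model's state dictionary shows the loaded values.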
[7]:
state_dict = torch.load('model_state.pth')
model = DMGP_regression(
    input_dim=1,
    output_dim=1,
    design_class=HyperbolicCrossDesign,
    kernel=LaplaceProductKernel(1.),
)
model.load_state_dict(state_dict)
[7]:
<All keys matched successfully>
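To check the loaded values programmatically rather than by eye, a minimal sketch assuming only standard PyTorch: the hypothetical helper state_dicts_equal compares two state dictionaries entry by entry and densifies sparse tensors first, since the chol_inv entries of this model use the sparse COO layout. In the notebook one could call state_dicts_equal(state_dict, model.state_dict()) after cell [7]; the usage below runs on hypothetical tensors.

```python
import torch


def state_dicts_equal(a, b):
    """Return True if both state dicts have the same keys and identical values.

    Sparse tensors are converted to dense first, so the same matrix stored with
    a different entry order or with explicit zero padding still compares equal.
    """
    if a.keys() != b.keys():
        return False
    for key in a:
        x, y = a[key], b[key]
        if x.is_sparse:
            x = x.to_dense()
        if y.is_sparse:
            y = y.to_dense()
        if x.shape != y.shape or not torch.equal(x, y):
            return False
    return True


# Usage sketch with hypothetical tensors
dense = torch.tensor([[1.0, -0.4], [0.0, 1.1]])
sparse = dense.to_sparse()
print(state_dicts_equal({"chol_inv": sparse, "w": dense},
                        {"chol_inv": sparse.clone(), "w": dense.clone()}))
```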
[8]:
model.state_dict()
[8]:
OrderedDict([('tmk1.design_points',
tensor([[ 0.0000],
[-1.0000],
[ 1.0000],
[-1.5000],
[-0.5000],
[ 0.5000],
[ 1.5000]])),
('tmk1.chol_inv',
tensor(indices=tensor([[0, 0, 1, 0, 2, 1, 3, 1, 4, 0, 0, 5, 2, 2, 6, 0, 0, 0, 0],
[0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 0, 0, 0, 0]]),
values=tensor([ 1.0000, -0.3956, 1.0754, -0.3956, 1.0754, -0.7629,
1.2578, -0.6523, 1.4710, -0.6523, -0.6523, 1.4710,
-0.6523, -0.7629, 1.2578, 0.0000, 0.0000, 0.0000,
0.0000]),
size=(7, 7), nnz=19, layout=torch.sparse_coo)),
('fc1.mu_weight',
tensor([[ 0.0345, -0.0331, 0.1433, 0.0244, 0.0466, 0.0648, -0.1209],
[-0.0860, -0.0045, 0.2755, 0.1140, 0.2405, 0.0235, 0.0019],
[ 0.1102, -0.1545, 0.0109, 0.0156, -0.0386, 0.0741, -0.0063],
[ 0.0080, -0.0239, -0.0153, 0.1926, 0.0542, -0.0050, -0.0751],
[-0.0851, -0.0198, -0.0946, -0.0516, -0.1105, -0.0636, 0.2105],
[ 0.0762, -0.1456, -0.0844, 0.0429, -0.1963, 0.0546, 0.0852],
[-0.2232, 0.0288, -0.0359, 0.0231, 0.0172, -0.0894, 0.0932]])),
('fc1.rho_weight',
tensor([[-3.0058, -3.0219, -3.0735, -3.1648, -3.1597, -3.1056, -2.9991],
[-3.1995, -2.9587, -3.0155, -2.8990, -3.0019, -2.8497, -2.9690],
[-3.0400, -3.0067, -2.9067, -3.1761, -3.0889, -2.9560, -2.8502],
[-2.9557, -3.0494, -2.9368, -2.8174, -3.0653, -2.9333, -3.1594],
[-2.9583, -3.1364, -3.0424, -2.9543, -2.9114, -3.1319, -3.1132],
[-3.0926, -3.0476, -3.2349, -2.9633, -2.8772, -3.1382, -3.0902],
[-2.9531, -3.0845, -2.8782, -2.8496, -2.8998, -2.9290, -3.0826]])),
('fc1.mu_bias',
tensor([-0.0513, -0.0315, -0.1563, 0.0572, 0.0033, -0.0065, 0.0829])),
('fc1.rho_bias',
tensor([-2.9364, -3.0666, -3.0992, -3.0973, -2.9117, -3.0521, -3.0974])),
('tmk2.design_points',
tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -1.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.0000],
...,
[ 0.7500, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 1.2500, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 1.7500, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])),
('tmk2.chol_inv',
tensor(indices=tensor([[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 1, 15, ..., 0, 0, 0]]),
values=tensor([ 1.0000e+00, -3.9563e-01, -2.1458e-06, ...,
0.0000e+00, 0.0000e+00, 0.0000e+00]),
size=(799, 799), nnz=15435, layout=torch.sparse_coo)),
('fc2.mu_weight',
tensor([[-9.4429e-02, -7.8046e-02, 1.2859e-02, -8.8056e-02, -6.6972e-02,
1.2411e-02, -3.2759e-02, -3.4576e-02, 2.6424e-02, 7.4473e-02,
5.1041e-02, 2.1938e-03, -5.4987e-02, 4.0834e-02, 5.7914e-02,
-4.2282e-02, 2.1575e-03, 1.1776e-01, -6.2707e-03, 2.4497e-01,
1.3857e-01, -3.7841e-02, -3.7947e-02, 1.5835e-02, 3.5271e-02,
-9.5948e-02, 1.5443e-01, 4.3158e-02, -2.6382e-02, -1.3318e-01,
9.8079e-02, 6.3483e-02, -4.1358e-03, 1.6909e-01, -6.2706e-02,
4.7923e-02, -5.1022e-02, -1.1936e-01, 3.3957e-02, 5.6257e-02,
-1.1588e-01, -1.0752e-01, -4.2216e-02, -4.3752e-02, 4.7362e-02,
2.6536e-02, -6.9206e-02, 2.5816e-01, -1.5933e-01, -2.0594e-01,
-2.7107e-02, -5.2035e-02, 5.3444e-02, -8.2445e-02, -1.8551e-02,
9.3028e-03, -4.8265e-03, 1.8404e-01, 3.3445e-02, -1.2936e-01,
-6.3829e-03, -3.6115e-02, -1.7655e-01, 9.0745e-03, 7.0674e-02,
-2.7392e-03, 3.7175e-02, 1.1430e-01, -1.3430e-02, -3.4107e-02,
1.4393e-01, 1.5431e-01, 1.3068e-01, -1.0216e-02, -3.0086e-02,
-2.7918e-02, 7.3348e-03, -1.5069e-01, -1.2261e-02, 1.1872e-01,
1.3326e-01, 5.0157e-02, -1.8507e-01, 2.3393e-02, -1.4417e-01,
-4.7938e-02, 1.9846e-03, -7.1360e-03, -1.1070e-01, 5.0470e-03,
1.2413e-01, 6.9701e-02, 3.8212e-02, -1.1223e-02, -3.8455e-02,
2.3074e-02, 1.6638e-01, -1.8503e-01, 1.0554e-02, -4.7552e-02,
1.1916e-01, -1.4256e-01, -6.8961e-02, -4.1993e-02, -2.9059e-02,
-5.8580e-02, 5.4103e-02, -1.7790e-02, 8.3428e-02, 2.1139e-02,
1.8140e-01, 1.8463e-01, -1.0903e-01, -3.0303e-02, 3.4089e-02,
1.2216e-01, 2.6792e-02, -1.9784e-01, -1.6897e-01, -2.4044e-01,
1.6775e-01, -2.6324e-02, -2.6981e-01, -4.5328e-02, 1.2873e-01,
3.6020e-02, 1.4180e-01, -2.0575e-02, 7.8597e-02, 3.6656e-02,
-1.4003e-02, 9.9597e-02, -6.2969e-03, 1.8014e-01, -5.6130e-02,
-1.7058e-01, -5.3616e-02, -5.3653e-02, -7.9732e-02, -8.9190e-04,
-5.9352e-03, -7.2967e-02, -1.6586e-01, -2.5340e-02, 7.0967e-02,
1.0545e-01, 8.7816e-02, -1.0111e-01, -1.7147e-02, 1.5197e-01,
8.0631e-02, 9.6178e-03, 6.1733e-02, 1.6722e-01, -2.8026e-02,
2.8775e-02, -3.8213e-02, -1.0358e-01, -1.4963e-01, -4.0325e-02,
9.3771e-02, -2.6921e-02, 5.0502e-02, 1.5701e-01, -6.1409e-02,
-4.0770e-02, -8.2082e-02, -3.0024e-02, -2.2377e-02, -9.5304e-02,
2.0790e-02, -9.6612e-02, -2.1830e-01, -2.3801e-03, 1.3033e-01,
-6.4815e-02, -1.4418e-01, -5.4029e-02, -7.9638e-02, -1.4334e-01,
1.4651e-02, -1.1681e-01, -5.2683e-02, 9.2108e-02, 2.1576e-01,
5.8409e-02, 1.3206e-01, -3.4568e-02, -6.8758e-02, 4.0455e-02,
-6.2726e-02, -1.1169e-01, -9.6546e-02, 6.8018e-02, -3.9184e-02,
6.8218e-02, -1.1212e-02, -8.4942e-02, -2.3065e-01, 1.1252e-01,
2.6671e-01, 1.3829e-02, 4.1824e-02, 1.5090e-01, 5.1433e-02,
8.2920e-02, 8.5281e-03, 1.5213e-02, 1.6357e-02, 6.6204e-02,
4.8588e-02, -1.6610e-02, -5.7956e-02, 1.3209e-01, -3.1444e-02,
-3.8161e-02, 6.8461e-02, 9.9941e-02, 5.7587e-02, -2.2967e-02,
1.1567e-01, -1.5147e-01, 5.4065e-02, -1.5081e-01, -2.6044e-03,
2.9335e-02, 1.0480e-01, -1.7155e-01, -7.3994e-02, -7.3775e-02,
2.8377e-02, 6.7576e-02, 2.4395e-02, 1.7674e-01, 1.1761e-01,
-1.3512e-01, -7.3963e-03, 1.0032e-01, -7.4382e-02, 1.5761e-03,
-7.2153e-02, 3.3959e-02, -6.9279e-02, 2.5722e-01, -3.7791e-02,
2.2146e-01, -1.5793e-01, 7.6427e-03, -7.6188e-02, 3.5461e-02,
1.0163e-01, 1.3135e-01, 1.2594e-01, 4.4995e-01, 2.2407e-03,
5.8003e-02, -2.8487e-02, 4.7213e-02, 1.2835e-02, 1.1065e-02,
-7.8124e-02, -3.8556e-02, 1.4246e-01, -1.2569e-01, -2.7167e-01,
-8.9190e-02, -7.7483e-02, 6.2573e-02, -3.5969e-02, -1.4137e-01,
4.2762e-03, 1.0273e-01, -8.6150e-02, -1.5294e-03, 8.1303e-02,
-4.0402e-02, 1.3192e-01, 7.5185e-02, -6.0120e-02, 7.6040e-03,
1.0638e-01, 5.9695e-02, 5.0386e-02, 5.6979e-02, -9.4388e-03,
8.9457e-02, -1.3666e-01, 5.3046e-04, -6.2615e-02, -1.7981e-01,
8.7845e-02, -1.4710e-01, 9.0020e-02, -9.6443e-02, -1.3178e-01,
7.9424e-02, -9.3706e-02, 5.6118e-02, -7.4922e-02, -6.2394e-02,
3.3124e-03, -9.1033e-02, 6.5171e-02, -1.8169e-01, -8.7617e-02,
-1.1573e-03, -4.0484e-02, -3.5440e-02, -4.1749e-03, -1.6652e-01,
6.5167e-02, -4.1553e-02, -6.2776e-02, -1.1885e-01, 8.1885e-02,
-6.5344e-02, -8.0583e-02, -9.4331e-02, -9.1452e-02, -9.8784e-02,
-6.7502e-02, 1.5669e-01, -3.3251e-02, 1.1141e-01, -1.3680e-01,
1.1193e-01, 3.1364e-02, -1.2166e-01, -8.4834e-02, 1.3484e-01,
-2.1929e-02, -1.4246e-01, -1.4735e-01, -1.8576e-02, -9.7115e-02,
7.5656e-02, 2.2129e-01, -1.5467e-01, -8.1069e-02, 1.9180e-03,
8.0702e-03, -3.2955e-02, 9.7027e-02, -3.1433e-02, -4.1769e-03,
-6.9799e-02, -1.0125e-01, 2.7872e-03, 1.0488e-01, -2.6389e-01,
-2.0264e-01, 8.0821e-02, -7.3888e-02, -7.4624e-03, -1.0062e-01,
-1.4872e-01, -1.0991e-02, 7.8438e-02, 3.4813e-02, 9.8469e-02,
3.1754e-02, -2.3251e-01, 5.1904e-02, -8.2635e-02, -2.1549e-01,
-6.0398e-02, -2.7530e-02, -2.7527e-02, 1.1731e-01, 1.8717e-02,
-1.6044e-01, 6.5581e-02, -3.6864e-02, -5.3139e-02, 1.0745e-02,
6.8184e-02, 5.7974e-02, 8.9485e-02, -6.8882e-02, 5.0115e-02,
3.6873e-02, 1.0328e-01, 2.7528e-03, 1.0575e-02, 1.0594e-01,
8.3048e-02, -6.5197e-02, 4.8077e-02, 1.2236e-01, 7.9487e-02,
-7.5863e-02, -1.4945e-02, -6.1988e-02, -1.6195e-02, 6.8115e-03,
1.5798e-01, 1.3103e-01, -5.0274e-03, 7.5803e-02, -4.2832e-02,
-1.0818e-02, 1.3495e-01, 4.5693e-02, -5.1622e-02, -1.4847e-01,
-1.5049e-01, -3.9425e-02, 1.3974e-01, 3.9423e-02, -6.0817e-02,
1.2107e-01, 2.1204e-02, -9.0793e-02, -4.4810e-02, -6.3235e-02,
-2.0422e-01, -5.7302e-02, -6.2192e-02, -1.1961e-01, -2.0862e-02,
1.1194e-01, 3.7626e-02, -3.9478e-02, 6.9883e-02, -7.7217e-02,
1.2580e-03, 6.4918e-02, -1.9830e-01, -1.6213e-02, 1.5035e-01,
1.3347e-01, 3.4958e-02, -6.8531e-02, 1.6133e-01, 1.0903e-01,
-1.7875e-01, -9.1240e-02, 1.6896e-01, -7.9950e-02, 1.2757e-01,
-1.7118e-01, -1.4887e-01, -9.5463e-02, -2.1121e-01, -6.7617e-02,
7.2990e-02, 3.6344e-02, 1.6013e-01, 6.6087e-02, -8.9617e-02,
1.4792e-01, 9.4042e-02, -8.6990e-02, -1.0462e-01, 1.1650e-01,
-1.3754e-01, -2.3416e-01, 1.4205e-02, 5.1774e-03, 5.2973e-02,
2.2291e-02, 1.0215e-01, 6.0150e-02, 4.0155e-02, -1.7372e-01,
-5.2399e-02, -5.1387e-02, -1.2185e-01, -2.7640e-02, -1.4897e-02,
6.7853e-02, 1.5557e-02, 1.0297e-01, -1.0305e-01, 1.5846e-03,
9.8716e-02, 1.4407e-01, -7.1671e-03, 7.3873e-02, -8.9970e-02,
8.3439e-03, 4.6997e-02, 1.5637e-01, -4.9752e-02, 3.3278e-02,
-2.0959e-01, 1.0780e-01, 1.0590e-01, 2.6830e-02, -2.2623e-01,
2.1198e-02, 2.8747e-02, 8.5375e-02, -5.4920e-02, -5.0830e-02,
1.1250e-01, 6.6987e-02, -4.0003e-02, -8.0875e-02, 2.9926e-02,
5.1245e-03, -8.7718e-02, 2.5203e-01, 1.6361e-02, 1.1680e-01,
1.0005e-01, -1.5820e-01, 1.3098e-01, 6.5367e-02, -4.0919e-02,
6.0854e-03, 6.2880e-02, -2.0403e-02, -9.5495e-02, 1.2390e-02,
1.1202e-01, -4.4130e-02, -1.3241e-01, -4.4972e-03, -1.5942e-03,
-8.2943e-02, 3.9465e-02, 6.7188e-02, 9.2201e-03, -1.7799e-01,
-7.2726e-02, -1.2609e-01, -4.0010e-02, 5.4917e-02, 8.6904e-02,
4.9283e-02, 6.8435e-02, -1.7060e-02, 1.7555e-01, -7.8016e-02,
-1.0467e-01, 2.3970e-02, -4.9032e-03, -3.4889e-03, -3.6150e-02,
4.0716e-02, 5.9861e-02, 7.8703e-02, -8.8551e-02, -9.0751e-02,
-5.3700e-02, 2.5523e-01, -4.3281e-02, -5.3891e-02, 7.4523e-02,
2.5913e-02, 5.6908e-02, -2.0300e-01, -2.2136e-01, -1.7140e-02,
6.8844e-02, 1.5783e-02, -1.2978e-01, 1.3677e-01, 7.1011e-02,
4.5775e-02, 1.0420e-01, -2.3742e-01, 6.8321e-02, 1.4909e-01,
-3.8553e-02, -1.5508e-02, -1.8003e-02, -1.2939e-01, 6.0976e-02,
-1.3076e-01, -1.7003e-02, 1.4587e-01, 1.0770e-01, -5.3172e-02,
1.5763e-01, 5.1856e-02, 1.2848e-01, 1.5859e-01, -4.1042e-02,
1.3093e-01, -6.5869e-03, 1.6295e-01, -1.1253e-01, 8.4330e-02,
-1.3044e-01, -6.4644e-02, 5.5736e-02, 3.8473e-02, -1.4547e-01,
1.8274e-02, -8.0717e-02, 1.3257e-01, 4.4037e-02, -7.7392e-02,
-1.1963e-01, -8.5312e-02, -6.1887e-02, 1.4149e-01, 8.4646e-03,
3.7163e-02, -2.6160e-02, -1.1007e-01, -4.6349e-02, 2.8228e-02,
1.4670e-01, 5.3132e-02, -4.2938e-02, -4.6612e-02, 1.3667e-02,
7.1889e-03, 3.2215e-02, -1.9972e-01, 2.8313e-02, 2.5599e-02,
8.3035e-02, -1.0410e-01, 3.7239e-02, -6.1465e-02, 1.7514e-01,
-7.9751e-02, -2.5842e-01, 3.3786e-02, -1.6385e-02, -3.5068e-02,
1.3458e-01, 1.9884e-02, 9.9673e-02, -6.9933e-02, 6.2514e-02,
-6.6644e-02, 1.0848e-01, -1.5673e-01, 8.0039e-02, -7.6737e-02,
-2.0104e-02, -2.0236e-02, -5.0319e-02, -1.6769e-02, -6.4063e-02,
6.7056e-03, -1.7998e-01, 4.7727e-02, 3.9858e-02, -1.4198e-01,
5.6424e-02, 6.1276e-02, 4.4410e-02, 1.8184e-02, -4.5429e-02,
-2.1573e-01, 1.0196e-01, 1.7774e-01, 1.9351e-01, 7.1946e-02,
-5.0811e-02, 3.9965e-02, 6.6085e-02, 7.7655e-02, 3.5857e-02,
1.1504e-01, -1.6815e-01, -5.7878e-03, -1.4887e-01, -1.1657e-02,
6.0857e-03, 1.1087e-02, 3.7716e-02, -3.9674e-03, -6.1191e-02,
1.0470e-01, 6.4773e-02, 8.1343e-02, -8.9914e-02, 3.1231e-03,
5.6842e-02, -1.0967e-01, -2.5089e-02, -1.1995e-01, -8.8970e-03,
9.5199e-02, 4.4099e-02, 6.5321e-02, -2.2275e-02, -4.9469e-02,
-7.8414e-02, 2.4580e-02, 1.2900e-01, 1.5357e-01, 1.4761e-01,
7.7674e-02, -6.8827e-02, 4.5120e-02, 6.8950e-02, -5.2183e-02,
-2.8262e-02, -1.3671e-01, 1.5802e-01, 1.3085e-01, 1.6027e-01,
-9.7038e-02, -5.3655e-02, -4.9763e-02, 1.1105e-01, 1.6940e-01,
-1.7430e-02, 9.9733e-02, 6.6155e-02, -1.3356e-01, -3.7169e-02,
5.8091e-02, -3.6726e-02, -6.3138e-02, 5.1737e-02, 1.5815e-01,
5.2862e-02, -1.0068e-01, -6.0800e-02, -6.5441e-02, -6.5780e-03,
2.7200e-03, -5.0979e-02, 7.5990e-02, 3.6195e-03, -3.9769e-03,
1.6375e-01, 8.9654e-02, 1.5271e-04, 1.0345e-01, 1.1280e-02,
1.1795e-01, 9.8938e-02, 8.4720e-02, 1.7462e-01, 7.1916e-03,
1.1038e-02, 1.0439e-01, -4.2888e-02, -9.5676e-02, -2.9342e-02,
5.4187e-02, 1.7264e-02, 4.0791e-02, 1.2400e-01, 5.6593e-03,
9.3506e-02, -2.5770e-02, 1.8773e-01, 6.0283e-03, 1.7846e-02,
8.9632e-03, 6.0583e-02, 5.0935e-02, -3.2352e-02, 3.0245e-03,
-7.6975e-02, -4.3842e-02, -1.2988e-02, -1.5275e-01, -6.9040e-02,
-5.5662e-02, -2.0024e-01, 4.8018e-02, -2.3397e-02, 1.1357e-01,
1.4834e-01, -1.0157e-01, -1.1493e-01, -8.5344e-02, -9.5561e-02,
-4.9137e-02, 3.1638e-03, -1.5516e-01, 1.2394e-01, 1.7985e-01,
-3.1265e-02, -8.7387e-02, 1.4266e-01, 2.0653e-02, -1.3971e-03,
-7.2030e-02, 3.9646e-02, 9.7855e-03, -1.4678e-01, 7.1699e-02,
-7.0665e-02, -7.3419e-02, -3.9464e-04, -7.3373e-02, -1.1660e-02,
1.4766e-01, 3.3607e-02, 1.7193e-01, -1.0179e-01, -3.8493e-02,
7.5983e-02, 1.9195e-01, 4.6902e-02, -1.2967e-01]])),
('fc2.rho_weight',
tensor([[-2.9829, -3.0047, -2.9570, -3.0360, -2.8879, -3.1766, -2.9494, -3.0672,
-2.8386, -2.7699, -3.0606, -3.0070, -2.9377, -3.0110, -2.8399, -2.9882,
-3.0877, -3.0715, -2.9137, -2.9577, -2.7921, -2.9758, -3.0343, -3.0136,
-3.0193, -2.9281, -2.9983, -3.1204, -2.9799, -3.0695, -3.1678, -3.0130,
-3.0980, -2.8355, -2.9479, -3.0524, -2.8583, -2.8712, -2.9014, -2.9215,
-3.0726, -2.9735, -2.9525, -2.9023, -2.9228, -3.0825, -3.0013, -2.7888,
-2.9306, -3.0439, -2.9888, -3.0645, -3.2737, -3.1240, -3.0454, -3.0283,
-2.9835, -3.0823, -3.2054, -3.1669, -3.0419, -3.0074, -3.0234, -2.8896,
-3.0153, -2.8804, -2.8997, -3.1020, -3.0918, -2.8640, -2.9834, -3.1010,
-3.0409, -2.9695, -2.9525, -2.9359, -2.8352, -2.9794, -2.8859, -3.0675,
-2.8513, -3.0101, -3.0585, -3.0495, -3.0055, -2.8001, -2.9709, -2.9890,
-2.8921, -3.0771, -2.8579, -3.1639, -2.9965, -2.8490, -3.0690, -3.0394,
-3.1169, -2.8636, -2.8388, -3.0033, -2.9176, -2.9659, -2.8882, -2.9336,
-3.0109, -2.8911, -2.9826, -3.0069, -2.7588, -3.1698, -3.1057, -3.0451,
-2.8142, -3.0044, -3.0414, -3.1612, -2.9123, -2.9808, -3.0021, -2.9918,
-2.9860, -3.1009, -3.0660, -2.9780, -2.8015, -3.0815, -2.9802, -2.7907,
-3.0164, -3.1469, -2.9931, -2.9630, -3.0333, -3.0583, -3.1679, -3.0082,
-3.0418, -2.8212, -3.0690, -3.0878, -3.0754, -2.9217, -2.8843, -2.8342,
-2.8914, -2.9593, -2.9435, -2.8802, -3.1271, -3.0106, -2.8565, -2.9416,
-2.9377, -3.0922, -2.9455, -3.0195, -2.9499, -3.1423, -3.1442, -3.0080,
-2.9155, -3.0334, -2.9534, -3.0818, -3.0217, -3.0052, -2.9024, -2.8618,
-2.8283, -3.0610, -3.0181, -2.8918, -3.0782, -2.8492, -2.8782, -2.9875,
-3.0840, -3.1547, -2.9203, -2.9514, -2.8977, -3.0491, -2.9121, -3.0861,
-2.8774, -3.0128, -2.9025, -3.1787, -3.0647, -2.7906, -2.9883, -2.9587,
-3.0410, -2.8601, -3.1737, -3.0479, -3.0819, -3.0477, -2.9868, -2.9024,
-3.0275, -3.0417, -3.0774, -2.9819, -2.9000, -2.9556, -2.9307, -2.9814,
-3.0289, -3.0943, -2.9042, -3.0491, -3.0925, -2.9762, -3.0900, -2.8698,
-3.0360, -3.1329, -2.8815, -2.9002, -2.9418, -2.9736, -2.8787, -2.8967,
-3.0662, -2.9857, -2.9476, -2.9390, -3.0513, -3.0437, -3.1175, -3.0301,
-3.1483, -3.1527, -2.7861, -3.1371, -3.0505, -2.9435, -3.1437, -2.7309,
-3.0298, -3.0350, -2.9070, -2.9550, -3.0971, -3.1178, -3.0738, -2.9796,
-3.0355, -3.0059, -2.9847, -3.0803, -2.9114, -3.1475, -2.9690, -2.9070,
-3.0039, -2.9606, -2.8244, -2.9586, -2.9302, -3.1236, -2.9911, -3.2581,
-2.8627, -2.8215, -3.0391, -3.0943, -3.1396, -3.0491, -2.9677, -2.9401,
-3.0572, -2.8709, -2.9922, -3.0391, -2.9880, -3.0226, -3.0617, -3.0712,
-2.8863, -3.0599, -2.9237, -2.9731, -2.9047, -3.0713, -3.0532, -2.9893,
-3.1647, -2.8586, -3.0425, -2.9910, -3.3008, -3.0244, -3.0073, -3.0995,
-2.9264, -3.1765, -2.9276, -2.9274, -2.9101, -2.9702, -3.0714, -3.0964,
-3.0688, -3.1275, -2.9110, -3.0602, -3.1203, -2.9726, -3.0462, -2.9663,
-2.9732, -3.0669, -3.0632, -2.8376, -3.0559, -2.8104, -3.0296, -2.9035,
-3.1053, -2.8826, -2.9600, -2.9353, -2.9093, -2.9355, -3.1468, -2.8620,
-3.0436, -3.1621, -2.9782, -2.9230, -3.0041, -2.9086, -2.9366, -3.1147,
-3.0928, -2.9381, -3.1314, -2.9102, -3.1127, -3.0069, -2.9949, -3.2408,
-2.9820, -3.1620, -3.1759, -3.0649, -3.1534, -2.9196, -2.9862, -3.1150,
-2.9503, -3.0369, -2.9861, -3.0078, -3.1044, -2.9105, -2.9347, -3.0000,
-2.9461, -2.9267, -2.9945, -2.9911, -2.9804, -3.0625, -2.9553, -2.9028,
-2.9493, -3.0278, -3.0286, -3.0061, -2.9527, -2.8346, -2.8534, -2.8998,
-3.0094, -3.0435, -2.9483, -2.9080, -2.9436, -3.0654, -3.0258, -3.0415,
-2.9915, -2.9641, -3.0494, -2.9894, -2.9352, -2.8864, -2.8459, -2.9853,
-2.8688, -2.8498, -2.9809, -3.0047, -2.8838, -2.7494, -3.0872, -2.9672,
-3.1121, -2.9999, -3.0556, -2.9690, -2.9214, -3.0549, -3.0038, -3.1795,
-3.0137, -3.0741, -3.1380, -3.1081, -2.9288, -2.9251, -3.0362, -2.8812,
-2.9888, -3.2129, -2.7653, -2.9546, -3.0623, -2.8575, -2.7941, -3.0405,
-2.9822, -3.0462, -2.9562, -3.1998, -2.9719, -2.8572, -2.9658, -2.7420,
-3.1541, -3.0511, -3.0243, -2.8883, -2.9224, -3.1891, -2.8251, -2.9129,
-3.0507, -3.0758, -3.0238, -3.2220, -2.9039, -2.9713, -2.9436, -2.9332,
-2.9481, -2.9874, -2.9242, -2.8805, -3.1984, -3.0944, -2.9968, -3.0850,
-2.8526, -2.9497, -3.0334, -3.0341, -2.8833, -3.1369, -3.1160, -2.8509,
-2.8965, -2.8406, -2.8983, -3.0071, -2.9437, -3.1443, -2.9567, -3.0426,
-2.9580, -2.9034, -2.7715, -2.9652, -2.9363, -3.0531, -2.9687, -2.9758,
-3.0533, -3.0157, -3.0676, -3.0508, -2.9865, -2.9132, -3.0986, -2.9482,
-2.9567, -2.8824, -2.9814, -3.0409, -3.1294, -2.9416, -2.9321, -3.0957,
-3.0022, -2.9672, -2.8444, -2.9847, -3.1457, -2.9132, -2.9788, -2.8792,
-3.0897, -3.0332, -3.1515, -3.1739, -3.1642, -2.9052, -3.0404, -3.0191,
-2.9527, -2.9734, -2.9422, -3.1456, -2.9272, -3.0248, -2.9922, -2.9872,
-2.9983, -2.9984, -2.9674, -3.0935, -2.8651, -2.8782, -2.9344, -2.8170,
-2.8574, -2.8531, -2.9813, -2.8502, -3.0552, -3.1115, -3.0922, -2.9271,
-3.0387, -2.7450, -3.1254, -3.1245, -3.0719, -3.0280, -2.9168, -3.0306,
-2.9118, -2.8322, -3.0969, -2.9511, -3.1041, -3.1169, -3.1567, -3.1042,
-2.8920, -3.1966, -3.0437, -2.8167, -2.8032, -2.7280, -3.0565, -2.9344,
-2.8481, -2.8699, -3.0712, -2.9349, -3.0281, -3.0541, -3.0288, -3.0218,
-3.0103, -3.0491, -2.9265, -2.9017, -3.0379, -3.1720, -2.8867, -2.9611,
-3.0134, -2.9166, -3.0186, -2.9407, -2.9264, -3.0563, -2.9653, -2.9049,
-2.8952, -3.0098, -3.0592, -2.9435, -3.0893, -2.9763, -2.9282, -2.7825,
-3.0228, -2.9142, -2.9780, -3.1433, -3.0289, -3.1464, -2.9495, -2.9127,
-2.8501, -2.9135, -2.9769, -2.9334, -3.0989, -2.9429, -3.0510, -2.9790,
-2.9718, -3.0443, -2.9025, -3.1442, -2.9868, -3.0490, -2.9984, -3.0900,
-3.1665, -2.7649, -2.9574, -3.1464, -3.0246, -2.9073, -2.8108, -3.0097,
-2.9368, -3.0094, -2.8170, -2.9479, -2.9568, -2.9459, -2.9817, -3.0570,
-2.9915, -2.9170, -3.0845, -3.1740, -2.9748, -2.9488, -3.1068, -3.0409,
-2.9297, -2.9540, -2.9643, -3.0240, -3.0377, -3.2697, -3.0383, -2.9242,
-3.2493, -3.0242, -3.0675, -2.9293, -3.0722, -2.9633, -3.0721, -2.9683,
-3.0910, -2.9263, -3.0295, -3.0362, -2.9124, -2.8997, -3.0242, -3.1577,
-3.1334, -2.9843, -3.0584, -3.0196, -3.0536, -2.8790, -2.9818, -2.9272,
-3.1058, -2.9366, -2.9965, -3.0424, -3.0794, -3.0221, -3.0302, -2.8847,
-2.9373, -3.0013, -2.9063, -2.9960, -3.0191, -2.9152, -3.1098, -2.9930,
-3.0546, -3.0811, -3.0429, -2.9600, -2.7779, -3.2812, -3.0303, -2.9594,
-3.0089, -3.0883, -2.8901, -2.9762, -2.8075, -3.0997, -3.2077, -2.8847,
-3.0763, -3.0311, -2.9621, -2.9729, -3.0469, -3.0700, -2.8390, -2.8529,
-3.2463, -3.1426, -2.9858, -3.0481, -2.9259, -3.0200, -2.8609, -3.0338,
-3.0395, -3.1633, -3.0922, -3.1221, -2.9088, -2.9813, -2.9826, -2.9588,
-2.9859, -2.9231, -2.9448, -2.9169, -2.8424, -3.1484, -3.0659, -2.7960,
-3.0139, -2.9052, -3.0881, -3.0359, -3.0615, -3.0882, -2.8556, -3.0198,
-2.8864, -2.8869, -3.0947, -3.0226, -3.0377, -3.1023, -2.9630, -3.1662,
-2.8621, -2.8843, -3.1055, -2.9409, -3.1010, -2.9974, -2.9903, -2.9824,
-3.0130, -3.1127, -2.9906, -3.1545, -3.2000, -3.0201, -3.2173, -3.0799,
-3.0153, -3.0565, -3.1110, -3.1201, -2.9716, -2.9103, -3.2370, -3.0344,
-2.8454, -2.9012, -2.8696, -3.0083, -3.1281, -3.0945, -2.8451, -3.0138,
-2.9574, -3.0385, -2.9091, -2.8403, -3.1696, -2.8779, -3.0542, -3.0023,
-3.0290, -3.0893, -2.7972, -2.9449, -3.0427, -3.0695, -2.9408]])),
('fc2.mu_bias', tensor([0.1011])),
('fc2.rho_bias', tensor([-3.1962]))])
See our documentation for more information on how to use the library.