Estimating MI by the noise-contrastive method (InfoNCE)
Setup
The following cell imports the necessary packages:
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import argparse
from torch.autograd import Variable
import itertools
from tqdm import tqdm
from data.mix_gaussian import MixedGaussian
from data.gaussian import Gaussian
from model.utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime
Use GPU if GPU is available.
# Device selection: when CUDA is present, newly created float tensors default
# to GPU tensors; otherwise they stay on the CPU.
cuda = torch.cuda.is_available()
if cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
torch.set_default_tensor_type(FloatTensor)

# Output locations: checkpoint file prefix and TensorBoard log directory.
name = './results/InfoNCE'  # filename
chkpt_name = f'{name}.pt'  # checkpoint
writer = SummaryWriter('./results/log/InfoNCE')
Set random seed for reproducibility.
# Seed every random number generator used in this notebook (torch CPU,
# torch CUDA, NumPy) so that runs are reproducible.
SEED = 0
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(SEED)
Specify the hyper-parameters:
# arguments
# Command-line hyper-parameters. parse_known_args() collects unrecognised
# arguments in `unknown` instead of raising an error, so this cell also runs
# when extra arguments are present (e.g. inside a notebook kernel).
# NOTE(review): gamma, c_0_1_ratio, ma_rate and ma_ef are not read by the code
# visible in this notebook — confirm whether they are still needed.
parser = argparse.ArgumentParser()
parser.add_argument("--rho", type=float, default=0.9, help="coefficient of Gaussian")
parser.add_argument("--d", type=int, default=6, help="dimension of X & Y")
parser.add_argument("--sample_size", type=int, default=400, help="sample size")
parser.add_argument("--gamma", type=float, default=0, help="clipping parameter")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# beta2 controls Adam's second-moment (squared-gradient) average, not the first.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
# The training loop performs one optimizer step per count, i.e. iterations, not epochs.
parser.add_argument("--n_iters", type=int, default=60000, help="number of training iterations")
parser.add_argument("--batch_size", type=int, default=100, help="size of the batches")
parser.add_argument("--lr", type=float, default=1e-4, help="adam: learning rate")
parser.add_argument("--hidden_dim", type=int, default=100, help="Hidden dimension")
parser.add_argument("--c_0_1_ratio", type=float, default=1, help="Ratio of samples with label 0 and samples with label 1 ")
parser.add_argument("--ma_rate", type=float, default=0.1, help="move average rate")
parser.add_argument("--ma_ef", type=float, default=1, help="move average ef")
opt, unknown = parser.parse_known_args()
Specify the distribution to generate the data samples:
# Distribution of the synthetic data: 'Gaussian' or 'MixedGaussian'.
# generate_data treats any value other than 'Gaussian' as the mixed case.
density = 'Gaussian'
Model
Define the function generate_data for generating samples from Gaussian or mixed Gaussian distributions.
When generating the Gaussian distribution, \({X}\) and \({Y}\) are distributed as
\begin{align} p_{XY} = \prod_{i=1}^d \mathcal{N} \left( \mathbf{0}, \begin{bmatrix}1 & \rho \\ \rho & 1 \end{bmatrix} \right) \end{align}
When generating the mixed Gaussian distribution, each coordinate pair \((X_i, Y_i)\) is drawn independently from an equal mixture of two bivariate Gaussians with correlations \(\rho\) and \(-\rho\), so \({X}\) and \({Y}\) are distributed as
\begin{align} p_{XY} = \prod_{i=1}^d \left[ \frac{1}{2} \mathcal{N} \left( \mathbf{0}, \begin{bmatrix}1 & \rho \\ \rho & 1 \end{bmatrix} \right) + \frac{1}{2} \mathcal{N} \left( \mathbf{0}, \begin{bmatrix}1 & -\rho \\ -\rho & 1 \end{bmatrix} \right) \right] \end{align}
def generate_data(distribution='Gaussian', rho=0.9):
    """Generate ``opt.sample_size`` paired samples of ``opt.d``-dimensional X and Y.

    Column ``j`` of X and column ``j`` of Y are taken from a separate read of
    ``mg.data``. That property presumably yields a fresh two-column sample on
    each access (assumption based on how it is used here; confirm in the data
    module). Under that assumption the coordinate pairs are independent.

    Args:
        distribution: ``'Gaussian'`` selects ``Gaussian``; any other value
            selects ``MixedGaussian`` with correlations ``rho`` and ``-rho``
            and both means set to 0.
        rho: correlation parameter handed to the generator.

    Returns:
        Tuple ``(XY, X, Y, mi)``:
            ``X`` and ``Y`` are tensors with ``opt.sample_size`` rows and
            ``opt.d`` columns.
            ``XY`` is ``torch.cat((X, Y), dim=1)``.
            ``mi`` is the generator's ``ground_truth`` multiplied by ``opt.d``.
    """
    n_rows, n_dims = opt.sample_size, opt.d

    # Pick the generator; both mixture components are centred at zero.
    if distribution != 'Gaussian':
        mg = MixedGaussian(sample_size=n_rows, mean1=0, mean2=0, rho1=rho, rho2=-rho)
    else:
        mg = Gaussian(sample_size=n_rows, rho=rho)

    # Ground-truth MI of one coordinate pair, times the number of pairs.
    mi = mg.ground_truth * n_dims

    x_cols = np.zeros((n_rows, n_dims))
    y_cols = np.zeros((n_rows, n_dims))
    for col in range(n_dims):
        pair = mg.data  # column 0 feeds X, column 1 feeds Y
        x_cols[:, col], y_cols[:, col] = pair[:, 0], pair[:, 1]

    X = torch.Tensor(x_cols)
    Y = torch.Tensor(y_cols)
    return torch.cat((X, Y), dim=1), X, Y, mi
Define the neural network.
class Net(nn.Module):
    """Three-layer ELU MLP mapping an ``input_size`` vector to one scalar score.

    After PyTorch's default initialisation, every weight is re-drawn from
    N(0, sigma^2) and every bias is set to zero, so the initial outputs are
    close to zero.
    """

    def __init__(self, input_size=2, hidden_size=100, sigma=0.02):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        # Re-initialise layer by layer (fc1, fc2, fc3): small Gaussian weights, zero biases.
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.normal_(layer.weight, std=sigma)
            nn.init.constant_(layer.bias, 0)

    def forward(self, input):
        hidden = F.elu(self.fc2(F.elu(self.fc1(input))))
        return self.fc3(hidden)
The function _resample below draws a random minibatch of rows (without replacement by default) from the given data samples for training the neural network.
def _resample(data, batch_size, replace=False):
# Resample the given data sample.
index = np.random.choice(
range(data.shape[0]), size=batch_size, replace=replace)
batch = data[index]
return batch
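The function infonce_loss below returns the empirical estimate of the following InfoNCE lower bound on the mutual information \(I(X\wedge Y)\):
\begin{align} I_{\text{infoNCE}}(f) := E\left[\log \frac{e^{f(X, Y)}}{E \left[\left. e^{f( X, Y')} \right| X \right]} \right] \leq I(X\wedge Y) \end{align}
where \(f: \mathcal{X}\times \mathcal{Y}\to \mathbb{R}\) is a function to be implemented by a neural network, and \(p_{Y'|X,Y}(y'|x,y)=p_{Y}(y')\), i.e., \(Y'\) is identically distributed as \(Y\) and independent of \((X,Y)\). The code estimates this bound from a batch of \(n\) paired samples \((x_i, y_i)\), replacing the inner expectation by an average over the batch. It scores each \(y_i\) against every \(x_j\) in the batch, i.e., it uses the symmetric form of the bound with the roles of \(X\) and \(Y\) exchanged. Because the positive pair is included in that average, the estimate can never exceed \(\log n\).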
def infonce_loss(net, x_samples, y_samples):
    """Return the minibatch InfoNCE lower bound on I(X; Y) as a scalar tensor.

    Args:
        net: the critic ``f``. It is called on samples concatenated along the
            last dimension and should return one score per pair (presumably
            with a trailing dimension of size 1, as ``Net`` does).
        x_samples, y_samples: tensors with the same number ``n`` of rows,
            where row ``i`` of each forms the paired sample ``(x_i, y_i)``.

    The score grid is ``scores[i, j] = f(x_j, y_i)``; positive pairs lie on the
    diagonal. The returned value is

        mean_i f(x_i, y_i) - (mean_i logsumexp_j f(x_j, y_i) - log n),

    so for each ``y_i`` the batch ``x_j`` act as contrastive samples. The
    diagonal term is part of each logsumexp, so the value cannot exceed
    ``log n`` (up to rounding).
    """
    n = x_samples.shape[0]
    x_grid = x_samples.unsqueeze(0).repeat((n, 1, 1))  # x_grid[i, j] = x_j
    y_grid = y_samples.unsqueeze(1).repeat((1, n, 1))  # y_grid[i, j] = y_i

    positive_scores = net(torch.cat([x_samples, y_samples], dim=-1))
    all_scores = net(torch.cat([x_grid, y_grid], dim=-1))

    log_mean_exp = all_scores.logsumexp(dim=1).mean() - np.log(n)
    return positive_scores.mean() - log_mean_exp
The neural network discriminator below implements \(f\):
# Critic network f: takes the concatenated (x, y) pair (2 * d inputs) and
# outputs one scalar score. Its width now comes from the --hidden_dim
# argument, whose default of 100 matches the previously hard-coded value.
discriminator = Net(input_size=opt.d * 2, hidden_size=opt.hidden_dim)
# move NN model to GPU if GPU is available
if cuda:
    discriminator.cuda()
# Adam optimizer with learning rate and betas taken from the command-line arguments
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Load previous results.
# Resume from a previous run when a checkpoint file exists.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
load_available = True  # set to False to prevent loading previous results
if not (load_available and os.path.exists(chkpt_name)):
    mi_list = []  # MI estimate recorded at each training iteration
else:
    checkpoint = torch.load(
        chkpt_name,
        map_location=('cuda' if torch.cuda.is_available() else 'cpu'),
    )
    mi_list = checkpoint['mi_list']
    model_state = checkpoint['model_state']
    discriminator.load_state_dict(model_state)
    print('Previous results loaded.')
Training
Since maximizing the infoNCE bound over \(f\) attains the mutual information
\begin{align} I(X\wedge Y) = \sup_{f} I_{\text{infoNCE}}(f), \end{align}
we train the neural network to minimize the loss \(-I_{\text{infoNCE}}(f)\). The optimal solution \(f\) satisfies
\begin{align} f(x,y)=\log \frac{p_{Y|X}(y|x)}{p_{Y}(y)} + c \quad \forall x\in \mathcal{X}, y \in \mathcal{Y} \end{align}
for any constant \(c\).
The following cell trains the neural network using the data samples.
# Draw the training set once; every minibatch below is resampled from it.
XY, X, Y, Ground_truth = generate_data(distribution=density, rho=opt.rho)
continue_train = False  # set to True to continue to train
if continue_train:
    for i in range(opt.n_iters):
        # Global step: counts on from any estimates already in mi_list (for
        # example loaded from a checkpoint), so resumed runs do not overwrite
        # earlier TensorBoard steps. For a fresh run it equals i.
        step = len(mi_list)

        data_joint = _resample(XY, batch_size=opt.batch_size)
        train_x = data_joint[:, 0:opt.d]
        train_y = data_joint[:, opt.d:]

        optimizer_D.zero_grad()
        # Negative InfoNCE bound as the loss. The bound cannot exceed
        # log(batch_size), so the loss is bounded below by -log(batch_size).
        loss = -infonce_loss(discriminator, train_x, train_y)
        loss.backward()
        optimizer_D.step()

        # Full-sample estimate (bounded by log(sample_size)). It is computed on
        # the training samples themselves, so it is not a held-out estimate.
        with torch.no_grad():
            mi_est = infonce_loss(discriminator, X, Y)
        mi_value = mi_est.item()
        loss_value = loss.item()

        mi_list.append(mi_value)
        writer.add_scalar('mi_list', mi_value, step)
        writer.add_scalar('loss', loss_value, step)
        if step % 500 == 0:
            print("Iternation: %d, loss: %f, mi_est: %f" % (step, loss_value, mi_value))
    writer.add_graph(discriminator, (XY,))
Iternation: 0, loss: 0.000010, mi_est: -0.000011
Iternation: 500, loss: -1.986413, mi_est: 2.064315
Iternation: 1000, loss: -3.324629, mi_est: 3.781751
Iternation: 1500, loss: -3.816203, mi_est: 4.240057
Iternation: 2000, loss: -3.713367, mi_est: 4.403550
Iternation: 2500, loss: -3.796084, mi_est: 4.479726
Iternation: 3000, loss: -3.887900, mi_est: 4.518673
Iternation: 3500, loss: -4.057630, mi_est: 4.549497
Iternation: 4000, loss: -3.867954, mi_est: 4.571889
Iternation: 4500, loss: -3.878994, mi_est: 4.592578
Iternation: 5000, loss: -3.913843, mi_est: 4.617288
Iternation: 5500, loss: -4.089233, mi_est: 4.643414
Iternation: 6000, loss: -3.991745, mi_est: 4.663235
Iternation: 6500, loss: -3.934570, mi_est: 4.686573
Iternation: 7000, loss: -4.032845, mi_est: 4.712495
Iternation: 7500, loss: -4.063352, mi_est: 4.731988
Iternation: 8000, loss: -4.008489, mi_est: 4.754710
Iternation: 8500, loss: -4.079526, mi_est: 4.782144
Iternation: 9000, loss: -4.210361, mi_est: 4.795047
Iternation: 9500, loss: -4.119889, mi_est: 4.821354
Iternation: 10000, loss: -4.167799, mi_est: 4.841077
Iternation: 10500, loss: -4.207998, mi_est: 4.863544
Iternation: 11000, loss: -4.259162, mi_est: 4.880834
Iternation: 11500, loss: -4.085646, mi_est: 4.909222
Iternation: 12000, loss: -4.237001, mi_est: 4.927190
Iternation: 12500, loss: -4.238161, mi_est: 4.941082
Iternation: 13000, loss: -4.208549, mi_est: 4.969678
Iternation: 13500, loss: -4.337718, mi_est: 4.987714
Iternation: 14000, loss: -4.229445, mi_est: 5.016221
Iternation: 14500, loss: -4.249481, mi_est: 5.035546
Iternation: 15000, loss: -4.330183, mi_est: 5.052544
Iternation: 15500, loss: -4.326031, mi_est: 5.071255
Iternation: 16000, loss: -4.261763, mi_est: 5.099937
Iternation: 16500, loss: -4.250395, mi_est: 5.124137
Iternation: 17000, loss: -4.325122, mi_est: 5.137771
Iternation: 17500, loss: -4.382256, mi_est: 5.164788
Iternation: 18000, loss: -4.259338, mi_est: 5.176157
Iternation: 18500, loss: -4.387579, mi_est: 5.207092
Iternation: 19000, loss: -4.284403, mi_est: 5.226067
Iternation: 19500, loss: -4.365234, mi_est: 5.249267
Iternation: 20000, loss: -4.372822, mi_est: 5.271327
Iternation: 20500, loss: -4.368681, mi_est: 5.277494
Iternation: 21000, loss: -4.467575, mi_est: 5.301701
Iternation: 21500, loss: -4.394046, mi_est: 5.332717
Iternation: 22000, loss: -4.323278, mi_est: 5.349476
Iternation: 22500, loss: -4.371675, mi_est: 5.367954
Iternation: 23000, loss: -4.457518, mi_est: 5.389325
Iternation: 23500, loss: -4.489443, mi_est: 5.410488
Iternation: 24000, loss: -4.405626, mi_est: 5.424552
Iternation: 24500, loss: -4.381571, mi_est: 5.449566
Iternation: 25000, loss: -4.434167, mi_est: 5.469605
Iternation: 25500, loss: -4.520988, mi_est: 5.486578
Iternation: 26000, loss: -4.437832, mi_est: 5.503848
Iternation: 26500, loss: -4.452613, mi_est: 5.523114
Iternation: 27000, loss: -4.480072, mi_est: 5.542090
Iternation: 27500, loss: -4.533477, mi_est: 5.559884
Iternation: 28000, loss: -4.506382, mi_est: 5.581320
Iternation: 28500, loss: -4.476612, mi_est: 5.595648
Iternation: 29000, loss: -4.548304, mi_est: 5.610945
Iternation: 29500, loss: -4.542811, mi_est: 5.632243
Iternation: 30000, loss: -4.514511, mi_est: 5.645619
Iternation: 30500, loss: -4.521682, mi_est: 5.657323
Iternation: 31000, loss: -4.531125, mi_est: 5.681959
Iternation: 31500, loss: -4.500861, mi_est: 5.696266
Iternation: 32000, loss: -4.536700, mi_est: 5.711816
Iternation: 32500, loss: -4.556186, mi_est: 5.723273
Iternation: 33000, loss: -4.522396, mi_est: 5.737668
Iternation: 33500, loss: -4.533624, mi_est: 5.754815
Iternation: 34000, loss: -4.552873, mi_est: 5.767499
Iternation: 34500, loss: -4.548962, mi_est: 5.778892
Iternation: 35000, loss: -4.574831, mi_est: 5.795615
Iternation: 35500, loss: -4.564256, mi_est: 5.801952
Iternation: 36000, loss: -4.561363, mi_est: 5.816156
Iternation: 36500, loss: -4.543841, mi_est: 5.828671
Iternation: 37000, loss: -4.563581, mi_est: 5.835651
Iternation: 37500, loss: -4.556992, mi_est: 5.846558
Iternation: 38000, loss: -4.559249, mi_est: 5.858206
Iternation: 38500, loss: -4.577271, mi_est: 5.867131
Iternation: 39000, loss: -4.577236, mi_est: 5.874681
Iternation: 39500, loss: -4.589894, mi_est: 5.884233
Iternation: 40000, loss: -4.569760, mi_est: 5.892098
Iternation: 40500, loss: -4.589166, mi_est: 5.898573
Iternation: 41000, loss: -4.597407, mi_est: 5.903266
Iternation: 41500, loss: -4.582931, mi_est: 5.909054
Iternation: 42000, loss: -4.581110, mi_est: 5.918843
Iternation: 42500, loss: -4.590831, mi_est: 5.923814
Iternation: 43000, loss: -4.581953, mi_est: 5.928774
Iternation: 43500, loss: -4.586583, mi_est: 5.933340
Iternation: 44000, loss: -4.589789, mi_est: 5.939531
Iternation: 44500, loss: -4.592797, mi_est: 5.944098
Iternation: 45000, loss: -4.594497, mi_est: 5.945866
Iternation: 45500, loss: -4.598111, mi_est: 5.952806
Iternation: 46000, loss: -4.598542, mi_est: 5.955683
Iternation: 46500, loss: -4.592979, mi_est: 5.958319
Iternation: 47000, loss: -4.597659, mi_est: 5.962470
Iternation: 47500, loss: -4.594491, mi_est: 5.964867
Iternation: 48000, loss: -4.598873, mi_est: 5.967290
Iternation: 48500, loss: -4.601477, mi_est: 5.969827
Iternation: 49000, loss: -4.601346, mi_est: 5.972457
Iternation: 49500, loss: -4.601625, mi_est: 5.974484
Iternation: 50000, loss: -4.600704, mi_est: 5.975776
Iternation: 50500, loss: -4.601782, mi_est: 5.977545
Iternation: 51000, loss: -4.599659, mi_est: 5.979087
Iternation: 51500, loss: -4.602352, mi_est: 5.980538
Iternation: 52000, loss: -4.602472, mi_est: 5.981858
Iternation: 52500, loss: -4.604167, mi_est: 5.982779
Iternation: 53000, loss: -4.604033, mi_est: 5.983732
Iternation: 53500, loss: -4.603842, mi_est: 5.984695
Iternation: 54000, loss: -4.603607, mi_est: 5.985449
Iternation: 54500, loss: -4.603514, mi_est: 5.986165
Iternation: 55000, loss: -4.604076, mi_est: 5.986738
Iternation: 55500, loss: -4.604426, mi_est: 5.986921
Iternation: 56000, loss: -4.603978, mi_est: 5.987811
Iternation: 56500, loss: -4.604378, mi_est: 5.988275
Iternation: 57000, loss: -4.604061, mi_est: 5.988607
Iternation: 57500, loss: -4.604918, mi_est: 5.989032
Iternation: 58000, loss: -4.604554, mi_est: 5.989225
Iternation: 58500, loss: -4.604691, mi_est: 5.989665
Iternation: 59000, loss: -4.604871, mi_est: 5.989891
Iternation: 59500, loss: -4.604836, mi_est: 5.990000
MI estimates
Use a moving average to smooth the estimated MI across training iterations.
# Exponential moving average of the per-iteration estimates:
#   s_0 = m_0,   s_k = (1 - ma_rate) * s_{k-1} + ma_rate * m_k.
# This local rate is independent of the --ma_rate command-line argument.
ma_rate = 0.01  # moving average rate
mi_copy = []
for k, mi_value in enumerate(mi_list):
    if k == 0:
        mi_copy.append(mi_value)
    else:
        mi_copy.append((1 - ma_rate) * mi_copy[-1] + ma_rate * mi_value)
Plot the MI estimation curve against the training iteration, together with the ground truth.
# Smoothed MI estimate vs. iteration, with the ground truth as a dashed line.
plt.plot(mi_copy, label='MI estimate')
plt.axhline(Ground_truth, label='ground truth', linestyle='--', color='red')
# Mark the first iteration whose smoothed estimate exceeds 80% of the ground truth, if any.
threshold = .8 * Ground_truth
first_hit = next((t for t, v in enumerate(mi_copy) if v > threshold), None)
if first_hit is not None:
    plt.axvline(first_hit, label='80% reached', linestyle=':', color='green')
plt.xlabel('number of iterations')
plt.ylabel('MI estimation')
plt.legend()
<matplotlib.legend.Legend at 0x7ff35c1ab990>
Save the model.
# Persist the MI curves and model weights. The TensorBoard writer is closed
# unconditionally, so pending events are flushed even when an existing
# checkpoint is kept (previously close() ran only inside the save branch).
overwrite = True  # set to False to keep previously stored results
if overwrite or not os.path.exists(chkpt_name):
    model_state = discriminator.state_dict()
    torch.save({
        'mi_list': mi_list,
        'mi_copy': mi_copy,
        'model_state': model_state
    }, chkpt_name)
    print('Current results saved.')
writer.close()
Current results saved.