[論文理解] Mutual Information Neural Estimation
阿新 • • 發佈:2021-10-02
Mutual Information Neural Estimation
互資訊定義:
\(I(X;Z) = \int_{X \times Z} log\frac{d\mathbb{P}(XZ)}{d\mathbb{P}(X) \otimes \mathbb{P}(Z)}d\mathbb{P}(XZ)\)
CPC文章裡用下面這個公式定義要更加容易理解,都是一樣的:
\[I(x;z) = \sum_{x,z}p(x,z) log \frac{p(x,z)}{p(x)p(z)} \]互資訊越大,表明兩個變數依賴關係越強,互資訊越小,表示兩個隨機變數越獨立。
KL散度的對偶問題:
因此根據KL散度和其對偶問題之間的關係我們可以得到:
利用上式優化互資訊的下界:
\[I(X ; Z) \geq I_{\Theta}(X, Z) \]\[I_{\Theta}(X, Z)=\sup _{\theta \in \Theta} \mathbb{E}_{\mathbb{P}_{X Z}}\left[T_{\theta}\right]-\log \left(\mathbb{E}_{\mathbb{P}_{X} \otimes \mathbb{P}_{Z}}\left[e^{T_{\theta}}\right]\right) \]優化演算法:
一般來說z的分佈用高斯分佈,x和z的分佈(marginal distribution)都好取樣;
對於joint distribution,用一個神經網路來建模,即F(x,z),然後其結果就是joint distribution的取樣了。
代入公式計算即可。
class Mine(nn.Module): def __init__(self, input_size=2, hidden_size=100): super().__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, 1) def forward(self, input): output = F.elu(self.fc1(input)) output = F.elu(self.fc2(output)) output = self.fc3(output) return output def mutual_information(joint, marginal, mine_net): t = mine_net(joint) et = torch.exp(mine_net(marginal)) mi_lb = torch.mean(t) - torch.log(torch.mean(et)) return mi_lb, t, et