
🤖 CKA PyTorch 🤖

CKA (Centered Kernel Alignment) in PyTorch.



✒️ About

Note

Centered Kernel Alignment (CKA) [1] is a similarity index between representations of features in neural networks, based on the Hilbert-Schmidt Independence Criterion (HSIC) [2]. Given a set of examples, CKA compares the representations produced by the two layers under comparison.

Given two matrices $\boldsymbol{X} \in \mathbb{R}^{n\times s_1}$ and $\boldsymbol{Y} \in \mathbb{R}^{n\times s_2}$ representing the output of two layers, we can define two auxiliary $n \times n$ Gram matrices $\boldsymbol{K}=\boldsymbol{XX^T}$ and $\boldsymbol{L}=\boldsymbol{YY^T}$ and compute the dot-product similarity between them

$$\langle vec(\boldsymbol{XX^T}), vec(\boldsymbol{YY^T})\rangle = tr(\boldsymbol{XX^T YY^T}) = \lVert \boldsymbol{Y^T X} \rVert_F^2.$$
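As a quick numerical sanity check, the three forms above agree up to floating-point error. A minimal PyTorch sketch (the matrices and sizes below are illustrative, not part of this repository's API):

```python
import torch

# Illustrative activations for n = 8 examples from two layers.
n, s1, s2 = 8, 16, 32
X = torch.randn(n, s1)
Y = torch.randn(n, s2)

K = X @ X.T  # Gram matrix of X, shape (n, n)
L = Y @ Y.T  # Gram matrix of Y, shape (n, n)

dot_product = K.flatten() @ L.flatten()           # <vec(K), vec(L)>
trace_form = torch.trace(K @ L)                   # tr(K L)
frobenius_form = torch.linalg.norm(Y.T @ X) ** 2  # ||Y^T X||_F^2
assert torch.allclose(dot_product, trace_form, rtol=1e-4)
assert torch.allclose(trace_form, frobenius_form, rtol=1e-4)
```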

Then, the (biased) estimator $HSIC_0$ on $\boldsymbol{K}$ and $\boldsymbol{L}$ is defined as

$$HSIC_0(\boldsymbol{K}, \boldsymbol{L}) = \frac{tr(\boldsymbol{KHLH})}{(n - 1)^2},$$

where $\boldsymbol{H} = \boldsymbol{I_n} - \frac{1}{n}\boldsymbol{J_n}$ is the centering matrix and $\boldsymbol{J_n}$ is an $n \times n$ matrix filled with ones. Finally, to obtain the CKA value we only need to normalize $HSIC_0$

$$CKA(\boldsymbol{K}, \boldsymbol{L}) = \frac{HSIC_0(\boldsymbol{K}, \boldsymbol{L})}{\sqrt{HSIC_0(\boldsymbol{K}, \boldsymbol{K}) HSIC_0(\boldsymbol{L}, \boldsymbol{L})}}.$$
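A minimal PyTorch sketch of the two equations above (the names hsic0 and linear_cka are ours for illustration; they are not necessarily this repository's API):

```python
import torch

def hsic0(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Biased estimator HSIC_0 on n x n Gram matrices K and L."""
    n = K.shape[0]
    H = torch.eye(n) - torch.ones(n, n) / n  # centering matrix H
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2

def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between activations X (n x s1) and Y (n x s2)."""
    K, L = X @ X.T, Y @ Y.T
    return hsic0(K, L) / torch.sqrt(hsic0(K, K) * hsic0(L, L))
```

By construction, linear_cka(X, X) equals 1 for any non-degenerate X, since the numerator and denominator coincide.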

Note

However, naive computation of linear CKA (i.e., the previous equation) requires keeping the activations of the entire dataset in memory, which is challenging for wide and deep networks [3].

Therefore, we turn to the unbiased estimator of HSIC, which makes the CKA value independent of the batch size

$$HSIC_1(\boldsymbol{K}, \boldsymbol{L})=\frac{1}{n(n-3)}\left( tr(\boldsymbol{\tilde{K}\tilde{L}}) + \frac{\boldsymbol{1^T\tilde{K}11^T\tilde{L}1}}{(n-1)(n-2)} - \frac{2}{n-2}\boldsymbol{1^T\tilde{K}\tilde{L}1}\right),$$

where $\boldsymbol{\tilde{K}}$ and $\boldsymbol{\tilde{L}}$ are obtained by setting the diagonal entries of $\boldsymbol{K}$ and $\boldsymbol{L}$ to zero. Finally, we can compute the minibatch version of CKA by averaging $HSIC_1$ scores over $k$ minibatches
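A sketch of the unbiased estimator under the same conventions (hsic1 is an illustrative name; the estimator is only defined for $n \geq 4$):

```python
import torch

def hsic1(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Unbiased estimator HSIC_1 on n x n Gram matrices K and L (n >= 4)."""
    n = K.shape[0]
    K = K.clone().fill_diagonal_(0)  # K tilde: K with a zeroed diagonal
    L = L.clone().fill_diagonal_(0)  # L tilde
    ones = torch.ones(n, dtype=K.dtype)
    trace = torch.trace(K @ L)                                        # tr(K~ L~)
    mid = (ones @ K @ ones) * (ones @ L @ ones) / ((n - 1) * (n - 2))  # 1'K~1 1'L~1 term
    last = (ones @ K @ L @ ones) * 2 / (n - 2)                         # 1'K~L~1 term
    return (trace + mid - last) / (n * (n - 3))
```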

$$CKA_{minibatch}=\frac{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{K_i}, \boldsymbol{L_i})}{\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{K_i}, \boldsymbol{K_i})}\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{L_i}, \boldsymbol{L_i})}},$$

with $\boldsymbol{K_i}=\boldsymbol{X_iX_i^T}$ and $\boldsymbol{L_i}=\boldsymbol{Y_iY_i^T}$, where $\boldsymbol{X_i} \in \mathbb{R}^{m \times s_1}$ and $\boldsymbol{Y_i} \in \mathbb{R}^{m \times s_2}$ are now matrices containing the activations of the $i^{th}$ minibatch of $m$ examples sampled without replacement [3].
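A sketch of the minibatch accumulation; since the $\frac{1}{k}$ factors cancel between numerator and denominator, running sums over the $k$ minibatches suffice (minibatch_cka is an illustrative name and reuses the hsic1 sketch above):

```python
import torch

def minibatch_cka(xs: list[torch.Tensor], ys: list[torch.Tensor]) -> torch.Tensor:
    """Minibatch CKA over paired batches xs[i], ys[i] of the same m examples."""
    cross = self_x = self_y = torch.tensor(0.0)
    for X, Y in zip(xs, ys):
        K, L = X @ X.T, Y @ Y.T
        cross = cross + hsic1(K, L)    # hsic1 as sketched above
        self_x = self_x + hsic1(K, K)
        self_y = self_y + hsic1(L, L)
    # The 1/k factors cancel between numerator and denominator.
    return cross / torch.sqrt(self_x * self_y)
```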


📦 Installation

This project requires Python >= 3.10. All the necessary packages can be installed with

pip install -r requirements.txt

Take a look at the examples directory to understand how to compute CKA in different scenarios.


🖼️ Plots

Example plots: a model compared with itself, and a comparison between different models.

📚 Bibliography

[1] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International Conference on Machine Learning. PMLR, 2019.

[2] Wang, Tinghua, Xiaolu Dai, and Yuze Liu. "Learning with Hilbert–Schmidt independence criterion: A review and new perspectives." Knowledge-Based Systems 234 (2021): 107567.

[3] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "Do wide and deep networks learn the same things? Uncovering how neural network representations vary with width and depth." arXiv preprint arXiv:2010.15327 (2020).

This project is also based on the following repositories:


📝 License

This project is MIT licensed.
