NTK reparametrization


I’ve been learning about the neural tangent kernel and some things confused me.

In the NTK paper, the network layers have the form \(\frac{1}{\sqrt{n_A}}A\) where \(n_A\) is the number of neurons in \(A\). Why the square root?

So I worked it out.

Setup

Input: \(x : {}^*\mathbb{R}_{\lim}\), where \({}^*\mathbb{R}_{\lim}\) denotes the limited (hyper)real numbers. Limited just means not infinitely big: \(\vert x \vert \le c\) for some standard real number \(c\).

Output: \(y : {}^*\mathbb{R}_{\lim}\).

The input and output are scalars only to keep the exposition simple; the analysis works the same way for vector inputs and outputs.

Layers: \(A\) and \(B\), of shapes \(H \times 1\) and \(1 \times H\) respectively, where \(H\) is hyperfinite (an infinitely big hypernatural number). Every entry of both matrices is limited.

Activation function: \(r = t \mapsto \max(t,0)\), the ReLU, applied elementwise.

\(\text{Net}(x) := B (r (Ax))\). I left out biases here, but the analysis doesn’t depend on them anyway.
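A computer can't hold a hyperfinite \(H\), but a large ordinary integer is a reasonable stand-in for experimenting. Here is a minimal plain-Python sketch of \(\text{Net}(x) = B(r(Ax))\), assuming \(A\) and \(B\) are stored as flat lists of length \(H\); the helper names `relu` and `net` are hypothetical, and there are no biases, matching the setup.

```python
def relu(t: float) -> float:
    """r(t) = max(t, 0) for a single entry."""
    return max(t, 0.0)


def net(x: float, A: list[float], B: list[float]) -> float:
    """Net(x) = B r(A x) for scalar x, A of shape (H, 1), B of shape (1, H).

    A and B are stored as flat lists of length H (an assumed layout).
    """
    if len(A) != len(B):
        raise ValueError("A and B must both have H entries")
    v = [a * x for a in A]                   # v := A x, shape (H,)
    h = [relu(t) for t in v]                 # elementwise ReLU
    return sum(b * t for b, t in zip(B, h))  # B h, a scalar


print(net(2.0, [1.0, -1.0, 3.0], [1.0, 1.0, 1.0]))  # 8.0
```

With \(H = 3\), \(v = (2, -2, 6)\), the ReLU zeroes the middle entry, and \(B\) sums what is left.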

Diagram:

    x: (1,)
    |
    V
+----------+
|          |
| A: (H,1) |
|          |
|          |
|          |
|          |
|          |
+----------+
    |
    | ReLU
    V
+-----------------+
|    B: (1,H)     |
|                 |
+-----------------+
    |
    V
    y: (1,)

Now what?

Brute force: run \(x\) through the net. The first step is multiplying by \(A\) to get \(Ax\); call it \(v\).

\(v := Ax\) has shape \((H,)\) (in NumPy notation): a vector of unlimited length. This is the first sign that something may be off. Each entry \(v_i = A_i x\) is limited, being a product of two limited numbers, but the norm \(\vert v \vert = \sqrt{v_1^2 + v_2^2 + \dots + v_H^2}\) need not be, since there are unlimited many entries.

Example: let \(v\) be all 1s. Then its norm is \(\sqrt{1 + 1 + \dots + 1}\) with \(H\) ones under the root, which equals \(\sqrt{H}\), an unlimited number.
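A quick numerical check of the all-ones example, with ordinary large integers standing in for \(H\) (an assumption: finite \(H\) can only suggest the hyperfinite behaviour, not prove it). The `norm` helper is a hypothetical name for the Euclidean norm.

```python
import math


def norm(v: list[float]) -> float:
    """Euclidean norm sqrt(v_1^2 + ... + v_H^2)."""
    return math.sqrt(sum(t * t for t in v))


for H in (100, 10_000, 1_000_000):
    ones = [1.0] * H
    # Every entry is 1 (limited), yet the norm is sqrt(H) and grows with H.
    print(H, norm(ones), math.sqrt(H))
```

Both columns print 10.0, 100.0 and 1000.0: each entry stays put while the norm keeps growing.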

We want our network to have limited (aka assignable) values, so what to do? We can renormalize by inserting a constant in front of the matrix \(A\) so its output is limited. But which constant?

Divide by \(\sqrt{H}\), i.e. put \(\frac{1}{\sqrt{H}}\) in front of \(A\).

Proof

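\(\vert v \vert\) is at most the norm of the vector whose \(H\) entries all equal the largest entry of \(v\) in absolute value. Let \(M = \max_{i=1 \dots H} \vert v_i \vert\). This maximum exists because the index set \(\{1, \dots, H\}\) is hyperfinite, and (by transfer) every nonempty hyperfinite set of hyperreals has a maximum, just as every nonempty finite set of reals does. \(M\) is also limited, since it equals \(\vert v_j \vert\) for some \(j\).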

In math: \(\vert v \vert \leq \vert (M, M, \dots, M) \vert\). The magnitude of the right-hand side is \(\sqrt{H M^2} = \sqrt{H} \cdot M\). The factor of \(M\) is harmless because \(M\) is limited, so dividing out the \(\sqrt{H}\) is enough to make the output limited. When \(\vert v \vert\) actually reaches this order, as in the all-ones example, no other order works: dividing by anything of lower order still leaves an unlimited result, and dividing by anything of higher order (like \(H^{0.6}\)) makes it infinitesimal. Only \(\sqrt{H}\) is of the correct order.
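The bound \(\vert v \vert \le \sqrt{H} \cdot M\) and the effect of dividing by \(\sqrt{H}\) can be checked numerically. This sketch assumes the entries of \(A\) are drawn uniformly from \([-1, 1]\), an arbitrary stand-in for "limited"; `check_bound` is a hypothetical helper. After scaling, \(\vert v \vert / \sqrt{H}\) should settle near a fixed limited value (about \(2/\sqrt{3} \approx 1.15\) for \(x = 2\) and this distribution) instead of growing or shrinking with \(H\).

```python
import math
import random


def norm(v: list[float]) -> float:
    """Euclidean norm."""
    return math.sqrt(sum(t * t for t in v))


def check_bound(H: int, x: float, seed: int = 0) -> tuple[float, float, float]:
    """Return (|v|, sqrt(H) * M, |v| / sqrt(H)) for v = A x.

    Entries of A are uniform on [-1, 1]; this distribution is an assumption.
    """
    rng = random.Random(seed)
    A = [rng.uniform(-1.0, 1.0) for _ in range(H)]
    v = [a * x for a in A]
    M = max(abs(t) for t in v)  # largest entry of v by absolute value
    return norm(v), math.sqrt(H) * M, norm(v) / math.sqrt(H)


for H in (100, 10_000, 100_000):
    nv, bound, scaled = check_bound(H, x=2.0)
    print(f"H={H}: |v|={nv:.1f} <= {bound:.1f}, |v|/sqrt(H)={scaled:.3f}")
```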

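By similar reasoning, for each layer we divide by the square root of the number of elements in the layer matrix. The activation function doesn't matter much here: the ReLU satisfies \(\vert r(t) \vert \le \vert t \vert\), so applying it can only shrink the norm, never raise its order.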
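A small sketch of both points, assuming a hypothetical net whose \(A\) and \(B\) entries are all 1 (chosen because the arithmetic is easy to follow, not because it is typical). Without scaling the output is \(H \cdot r(x)\), which is unlimited; dividing each layer by \(\sqrt{H}\) gives \(H \cdot \frac{1}{\sqrt{H}} \cdot \frac{1}{\sqrt{H}} \cdot r(x) = r(x)\), which is limited.

```python
import math
import random


def relu_vec(v: list[float]) -> list[float]:
    """Elementwise ReLU r(t) = max(t, 0)."""
    return [max(t, 0.0) for t in v]


def norm(v: list[float]) -> float:
    """Euclidean norm."""
    return math.sqrt(sum(t * t for t in v))


def all_ones_net(x: float, H: int) -> tuple[float, float]:
    """Net(x) = B r(A x) with every entry of A and B equal to 1.

    Returns (output with no scaling, output with each layer divided by
    sqrt(H), the number of entries in that layer matrix).
    """
    s = 1.0 / math.sqrt(H)
    unscaled = sum(relu_vec([x] * H))                    # B r(A x) = H * r(x)
    scaled = sum(s * t for t in relu_vec([s * x] * H))   # (B/sqrt(H)) r((A/sqrt(H)) x)
    return unscaled, scaled


# ReLU never increases the norm, since |r(t)| <= |t| entry by entry.
rng = random.Random(0)
v = [rng.uniform(-1.0, 1.0) for _ in range(1000)]
print(norm(relu_vec(v)) <= norm(v))  # True

for H in (100, 10_000, 1_000_000):
    print(H, all_ones_net(2.0, H))
```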

Another way

This cannot be implemented directly, but it is mathematically equivalent. We stop worrying about the latent variables having unlimited norm and simply let them be unlimited, since algebraically they behave like any other number.

In that case, consider \(\hat{y} = B(r(Ax))\) and assume every entry of \(Ax\) is positive, so the ReLU acts as the identity and can be ignored, leaving \(\hat{y} = BAx\).

The norm \(\vert BAx \vert\) is bounded above by \(\vert B \vert \vert A \vert \vert x \vert\), which (by the reasoning above) is of order \(\sqrt{H} \cdot \sqrt{H} \cdot 1 = H\), where the \(1\) is the order of \(x\); since \(x\) is limited, it doesn't really matter. This suggests a different but equivalent renormalization strategy: divide by the product of the square roots of the number of elements in each layer, \(\prod_l \sqrt{n_l}\), not at each step but only once, at the very end. The two strategies agree because the ReLU is positively homogeneous, \(r(ct) = c\,r(t)\) for \(c > 0\), so positive constants can be pulled out through it.
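Written out under this section's assumptions (ReLU activation, \(n_A = n_B = H\), and \(\vert \cdot \vert\) the Euclidean norm, i.e. the Frobenius norm for matrices), the equivalence and the bound look like this. \(M_A\) and \(M_B\) are labels introduced here for the largest absolute entries of \(A\) and \(B\), both limited. The bound does not need the positivity assumption, because \(\vert r(t) \vert \le \vert t \vert\).

```latex
% Positive homogeneity of ReLU: r(ct) = c\,r(t) for c > 0
\[
\frac{1}{\sqrt{n_B}}\, B\, r\!\left(\frac{1}{\sqrt{n_A}}\, A x\right)
  = \frac{1}{\sqrt{n_A}\,\sqrt{n_B}}\, B\, r(Ax)
  = \frac{1}{H}\, B\, r(Ax)
\]
% Cauchy--Schwarz, then |r(t)| <= |t|, then |A| <= sqrt(H) M_A and |B| <= sqrt(H) M_B
\[
\lvert B\, r(Ax) \rvert \le \lvert B \rvert\, \lvert r(Ax) \rvert
  \le \lvert B \rvert\, \lvert A \rvert\, \lvert x \rvert
  \le \sqrt{H} M_B \cdot \sqrt{H} M_A \cdot \lvert x \rvert
  = H\, M_A M_B\, \lvert x \rvert
\]
```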

In practice, we would want each step to stay limited (if only so it can be put on a computer), but for theoretical analysis, leaving out the normalizing constants until the very end is nice.
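The practical difference shows up in the hidden vector, not the output. This sketch runs one finite-\(H\) net both ways, assuming uniform \([-1, 1]\) weights as a stand-in for limited entries; `forward` and its `per_layer` flag are hypothetical names. The outputs agree up to floating-point rounding, but with end-only scaling the hidden norm is \(\sqrt{H}\) times larger, which is exactly what would be unlimited for hyperfinite \(H\).

```python
import math
import random


def relu_vec(v: list[float]) -> list[float]:
    """Elementwise ReLU."""
    return [max(t, 0.0) for t in v]


def norm(v: list[float]) -> float:
    """Euclidean norm."""
    return math.sqrt(sum(t * t for t in v))


def forward(x: float, A: list[float], B: list[float],
            per_layer: bool) -> tuple[float, float]:
    """Return (output, norm of the hidden vector) for Net(x) = B r(A x).

    per_layer=True divides each layer by sqrt(H) as it goes, keeping the
    hidden vector limited; per_layer=False leaves the constants out and
    divides once by sqrt(H) * sqrt(H) = H at the very end.
    """
    H = len(A)
    s = 1.0 / math.sqrt(H)
    if per_layer:
        h = relu_vec([s * a * x for a in A])
        out = s * sum(b * t for b, t in zip(B, h))
    else:
        h = relu_vec([a * x for a in A])
        out = sum(b * t for b, t in zip(B, h)) / H
    return out, norm(h)


rng = random.Random(1)
for H in (100, 10_000, 100_000):
    # Uniform [-1, 1] weights: an arbitrary stand-in for "limited" entries.
    A = [rng.uniform(-1.0, 1.0) for _ in range(H)]
    B = [rng.uniform(-1.0, 1.0) for _ in range(H)]
    out_step, hid_step = forward(2.0, A, B, per_layer=True)
    out_end, hid_end = forward(2.0, A, B, per_layer=False)
    print(f"H={H}: outputs {out_step:.6g} vs {out_end:.6g}; "
          f"hidden norms {hid_step:.3f} vs {hid_end:.1f}")
```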

November 2, 2023
