We frequently observe decorrelation in projection neuron responses. This has often been linked to either redundancy reduction or pattern separation. Can we make an explicit link to inference?
A simple case to consider is $\ell_2$-regularized MAP inference, where we minimize the negative log posterior (up to additive constants), $$ L(x,y) = {1 \over 2\sigma^2} \|y - A x\|_2^2 + {\gamma \over 2} \|x\|_2^2.$$
We can find the solution directly by taking the gradient with respect to $x$,
$$ \nabla_x L = -{1\over \sigma^2} A^T(y - A x) + \gamma x.$$
Setting this to zero, we get $$ x = {1 \over \gamma \sigma^2}A^T ( y - A x).$$ Rearranging to isolate $x$, we get $$ (\gamma \sigma^2 I + A^T A)x = A^T y.$$
Computing the SVD of $A$ as $USV^T$, we can decompose the identity as $$ I = V V^T + V^\perp (V^\perp)^T.$$ Substituting into the condition on $x$ we get
$$ (\gamma \sigma^2 V^\perp (V^\perp)^T + \gamma \sigma^2 V V^T + V S^2 V^T) x = V S U^T y.$$
The right-hand side is in the span of $V$ so the left-hand side must be too, which implies $(V^\perp)^T x = 0$. This makes sense because the only part of $x$ that interacts with the likelihood is the projection on $V$ (through $A$), and the regularizer sends the remaining component to zero.
We’re left with $$ V (\gamma \sigma^2 + S^2) V^T x = VS U^T y.$$ We then arrive at a result for the projection of $x$ on $V$:
$$ V^T x = {S \over \gamma \sigma^2 + S^2} U^T y,$$ from which we get $x$ by left multiplying by $V$,
$$ x = V {S \over \gamma \sigma^2 + S^2} U^T y.$$
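As a quick sanity check, here's a minimal numerical sketch (NumPy; the dimensions, seed, and parameter values are arbitrary choices, not from the text) comparing this SVD expression against solving $(\gamma \sigma^2 I + A^TA)x = A^Ty$ directly.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 10, 30              # A is wide: m measurements, n causes
sigma2, gamma = 0.5, 2.0   # sigma^2 and gamma

A = rng.standard_normal((m, n))
y = rng.standard_normal(m)

# Direct solution of (gamma sigma^2 I + A^T A) x = A^T y
x_direct = np.linalg.solve(gamma * sigma2 * np.eye(n) + A.T @ A, A.T @ y)

# SVD solution x = V S/(gamma sigma^2 + S^2) U^T y
U, S, Vt = np.linalg.svd(A, full_matrices=False)   # A = U diag(S) Vt
x_svd = Vt.T @ (S / (gamma * sigma2 + S**2) * (U.T @ y))

print(np.allclose(x_direct, x_svd))   # True
```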
It’s interesting to observe what happens if $y = A x^*.$ Our inferred value is then
\begin{align} x &= V {S \over \gamma \sigma^2 + S^2} U^T A x^*\\ &= V {S \over \gamma \sigma^2 + S^2} U^T U S V^T x^*\\ &= V {S^2 \over \gamma \sigma^2 + S^2}V^T x^*. \end{align}
This says, first, that $x$ is a projection of $x^*$ onto the span of $V$. Secondly, it says that dimensions where $S^2 \gg \gamma \sigma^2$ are preserved, whereas dimensions with $S^2 \ll \gamma \sigma^2$ are strongly attenuated, picking up a factor of roughly $S^2/(\gamma \sigma^2)$.
A second population
We can achieve the same solution using two populations,
\begin{align} \dot \lambda &= -\sigma^2 \lambda + y - A x\\
\dot x &= -\gamma x + A^T \lambda. \end{align}
To see this, notice that at convergence, we get $$\lambda = {1 \over \sigma^2}(y - A x).$$ Substituting this into $\gamma x = A^T \lambda$ then gives the same condition on $x$ we derived above, $$ x = {1 \over \gamma \sigma^2} A^T (y - A x).$$
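For concreteness, here is a rough Euler simulation of these dynamics (the step size, horizon, and parameter values are arbitrary assumptions) settling at the same MAP fixed point:

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 10, 30
sigma2, gamma = 0.5, 2.0
A = rng.standard_normal((m, n))
y = rng.standard_normal(m)

x = np.zeros(n)
lam = np.zeros(m)
dt = 0.01
for _ in range(50_000):
    dlam = -sigma2 * lam + y - A @ x   # projection neurons track the residual
    dx = -gamma * x + A.T @ lam        # causes driven by the fed-back residual
    lam += dt * dlam
    x += dt * dx

x_map = np.linalg.solve(gamma * sigma2 * np.eye(n) + A.T @ A, A.T @ y)
print(np.allclose(x, x_map, atol=1e-6))   # True: same fixed point as the direct solution
```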
Correlations are typically measured in the “projection neuron” population $\lambda$, so it's more useful to work in terms of $\lambda$. Eliminating $x$ using $\gamma x = A^T \lambda$, we get
$$ \lambda = {1 \over \sigma^2} (y - \gamma^{-1} A A^T \lambda).$$
Rearranging gives
$$ ( \gamma \sigma^2 I + AA^T)\lambda = \gamma y.$$
Using the SVD of $A$, we get
$$ U (\gamma \sigma^2 + S^2 ) U^T \lambda = \gamma y,$$ so $$ \lambda = U {\gamma \over \gamma \sigma^2 + S^2} U^T y.$$
If we take $y = A x$, then
\begin{align} \lambda &= U {\gamma \over \gamma \sigma^2 + S^2} U^T U S V^T x = U {\gamma S \over \gamma \sigma^2 + S^2} V^T x. \end{align}
If $\mathbb E x_i x_j = \delta_{ij},$ then the channel correlations are
$$ \mathbb E \lambda \lambda^T = U \left[{\gamma S \over \gamma \sigma^2 + S^2}\right]^2 U^T.$$
This means that the more energy a dimension carries in the input, the more it's shrunk in the output relative to that input.
Another way to view this is through an input-output “transfer function”, where we divide the eigenvalues of corresponding dimensions and get, with some abuse of notation
$$ U^T {\mathbb E \lambda \lambda ^T \over \mathbb E y y^T} U = {\gamma^2 \over (\gamma \sigma^2 + S^2)^2 }.$$
So dimensions with energy $\ll \gamma \sigma^2$ are all shrunk by the same amount, while those much larger than this are shrunk by a factor of roughly $S^4/\gamma^2$.
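Here is a hedged numerical illustration of these channel correlations, sampling white causes $x$, setting $y = Ax$, and applying the closed-form steady state for $\lambda$ (all sizes, seeds, and parameters below are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(2)
m, n = 10, 30
sigma2, gamma = 0.5, 2.0
A = rng.standard_normal((m, n))
U, S, Vt = np.linalg.svd(A, full_matrices=False)

# Closed-form steady state: lambda = U diag(gamma/(gamma sigma^2 + S^2)) U^T y, with y = A x
X = rng.standard_normal((n, 100_000))                    # white causes, E[x x^T] = I
Lam = U @ (np.diag(gamma / (gamma * sigma2 + S**2)) @ (U.T @ (A @ X)))

C_emp = Lam @ Lam.T / X.shape[1]                         # empirical E[lambda lambda^T]
C_pred = U @ np.diag((gamma * S / (gamma * sigma2 + S**2))**2) @ U.T
print(np.abs(C_emp - C_pred).max())                      # small (sampling noise only)

# Input-output "transfer function": E[y y^T] has eigenvalues S^2 along U
transfer = np.diag(U.T @ C_pred @ U) / S**2
print(np.allclose(transfer, gamma**2 / (gamma * sigma2 + S**2)**2))   # True
```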
Equalization
Is there a way to equalize the output channel correlations, while still performing inference? That is, to get $ \mathbb E \lambda \lambda^T \propto I$?
Examining the expression for the channel correlations, we see that the factor of $S$ in the numerator comes from the inputs. The problematic term is that $S^2$ in the denominator. If we could replace it with $S$, then, at least for large $S$, the ratio would be $O(1)$.
The $S^2$ comes from the $AA^T$ in the steady-state expression for $\lambda$. The $A$ comes from $y - A x$, while the $A^T$ comes from $x \propto A^T \lambda$. What if we tried to modify this latter component?
Since $AA^T = (U S V^T) (V S U^T)$, what about replacing $A^T$ with $V U^T$? That is, we update our $x$ dynamics to $$ \dot x' = -\gamma x + V U^T \lambda,$$ where we've used the tick to indicate the modified dynamics.
A descent direction?
Is $\dot x'$ a descent direction? If so, it must have positive projection onto $\dot x$, which was proportional to the negative gradient. We have
\begin{align} \dot x^T \dot x' &= (-\gamma x + A^T \lambda)^T (-\gamma x + V U^T \lambda)\\ &= \gamma^2 x^2 + \lambda^T U S U^T \lambda - \gamma x^T (A^T + V U^T) \lambda\\ &= \gamma^2 x^2 + \lambda^T U S U^T \lambda - \gamma x^T V( S + 1 ) U^T \lambda \end{align}
To determine whether this expression is positive (here and throughout, $x^2$, $\lambda^2$, and $x\lambda$ stand for $\|x\|^2$, $\|\lambda\|^2$, and $\|x\|\,\|\lambda\|$), we first lower bound the first two terms as
$$ \gamma^2 x^2 + \lambda^T U S U^T \lambda \ge \gamma^2 x^2 + \mu \lambda^2, \quad \mu = \min(S).$$
We can write the lower bound as $$ (\gamma x - \sqrt{\mu} \lambda)^2 + 2 \gamma \sqrt{\mu} x \lambda.$$
We also have $$\gamma x^T V( S + 1 ) U^T \lambda \le \gamma (M+1)x \lambda, \quad M = \max(S).$$
So we can bound the projection as
\begin{align} \dot x^T \dot x' &\ge (\gamma x - \sqrt{\mu} \lambda)^2 + 2 \gamma \sqrt{\mu} x \lambda - \gamma (M+1)x \lambda \\
&= (\gamma x - \sqrt{\mu} \lambda)^2 + (2 \sqrt{\mu} - (M+1)) \gamma x \lambda.
\end{align}
To guarantee that this is non-negative, we need the second term to be non-negative, so
$$ M+1 \le 2 \sqrt{\mu}.$$
Let’s say $M = k \mu.$ Then we have
$$ k \mu + 1 \le 2 \sqrt{\mu} \implies k^2 \mu^2 + 2 k\mu + 1 \le 4 \mu \implies k^2 \mu^2 + 2(k-2)\mu + 1 \le 0.$$
So at the very least we need $k < 2$, otherwise every term on the left is positive. Since $M \ge \mu$, we have $k \in [1,2)$. And even then, since the parabola above opens upward, the condition will only hold for a limited range of $\mu$'s.
From this derivation it seems like we’d almost never have descent directions. But empirically, if I sample $x$ and $\lambda$ randomly and compute the projection, it’s always positive.
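Here is a minimal version of that empirical check (the dimensions and sampling scales are arbitrary assumptions, and whether the projection stays positive will depend on them):

```python
import numpy as np

rng = np.random.default_rng(3)
m, n = 10, 30
gamma = 2.0
A = rng.standard_normal((m, n))
U, S, Vt = np.linalg.svd(A, full_matrices=False)

projections = []
for _ in range(10_000):
    x = rng.standard_normal(n)
    lam = rng.standard_normal(m)
    dx = -gamma * x + A.T @ lam                 # original x dynamics
    dx_mod = -gamma * x + Vt.T @ (U.T @ lam)    # modified x dynamics
    projections.append(dx @ dx_mod)

projections = np.array(projections)
print((projections > 0).mean())   # fraction of positive projections across random draws
```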
Part of the reason the bound is pessimistic may be because in the first part I assume that $U^T \lambda$ falls along the smallest dimension, and in the second part, I assume it falls along the largest. So let’s see if we can just minimize the part that depends on $\lambda$ directly. Returning to our expression for the overlap,
$$ \dot x ^T \dot x' = \gamma^2 x^2 + f(U^T \lambda), \quad f(\lambda) = \lambda^T S \lambda - \gamma x^T V( S + 1 ) \lambda.$$
$f$ is convex (each entry of $S$ is positive), and its gradient is $$\nabla f = 2 S \lambda - (S + 1) \gamma V^T x .$$ Setting this to zero and solving for $\lambda$ we get $$ \lambda_\text{min} = \gamma {S + 1 \over 2 S} V^T x.$$
Then \begin{align} f(\lambda_\text{min}) &= \gamma^2 x^T V { (S+1) S (S + 1) \over 4S^2} V^T x - \gamma^2 x^T V(S+1){ S + 1 \over 2 S} V^T x\\
&= \gamma^2 x^T V { (S+1)^2 \over 4S} V^T x - \gamma^2 x^T V{ (S + 1)^2 \over 2 S} V^T x\\ &= - \gamma^2 x^T V { (S+1)^2 \over 4S} V^T x. \end{align}
Now $$- { (S+1)^2 \over 4S} \le -1,$$ so $$ f(\lambda_\text{min}) \le -\gamma^2 \|V^T x\|^2.$$
For $x$ in the span of $V$ (where $\|V^T x\|^2 = x^2$; the leak sends the $V^\perp$ component of $x$ to zero under either set of dynamics anyway), substituting this into our expression for the projection shows that, at least in the worst case, the direction $\dot x'$ is not a descent direction,
$$ \dot x^T \dot x' \le \gamma^2 x^2 - \gamma^2 x^2 = 0.$$
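A small numerical check of this worst case (same arbitrary setup as in the sketches above; $x$ is drawn in the span of $V$ so that the bound applies):

```python
import numpy as np

rng = np.random.default_rng(4)
m, n = 10, 30
gamma = 2.0
A = rng.standard_normal((m, n))
U, S, Vt = np.linalg.svd(A, full_matrices=False)

x = Vt.T @ rng.standard_normal(m)                  # x in the span of V
lam_tilde = gamma * (S + 1) / (2 * S) * (Vt @ x)   # worst-case U^T lambda from above
lam = U @ lam_tilde

dx = -gamma * x + A.T @ lam
dx_mod = -gamma * x + Vt.T @ (U.T @ lam)
print(dx @ dx_mod <= 1e-9)   # True: at this lambda, the modified direction is not a descent direction
```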
Steady-state solution
Where do the modified dynamics lead anyway? At steady-state,
$$ \gamma x = VU^T \lambda = VU^T (y - Ax)/\sigma^2,$$ so
$$ \gamma \sigma^2 V^T x = U^T(y - U S V^T x).$$ If we switch to $\tilde x = V^T x$ and $\tilde y = U^T y$, then $$ \gamma \sigma^2 \tilde x = \tilde y - S \tilde x,$$ so finally,
$$ \tilde x = {\tilde y \over \gamma \sigma^2 + S}.$$
Compare this to the desired value, which is
$$ \tilde x = {\tilde y S \over \gamma \sigma^2 + S^2},$$
and we see that for large $S$, the two give similar behaviour.
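A quick numerical comparison of the two per-dimension filters (parameters chosen, arbitrarily, so that $\gamma \sigma^2$ is small relative to the typical singular values):

```python
import numpy as np

rng = np.random.default_rng(5)
m, n = 10, 30
sigma2, gamma = 0.5, 0.1     # gamma sigma^2 small relative to typical singular values
A = rng.standard_normal((m, n))
S = np.linalg.svd(A, compute_uv=False)

# Per-dimension filters applied to tilde-y = U^T y
gain_modified = 1 / (gamma * sigma2 + S)      # steady state of the modified dynamics
gain_map = S / (gamma * sigma2 + S**2)        # the desired MAP filter

print(np.abs(gain_modified / gain_map - 1).max())   # small when S >> gamma sigma^2
```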
Changing the prior
It’s rather unsatisfactory to make the dynamics and solution suboptimal as above, just to achieve decorrelation for the sake of it. What if we changed the prior, so that the optimal dynamics produced a decorrelated representation in the projection neurons at steady-state?
If we go back to the objective, we can write it in saddle form by introducing the projection neuron population $\lambda$ (maximizing over $\lambda$ recovers the original loss; $\phi$ is the prior term, previously ${\gamma \over 2}\|x\|_2^2$):
$$ L(x,\lambda) = -{\sigma^2 \over 2} \lambda^2 + \lambda^T (y - Ax) + \phi(x).$$
The $x$ dynamics are $\dot x = -\nabla \phi(x) + A^T \lambda$, so at steady state we have
$$ \nabla \phi(x) = A^T \lambda.$$
We can extract $x$ by inverting the gradient of the prior,
$$ x = (\nabla \phi)^{-1}(A^T \lambda).$$
Plugging this into our steady-state equation for $\lambda$, we get
$$ \sigma^2 \lambda + A (\nabla \phi)^{-1} [A^T \lambda] = y.$$
To equalize, we want the second summand on the left-hand side to come out as $U S U^T \lambda$ (with the $\ell_2$ prior it was $\gamma^{-1} U S^2 U^T\lambda$). In other words, we need
$$ U S V^T (\nabla \phi)^{-1} [V S U^T \lambda] = U S U^T \lambda.$$
This then says that $$ (\nabla \phi)^{-1} [V S U^T \lambda] = V U^T \lambda.$$ Applying the gradient to both sides,
$$ V S U^T \lambda = \nabla \phi (V U^T \lambda).$$
Letting $V U^T \lambda = z$, we get that $$\nabla \phi(z) = V S V^T z,$$ so that $$\phi(z) = {1 \over 2} z^T V S V^T z,$$ dropping constants in $z$.
Indeed, we then have at steady state
$$ V S V^T x = A^T \lambda \implies x = V S^{-1} V^T A^T \lambda = V S^{-1} S U^T \lambda = V U^T \lambda,$$ where we've glossed over that $VSV^T$ is not invertible if $A$ is a wide matrix. Notice also that this gives the same steady state solution as our $\dot x'$ dynamics (up to a factor of $\gamma$).
Plugging this into the equation for the projection neurons, we get $$ \sigma^2 \lambda = y - A V U^T \lambda = y - U S U^T \lambda \implies (\sigma^2 I + U S U^T) \lambda = y,$$ as desired.
So what this seems to say is that if we perform inference under a prior that's matched to the statistics of the data (maybe? see below), then the projection neuron activities will be (approximately channel) decorrelated.
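To see how much flatter the output becomes, here's a small sketch (arbitrary parameters) comparing the per-channel gains, i.e. the eigenvalues of $\mathbb E \lambda \lambda^T$ along $U$ with $y = Ax$ and white $x$: $\gamma^2 S^2/(\gamma\sigma^2 + S^2)^2$ under the original $\ell_2$ prior (derived earlier), versus $S^2/(\sigma^2 + S)^2$, which follows from the steady state just above, under the matched prior.

```python
import numpy as np

rng = np.random.default_rng(6)
m, n = 10, 30
sigma2, gamma = 0.5, 2.0
A = rng.standard_normal((m, n))
S = np.linalg.svd(A, compute_uv=False)

# Channel gains (eigenvalues of E[lambda lambda^T] along U), with y = A x and white x
gain_l2 = (gamma * S / (gamma * sigma2 + S**2))**2   # original ell_2 prior (derived earlier)
gain_matched = (S / (sigma2 + S))**2                 # matched prior phi(z) = z^T V S V^T z / 2

# Spread of gains across channels: the matched prior is much flatter (-> 1 as S grows)
print(gain_l2.max() / gain_l2.min(), gain_matched.max() / gain_matched.min())
```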
And note that although the gradient flow dynamics are $$ \dot x = -V S V^T x + A^T \lambda,$$ which seem to require lots of interaction between the corresponding units, we achieve the same steady-state using
$$ \dot x = -x + V U^T \lambda.$$
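And a rough Euler simulation of these simplified dynamics (arbitrary step size and parameters, as before), checking that they settle at $x = VU^T\lambda$ with $(\sigma^2 I + USU^T)\lambda = y$:

```python
import numpy as np

rng = np.random.default_rng(7)
m, n = 10, 30
sigma2 = 0.5
A = rng.standard_normal((m, n))
U, S, Vt = np.linalg.svd(A, full_matrices=False)
y = rng.standard_normal(m)

x = np.zeros(n)
lam = np.zeros(m)
dt = 0.01
for _ in range(50_000):
    dlam = -sigma2 * lam + y - A @ x
    dx = -x + Vt.T @ (U.T @ lam)       # simplified x dynamics with V U^T feedback
    lam += dt * dlam
    x += dt * dx

print(np.allclose(x, Vt.T @ (U.T @ lam), atol=1e-6))                                 # x = V U^T lambda
print(np.allclose((sigma2 * np.eye(m) + U @ np.diag(S) @ U.T) @ lam, y, atol=1e-6))  # equalized steady state
```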
The next question is how the $VU^T$ weights can be learned – but perhaps this can be done through channel decorrelation itself?
It's also not clear up there what I mean by 'statistics matched to the data.'
The prior does impose some statistics on the causes: it corresponds to a Gaussian with precision $VSV^T$. And note that this isn't even $A^T A$, which would be $V S^2 V^T$, not $V S V^T.$ So at best we can say that decorrelation corresponds to a particular prior on the causes.
We can interpret the prior as penalizing causes whose projected dimensions correspond to large values of $S$, effectively imposing a kind of gain control. When the causes then go through $A$, they're amplified along those dimensions, and the net effect is an equalization.
So in effect, we get the equalization if our prior cancels out the effect of the likelihood. A trivial result, perhaps.
$$\blacksquare$$