Deep Linear Networks Learn Hierarchical Structure

Below are notes on Saxe et al.’s “A mathematical theory of semantic development in deep neural networks”, which I initially made for the Discord sessions of March 2026 and have now (15 Apr 2026) finalized after presenting the paper at the Crick.

In this paper the authors use a two-layer linear network to investigate how brains and machines uncover structure in data. Specifically, they consider the problem of mapping items, modelled on plants and animals, to their features, where the features exhibit hierarchical structure, ranging from high-level features like “can move” to low-level features like “has leaves”. Agents in the world encounter items and observe their features. The question is whether and how such agents can organize these observations to reflect the underlying structure that generated them.

I think this is a great paper for at least four reasons. The first is scientific, and the remaining three are meta-scientific:

  1. The scientific contribution: it gives us a concrete framework for understanding semantic development and answering various questions about it.
  2. Simplicity: Rather than studying semantic development by throwing large datasets at fancy deep learning models, the authors study a very simplified version of the problem, which they can analyze in detail to get insight.
  3. Theoretical moves: Despite its simplicity, the model is still hard to analyze at first. Through a series of assumptions the authors reduce it to a system that has a closed-form solution. Remarkably, this solution describes the behaviour of the original, unsimplified model well.
  4. Picking a good title: You can have the best idea in the world, but it won’t matter if nobody reads it. They could have given their paper a cryptic title like “Two-layer linear networks perform sequential SVD”, which is the core of their findings, but that would sound too technical to most people, even those who would benefit from reading it. Instead, they refer to “deep neural networks”, even though the network is only two layers deep and is linear. Nevertheless, their insights apply to deeper, nonlinear networks, so by choosing a good title they can spread knowledge of their results more effectively. For an example of a poor title, see this paper.

The paper has two parts. The first part is about learning. The second part is about what is learned (the SVD of the input-output covariance) and how the resulting bases are useful for answering a variety of questions. In this post I’ll focus on the first part.

Problem Setup

The learning problem they study is mapping items, such as different plants and animals, to various properties, such as whether they can move, whether they are big, etc.

Semantic development

Deep nonlinear networks solving this task can produce a hierarchical decomposition in time: early in learning the representations of all items are similar. As learning progresses, items become gradually distinguished, first along high-level dimensions and later along more fine-grained ones. Below is a panel showing dimensionality-reduced representations for one such network:

The representations for the items seem to split first by the animal/plant distinction, then by finer distinctions such as bird/fish, until finally each item is distinguished individually. This seems qualitatively similar to what happens in human learning, where (presumably) broad categories are learned before narrow distinctions. How does this come about?

A deep linear network

To address this problem while retaining analytic tractability, the authors turn to a deep linear network. It’s deep because it has one hidden layer, and linear because the activations of all units are linear. Since a product of linear maps is itself linear, depth doesn’t add anything to the kind of input-output transformation the network can represent, but, as it turns out, it creates interesting learning dynamics.

They encode inputs $\xx$ as one-hot vectors indicating specific animals, plants etc. Target outputs are binary vectors $\yy$ indicating various properties of each input. The predicted output $\hat \yy$ is generated by transforming the inputs to a hidden layer via a first set of weights $\WW_1$, and then transforming the hidden layer activations $\hh$ to predicted outputs using a second set of weights $\WW_2$: $$ \hat \yy = \WW_2 \hh , \quad \hh = \WW_1 \xx.$$

Learning is performed by minimizing the least-squares loss $$L = {1 \over 2}\|\yy - \hat \yy\|_2^2.$$ Item-property pairs $(\xx_i, \yy_i)$ are presented in sequence, and the weights are updated down the loss gradient after each presentation.
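To make the setup concrete, here is a minimal NumPy sketch of the forward pass and the per-item loss. The layer sizes, the small random initialisation and the example feature vector are my own hypothetical choices, not the paper’s. It also checks that the two layers compose into the single linear map $\WW_2 \WW_1$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 4 items, 3 hidden units, 7 binary features
n_items, n_hidden, n_features = 4, 3, 7

# Small random initial weights (an assumption; the post does not fix an initialisation)
W1 = 0.1 * rng.standard_normal((n_hidden, n_items))
W2 = 0.1 * rng.standard_normal((n_features, n_hidden))


def forward(W1, W2, x):
    """Two-layer linear network: h = W1 x, y_hat = W2 h."""
    h = W1 @ x
    y_hat = W2 @ h
    return h, y_hat


def loss(W1, W2, x, y):
    """Per-item least-squares loss L = 0.5 * ||y - y_hat||^2."""
    _, y_hat = forward(W1, W2, x)
    return 0.5 * np.sum((y - y_hat) ** 2)


x = np.eye(n_items)[0]                             # one-hot input for item 0
y = np.array([1, 1, 0, 1, 0, 0, 0], dtype=float)   # hypothetical feature vector

h, y_hat = forward(W1, W2, x)
# The two layers compose into one linear map, W2 @ W1
assert np.allclose(y_hat, (W2 @ W1) @ x)
print(h.shape, y_hat.shape, loss(W1, W2, x, y))
```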

Weight updates by gradient descent

To find the gradients, we compute the differential $$ dL = -\bE^T d\WW_2 \WW_1 \xx - \bE^T \WW_2 d\WW_1 \xx,$$ where we’ve defined $$ \bE \triangleq (\yy - \hat \yy) = (\yy - \WW_2 \WW_1 \xx).$$

The gradients are then $${ \partial L \over \partial \WW_1} = -\WW_2^T \bE \xx^T = -\WW_2^T(\yy - \WW_2 \WW_1 \xx) \xx^T,$$ and $${\partial L \over \partial \WW_2} = -\bE \xx^T \WW_1^T = -(\yy - \WW_2 \WW_1 \xx) \xx^T \WW_1^T.$$
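Sign slips are easy in this kind of derivation, so a quick sanity check is to compare these gradients against central finite differences. This is a minimal sketch with random, arbitrarily sized matrices; `analytic_grads` and `numeric_grad` are hypothetical helper names.

```python
import numpy as np

rng = np.random.default_rng(1)
n_in, n_hid, n_out = 4, 3, 5          # arbitrary sizes for the check

W1 = rng.standard_normal((n_hid, n_in))
W2 = rng.standard_normal((n_out, n_hid))
x = rng.standard_normal(n_in)
y = rng.standard_normal(n_out)


def loss(W1, W2):
    e = y - W2 @ W1 @ x
    return 0.5 * e @ e


def analytic_grads(W1, W2):
    """Gradients from the text: dL/dW1 = -W2^T E x^T, dL/dW2 = -E x^T W1^T."""
    E = y - W2 @ W1 @ x
    g1 = -W2.T @ np.outer(E, x)
    g2 = -np.outer(E, x) @ W1.T
    return g1, g2


def numeric_grad(f, W, eps=1e-6):
    """Central finite differences, perturbing one entry of W at a time."""
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        g[idx] = (f(Wp) - f(Wm)) / (2 * eps)
    return g


g1, g2 = analytic_grads(W1, W2)
n1 = numeric_grad(lambda W: loss(W, W2), W1)
n2 = numeric_grad(lambda W: loss(W1, W), W2)
print(np.max(np.abs(g1 - n1)), np.max(np.abs(g2 - n2)))  # both should be tiny
```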

The weight updates are a step of size $\lambda$ down these gradients: \begin{align*} \Delta \WW_1 &= -\lambda {\partial L \over \partial \WW_1} = \lambda \WW_2^T(\yy - \WW_2 \WW_1 \xx) \xx^T \\ \Delta \WW_2 &= -\lambda {\partial L \over \partial \WW_2} = \lambda (\yy - \WW_2 \WW_1 \xx) \xx^T \WW_1^T \end{align*}
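As a sketch of this per-item procedure (my own illustration, not the authors’ code), the loop below presents one-hot items in random order and applies these updates. The four items, seven features, hidden width and step size are hypothetical choices; the feature table only mimics the plant/animal hierarchy described above.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical toy data: columns are items (canary, salmon, oak, rose),
# rows are binary features
Y = np.array([
    [1, 1, 1, 1],   # grows
    [1, 1, 0, 0],   # moves
    [0, 0, 1, 1],   # has roots
    [1, 0, 0, 0],   # flies
    [0, 1, 0, 0],   # swims
    [0, 0, 1, 0],   # has bark
    [0, 0, 0, 1],   # has petals
], dtype=float)
X = np.eye(4)            # column i is the one-hot input for item i
P = X.shape[1]
n_hidden = 4             # assumed hidden width
lam = 0.05               # step size lambda (arbitrary)

W1 = 0.01 * rng.standard_normal((n_hidden, P))
W2 = 0.01 * rng.standard_normal((Y.shape[0], n_hidden))


def total_loss(W1, W2):
    """Sum of the per-item losses over the whole dataset."""
    E = Y - W2 @ W1 @ X
    return 0.5 * np.sum(E ** 2)


initial = total_loss(W1, W2)
for epoch in range(2000):
    for i in rng.permutation(P):            # present items one at a time, in random order
        x, y = X[:, i], Y[:, i]
        E = y - W2 @ W1 @ x                 # error for this item
        dW1 = lam * W2.T @ np.outer(E, x)   # Delta W1 = -lambda dL/dW1
        dW2 = lam * np.outer(E, x) @ W1.T   # Delta W2 = -lambda dL/dW2
        W1 += dW1
        W2 += dW2
final = total_loss(W1, W2)
print(initial, final)
```

Both increments are computed before either matrix changes, so each step uses the same error $\bE$ for both layers, as in the equations.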

From discrete to continuous time

To convert these discrete per-item updates into continuous-time dynamics, we can take the changes above to occur over a time step $\Delta t$, and consider the net change after having viewed all $P$ items in the dataset.

For the first set of weights, this would look like \begin{align*} \Delta \WW_1 &= \Delta t \lambda \sum_{i=1}^P \WW_2^T(\yy_i - \WW_2 \WW_1 \xx_i) \xx_i^T\\ &= \Delta t \lambda \WW_2^T \left(\left(\sum_{i=1}^P \yy_i \xx_i^T\right) - \WW_2 \WW_1 \left(\sum_{i=1}^P \xx_i \xx_i^T\right)\right) \\ &= \Delta t \lambda P \WW_2^T (\Sigma_{yx} - \WW_2 \WW_1 \Sigma_x), \end{align*} where we’ve defined the covariances $$ {1 \over P} \sum_{i=1}^P \yy_i \xx_i^T = \Sigma_{yx}, \quad {1 \over P}\sum_{i=1}^P \xx_i \xx_i^T = \Sigma_x.$$
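The middle step above is just linearity of the sum. A numerical check with random data (the sizes are arbitrary, and the “covariances” are the uncentred averages defined in the text):

```python
import numpy as np

rng = np.random.default_rng(3)
n_in, n_hid, n_out, P = 4, 3, 5, 6     # arbitrary sizes
Xs = rng.standard_normal((P, n_in))    # rows are inputs x_i
Ys = rng.standard_normal((P, n_out))   # rows are targets y_i
W1 = rng.standard_normal((n_hid, n_in))
W2 = rng.standard_normal((n_out, n_hid))

# Sum of the per-item terms W2^T (y_i - W2 W1 x_i) x_i^T
summed = sum(W2.T @ np.outer(y - W2 @ W1 @ x, x) for x, y in zip(Xs, Ys))

# Averages as defined in the text
Sigma_yx = Ys.T @ Xs / P    # (1/P) sum_i y_i x_i^T
Sigma_x = Xs.T @ Xs / P     # (1/P) sum_i x_i x_i^T
batched = P * W2.T @ (Sigma_yx - W2 @ W1 @ Sigma_x)

print(np.max(np.abs(summed - batched)))   # should be at rounding-error level
```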

Then $$\lim_{\Delta t \to 0} {1 \over P \lambda}{ \Delta \WW_1 \over \Delta t} = \tau {d \WW_1 \over dt} = \WW_2^T (\Sigma_{yx} - \WW_2 \WW_1 \Sigma_x),$$ where we’ve defined $\tau \triangleq 1/(P\lambda)$.

Similarly, for the second set of weights, we get $$ \tau {d \WW_2 \over dt} = (\Sigma_{yx} - \WW_2 \WW_1 \Sigma_x) \WW_1^T.$$

Weight dynamics

At this point we have the following system to analyze: \begin{align*} \tau {d \WW_1 \over dt} &= \WW_2^T (\Sigma_{yx} - \WW_2 \WW_1 \Sigma_x)\\ \tau {d \WW_2 \over dt} &= (\Sigma_{yx} - \WW_2 \WW_1 \Sigma_x) \WW_1^T. \end{align*} So, despite our simple, linear model of a two-layer network, the learning dynamics are coupled and nonlinear, since $\WW_2$ and $\WW_1$ multiply each other in the expressions above.
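These equations can be integrated directly. Below is a minimal forward-Euler sketch with $\tau = 1$; the particular $\Sigma_x$, $\Sigma_{yx}$, step size and initialisation are hypothetical choices. With a hidden layer as wide as the input and an invertible $\Sigma_x$, the shared error term vanishes when $\WW_2 \WW_1 = \Sigma_{yx}\Sigma_x^{-1}$, and from a small generic initialisation the sketch checks that the product settles there.

```python
import numpy as np

rng = np.random.default_rng(4)

# Hypothetical statistics: a correlated (but invertible) Sigma_x and a fixed Sigma_yx
Sigma_x = np.array([[1.0, 0.3, 0.0],
                    [0.3, 1.0, 0.2],
                    [0.0, 0.2, 1.0]])
Sigma_yx = np.array([[2.0, 0.5, 0.0],
                     [0.0, 1.5, 0.3],
                     [0.4, 0.0, 1.0],
                     [0.0, 0.2, 0.0]])
n_out, n_in = Sigma_yx.shape
n_hid = n_in                          # hidden layer as wide as the input

tau, dt, steps = 1.0, 0.01, 10_000    # forward-Euler settings (arbitrary)
W1 = 0.01 * rng.standard_normal((n_hid, n_in))
W2 = 0.01 * rng.standard_normal((n_out, n_hid))

for _ in range(steps):
    err = Sigma_yx - W2 @ W1 @ Sigma_x        # shared error term
    dW1 = W2.T @ err / tau                     # tau dW1/dt = W2^T (Sigma_yx - W2 W1 Sigma_x)
    dW2 = err @ W1.T / tau                     # tau dW2/dt = (Sigma_yx - W2 W1 Sigma_x) W1^T
    W1, W2 = W1 + dt * dW1, W2 + dt * dW2      # update both from the same state

target = Sigma_yx @ np.linalg.inv(Sigma_x)     # product at which the error term vanishes
print(np.max(np.abs(W2 @ W1 - target)))
```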

Instead of using advanced math to perform a complicated analysis of the equations above, the authors follow George Polya’s advice that when one encounters a problem one can’t solve, one should try solving a simpler problem that can be solved. They apply a series of approximations and assumptions, which I’ll call moves, to convert the system above into one that has a closed-form solution.

First Move: Decorrelated Inputs

The dynamics above depend on both the input covariance $\Sigma_x$ and the input-output covariance $\Sigma_{yx}$.

To gain insight into the weight dynamics, we can first simplify the problem by assuming that the input dimensions are uncorrelated and have unit variance, so $\Sigma_x = \II$.
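For one-hot items this assumption costs little. This is my own observation rather than a claim from the paper: one-hot vectors are mutually orthogonal, so if every item is presented equally often, $\Sigma_x = \II/P$, which is already decorrelated, and rescaling the inputs by $\sqrt{P}$ makes it exactly $\II$ (this also scales $\Sigma_{yx}$ by $\sqrt{P}$, which only rescales its singular values). A quick check, with $P = 4$ as a hypothetical item count:

```python
import numpy as np

P = 4                                  # number of items (hypothetical)
X = np.eye(P)                          # columns are one-hot inputs x_i
Sigma_x = X @ X.T / P                  # (1/P) sum_i x_i x_i^T
print(Sigma_x)                         # I / P: uncorrelated, but each variance is 1/P

X_scaled = np.sqrt(P) * X              # rescale every one-hot input by sqrt(P)
Sigma_x_scaled = X_scaled @ X_scaled.T / P
assert np.allclose(Sigma_x_scaled, np.eye(P))
```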

Second Move: Switching coordinates

We can then switch to coordinates that reflect the input and output statistics by applying the SVD to the input-output covariance, $$\Sigma_{yx} = \UU \SS \VV^T.$$

The rows of $\VV^T$ form an orthonormal basis for the input space, and the columns of $\UU$ do the same for the output space. Since the weights combine as $\WW_2 \WW_1$ to transform the inputs to the outputs, we can reparameterize the weights in terms of $\UU$, $\VV$ and an arbitrary orthogonal matrix $\RR$ as $$ \WW_1 = \RR \ol{\WW}_1 \VV^T, \quad \WW_2 = \UU \ol{\WW}_2 \RR^{T}.$$
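A small NumPy check of this reparameterization, assuming a random stand-in for $\Sigma_{yx}$ and a hidden layer as wide as the input: for any orthogonal $\RR$, the barred weights are recovered as $\ol{\WW}_1 = \RR^T \WW_1 \VV$ and $\ol{\WW}_2 = \UU^T \WW_2 \RR$, and $\RR$ cancels in the product. Note that `np.linalg.svd` returns $\VV^T$ (here `Vt`) rather than $\VV$.

```python
import numpy as np

rng = np.random.default_rng(5)
n_in, n_out = 4, 6
Sigma_yx = rng.standard_normal((n_out, n_in))   # stand-in input-output covariance

# Full SVD: U is n_out x n_out, Vt is n_in x n_in, s holds the singular values
U, s, Vt = np.linalg.svd(Sigma_yx)
assert np.allclose(U.T @ U, np.eye(n_out)) and np.allclose(Vt @ Vt.T, np.eye(n_in))

n_hid = n_in
W1 = rng.standard_normal((n_hid, n_in))
W2 = rng.standard_normal((n_out, n_hid))

# Any orthogonal R works; here the Q factor of a random matrix's QR decomposition
R, _ = np.linalg.qr(rng.standard_normal((n_hid, n_hid)))
W1_bar = R.T @ W1 @ Vt.T     # from W1 = R W1_bar V^T
W2_bar = U.T @ W2 @ R        # from W2 = U W2_bar R^T

assert np.allclose(W1, R @ W1_bar @ Vt)
assert np.allclose(W2, U @ W2_bar @ R.T)
# R cancels in the product, so only U and V shape the net map
assert np.allclose(W2 @ W1, U @ (W2_bar @ W1_bar) @ Vt)
```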

We can express the dynamics in terms of this parameterization (with $\Sigma_x = \II$ and $\tau = 1$) as \begin{align*} {d \WW_1 \over dt} &= \RR {d \ol{\WW}_1 \over dt} \VV^T \\ \implies {d \ol{\WW}_1 \over dt} &= \RR^T {d \WW_1 \over dt} \VV\\ &= \RR^T\WW_2^T (\Sigma_{yx} - \WW_2 \WW_1 ) \VV \\ &= \RR^T \RR \ol{\WW}^T_2 \UU^T(\UU \SS \VV^T - \UU \ol{\WW}_2 \RR^T \RR \ol{\WW}_1\VV^T)\VV \\ &= \ol{\WW}_2 ^T( \SS - \ol{\WW}_2 \ol{\WW}_1). \end{align*}

Similarly, $$ {d \ol{\WW}_2 \over dt} = (\SS - \ol{\WW}_2 \ol{\WW}_1)\ol{\WW}_1^T.$$
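To check this algebra numerically (a sketch assuming $\Sigma_x = \II$, $\tau = 1$ and random weights), we can compute the velocities in the original coordinates, map them into the barred coordinates, and compare with the two closed-form expressions. Because $\SS$ is rectangular when the input and output sizes differ, it is built explicitly from the singular values.

```python
import numpy as np

rng = np.random.default_rng(6)
n_in, n_hid, n_out = 4, 4, 5
Sigma_yx = rng.standard_normal((n_out, n_in))
U, s, Vt = np.linalg.svd(Sigma_yx)
S = np.zeros((n_out, n_in))
S[:len(s), :len(s)] = np.diag(s)      # rectangular "diagonal" matrix of singular values

R, _ = np.linalg.qr(rng.standard_normal((n_hid, n_hid)))
W1 = rng.standard_normal((n_hid, n_in))
W2 = rng.standard_normal((n_out, n_hid))
W1_bar = R.T @ W1 @ Vt.T
W2_bar = U.T @ W2 @ R

# Velocities in the original coordinates, with Sigma_x = I and tau = 1
err = Sigma_yx - W2 @ W1
dW1 = W2.T @ err
dW2 = err @ W1.T

# Map them into the barred coordinates ...
dW1_bar_mapped = R.T @ dW1 @ Vt.T
dW2_bar_mapped = U.T @ dW2 @ R
# ... and compare with the closed-form expressions
dW1_bar = W2_bar.T @ (S - W2_bar @ W1_bar)
dW2_bar = (S - W2_bar @ W1_bar) @ W1_bar.T
print(np.allclose(dW1_bar_mapped, dW1_bar), np.allclose(dW2_bar_mapped, dW2_bar))
```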

Third Move: Diagonal Assumption

To further simplify the analysis, they now assume that $\ol{\WW}_1$ and $\ol{\WW}_2$ are diagonal, with diagonal elements $c_\alpha$ and $d_\alpha$ respectively, and write $s_\alpha$ for the diagonal elements (singular values) of $\SS$. This decouples the dynamics above into \begin{align*} {d c_\alpha \over dt} &= (s_\alpha - c_\alpha d_\alpha) d_\alpha \\ {d d_\alpha \over dt} &= (s_\alpha - c_\alpha d_\alpha) c_\alpha \end{align*}

The dynamics of the net transformation $\WW_2 \WW_1$ in the transformed coordinates are then captured by the dynamics of the product $\ol{\WW}_2 \ol{\WW}_1$, which, under our diagonal assumption, reduce to the dynamics of $a_\alpha = c_\alpha d_\alpha$: $$ {d a_\alpha \over dt} = {d c_\alpha \over dt} d_\alpha + c_\alpha {d d_\alpha \over dt} = (s_\alpha - c_\alpha d_\alpha)( c_\alpha^2 + d_\alpha^2).$$

Fourth Move: Balanced Regime

They then simplify further and study the balanced regime where $c_\alpha = d_\alpha$, so that $c_\alpha^2 + d_\alpha^2 = 2 a_\alpha$, which yields $$ {d a_\alpha \over dt} = 2 a_\alpha (s_\alpha - a_\alpha) .$$
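Two properties of this move can be checked numerically. First, the exact $(c_\alpha, d_\alpha)$ dynamics conserve $c_\alpha^2 - d_\alpha^2$, since its time derivative is $2c_\alpha(s_\alpha - c_\alpha d_\alpha)d_\alpha - 2d_\alpha(s_\alpha - c_\alpha d_\alpha)c_\alpha = 0$, so a balanced start stays balanced. Second, from a balanced start $a_\alpha$ follows the logistic equation above. The sketch uses forward Euler with $\tau = 1$; the value of $s_\alpha$, the initial conditions and the step size are arbitrary, and `simulate_cd` is a hypothetical helper.

```python
import numpy as np


def simulate_cd(c0, d0, s, dt=1e-3, steps=10_000):
    """Forward-Euler integration of dc/dt = (s - c d) d, dd/dt = (s - c d) c (tau = 1)."""
    c, d = c0, d0
    cs, ds = [c], [d]
    for _ in range(steps):
        err = s - c * d
        c, d = c + dt * err * d, d + dt * err * c   # both updated from the same state
        cs.append(c)
        ds.append(d)
    return np.array(cs), np.array(ds)


s = 2.0
# Unbalanced start: c^2 - d^2 is conserved by the exact dynamics
cs, ds = simulate_cd(0.3, 0.1, s)
gap = cs**2 - ds**2
print(gap[0], gap[-1])            # close, up to Euler discretisation error

# Balanced start: c = d throughout, and a = c d obeys da/dt = 2 a (s - a)
cb, db = simulate_cd(0.05, 0.05, s)
a = cb * db
dadt = np.gradient(a, 1e-3)       # numerical derivative of the simulated a(t)
print(np.max(np.abs(dadt - 2 * a * (s - a))))
```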

Sigmoidal dynamics

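Solving this ODE for $a_\alpha(t)$, starting from a small initial value $a_\alpha(0) = a_0$ and restoring the time constant $\tau$ (we set it to 1 above for clarity), so that $\tau\, da_\alpha/dt = 2 a_\alpha (s_\alpha - a_\alpha)$, gives the sigmoidal dynamics in the main text, $$ a_\alpha(t) = {s_\alpha e^{2 s_\alpha t/\tau} \over e^{2 s_\alpha t/\tau} - 1 + s_\alpha/a_0}.$$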
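As a sanity check of the sigmoidal solution (my own sketch, not from the paper), it can be compared with a direct forward-Euler integration of $\tau\, da_\alpha/dt = 2a_\alpha(s_\alpha - a_\alpha)$. The values of $s_\alpha$, $a_0$ and the Euler step are arbitrary choices.

```python
import numpy as np


def a_closed_form(t, s, a0, tau=1.0):
    """Sigmoidal solution of tau da/dt = 2 a (s - a) with a(0) = a0."""
    e = np.exp(2 * s * t / tau)
    return s * e / (e - 1 + s / a0)


def a_euler(T, s, a0, tau=1.0, dt=1e-4):
    """Forward-Euler integration of the same ODE up to time T, for comparison."""
    a = a0
    for _ in range(int(round(T / dt))):
        a += dt * 2 * a * (s - a) / tau
    return a


s, a0, tau = 3.0, 1e-3, 1.0
for T in [0.5, 1.0, 2.0, 4.0]:
    print(T, a_closed_form(T, s, a0, tau), a_euler(T, s, a0, tau))
```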

Thus, if one accepts all the assumptions we made along the way, we find that the modes of the data are learned in a stepwise manner, with the more dominant modes being learned first and faster. This is shown in the panel below:
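Setting $a_\alpha(t) = s_\alpha/2$ in the logistic solution of $\tau\, da_\alpha/dt = 2a_\alpha(s_\alpha - a_\alpha)$ gives the time at which a mode is half learned, $t_{1/2} = {\tau \over 2 s_\alpha} \ln\left({s_\alpha \over a_0} - 1\right)$, which shrinks roughly as $1/s_\alpha$. The sketch below evaluates it for a few hypothetical singular values, assuming every mode starts from the same small $a_0$.

```python
import numpy as np


def a_closed_form(t, s, a0, tau=1.0):
    """Sigmoidal solution of tau da/dt = 2 a (s - a) with a(0) = a0."""
    e = np.exp(2 * s * t / tau)
    return s * e / (e - 1 + s / a0)


def half_time(s, a0, tau=1.0):
    """Time at which a(t) = s / 2, i.e. (tau / 2s) * ln(s / a0 - 1)."""
    return tau / (2 * s) * np.log(s / a0 - 1)


singular_values = [3.0, 2.0, 1.0, 0.5]   # hypothetical mode strengths, strongest first
a0 = 1e-3                                 # same small initial strength for every mode

for s in singular_values:
    t_half = half_time(s, a0)
    print(f"s = {s}: half-learned at t = {t_half:.2f}, a = {a_closed_form(t_half, s, a0):.3f}")
```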

What makes a mode dominant is the size of its singular value $s_\alpha$, which in this hierarchical dataset grows with the number of items the mode relates to. For example, the first mode is a basic “does-it-grow” mode, shared by all the items. The second mode is the animal-vs-plant distinction, and so on.
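To see where the dominant modes come from, here is a hypothetical four-item version of the dataset (canary, salmon, oak and rose, with seven binary features), in the spirit of the paper’s example but not its actual data. With one-hot inputs, the SVD of $\Sigma_{yx}$ gives singular values $\sqrt{7}/4$, $\sqrt{3}/4$, $1/4$ and $1/4$: the strongest mode loads equally on all items (and most heavily on “grows”), the second separates animals from plants, and the two weakest, tied modes separate individual items.

```python
import numpy as np

items = ["canary", "salmon", "oak", "rose"]                 # hypothetical items
features = ["grows", "moves", "roots", "flies", "swims", "bark", "petals"]
Y = np.array([
    [1, 1, 1, 1],   # grows: shared by every item
    [1, 1, 0, 0],   # moves: animals only
    [0, 0, 1, 1],   # roots: plants only
    [1, 0, 0, 0],   # flies
    [0, 1, 0, 0],   # swims
    [0, 0, 1, 0],   # bark
    [0, 0, 0, 1],   # petals
], dtype=float)
P = len(items)
X = np.eye(P)                      # one-hot inputs, one column per item
Sigma_yx = Y @ X.T / P             # (1/P) sum_i y_i x_i^T

U, s, Vt = np.linalg.svd(Sigma_yx)
print("singular values:", np.round(s, 3))
print("mode 0 loads most on:", features[np.argmax(np.abs(U[:, 0]))])
for alpha in range(2):
    print(f"mode {alpha}: item loadings {np.round(Vt[alpha], 2)}")
```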

Explaining the hierarchical decomposition

We can now consider how the hidden layer representations evolve over learning, to see if we can recover the phenomenon observed in the deep nonlinear networks we started this post with.

Given an input $\xx$, the corresponding representation is $$ \hh(t) = \WW_1(t) \xx = \RR \diag{\cc(t)}\VV^T \xx,$$ where $\RR$ was our arbitrary orthogonal matrix. Factoring out this term (a fixed rotation applied to all items doesn’t change the distances or angles between their representations), and assuming the balanced regime where $a_\alpha(t) = c_\alpha(t)^2$, we get $$ \hh(t) = \diag\left(\sqrt{\aa(t)}\right) \VV^T \xx.$$

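Since the item inputs are one-hot, $\VV^T \xx$ picks out the column of $\VV^T$ corresponding to a particular item. For item $i$, with input $\xx_i$, the representation is therefore $$ \hh_i(t) = \diag\left(\sqrt{\aa(t)}\right) \VV^T \xx_i, \quad \text{i.e.} \quad h_{i\alpha}(t) = \sqrt{a_\alpha(t)}\, V_{i\alpha}.$$ Each coordinate $\alpha$ of an item’s representation thus switches on with the sigmoidal strength $a_\alpha(t)$ of its mode, so items become distinguished first along the stronger modes, such as animal vs plant, and only later along the weaker modes that separate individual items.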
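A sketch tying the pieces together, using the same hypothetical four-item dataset (columns canary, salmon, oak, rose) and my own choices of $a_0 = 10^{-4}$ for every mode and $\tau = 1$. It computes balanced-regime representations $\diag(\sqrt{\aa(t)})\VV^T\xx_i$ from the sigmoidal solution and tracks distances between items. Canary and salmon differ only along the weakest modes, while canary and oak also differ along the animal/plant mode, so that split opens up first.

```python
import numpy as np

# Hypothetical toy dataset: columns are canary, salmon, oak, rose
Y = np.array([[1, 1, 1, 1], [1, 1, 0, 0], [0, 0, 1, 1],
              [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=float)
P = Y.shape[1]
U, s, Vt = np.linalg.svd(Y / P)          # Sigma_yx when the inputs are one-hot


def a_closed_form(t, s, a0, tau=1.0):
    """Sigmoidal solution of tau da/dt = 2 a (s - a), applied elementwise to modes."""
    e = np.exp(2 * s * t / tau)
    return s * e / (e - 1 + s / a0)


def representations(t, a0=1e-4, tau=1.0):
    """Columns are h_i(t) = diag(sqrt(a(t))) V^T x_i, with the rotation R factored out."""
    a = a_closed_form(t, s, a0, tau)
    return np.sqrt(a)[:, None] * Vt       # scale row alpha of V^T by sqrt(a_alpha)


def dist(H, i, j):
    """Euclidean distance between the representations of items i and j."""
    return np.linalg.norm(H[:, i] - H[:, j])


for t in [5.0, 12.0, 40.0]:
    H = representations(t)
    print(f"t={t}: canary-salmon {dist(H, 0, 1):.3f}, canary-oak {dist(H, 0, 2):.3f}")
```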

$$ \blacksquare$$

