The following is a blog post I wrote on the Polymathic-AI website regarding a recent paper on mechanistic interpretability of Transformer models that we put on the arXiv.
At Polymathic-AI, part of our mission is to develop foundation models that help with scientific exploration and discovery. But it’s not enough to build these models; we also want to understand them! What algorithms do these networks actually learn under the hood? By uncovering them, we might discover improvements to our foundation models, or even new insights about the scientific domains they represent.
To understand how Transformers solve complex problems, it helps to start with simpler tasks. In a recent paper, we do exactly this: we introduce a new toy problem designed to help us understand how Transformers can count in a context-dependent way—a core capability for scientific and quantitative reasoning. We call this task contextual counting: the model must count tokens that lie in specified regions of a sequence. As such, it idealizes scenarios where precise localization and subsequent computation are critical, such as counting specific neuroreceptors within a neuron in biological research. While seemingly simple, this task is surprisingly hard for state-of-the-art LLMs.
The Contextual Counting Task
In this task, the input is a sequence composed of zeros, ones, and square-bracket delimiters: {0, 1, [, ]}. Each sample sequence contains ones and zeros, with several regions marked off by the delimiters. The task is to count the number of ones within each delimited region. For example, given a sequence such as [ 1 0 1 ] [ 1 1 ] [ 0 ], the target output is (2, 2, 0): two ones in the first region, two in the second, and none in the third.
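To make the setup concrete, here is a minimal data generator for the task. It is an illustrative sketch: the exact sampling choices (region lengths, whether filler tokens appear outside the regions, the number of regions) are assumptions and may differ from the recipe used in the paper.

```python
import random

def make_sample(n_regions=3, region_len=(1, 8), max_filler=4, seed=None):
    """Generate one contextual-counting sequence and its per-region targets.

    Illustrative generator: region lengths and the optional 0/1 filler tokens
    outside the regions are assumptions, not the paper's exact recipe.
    """
    rng = random.Random(seed)
    seq, targets = [], []
    for _ in range(n_regions):
        # Optional filler tokens before / between regions (an assumption).
        seq += [rng.choice("01") for _ in range(rng.randint(0, max_filler))]
        body = [rng.choice("01") for _ in range(rng.randint(*region_len))]
        seq += ["["] + body + ["]"]
        targets.append(body.count("1"))   # ground truth: ones inside this region
    return seq, targets

seq, targets = make_sample(seed=0)
print(" ".join(seq))
print("ones per region:", targets)
```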
Theoretical Insights

We provide some theoretical insights into the problem, showing that a Transformer with one causal encoding layer and one decoding layer can solve the contextual counting task for arbitrary sequence lengths and numbers of regions.

Contextual Position (CP)
Contextual position refers to positional information in a sequence that is meaningful only within the context of the problem. For the contextual counting task, this means knowing which region each token belongs to. For example, with three regions, the input sequence [ 1 0 1 ] [ 0 0 ] [ 1 ] would have contextual positions 1 1 1 1 1 2 2 2 2 3 3 3, with each token labeled by the index of its region.
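For reference, here is a small sketch that computes this labeling directly from the token sequence. The treatment of the delimiters and of any tokens outside the regions (labeled 0 here) is an assumed convention, not necessarily the one used in the paper.

```python
def contextual_positions(seq):
    """Label each token with the index of the delimited region it belongs to.

    Assumed convention: delimiters count as part of their region, and tokens
    outside every region are labeled 0.
    """
    labels, region, inside = [], 0, False
    for tok in seq:
        if tok == "[":
            region += 1
            inside = True
            labels.append(region)
        elif tok == "]":
            labels.append(region)
            inside = False
        else:
            labels.append(region if inside else 0)
    return labels

seq = ["[", "1", "0", "1", "]", "[", "0", "0", "]", "[", "1", "]"]
print(contextual_positions(seq))   # [1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3]
```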
Key Propositions

- Proposition 1: If the regional contextual position information is available in the latent representation of the tokens at some layer of a Transformer, the contextual counting task can be solved with a single additional layer.
- Proposition 2: A causal Transformer with a single layer and no position encoding (NoPE) can infer the regional contextual position (a toy illustration of one such mechanism is sketched below).
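The sketch below shows one way a single causal attention head with no position encoding could expose the region index: by attending to the BoS token and to the '[' delimiters, the attention mass left on BoS decreases each time a new region opens. This is an illustrative construction consistent with Proposition 2, not necessarily the paper's exact one.

```python
import numpy as np

def bos_attention_profile(seq):
    """Single causal attention head with no position encoding (NoPE).

    The head puts a large score on the BoS token and on every '[' delimiter.
    After the causal softmax, the weight on BoS is roughly 1 / (k + 1), where
    k is the number of regions opened so far -- a monotone code for the
    regional contextual position.
    """
    tokens = ["BoS"] + list(seq)
    scores = np.array([10.0 if t in ("BoS", "[") else 0.0 for t in tokens])
    weights = []
    for t in range(len(tokens)):
        s = scores[: t + 1]               # causal mask: only positions <= t are visible
        a = np.exp(s - s.max())
        weights.append(a[0] / a.sum())    # attention mass left on the BoS token
    return tokens, np.round(weights, 3)

tokens, w = bos_attention_profile(["[", "1", "0", "]", "[", "1", "1", "]"])
print(list(zip(tokens, w)))
# The BoS weight sits near 1/2 inside the first region and near 1/3 inside the second.
```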
Challenges for Non-Causal Transformers
For non-causal (bidirectional) Transformers, the task is more complicated:

- Proposition 3: A non-causal Transformer with no position code and a permutation-invariant output head cannot solve the contextual counting task.
- Proposition 4: To emulate a causal attention profile, a non-causal attention layer with an absolute position code would need an embedding dimension at least as large as the sequence length.
Experimental Results
The theoretical results above imply that exact solutions exist, but they do not tell us whether such solutions are actually found when the model is trained via SGD. We therefore trained various Transformer architectures on this task. Inspired by the theoretical arguments, we use an encoder-decoder architecture with one layer and one head on each side.
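For orientation, here is a minimal sketch of such a one-layer, one-head encoder-decoder in PyTorch. The vocabulary, the use of one learned query vector per region, and the size of the output head are assumptions made for illustration; the paper's model may differ in these details.

```python
import torch
import torch.nn as nn

class ContextualCounter(nn.Module):
    """One-layer, one-head encoder-decoder sketch for contextual counting."""

    def __init__(self, vocab_size=5, d_model=64, max_count=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)              # {0, 1, [, ], BoS}
        self.encoder = nn.TransformerEncoderLayer(d_model, nhead=1, batch_first=True)
        self.decoder = nn.TransformerDecoderLayer(d_model, nhead=1, batch_first=True)
        self.head = nn.Linear(d_model, max_count + 1)               # distribution over counts

    def forward(self, tokens, region_queries, causal=True):
        # tokens: (B, T) token ids; region_queries: (B, R, d_model), one query per region.
        x = self.embed(tokens)                                      # NoPE: no position encoding added
        mask = None
        if causal:
            T = tokens.size(1)
            mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        h = self.encoder(x, src_mask=mask)                          # (B, T, d_model)
        z = self.decoder(region_queries, h)                         # cross-attend to the encoding
        return self.head(z).log_softmax(dim=-1)                     # (B, R, max_count + 1)

model = ContextualCounter()
tokens = torch.randint(0, 5, (2, 20))                               # dummy batch of 2 sequences
queries = torch.randn(2, 3, 64)                                     # 3 region queries per sequence
print(model(tokens, queries).shape)                                 # torch.Size([2, 3, 33])
```

Here the `causal` flag switches the encoder between causal and bidirectional self-attention, mirroring the comparison discussed below.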
A typical output of the network is shown in the following image, where the model produces a probability distribution over the number of ones in each region.

1. Causal Transformers significantly outperform non-causal ones.

2. NoPE achieves the best performance but is harder to train than RoPE.
We also see that the single best model is trained with NoPE, but RoPE-based models train much more consistently.
3. In the best-performing models, the encoder captures the regional contextual position information.

As described above, the regional contextual position is a key piece of information for this task. Looking at a projection of the 1-token embeddings in the different regions, we see that this information is accurately captured.
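One generic way to produce such a projection is to reduce the encoder outputs of the 1-tokens to two dimensions and color them by region. The probe below is a hypothetical sketch with placeholder data (the array names, the synthetic embeddings, and the choice of PCA are assumptions), not the analysis code used in the paper.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Placeholder stand-ins: in practice `embeddings` holds the encoder outputs of
# the 1-tokens and `regions` the region index of each such token.
rng = np.random.default_rng(0)
regions = rng.integers(1, 5, size=200)
directions = rng.normal(size=(4, 64))
embeddings = directions[regions - 1] + 0.3 * rng.normal(size=(200, 64))

proj = PCA(n_components=2).fit_transform(embeddings)   # 2-D projection of the embeddings
plt.scatter(proj[:, 0], proj[:, 1], c=regions, cmap="viridis", s=15)
plt.colorbar(label="region index")
plt.title("1-token embeddings, colored by region")
plt.show()
```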
4. In the best-performing models, the decoder attends only to the 1-tokens in the relevant region.

We can verify explicitly that the regional contextual position inferred by the encoder is used in the decoder's cross-attention module: the attention profile is focused on the 1-tokens of the relevant region (the third region in the figure below).
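A simple way to quantify this is to measure how much of a region query's cross-attention mass lands on the 1-tokens of that region. The helper below is a hypothetical sketch that assumes access to the decoder's cross-attention weights; the numbers shown are placeholders.

```python
import numpy as np

def attention_mass_on_region(attn, tokens, regions, target_region):
    """Fraction of one query's cross-attention that lands on 1-tokens of `target_region`.

    attn: (T,) attention weights of a single region query over the input tokens.
    tokens, regions: per-token symbols and region labels (e.g. from contextual_positions).
    """
    mask = np.array([t == "1" and r == target_region for t, r in zip(tokens, regions)])
    return attn[mask].sum() / attn.sum()

# Placeholder example; in practice `attn` comes from the trained decoder.
tokens  = ["[", "1", "0", "1", "]", "[", "1", "]"]
regions = [1, 1, 1, 1, 1, 2, 2, 2]
attn    = np.array([0.02, 0.45, 0.03, 0.44, 0.02, 0.01, 0.02, 0.01])
print(attention_mass_on_region(attn, tokens, regions, target_region=1))  # 0.89
```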
5. Out-of-distribution generalization is directly linked to which tokens are used as bias terms.

The figure below shows the behavior of three different types of solutions when generalizing to sequences of different lengths and to inputs with different numbers of regions. Even though all three attain the same performance on in-distribution data, their out-of-distribution performance is very different. Why is this the case?
6. The network generates its output by balancing two learned shapes.

In some of our experiments, we removed the MLP and self-attention layers from the decoder block, so that the decoder is just a cross-attention layer. This configuration is less expressive, but it has the advantage that the output of the model is a linear combination of the value vectors derived from the encoder embeddings. We saw above that the decoder attends only to the 1-tokens of the relevant region and to the beginning-of-sequence (BoS) token. The figure below shows the value vectors of these two tokens.

- Even though the model has access to the number n through its attention profile, it does not construct a probability distribution that is sharply peaked at n. As the above figure shows, the distribution gets wider as n grows. We believe this is partly a side effect of this specific solution, in which two curves are balanced against each other. But it is also a more general problem: as the number of attended tokens grows, higher accuracy is needed to infer n exactly, because the information about n is encoded non-linearly after the attention layer. In this case, if we assume that the model attends equally to the BoS token and to the n 1-tokens of the relevant region, the output becomes

output = (1 / (n + 1)) v_BoS + (n / (n + 1)) v_1,

so the dependence on n enters only through the ratio n / (n + 1), which saturates as n grows.
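To see what this implies numerically, here is a short sketch of the attention weights appearing in the expression above; it illustrates only the arithmetic of the convex combination, not the trained model.

```python
import numpy as np

# Weight on the BoS value vector when attention is spread equally over the
# BoS token and the n 1-tokens of the relevant region: w(n) = 1 / (n + 1).
n = np.arange(1, 41)
w_bos = 1.0 / (n + 1)

# Distinguishing n from n + 1 requires resolving the gap between consecutive
# weights, which shrinks roughly like 1/n^2 -- hence the broader output
# distributions at large n.
gap = w_bos[:-1] - w_bos[1:]
for i in (0, 4, 9, 19, 38):
    print(f"n = {n[i]:2d}   BoS weight = {w_bos[i]:.4f}   gap to n+1 = {gap[i]:.5f}")
```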