# Mechanistic Interpretability
The core problem that must be solved is the **curse of dimensionality**. We need to be able to take these weird, intricate, high-dimensional objects (neural networks) and break them down into pieces that can be understood semi-independently.
Note that this paper, [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html), focuses entirely on attention-only transformers (it drops the MLPs). The model is made up of several attention layers, and each layer is made up of several heads. The heads can be thought of as working in parallel, and each head is an interesting object to try to understand on its own.
**Composition** is the entire point of deep learning. Deeper models have depth, and depth means that you can have composition. Fundamentally, the way to think about neural networks is that they are incredibly simple functions built from matrix multiplications, with a tiny veneer of nonlinearity on top. As you compose these simple but not-quite-linear functions, you go from "what I'm representing is basically a linear map" to "what I'm representing is an interesting series of composed operations". The main example of interesting composition that we will cover is **induction heads**.
Transformers are weirdly *linear*. One consequence is that there are many equivalent ways to represent the same computation. For instance, to compute the product $3 \times 4 \times 5$ we can group the multiplications in two equivalent ways:
$3 \times (4 \times 5) = (3 \times 4) \times 5$
If you let ML people choose how to represent things, they ask "how would I represent this in code so that it is computationally efficient?" (which makes sense, since millions of dollars are spent on these models). However, there is no reason to think that the representation that is optimal for computational efficiency is also optimal for human understanding. A lot of what this paper argues is that there is indeed a better way of thinking about the computation from a human-interpretable point of view.
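Matrix multiplication is associative in the same way as the arithmetic example, so two stacked linear maps can be grouped either way. The sketch below is illustrative only: `W_1` and `W_2` are hypothetical weights with made-up sizes. It shows that applying them one after the other (cheaper per input when the middle dimension is small) and folding them into a single matrix first (one object you can study without reference to any input) compute exactly the same function.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_small = 5, 64, 8        # illustrative sizes only

x = rng.normal(size=(seq_len, d_model))     # one vector per position
W_1 = rng.normal(size=(d_model, d_small))   # hypothetical map down to a small space
W_2 = rng.normal(size=(d_small, d_model))   # hypothetical map back up

# Grouping chosen for compute: stay in the small space as long as possible.
y_efficient = (x @ W_1) @ W_2

# Grouping chosen for understanding: fold both weights into one d_model x d_model
# matrix that can be studied on its own, without reference to any input.
W_combined = W_1 @ W_2
y_interpretable = x @ W_combined

assert np.allclose(y_efficient, y_interpretable)   # same function, different bookkeeping
```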
### Transformer Overview
The fundamental point of a transformer is to be a sequence prediction machine. You have a sequence, and you want a model that can take in sequences of varying length, do a bunch of computation in parallel on each individual element of the sequence, and also move information between sequence positions in an intelligent way. The main models we will focus on are GPT-2 style models, which are **autoregressive**: they predict each token from the tokens that come before it.
Our high-level architecture is as follows:

A few key points are:
* Each token maps to an embedding vector of dimension $d_m$, where $m$ stands for model.
* We also need something that represents which position a token occupies in the sequence. One way to do this is with another learned matrix, $W_{pos}$, of size $d_m \times$ context length.
* We then have an attention layer. A key thing to keep in mind (which the diagram does not show well) is that $x_0$ is actually a *tensor* holding an embedding-sized vector for every position in the sequence. The attention layer moves information between positions and does some processing along the way.
* Then we have the MLP layer, which does a bunch of processing on each token *in parallel*, *without moving information between positions*.
* The residual stream is the accumulated sum of every output so far: the embeddings, plus the output of the attention layer, plus the output of the MLP layer added on top of that.
* At the end of the model we want to convert from a vector of size $d_m$ back to tokens. We do this with an unembedding matrix, which maps each position's vector to a vector of logits over the vocabulary.
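Here is a minimal NumPy sketch of the pipeline in the list above, under heavy simplifying assumptions: random weights, toy sizes, one attention layer and one MLP layer, and no LayerNorm. The attention layer is a hypothetical stand-in with a fixed uniform causal pattern (a real head computes its pattern from the input). Weights are stored row-wise, so `W_pos` here has shape (context length, $d_m$), the transpose of the shape given above.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, d_model, n_ctx, d_mlp = 50, 16, 10, 64   # toy sizes, not GPT-2's

# Row-vector convention: every weight multiplies activations from the right.
W_E = rng.normal(size=(d_vocab, d_model)) * 0.1    # token embedding, one row per token
W_pos = rng.normal(size=(n_ctx, d_model)) * 0.1    # learned positional embedding, one row per position
W_attn = rng.normal(size=(d_model, d_model)) * 0.1
W_in = rng.normal(size=(d_model, d_mlp)) * 0.1
W_out = rng.normal(size=(d_mlp, d_model)) * 0.1
W_U = rng.normal(size=(d_model, d_vocab)) * 0.1    # unembedding

def attention_layer(x):
    """Hypothetical stand-in: every position averages itself and all earlier
    positions (a fixed causal pattern), then applies a linear map.
    A real head computes its pattern from the input."""
    n = x.shape[0]
    pattern = np.tril(np.ones((n, n)))
    pattern /= pattern.sum(axis=1, keepdims=True)
    return pattern @ x @ W_attn                   # mixes information across positions

def mlp_layer(x):
    """The same two-layer ReLU MLP, applied to every position independently."""
    return np.maximum(x @ W_in, 0) @ W_out

def forward(tokens):
    x0 = W_E[tokens] + W_pos[: len(tokens)]       # residual stream, (seq_len, d_model)
    x1 = x0 + attention_layer(x0)                 # attention output is added to the stream
    x2 = x1 + mlp_layer(x1)                       # and so is the MLP output
    return x2 @ W_U                               # logits, (seq_len, d_vocab)

logits = forward(np.array([3, 17, 42, 8]))
print(logits.shape)  # (4, 50)
```

Because the stand-in attention only looks backwards and the MLP acts per position, the logits at position `i` depend only on tokens up to `i`. Changing the last token therefore leaves the earlier rows of `logits` unchanged.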
A weird thing about GPT-2 style transformers is that they output a $\text{position} \times d_{vocab}$ tensor, so every position in the sequence gets a vector of logits over the entire vocabulary. This is odd because the thing we are trying to do is predict the next token, and for that the only position you care about is the final one. Why do we calculate something for every position? This is an artifact of how the model was trained: during training, the logits at each position are used to predict the token that follows it, so a single sequence provides a prediction target at every position.
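To make the training artifact concrete, here is a sketch of the standard next-token cross-entropy loss. The function name `next_token_loss` is hypothetical. Every position's logits are scored against the token that actually comes next, so a sequence of length n yields n - 1 predictions from a single forward pass.

```python
import numpy as np

def next_token_loss(logits, tokens):
    """Mean next-token cross-entropy over all positions.
    logits: (seq_len, d_vocab); tokens: (seq_len,) integer token ids.
    Position i is scored on how well it predicts tokens[i + 1]."""
    pred = logits[:-1]                                # the last position has no next token to score
    targets = tokens[1:]
    pred = pred - pred.max(axis=1, keepdims=True)     # subtract row max for numerical stability
    log_probs = pred - np.log(np.exp(pred).sum(axis=1, keepdims=True))  # log-softmax
    return -log_probs[np.arange(len(targets)), targets].mean()

# Uniform logits over a 10-token vocabulary give a loss of log(10) at every position.
print(next_token_loss(np.zeros((5, 10)), np.array([1, 4, 2, 7, 3])))  # ≈ 2.3026
```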
### The Residual Stream is a Big Deal
Transformers can be thought about as having a central object: a big, wide **residual stream** that carries forward all of the information of the network. Each layer is reading in from the residual stream, applying some edits, and then putting it back.
An interesting thing about transformers: we only ever read from and write to the residual stream via *linear operations*. We write to the residual stream via *addition*, and we read from it by applying a linear map. This is very important because it means the residual stream is simply the sum of the embeddings and the outputs of every previous layer. So you can decompose the input to any layer into a sum of terms that correspond to different parts of the network.
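A small sketch of what "write = addition, read = linear map" buys us. The layer outputs below are random stand-ins rather than real activations. Because the final stream is a sum and the unembedding is linear, the logits split exactly into one contribution per component. One caveat: real GPT-2 applies a final LayerNorm before unembedding, which is not linear, so in a real model this decomposition is exact only if that step is ignored or its scale is frozen.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_vocab, n_layers = 16, 50, 3

embed = rng.normal(size=d_model)                                      # embedding at one position
layer_outputs = [rng.normal(size=d_model) for _ in range(n_layers)]   # stand-ins for what each layer writes
W_U = rng.normal(size=(d_model, d_vocab))                             # unembedding

# Writing is addition, so the final residual stream is just a sum.
resid_final = embed + sum(layer_outputs)
logits = resid_final @ W_U

# Reading is linear, so the logits split into one term per component.
contributions = [embed @ W_U] + [out @ W_U for out in layer_outputs]
assert np.allclose(logits, sum(contributions))
```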
The way to think about this: the model is trying to perform a bunch of computation, which will likely involve information flowing from the input to the output via some layers. For example, a token is read by some head, the head moves it to another position in the sequence, it is then picked up by another head, processed by an MLP layer, and finally mapped via the unembedding to the output logits. One reason the residual stream is such an important object is that, rather than every path needing to go through every single layer of the network, the model can choose which layers a path goes through (if it doesn't want to go through a layer, it can just go through the residual connection). In practice, most of the computation a model does seems to go through a couple of layers rather than through everything. This gives the model the ability to choose which paths it sends information down, so we can expect much of the model's behavior to be **localized**: a few paths matter and most paths don't. In practice this seems to be the case.
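The "paths" picture can be checked in a toy setting. The sketch assumes each layer is a plain linear map on the residual stream, which real attention and MLP layers are not, so it only illustrates the path structure. With residual connections, the output expands into a sum over every subset of layers: each layer on a path is applied, and each skipped layer is bypassed via the residual connection. With 3 layers this gives 2^3 = 8 paths.

```python
import numpy as np
from itertools import product

rng = np.random.default_rng(0)
d_model, n_layers = 8, 3
A = [rng.normal(size=(d_model, d_model)) * 0.3 for _ in range(n_layers)]  # toy linear layers
x0 = rng.normal(size=d_model)

# Forward pass with residual connections: each layer adds its output to the stream.
x = x0
for A_l in A:
    x = x + x @ A_l

# Path expansion: one path per subset of layers; skipped layers use the residual connection.
paths = {}
for use in product([False, True], repeat=n_layers):
    v = x0
    for A_l, used in zip(A, use):
        if used:
            v = v @ A_l
    paths[use] = v

assert np.allclose(x, sum(paths.values()))
print(len(paths))  # 8
```

In this toy setting, "localized" behavior would mean that most of the vectors in `paths` are small and a few dominate the sum.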
Another implication: the model uses the residual stream to achieve composition. For instance, say a head in layer 1 writes a vector to the residual stream, and a head in layer 2 then reads it in and does something with it. You can think of the output of the first head as an encoded message that is read by the second head. The important thing is that every pair of composing parts of the model is completely free to choose its own encoding. There is no reason that the encoding between head 3 in layer 2 and head 5 in layer 6 should be the same as the encoding between head 4 in layer 1 and head 7 in layer 4. This means we should expect the residual stream to be pretty difficult to interpret, and in practice this is true. The way to deal with this is to say: "understanding the residual stream is too hard. Instead, let's try to understand which paths through the residual stream really matter, and then decompose each path into pieces between parts of the model that we expect to be interpretable".
Again, remember that the residual stream is the sum of all of these encoded messages between parts of the network, which makes it a terrible mess to interact with and try to understand!
Another way to think about this is via the notion of a **[privileged basis](https://harrisonpim.com/blog/privileged-vs-non-privileged-bases-in-machine-learning)** (another good video [here](https://www.youtube.com/watch?v=-oKuDRFHW_Y)). The residual stream does not have a privileged basis, which makes it very hard to interpret. So what does a privileged basis mean? Fundamentally, if you have a vector space, you need a basis to understand what is going on inside it, i.e. some way to decompose vectors into coefficients along a set of fixed coordinate axes (vectors). There are a bunch of techniques for taking an arbitrary set of vectors and finding a basis that might be sensible for them (e.g. [Principal Component Analysis](Principle%20Component%20Analysis)). But it would be really nice if we could just take a model and know what the right basis is. By privileged basis, we mean that we can predict a priori, without looking at the weights or activations, which basis vectors are likely to be meaningful. Note that a vector space on its own is just a geometric object with no preferred set of axes. The residual stream is like this: because every component reads from and writes to it linearly, you could (ignoring details such as LayerNorm) rotate it by any orthogonal matrix and compensate by rotating the weights that read from and write to it, without changing the model's behavior. An elementwise nonlinearity, such as the activation function on MLP neurons, breaks this symmetry, which is what gives those neurons a privileged basis.
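A sketch of why the residual stream has no privileged basis. All matrices are random, hypothetical stand-ins for one writer and one reader, and LayerNorm is ignored. Rotating the stream by an orthogonal `R`, while rotating the writer's weights into the new basis and the reader's weights back out, leaves everything downstream unchanged. An elementwise ReLU, by contrast, gives a different answer after rotation.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_head = 5, 16, 4

x = rng.normal(size=(seq_len, d_model))        # residual stream before some component writes
h = rng.normal(size=(seq_len, d_head))         # what a hypothetical earlier component writes
W_write = rng.normal(size=(d_head, d_model))   # its write (embed) matrix
W_read = rng.normal(size=(d_model, d_head))    # a hypothetical later component's read matrix

# Random orthogonal matrix via QR: a change of basis for the residual stream.
R, _ = np.linalg.qr(rng.normal(size=(d_model, d_model)))

# Original model: write by addition, then read linearly.
read_original = (x + h @ W_write) @ W_read

# Rotated model: writers rotate into the new basis, readers rotate back out
# (R.T is the inverse of R because R is orthogonal).
x_rot = x @ R
read_rotated = (x_rot + h @ (W_write @ R)) @ (R.T @ W_read)
assert np.allclose(read_original, read_rotated)   # downstream computation is unchanged

def relu(z):
    return np.maximum(z, 0)

# An elementwise nonlinearity does not commute with rotation, which is what
# makes a basis privileged (e.g. for MLP neurons).
assert not np.allclose(relu(x @ R) @ R.T, relu(x))
```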
From another podcast ([Neel Nanda–Mechanistic Interpretability, Superposition, Grokking - YouTube](https://youtu.be/cVBGjhN4-1g?t=3925)) - generally people draw residual connections like:

But Nanda (and Anthropic) prefer to draw the residual stream as the central channel, with the MLPs (and other layers) branching off it. This is because if you look at the norms of the vectors in the residual stream, they are much larger than those of the MLP outputs. So it is more appropriate to draw models with a big central residual stream and small side branches for each layer, each of which updates the stream by a small amount. The residual stream is a big, shared bandwidth that the model uses to communicate between layers, and each layer reads from it and writes to it as an incremental update.

### Virtual Weights
We can think of **reading** from and **writing** to the residual stream. Reading and writing sound like inverse or complementary operations, but in this case they are very different!
* **Reading** goes from big (the dimension of the residual stream) to small (the dimension of an attention head). Reading effectively extracts the directions and information the layer cares about via a **projection**, so **reading = projection**. Because we are going from big to small, most directions of the residual stream are simply thrown away; the model chooses to focus on a few meaningful directions.
* **Writing** goes from small (the dimension of the attention head) to big (the dimension of the residual stream), so writing is effectively an **embedding**: **writing = embed**. Here we choose a set of directions in the residual stream and write our information to them. Because the information lives in those directions, later parts of the model can choose to look at them.

A consequence of this is that we can multiply the reading matrix of a later layer by the writing matrix of an earlier layer to get "virtual weights" that directly connect the two layers through the residual stream. This gives us a rough picture of how the two layers communicate (more on this later).
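A sketch of virtual weights using random, hypothetical read and write matrices. Because of the row-vector convention, the product is written `W_out_1 @ W_in_2`. In the paper's column-vector notation the same object is $W_I^2 W_O^1$. The part of layer 2's input that comes from layer 1's output is exactly `h1 @ W_virtual`, however much else is in the stream.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 32, 4

W_out_1 = rng.normal(size=(d_head, d_model))   # layer-1 component writes: small -> big (embed)
W_in_2 = rng.normal(size=(d_model, d_head))    # layer-2 component reads: big -> small (projection)

# Virtual weights: a direct (d_head x d_head) map from layer 1's output to layer 2's input.
W_virtual = W_out_1 @ W_in_2

h1 = rng.normal(size=d_head)                   # a layer-1 output, in its own small space
other = rng.normal(size=d_model)               # everything else already in the residual stream

resid = other + h1 @ W_out_1                   # write: embed into the stream and add
h2_in = resid @ W_in_2                         # read: project out of the stream

# Layer 2's input splits into "everything else" plus the direct path from layer 1.
assert np.allclose(h2_in, other @ W_in_2 + h1 @ W_virtual)
```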
### Subspaces and Residual Stream Bandwidth
See more in [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html#residual-comms). The key point is that many layers read from and write to the same residual stream, so its bandwidth is in high demand, and trying to interpret it is just really hard.
### Attention Heads
An attention head has a key component, the **attention pattern**. For each position, the head computes a probability distribution over that position and all previous positions; this is the attention pattern. The head then takes some information it has chosen to read from the residual stream at each of those positions, weights it by the attention probabilities, and adds it all up. That weighted sum, mapped back to the residual stream dimension, is the head's output at this position.
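A minimal sketch of one causal attention head, assuming the row-vector convention, random weights, no biases, and the usual $1/\sqrt{d_{head}}$ scaling of the scores. Each row of `pattern` is the probability distribution described above, with zero weight on future positions. `output` is the attention-weighted sum of what the head read, mapped back to $d_m$.

```python
import numpy as np

def attention_head(x, W_Q, W_K, W_V, W_O):
    """One causal attention head (row-vector convention, no biases).
    x: (seq_len, d_model) residual stream.
    Returns (pattern, output): pattern is (seq_len, seq_len), output is (seq_len, d_model)."""
    q, k, v = x @ W_Q, x @ W_K, x @ W_V                      # each (seq_len, d_head)
    scores = q @ k.T / np.sqrt(W_Q.shape[1])                 # rows: query positions, cols: key positions
    seq_len = x.shape[0]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)               # a position may not attend to later positions
    scores = scores - scores.max(axis=1, keepdims=True)      # numerical stability
    pattern = np.exp(scores)
    pattern = pattern / pattern.sum(axis=1, keepdims=True)   # each row is a probability distribution
    output = pattern @ v @ W_O                               # weighted sum, mapped back to d_model
    return pattern, output

rng = np.random.default_rng(0)
seq_len, d_model, d_head = 6, 16, 4
x = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_head)) for _ in range(3))
W_O = rng.normal(size=(d_head, d_model))
pattern, output = attention_head(x, W_Q, W_K, W_V, W_O)
print(pattern[0])  # the first position can only attend to itself: [1. 0. 0. 0. 0. 0.]
```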
The output of an attention layer is the sum of the outputs of a bunch of independent attention heads.
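Implementations usually concatenate the heads and apply one big output matrix, which hides this independence. The sketch below uses random stand-ins for the per-head, attention-weighted values `z`. It checks that splitting the output matrix into per-head blocks and summing each head's contribution gives the same result as the concatenated form.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, n_heads, d_head = 5, 32, 4, 8

# Per-head attention-weighted values, and the layer's output matrix.
z = rng.normal(size=(n_heads, seq_len, d_head))
W_O = rng.normal(size=(n_heads * d_head, d_model))

# Usual implementation: concatenate the heads, then one big matmul.
z_concat = np.concatenate(list(z), axis=1)        # (seq_len, n_heads * d_head)
out_concat = z_concat @ W_O

# Equivalent view: split W_O into per-head row blocks; each head writes independently.
W_O_heads = W_O.reshape(n_heads, d_head, d_model)
out_sum = sum(z[h] @ W_O_heads[h] for h in range(n_heads))

assert np.allclose(out_concat, out_sum)
```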
### Attention Heads as Information Movement
Attention patterns represent information movement. Say we have the text: "The cat sat on the mat. The rat sat on the ____". The transformer might want to compress the information from the first sentence, "The cat sat on the mat", into a single vector that encodes things like "kids' story", "about animals", and "rhyming", and then use it when processing the next sentence. It would be wasteful to attend to every token in the previous sentence to gather all of the relevant information. Instead, the model might copy all of that information onto the full stop (the period at the end of the first sentence) and store it there. Later, when it needs information about the previous sentence, it can just attend to the full stop. So if the model attends to the full stop token, that doesn't mean it cares about the full stop itself! It just means it is reading information that was stored on the full stop token.
### Intuition: Distinguish between *Parameters* and *Activations*
When you start staring at transformer algebra, it is really helpful to keep in mind the distinction between **parameters** (weights stored in the model, learned by the model during training and updated by the optimizer. These are just a fixed part of its functional form) and **activations** (things that are purely calculated on a particular input, are a function of that input, and will vanish when you stop running the model on that input and run it on other inputs).
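The attention pattern is computed from a low-rank factorization of a residual stream × residual stream matrix. $W_Q^T W_K$ is a $d_m \times d_m$ matrix (a parameter) whose rank is at most the head dimension $d_{head}$. The attention pattern itself is an activation, computed from that matrix for each input.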
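A sketch tying this to the parameters/activations distinction, with random weights and illustrative sizes. In the row-vector convention used here, the paper's $W_Q^T W_K$ appears as `W_Q @ W_K.T`. That matrix is a fixed, input-independent parameter of shape $d_m \times d_m$ with rank at most $d_{head}$. The pre-softmax attention scores are an activation: a bilinear form of the residual stream through that matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, seq_len = 64, 8, 6

W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))
x = rng.normal(size=(seq_len, d_model))          # residual stream for one input

# Parameter: a (d_model, d_model) matrix that only has rank d_head.
W_QK = W_Q @ W_K.T
print(W_QK.shape, np.linalg.matrix_rank(W_QK))   # (64, 64) 8

# Activation: pre-softmax scores for this particular input.
scores = (x @ W_Q) @ (x @ W_K).T
assert np.allclose(scores, x @ W_QK @ x.T)
```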
---
Date: 20230725
Links to:
Tags:
References:
* [A Walkthrough of A Mathematical Framework for Transformer Circuits - YouTube](https://www.youtube.com/watch?v=KV5gbOmHbjU)
* [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)