TWIML AI Podcast

Recurrence and Attention for Long-Context Transformers with Jacob Buckman - #750

Tuesday, October 7, 2025 · 57 min

What You'll Learn

  • Context size is a key bottleneck in scaling AI models, even as models can handle millions of tokens
  • Not all contexts are equally useful - it's about the model's ability to effectively use the additional context
  • Manifest AI is working on 'retention' architectures as an alternative to attention for long-context processing
  • The goal is to develop models that can better synthesize and leverage large amounts of input information

AI Summary

The podcast discusses the importance of long context in AI models and the challenges of scaling context size. The guest, Jacob Buckman, co-founder of Manifest AI, explains how context is a crucial axis of scaling that can significantly improve model performance, but current approaches like attention have limitations. Manifest AI is focused on developing new architectures, like the 'retention' mechanism, to better handle long-context inputs and synthesize large amounts of information.

Key Points

  1. Context size is a key bottleneck in scaling AI models, even as models can handle millions of tokens
  2. Not all contexts are equally useful - it's about the model's ability to effectively use the additional context
  3. Manifest AI is working on 'retention' architectures as an alternative to attention for long-context processing
  4. The goal is to develop models that can better synthesize and leverage large amounts of input information

Topics Discussed

#Long-context transformers #Attention mechanisms #Retention architectures #In-context learning #Scaling AI models

Frequently Asked Questions

What is "Recurrence and Attention for Long-Context Transformers with Jacob Buckman - #750" about?

The podcast discusses the importance of long context in AI models and the challenges of scaling context size. The guest, Jacob Buckman, co-founder of Manifest AI, explains how context is a crucial axis of scaling that can significantly improve model performance, but current approaches like attention have limitations. Manifest AI is focused on developing new architectures, like the 'retention' mechanism, to better handle long-context inputs and synthesize large amounts of information.

What topics are discussed in this episode?

This episode covers the following topics: Long-context transformers, Attention mechanisms, Retention architectures, In-context learning, Scaling AI models.

What is key insight #1 from this episode?

Context size is a key bottleneck in scaling AI models, even as models can handle millions of tokens

What is key insight #2 from this episode?

Not all contexts are equally useful - it's about the model's ability to effectively use the additional context

What is key insight #3 from this episode?

Manifest AI is working on 'retention' architectures as an alternative to attention for long-context processing

What is key insight #4 from this episode?

The goal is to develop models that can better synthesize and leverage large amounts of input information

Who should listen to this episode?

This episode is recommended for anyone interested in Long-context transformers, Attention mechanisms, Retention architectures, and those who want to stay updated on the latest developments in AI and technology.

Episode Description

Today, we're joined by Jacob Buckman, co-founder and CEO of Manifest AI, to discuss achieving long context in transformers. We discuss the bottlenecks of scaling context length and recent techniques to overcome them, including windowed attention, grouped query attention, and latent space attention. We explore the idea of weight-state balance and the weight-state FLOP ratio as a way of reasoning about the optimality of compute architectures, and we dig into the Power Retention architecture, which blends the parallelization of attention with the linear scaling of recurrence and promises speedups of >10x during training and >100x during inference. We also review Manifest AI’s recent open source projects: Vidrial, a custom CUDA framework for building highly optimized GPU kernels in Python, and PowerCoder, a 3B-parameter coding model fine-tuned from StarCoder to use power retention. Our chat also covers the use of metrics like in-context learning curves and negative log likelihood to measure context utility, the implications of scaling laws, and the future of long context lengths in AI applications. The complete show notes for this episode can be found at https://twimlai.com/go/750.

Full Transcript

I'd like to thank our friends at Capital One for sponsoring today's episode. Capital One's tech team isn't just talking about multi-agentic AI. They already deployed one. It's called Chat Concierge, and it's simplifying car shopping. Using self-reflection and layered reasoning with live API checks, it doesn't just help buyers find a car they love. It helps schedule a test drive, get pre-approved for financing, and estimate trade-in value. Advanced, intuitive, and deployed. That's how they stack. That's technology at Capital One.

So a way to think about this is...

All right, everyone. Welcome to another episode of the TWIML AI Podcast. I am your host, Sam Charrington. Today, I'm joined by Jacob Buckman. Jacob is co-founder and CEO of Manifest AI. Before we get going, be sure to take a moment to hit that subscribe button wherever you're listening to today's show. Jacob, welcome to the podcast.

Thanks so much for having me, Sam.

I'm excited to meet you and dig into our conversation. We're going to be talking about achieving long context with transformers and the new Power Retention architecture that you recently published. To get us started, I'd love to have you share a little bit about your background.

Yeah, absolutely. So I've been working on deep learning for the better part of a decade now. I started maybe in 2016, when I was an undergrad at Carnegie Mellon. I did my master's thesis on deep learning for autoregressive language modeling, basically just because I thought it was cool and exciting. And now, here we are 10 years later, and it's apparently the coolest and most exciting thing in the world, which was definitely not something I anticipated when I started getting into it. But it's pretty fun to see.

Turns out.

As it turns out, yeah. And so I worked for Google Brain for a while and got lots of exposure to different areas of deep learning.
Some GAN stuff (generative adversarial networks), a little bit of adversarial examples, and I worked on deep reinforcement learning. And then I eventually went to get a PhD, mostly in deep reinforcement learning, at Mila in Montreal. For a long time, I felt like the writing was on the wall that just pushing the current paradigm was going to take us to unbelievable places. And I started thinking, together with my co-founder, Carlos: how far is it going to take us? People at this point were speculating that maybe scale would take us all the way to AGI. And they had a point; they had a pretty good argument. One of the things we started thinking about was: is it going to take us all the way? Do we see any technical bottlenecks that will actually prevent us from getting there? And the conclusion we reached was that although several of the major axes of scale are handled well (we know how to scale weights, we know how to scale the dataset), there's one crucial axis where scaling does not work so well, and that's the context axis. More generally, I think context size is a very specific instance of a more general problem: how do you make a network that can synthesize a really big input? When your input is text, that obviously becomes context length. But a lot of similar ideas show up in other domains. Images, of course; video is just a sequence of images, so even more so; and maybe even more esoteric modeling problems. This issue keeps coming to bear: it's so expensive to actually process larger and larger inputs. We started Manifest AI essentially to crack this problem. So we've been working for the past couple of years, mostly doing focused research on the question of how to design neural architectures that can synthesize extremely large inputs.

Okay. It's super interesting that you chose to focus on context as what you see as the linchpin to better results in AI.
I feel like there's a clear need for longer context, but maybe our enthusiasm for context length has waned now that we've got millions of tokens with Gemini. There was this initial enthusiasm that, oh, we don't need RAG anymore, we don't need to engineer context anymore. But I think that enthusiasm has waned a bit, and at least I get the sense that there's a realization that long context isn't everything, and that there are other things we need to be smarter about in order to make models more intelligent. Thinking, for example, has come along since the push for really long context, and I think it has had much bigger results on the usability of models. Talk a little bit more about why you think context is so important.

Yeah, that's a great point. I think a lot of the time, people confuse the top-line number of context ("a million tokens!") with actual utility. And all contexts are not created equal here, right? You can have a context that's extraordinarily useful and one that is completely useless, and these can, in fact, be the same context for two different models. It's about the interplay between the model and learning: the final model being able to effectively use what's in its context. That's what makes the context valuable. It's not that you can shove more stuff in there. It's that the model can use the additional stuff to improve its predictions and its abilities.

That's right. I was having this conversation with someone recently: usefulness, in and of itself, is very use-case dependent. We were talking about long context, and I think the person quoted examples from what was probably the Gemini release, and the metric they chose to demonstrate utility was needle in a haystack.
But needle in a haystack doesn't necessarily translate to the ability to use all the context, or the ability to prioritize something that's in the middle equally with what's at the beginning or the end. There are a lot of dimensions, and I don't know that we've fully figured out the right set of metrics to define context utility. Do you have a sense of that?

Yeah, I think in some ways you can't ever fully define it. Certainly not with a single scalar number, right? A model can be great at finding needles in a haystack and horrible at computing summaries of books of equivalent length to the haystack. Or it can be great at needle in a haystack and great at summarizing books, but terrible at performing long mathematical derivations that are many, many tokens long. All these different things are just consequences of: here's a particular input, do I get the output I want? And ultimately, it's such a vast space of possible inputs. I mean, that's why we love LLMs, right? You can show them whatever, and they'll give you something good back. But it depends on what you show them and how good it is; you can even have two different prompts that are the same length and encode the same question, but one of them is more well-suited to the model for whatever reason, and so it produces better outputs as a result. When we think about context, we mostly think about it in terms of what we call in-context learning curves. I think people are familiar with in-context learning in general, right? If you give an LLM a few examples, it'll get better and better at actually solving the task when you present it with the final task. That's one specific example, but the more general principle is: as you give the model more information in a particular format, how does the resulting ability improve? And on the pre-training side, you can actually measure this directly in a super clean way.
Pre-training evaluation in general is much cleaner than downstream task evaluation, because you're always just focused on minimizing the loss, right? The negative log likelihood. And you can actually measure, and we do measure this: given each additional token of context, how much marginally easier does predicting the subsequent token become? In other words, if I give you zero tokens of context and say, predict the first token, that's going to be pretty hard. It could be any token, right? Or maybe weighted by their unigram probabilities or something. But now I give you a hundred tokens of context and say, predict the 101st. Okay. Now you at least know: are we in a news article? Are we in a tweet? Are we in a mathematical equation? Are we in a computer program? You have some sense of the topic, so you're going to have a better chance here. Now, what if I give you 100,000 tokens and ask you to predict the 100,001st? If you are good at using a long context, this is hopefully going to be a much easier task than even predicting the 101st token from the first 100 was. Obviously, you need to average this over many, many different prefixes of length 100,000, because there's going to be variance here and you want to get an overall picture. But what you end up seeing are these really beautiful scaling curves on the negative log likelihood that are completely analogous to the ones you see in parameter scaling, dataset size scaling, training steps scaling, or any of the other scaling axes that we know and love. You get these beautiful log-linear plots showing that each additional token, on average, improves your model's ability to predict the next token by a small amount that follows a log-linear trend, basically. So I basically view the inference-time predictive ability of appropriately trained, appropriately situated models as downstream of the chosen context length.
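The in-context learning curve described above can be sketched in a few lines of plain Python. This is a minimal sketch, not Manifest AI's evaluation code: it assumes you already have the per-token negative log likelihoods (in nats) that some model assigned to many sequences, averages them at each context position, and fits loss ≈ a − b·ln(t + 1). The function names and the synthetic data are hypothetical illustrations.

```python
import math
from typing import List, Sequence, Tuple


def in_context_learning_curve(nll: Sequence[Sequence[float]]) -> List[float]:
    """Average per-token NLL at each context position across many sequences.

    nll[i][t] is the NLL the model assigned to token t of sequence i,
    given tokens 0..t-1 of that sequence as context.
    """
    length = min(len(seq) for seq in nll)
    return [sum(seq[t] for seq in nll) / len(nll) for t in range(length)]


def fit_log_linear(curve: Sequence[float]) -> Tuple[float, float]:
    """Least-squares fit of loss ~ a - b * ln(t + 1); returns (a, b).

    A positive b means each doubling of context lowers the loss by b * ln(2).
    """
    xs = [math.log(t + 1) for t in range(len(curve))]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(curve) / len(curve)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, curve))
    var = sum((x - mean_x) ** 2 for x in xs)
    slope = cov / var
    return mean_y - slope * mean_x, -slope


# Synthetic stand-in for real model outputs: three "sequences" whose loss
# follows 4.0 - 0.1 * ln(t + 1), with per-sequence offsets that average out.
fake = [[4.0 - 0.1 * math.log(t + 1) + off for t in range(1000)]
        for off in (-0.05, 0.0, 0.05)]
curve = in_context_learning_curve(fake)
a, b = fit_log_linear(curve)
print(round(a, 3), round(b, 3))  # 4.0 0.1
```

With real data, the averaging over many prefixes is what tames the per-sequence variance mentioned above; the fitted b is one scalar summary of how much a model gains from extra context.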
So in this sense, context is just another axis of scaling that can deeply improve performance. Now, does this translate into downstream tasks? Not necessarily, and not directly. That's a whole other can of worms. But I do believe that, just like we've seen in every other piece of the field of AI, if you can find a way to squeeze more intelligence into the model at pre-training time, it's going to come out on downstream tasks one way or another.

You're not the only one that's excited about long context. There are a lot of different groups pursuing long context. I've interviewed the folks who work on Mamba, among others, on state space models as a general class of approach, and there are others. Your approach is centered on the transformer, though. Why have you chosen to focus on the transformer?

Actually, something that might be a bit surprising is that many of the other architectures are very similar to transformers. Mamba, and Mamba 2 in particular, is extremely similar to a transformer. DeltaNet, RetNet, any other subquadratic transformer variant: most of the architecture looks just like a transformer. There are a few little tweaks. You add an extra residual connection here, an extra gating signal there. But more or less, they follow the same basic structure, this alternating pattern of MLP and then time mixing, then MLP and then time mixing. And you can take any of those architectures, say Mamba, and instead of the Mamba time mixing, just swap in attention for the time mixing. You get something that anyone would call a transformer. It may have some slight architectural differences, but it's a transformer, right? And our approach is no different. Our approach is also to start from this really solid architectural backbone and intervene directly on the one piece that we feel is bottlenecking the whole thing, and that's the time mixing.
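The shared backbone described above (time mixing alternating with a per-position MLP, where only the time mixer differs between "transformer" and "recurrent" variants) can be sketched as a toy in plain Python. This is an illustration under stated assumptions, not any real architecture: there are no learned projections, no normalization, and `toy_attention` and `toy_recurrence` are hypothetical stand-ins rather than Mamba, DeltaNet, or Manifest's retention.

```python
import math
from typing import Callable, List

Seq = List[List[float]]                     # [time][feature]
TimeMixer = Callable[[Seq], Seq]            # mixes information across time steps
Mlp = Callable[[List[float]], List[float]]  # acts on each time step independently


def add(a: List[float], b: List[float]) -> List[float]:
    return [x + y for x, y in zip(a, b)]


def backbone(x: Seq, mixers: List[TimeMixer], mlp: Mlp) -> Seq:
    """Alternate time mixing and a per-position MLP, each wrapped in a residual."""
    for mix in mixers:
        x = [add(xt, mt) for xt, mt in zip(x, mix(x))]
        x = [add(xt, mlp(xt)) for xt in x]
    return x


def toy_attention(x: Seq) -> Seq:
    """Causal softmax attention with q = k = v = x (no learned projections)."""
    out = []
    for t in range(len(x)):
        scores = [sum(a * b for a, b in zip(x[t], x[s])) for s in range(t + 1)]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(sc - peak) for sc in scores]
        total = sum(weights)
        out.append([sum(w * x[s][j] for s, w in enumerate(weights)) / total
                    for j in range(len(x[t]))])
    return out


def toy_recurrence(x: Seq, decay: float = 0.9) -> Seq:
    """Recurrent mixer with a fixed-size state: h_t = decay * h_{t-1} + x_t."""
    h = [0.0] * len(x[0])
    out = []
    for xt in x:
        h = [decay * hi + xi for hi, xi in zip(h, xt)]
        out.append(h)
    return out


def toy_mlp(v: List[float]) -> List[float]:
    return [max(0.0, 0.5 * vi) for vi in v]


x = [[0.1 * t, -0.2 * t, 1.0] for t in range(6)]
as_transformer = backbone(x, [toy_attention] * 4, toy_mlp)
as_recurrent = backbone(x, [toy_recurrence] * 4, toy_mlp)
as_hybrid = backbone(x, [toy_recurrence] * 3 + [toy_attention], toy_mlp)
```

The only thing that changes between the three models is the list of time mixers, which is the sense in which swapping the time-mixing layer is the whole architectural intervention.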
Transformers use attention, and we swap in this thing called retention, which is very related to Mamba 2 (less so to Mamba 1), very related to DeltaNet, and to all these subquadratic architectures. They're all variants of what we call retention. Retention is a form of linear attention that mixes the current most successful paradigm, attention, with the best parts of the previous most successful paradigm, recurrence. So: recurrence, attention, retention, right? The advantage of recurrence is that you get linear cost with respect to context length. Another way to view this is that the state of a recurrent model is of fixed size; it never grows. And since you're just paying this constant cost at every step, that's what causes the overall cost to grow linearly. Now, attention...

That's the strength, but also the limitation.

Right. The limitation, of course, is that all this computation must be done sequentially. GPUs love big matrix multiplications; there's nothing they like more, right? And when you're running a recurrent architecture, you have to unroll many, many sequential steps, so you can't express the computation as a big matrix multiplication. But attention changes this. The advantage of attention is that it's super parallelizable. It's basically just multiplying matrices together. You multiply the keys and queries together to get the attention weights, and then you multiply that matrix by the values to get the outputs. Very fat matmuls, very parallelizable, very GPU-friendly.
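The two computational forms contrasted above can be checked against each other on the simplest member of this family: plain causal linear attention, with no feature map, gating, or normalization. This is a sketch under that assumption, not Manifest's Power Retention. The attention form compares every query with every earlier key (quadratic, but matmul-shaped); the recurrent form keeps a fixed d_k × d_v state S_t = S_{t−1} + k_t v_tᵀ and reads it with y_t = q_tᵀ S_t (linear, but sequential).

```python
import random
from typing import List

Vec = List[float]


def attention_form(q: List[Vec], k: List[Vec], v: List[Vec]) -> List[Vec]:
    """y_t = sum_{s <= t} (q_t . k_s) v_s, computed from all keys and values at once.

    Every output touches every earlier key: quadratic in sequence length,
    but on a GPU this is a pair of big matrix multiplications.
    """
    out = []
    for t, qt in enumerate(q):
        weights = [sum(a * b for a, b in zip(qt, k[s])) for s in range(t + 1)]
        out.append([sum(w * v[s][j] for s, w in enumerate(weights))
                    for j in range(len(v[0]))])
    return out


def recurrent_form(q: List[Vec], k: List[Vec], v: List[Vec]) -> List[Vec]:
    """Same outputs from a fixed-size d_k x d_v state: S_t = S_{t-1} + k_t v_t^T, y_t = q_t^T S_t.

    Constant cost per token (linear overall), but the steps are sequential.
    """
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for qt, kt, vt in zip(q, k, v):
        for i in range(d_k):
            for j in range(d_v):
                state[i][j] += kt[i] * vt[j]
        out.append([sum(qt[i] * state[i][j] for i in range(d_k)) for j in range(d_v)])
    return out


random.seed(0)
T, D_K, D_V = 16, 4, 3
q, k, v = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
           for d in (D_K, D_K, D_V))
max_err = max(abs(a - b)
              for ya, yr in zip(attention_form(q, k, v), recurrent_form(q, k, v))
              for a, b in zip(ya, yr))
print(max_err < 1e-9)  # True
```

The equivalence follows from expanding q_tᵀ S_t = Σ_{s≤t} (q_t · k_s) v_s; the two forms differ only in summation order, so any mismatch is floating-point rounding.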
And because of this, you're able to squeeze so much more out of the hardware that you actually get faster wall-clock performance from attention, even though it has a quadratic cost with respect to context instead of a linear cost, up to massive context lengths, to the point where it's basically ridiculous that you would ever want to use recurrence, just in terms of pure wall-clock time. But this is where retention comes in. Retention is a family of architectures that has both forms. The computation can be expressed in a recurrent form (here's a fixed state, and here's its update function given one new token), or it can be expressed in an equivalent attention form (here's a whole bunch of keys, a whole bunch of queries, and a whole bunch of values; do one computation to process them all together). These two are mathematically equivalent. It doesn't matter which one you use in terms of the numbers you get as output. But as we were just saying, they have very different computational properties. And what's really cool is that you can combine them to get the best of both worlds. People call it the chunked algorithm. If you have a very long input sequence, you break it up into reasonably small chunks. Within each chunk, you use the attention calculation, so you can do it in a very hardware-friendly way. You're getting a huge amount of GPU utilization, just like you always do when you use attention. But the moment you hit that hardware limit, the moment your chunk is large enough that you're fully saturating the GPU, you end the chunk there. You don't put any additional tokens in it. So you get that hardware saturation within chunks, and then you use recurrence between chunks.
So you get both the parallelization and hardware efficiency of attention and an overall linear cost, because given a fixed chunk size, your overall cost for a longer and longer sequence is just linear in the number of chunks. You basically get a best-of-both-worlds situation. These are really, really powerful architecture families, and basically all the state space models you're familiar with fall under the retention umbrella, as I was saying. So I think a huge unlock of the past few years of research is the realization that these retention operations exist, with these two dual forms and the chunked form that combines them.

And that chunk size is not a hyperparameter that the developer needs to twiddle with. It's strictly determined by the size of the GPU.

Exactly. Yeah, it has no influence mathematically on the output. No matter what you set the chunk size to, you're producing the same numbers, right? So there's no quality trade-off to worry about.

You set it to optimize GPU utilization.

Exactly.

You mentioned that it doesn't have any influence mathematically, and so it has no influence qualitatively in terms of the result.

These are essentially identical operations, just performed computationally in different ways. The only thing that could maybe be affected is numerical precision. This isn't an issue that we've ever seen. We're usually training in BF16. If you went down to 8-bit, like FP8, or even FP4, as some people are doing now, you might start to have to think about this. But in none of the experiments we've run in BF16 has this ever been an issue. The choice purely determines how fast it runs.

And so is chunking part of retention, or is chunking a technique that preceded retention, and retention builds on the chunking concept?

The first retention algorithm was actually proposed under the name linear attention.
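The chunked algorithm described above can be sketched for the same plain causal linear attention (an assumption; real retention variants add feature maps, gating, and GPU kernels). Inside each chunk it uses the attention form; between chunks it carries a fixed-size state S = Σ k_s v_sᵀ over all earlier chunks, so each output is the intra-chunk attention term plus q_tᵀ S. The loop over chunk sizes checks the claim from the conversation that the chunk size does not change the numbers, only the speed.

```python
import random
from typing import List

Vec = List[float]


def causal_linear_attention(q: List[Vec], k: List[Vec], v: List[Vec]) -> List[Vec]:
    """Reference (attention form): y_t = sum_{s <= t} (q_t . k_s) v_s."""
    out = []
    for t, qt in enumerate(q):
        w = [sum(a * b for a, b in zip(qt, k[s])) for s in range(t + 1)]
        out.append([sum(ws * v[s][j] for s, ws in enumerate(w)) for j in range(len(v[0]))])
    return out


def chunked_form(q: List[Vec], k: List[Vec], v: List[Vec], chunk_size: int) -> List[Vec]:
    """Attention form inside each chunk, recurrence (a fixed-size state) between chunks."""
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # sum of k_s v_s^T over all previous chunks
    out = []
    for start in range(0, len(q), chunk_size):
        qc = q[start:start + chunk_size]
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        intra = causal_linear_attention(qc, kc, vc)  # parallel, matmul-friendly part
        for t, qt in enumerate(qc):
            # contribution of every token in earlier chunks, read from the state
            inter = [sum(qt[i] * state[i][j] for i in range(d_k)) for j in range(d_v)]
            out.append([a + b for a, b in zip(intra[t], inter)])
        # fold the whole chunk into the state: S += K_c^T V_c
        for kt, vt in zip(kc, vc):
            for i in range(d_k):
                for j in range(d_v):
                    state[i][j] += kt[i] * vt[j]
    return out


random.seed(1)
T, D_K, D_V = 20, 4, 3
q, k, v = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
           for d in (D_K, D_K, D_V))
reference = causal_linear_attention(q, k, v)
for size in (1, 3, 8, 20, 64):
    got = chunked_form(q, k, v, size)
    err = max(abs(a - b) for ya, yb in zip(got, reference) for a, b in zip(ya, yb))
    print(size, err < 1e-9)  # True for every chunk size
```

A chunk size of 1 is pure recurrence and a chunk size of at least T is pure attention; in practice the chunk would be sized to saturate the GPU, as described above.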
Actually, there are some references to it even earlier than that. This operation was actually used even before attention existed. Somebody proposed an operation that was equivalent to linear attention, which is itself an example of retention. But in that paper, they did not use the chunked algorithm, and they certainly didn't use a modern, hardware-efficient chunked algorithm. That came later, in concurrent work between a couple of different groups, including us at Manifest, in 2023, 2024. Yeah. There was about a six-month range where a couple of different groups arrived at this chunked algorithm in parallel and put out different variants of the same core idea.

You've talked a little bit about the equivalence of the different long context architectures, state space models, for example. Talk a little bit about the various things folks have done to try to overcome the quadratic scaling issues associated with attention, things like windowed attention, latent space attention, and other techniques.

The way I like to think about this is by focusing on the state size. For recurrent neural networks, everyone expects there to be a state. It's obvious what that is. For transformers, it's less obvious, but I think a really good way to frame it is to consider the KV cache to be the state of the network. And indeed, the network is Markovian with respect to the KV cache: you can always predict the next token as long as you've kept that cache around. Where this differs from the classic way we think about RNNs is that the state is no longer of a fixed size. Each additional token you see causes the state to grow a little bit. So ultimately, the state size depends on the context length, on the number of tokens you've seen. But that's not a problem, right? It's maybe a bit of a tweak to the definition, but it's still a completely valid definition.
And in general, you're only going to be training or running inference up to some finite maximum length anyway. So if we really want to put a single, fixed number on the state size, we can say it's the largest KV cache this transformer might encounter. And once you're looking at state size, you realize that there are a bunch of different techniques out there for reducing the size of the transformer state. The quadratic cost is downstream of the fact that the state size itself is growing linearly. If you're doing a linear number of operations, and the size of each operation is growing linearly, then that comes out to quadratic overall. So the reason the quadratic scaling of a transformer is a thing is that the state of a transformer is growing linearly. These are the same. So one way to think about what is really wrong with quadratic-scaling transformers is that the state is too big. The computation is too expensive because the state is just too large and growing too fast. This only kicks in at long contexts, of course, but that's where these techniques really come into play. Windowed attention is an obvious intervention you could make here. We say, okay, instead of attending over the entire history, we're just going to attend over a small local window of it. Or sometimes people use other forms of sparsity, where it's not the 128 most recent tokens but a more scattered set of 128 tokens or whatever. This will obviously decrease the size of the state. In fact, it'll make the state constant size with respect to time, so the overall cost will be linear and not quadratic again. So that's great. But the state size as a whole has a couple of different dimensions.
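The argument above (a linearly growing state, touched once per new token, gives quadratic total cost; a fixed window gives linear cost) can be checked with simple counting. This sketch assumes each step's attention work is proportional to the number of cached tokens and counts the current token as part of the cache; the function names are hypothetical.

```python
from typing import Optional


def cached_tokens(step: int, window: Optional[int] = None) -> int:
    """KV-cache 'state' size, in tokens, when producing token number `step` (1-based)."""
    return step if window is None else min(step, window)


def time_mixing_cost(seq_len: int, window: Optional[int] = None) -> int:
    """Total attention work over a sequence, in units of 'one cached token touched'.

    Each new token reads every entry of the current state once, so the total
    is the sum of the state sizes over all steps.
    """
    return sum(cached_tokens(step, window) for step in range(1, seq_len + 1))


for n in (1_000, 2_000, 4_000):
    print(n, time_mixing_cost(n), time_mixing_cost(n, window=128))
```

Full attention sums to n(n + 1)/2, so doubling n roughly quadruples the cost; with a 128-token window, the cost grows by about 128 per extra token, so doubling n roughly doubles it.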
To see them, think about what it would take to load a state from disk. This is something people actually do, for example, when they're running inference, right? You have to take a user's whole state and load it in. So how much memory is that going to take up? The flops it uses are proportional to the memory it takes up, because you basically need to use every element of the state in at least one flop, or it's sort of useless. So, at least when you're not thinking about mixture-of-experts type stuff, these two are going to be the same, right? And as for the size of that object: there's one state at each layer of the network. Each state has one dimension that is proportional to time. It also has one for each head. And finally, the last dimension is the head dimension, usually set to 64 or 128, just the feature dimension. Reductions on any of these axes do the same thing: they make the overall state smaller. And that's what we're trying to do when the state is too big and it's making the overall computation too expensive. So windowing is equivalent to a reduction on the time axis. Using something like GQA, grouped query attention, is a reduction on the heads axis, because you have the same number of queries but fewer keys and values along the heads axis, since they're being shared. Something like DeepSeek's latent attention is a reduction on the feature dimension: instead of keeping around a whole key or value of a particular size, you just keep around a smaller thing.

And what about the layers dimension?

The layers dimension is actually one of the most common ones to modify. This is a hybrid model, a very, very common design.
Basically every network from a lab whose details were released in recent years has used this design, where there are a few full attention layers and then many smaller layers in between, either linear attention or, more commonly, windowed attention. Assuming the size of the windowed layers is negligible compared to the giant size of the full attention layers, this is saying, okay, only one out of every eight things on the layers dimension is going to be meaningful. So it's basically a reduction by seven-eighths overall in terms of state size. I really like this lens because it just provides this unified...

That's a great way of unifying.

Yeah, yeah, it is.

And so if you're building a new model and you've got these different levers applying to these different axes, from a design perspective, how do you know which lever to tweak if they all result in a reduction in state size? Is it all empirical? Do we tweak these levers and see what gives us the best results? Or is there a more grounded way of thinking about it from a design perspective?

I think empirical is mostly my answer here. Ultimately, if you over-tweak any one lever, you're going to see bad performance. As an example, let's say we were using a hybrid architecture where the context was really long, and the windowed layers used a short window size of, like, 128 or something. And now we wanted to turn the sparsity on the layer dimension all the way up. What we end up with is an architecture that's literally one full attention layer and then a bunch of tiny windowed layers on top of it that are basically as good as MLPs, because the windows are so small. That's not going to be great, right? You're going to be able to improve on that by, say, quadrupling the number of attention layers and decreasing the number of GQA groups a little bit.
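The unified "state size" lens above can be made concrete with a back-of-the-envelope count of KV-cache elements: 2 (a key and a value) × layers × cached tokens × KV heads × feature dimension, with each technique shrinking one factor. All model sizes below are hypothetical round numbers chosen for illustration, and treating latent attention as "a smaller cached feature per token" follows the speaker's framing only loosely; real multi-head latent attention caches a compressed vector shared across heads.

```python
def kv_state_elements(layers: int, tokens: int, kv_heads: int, feature_dim: int) -> int:
    """KV-cache size: a key and a value (factor 2) per layer, cached token, KV head, and feature."""
    return 2 * layers * tokens * kv_heads * feature_dim


CONTEXT = 131_072  # hypothetical long-context setting
baseline = kv_state_elements(layers=32, tokens=CONTEXT, kv_heads=32, feature_dim=128)

variants = {
    # time axis: every layer keeps only a 4,096-token window
    "windowed": kv_state_elements(32, 4_096, 32, 128),
    # heads axis (GQA): 32 query heads share 8 key/value heads
    "gqa": kv_state_elements(32, CONTEXT, 8, 128),
    # feature axis: a smaller cached vector per token (hypothetical size 16)
    "latent": kv_state_elements(32, CONTEXT, 32, 16),
    # layers axis (hybrid): 4 full-attention layers, 28 layers with a 128-token window
    "hybrid": kv_state_elements(4, CONTEXT, 32, 128) + kv_state_elements(28, 128, 32, 128),
}

for name, size in variants.items():
    print(f"{name:9s} {size / baseline:.4f} of baseline")
```

The hybrid row lands at roughly 1/8 of the baseline, matching the "reduction by seven-eighths" described above, because the 28 windowed layers contribute almost nothing next to the 4 full layers.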
So, in general, whenever you're working with an axis of scaling like this, you want to keep everything balanced. If you lean too heavily on any one axis, you'll generally get worse performance. But ultimately, this is a heuristic, and I think the real answer is, yeah, you just fiddle with the knobs until you find the setting that seems to work best.

So taking a step back, how do you think about the interplay between architecture decisions, like state space versus power retention, and these windowing levers that you have? How do you think about the relationship between those?

RNNs, including all state space models, have the opposite problem of transformers. At long context at least, transformers have a state that's way too large, and all these levers exist to make that state smaller. RNNs do not have this problem. RNN states are way too small. You can napkin-math this out, and it becomes actually crazy that people ever trained an LSTM. For an LSTM, the state size at each layer is, let's say, equivalent to the feature size, so let's say 64, right? Now in a transformer, for every single entry in the KV cache, there's a key and a value, each of size 64. That means a transformer with even a small context of length 1,000 has a 2,000 times larger state than an LSTM that is equivalent in every way except for the time mixing. It's a crazy ratio. And we're not even training transformers at context 1,000. We're training transformers at context sizes like 8,000, 16,000, 100,000. At 100,000, this is a five-orders-of-magnitude gap in state size. People love to focus on parameter count as one of the most defining axes of scale of a model. And indeed, a model with way, way more parameters is always going to be better than a model with significantly fewer. But state size is just as important.
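The napkin math above can be reproduced directly. This follows the speaker's simplification that an LSTM layer carries one feature-sized vector of state (a real LSTM also carries a cell state, which would change the ratio only by a factor of 2), while a transformer layer caches a key and a value of that size for every token; `state_ratio` is a hypothetical helper name.

```python
import math


def state_ratio(context: int, feature_dim: int = 64) -> float:
    """Transformer KV-cache state vs. an LSTM state, per layer, other things equal."""
    lstm_state = feature_dim                     # one feature-sized vector
    transformer_state = context * 2 * feature_dim  # a key and a value per cached token
    return transformer_state / lstm_state


for ctx in (1_000, 8_000, 16_000, 100_000):
    r = state_ratio(ctx)
    print(f"context {ctx:>7,}: {r:>9,.0f}x larger (~10^{math.log10(r):.1f})")
```

The feature dimension cancels, so the ratio is simply 2 × context: 2,000x at 1,000 tokens and 200,000x (a bit over five orders of magnitude) at 100,000 tokens.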
It's just much less understood and much less studied, because not many people have yet adopted this perspective of viewing the KV cache of a transformer as its state. But once you do, you see that five orders of magnitude of scale between a transformer and an LSTM make it completely impossible for the LSTM to compare in performance. It's just not going to, right? It's so far away. So all these interventions that are about reducing the state of a transformer don't apply to RNNs, including state space models like Mamba and all those good architectures, because their states are way too small already. You want to make those states bigger. And that's actually one of the main focuses of our research at Manifest.

So you talked a little bit about the importance of state relative to the number of parameters, which is a more popular way of thinking about model performance. Is state something we should be thinking about independently of the other things happening in the model? Or is more required to really use state as a relative metric for comparing models?

All axes of scale are important. I don't want to say that state is the only important one at all. What you really want is an architecture that's balanced, right? Just like I was saying before, in this field, it's all about balance. When a model is beautiful and balanced and elegant, that's how you know it's going to be a good compute-optimal model.
And the problem with having a state that's really big or really small is, fundamentally, that it's imbalanced. So a way to think about this is as the ratio between the flops spent on parameter calculations and the flops spent on state calculations. To paint a picture a little bit here, right: if you're doing, say, the forward pass of a transformer, you're starting with some tokens, embedding them into activations, and then those activations get transformed and transformed and transformed again until they finally give you some output. And each step of transformation costs you something computationally, some number of flops. So what we can do is partition the flops into two broad categories that together account for something like 99.99% of the total flops done in the forward pass. Some of them we can call weight flops. A weight flop is just a multiplication between a parameter and an activation. So in the MLPs, all you're doing are weight flops. Also, when you're projecting X to Q, those are going to be weight flops. But then there are the state flops. State flops are multiplications between some state element, which in the case of a transformer would be some member of the KV cache, and an activation. So when you're doing, for example, the QK inner product, that's all state flops. And it's the ratio between these two things, how many state flops we're doing and how many weight flops we're doing, that really guides you to compute-optimal architectures. And the reason is this, right? Both parameters and state scale log-linearly. That means that doubling them, and doubling them, and doubling them again always has the same fixed effect on the ability of the resulting model. But of course, the cost in flops is not log-linear. The cost in flops is just linear. So doubling either parameters or state will cost you twice as many of that sort of flop and produce one fixed step of improvement.
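A rough way to see this ratio for a standard transformer layer is to count both kinds of flops per token. The sketch below uses common textbook assumptions, not the accounting from Manifest's paper: a multiply-add counts as two flops, the Q/K/V/output projections are each d×d, the MLP expands to 4x the model width, and at context position t attention does one QK dot product and one weighted value sum per cached token.

```python
def weight_flops_per_token(d_model: int, mlp_ratio: int = 4) -> int:
    # Weight flops: parameter x activation multiplications.
    projections = 4 * d_model * d_model           # Q, K, V and output projections
    mlp = 2 * d_model * (mlp_ratio * d_model)     # MLP up- and down-projection
    return 2 * (projections + mlp)                # 2 flops per multiply-add


def state_flops_per_token(context_len: int, d_model: int) -> int:
    # State flops: KV-cache element x activation multiplications.
    qk = context_len * d_model                    # q . k against every cached key
    av = context_len * d_model                    # weighted sum over every cached value
    return 2 * (qk + av)


def state_to_weight_ratio(context_len: int, d_model: int) -> float:
    return state_flops_per_token(context_len, d_model) / weight_flops_per_token(d_model)


print(state_to_weight_ratio(1_024, 4_096), state_to_weight_ratio(100_000, 4_096))
```

Under these assumptions the ratio simplifies to t / (6·d_model), independent of depth, so a layer with d_model = 4096 only reaches a one-to-one split at about 24,576 tokens; beyond that, state flops dominate.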
And so what this means is, think about an architecture that's wildly imbalanced. Let's say you're using an LSTM, which has this tiny little state and a huge amount of weights, relatively speaking. Now, here are two things you could do that would improve performance by the same amount. You could either double the state size, or you could double the weight size. Since they're both log-linear, these will have, and I'm obviously being a little hand-wavy here, roughly the same effect on the outcome. But they will not have the same effect on the computational cost. The cost of doubling the state size is negligible. Say your weights are 100 times larger than your state. You double the state and you don't even notice the difference in runtime, but you do notice the difference in performance. So essentially, it's a free win, right? If one of these things is much larger than the other, double the smaller thing: free win. And then you apply that same argument again. You can double it and double it and double it and double it. And eventually, you'll get into a regime where you notice the doubling in terms of wall-clock time. But that won't happen until weights and state are doing roughly the same number of flops. Not exactly one-to-one, but at least within an order of magnitude. And the same argument applies, by the way, in the reverse direction. If you have a transformer with a very long context, attention is so, so expensive; you're doing all those state flops. It's basically free to double the weights. You might as well do it, right?
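The "double the smaller thing" argument can be played out with a toy model. Everything here is a hypothetical illustration: quality is assumed to be exactly log-linear with equal coefficients for weight flops and state flops, and cost is simply their sum.

```python
import math


def toy_quality(weight_flops: float, state_flops: float) -> float:
    # Hypothetical log-linear model: every doubling of either axis adds one unit.
    return math.log2(weight_flops) + math.log2(state_flops)


def double_smaller(weight_flops: float, state_flops: float, budget: float):
    """Greedily double whichever axis is cheaper to double, while the total fits the budget."""
    while True:
        if state_flops <= weight_flops:
            if weight_flops + 2 * state_flops > budget:
                break
            state_flops *= 2
        else:
            if 2 * weight_flops + state_flops > budget:
                break
            weight_flops *= 2
    return weight_flops, state_flops


# An LSTM-like starting point: weights do 100x more flops than the state.
w, s = double_smaller(100, 1, budget=300)
print(w, s, toy_quality(w, s) - toy_quality(100, 1))
```

Starting from 100 weight flops and 1 state flop, the loop doubles the state seven times, taking total cost from 101 to 228 and adding seven units of toy quality, and stops once the two axes are within a factor of two of each other. Doubling the weights first would have nearly doubled the cost for a single unit.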
And so this means that the only transformers for which it makes sense to train on a long context have this unfathomably huge number of weights, beyond even the very largest models that people are commonly training today. To get to a context length of a million or something, right, you need several orders of magnitude of increase in parameter scale, which, I mean, maybe we're getting there, but we're certainly not there now. Does this imply that you could graph all these models in terms of the state-to-weight flop ratio, find some Pareto frontier, and see that it's consistent with what we're seeing in leaderboard or benchmark performance? So I actually haven't done this for the models that are out there, mostly because it's very hard to get information on exactly what the architectures look like and to compute these things. But I have run smaller, more controlled experiments on our own experimental setup and seen exactly what you're describing. Essentially, I can show you an experiment in our paper, and I can link it afterward if you like, where we train three models with the exact same everything. They have the same flops per update, the same tokens per update, the same batch size; they're exactly the same in all regards. The only way they differ is in the weight-to-state flop ratio. We have one model that has a huge, huge context and a very small number of weights. Another model has a huge number of weights and a very small context. And finally, there's the Goldilocks model, right? In between, perfectly balanced. And you see that, yeah, that's the one that has the best performance. So that's the one you want to be training, given this compute budget and this set of constraints. So this is a pretty ubiquitous principle. You'll basically see it everywhere.
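One way to set up the kind of controlled comparison described above is to hold the per-token flop budget fixed and trade model width against context length. This is a hypothetical sketch, not the configuration from the paper. It assumes a rough flop model for a standard transformer layer: about 24·d² weight flops and 4·t·d state flops per token, where d is the model width and t the context length.

```python
def context_for_budget(d_model: int, flops_per_token: float) -> float:
    """Context length that spends the rest of a fixed per-token budget on state flops."""
    weight_flops = 24 * d_model**2          # projections + 4x MLP, 2 flops per multiply-add
    remaining = flops_per_token - weight_flops
    if remaining <= 0:
        raise ValueError("model width alone exceeds the flop budget")
    return remaining / (4 * d_model)        # QK and AV over t cached tokens


budget = 2 * 24 * 1024**2                   # hypothetical budget, balanced at d_model = 1024
for d in (512, 1024, 1400):
    t = context_for_budget(d, budget)
    ratio = (4 * t * d) / (24 * d**2)
    print(f"d_model={d:5d}  context={t:9.0f}  state/weight flops={ratio:5.2f}")
```

The three rows are a long-context, narrow model; a balanced one; and a wide, short-context one, all at identical flops per token. The conversation's claim is that the balanced configuration trains best at equal compute. The sketch only sets up the comparison and does not reproduce that result.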
And I think if we were able to get details for each model on how much state it has versus how much weight it has, we'd see this as well. Now, unfortunately, those aren't the only factors. You could have two models, one of which is much better balanced than the other, but maybe the worse-balanced one has been trained 100 times as long, or on much better data, right? And these will completely mess up the result, so you'll see something very different in downstream performance. That's why I prefer to do it in a nice, clean experimental setup. But I think the principle is very robust, and you'll see it everywhere when you perform appropriately controlled experiments. So in publishing this power retention architecture, there's the paper, with a mathematical description of it. There's a model that you published, a 3 billion parameter coding model, if I'm remembering correctly. And the third piece is some CUDA kernels. Can you talk a little bit about how these elements fit together, and maybe your general experience of building this out, beyond the broader perspective on the problem that you gained? So we are a very practical research lab. We aren't satisfied with just getting a theoretical win. We want to actually have the best architecture, the one that is most worthwhile for almost everyone to train. Basically, anyone who's doing long-context anything should be training their models using our architecture. That's our goal. And the first piece of that is, as you said, mathematical. We need an architecture that, when measured per flop, actually gets the performance gains we expect: one that is able to do fewer operations to produce more intelligence. And that's really what the paper outlines.
It shows very clearly that two things are required to get the best performance, the best scaling laws: you focus on getting models that have a balanced weight-to-state flop ratio, and you use models that process the entire input, instead of something like sparse attention that only processes a piece of the input. Process the whole input, and do so with a balanced weight-to-state flop ratio. That's the first piece, which we explain in the paper with very careful comparisons to transformers and ablations, really painting this whole picture that we've been talking about this whole time. But that's just theoretical. The next piece is that it's got to be practical. We need something that you can actually run and get wall-clock speedups from. And there's a bit of co-design between this part of things and the mathematical part I mentioned a moment ago, because we would only be happy with an architecture if we knew it could be given a very efficient, GPU-friendly implementation. But the second piece is to actually build that GPU-friendly implementation. We started off with Triton, which works fairly well. You're able to get very good utilization. The kernels are written in Triton in a style basically similar to FlashAttention, but that's not the most you can do. So we actually went a level deeper. We knew that we could really, really crush FlashAttention if we squeezed all the performance possible out of the hardware. And because some of the pieces of this architecture are a little unique and a little complex, we weren't able to do that at the Triton level. We had to drop down a level of abstraction and program directly in CUDA. And it was so complicated that even typical CUDA programming patterns didn't really work for our application. There were too many moving pieces, too many workarounds that we needed to do.
So we realized that the best way to write these kernels was actually to build a whole CUDA framework from the ground up. So we built something that we call Vidrial, which is basically a way of writing non-spaghetti CUDA, not to put too fine a point on it, right? Basically everyone who writes CUDA kernels writes spaghetti. There's very little code reuse, no testable components, a bunch of if statements everywhere with different compilation paths, and hundreds and hundreds of lines of very complicated, nuanced pointer arithmetic to make all the memory movements happen correctly. And if you ever want to make a change, maybe to exploit new hardware or adapt to a different problem shape, it's a nightmare. It takes so much rewriting and retesting to make that happen. So Vidrial is a framework that gives you a set of very clean patterns for writing beautiful, testable, and very, very efficient CUDA. Basically, what Vidrial does is let you write every possible CUDA kernel that implements a given operation. All of these kernels are basically doing the same thing, but they use the hardware in very different ways, in all the different possible patterns one might consider. Things like tile shapes, right? How much memory does each thread own and copy? Where does it copy it to? Which thread does it hand it off to? Things like pipelining: how are operations staggered within the GPU? All of these things can take lots of different values, and you basically have to do a huge empirical sweep to find out which ones are best for any given setup. Vidrial lets you write just the generic patterns you want, and then gives you a clean Python JIT system for doing the empirical sweep. And to be clear, is it specific to the power retention implementation, or is it a generalized framework for building CUDA kernels? It's completely generic.
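The "write the generic pattern, then sweep configurations empirically" workflow can be sketched in plain Python. This is not Vidrial's API, which isn't described here. `run_kernel` and its `tile`/`unroll` settings are hypothetical stand-ins; a real sweep would JIT-compile and time an actual GPU kernel for each configuration.

```python
import itertools
import time
from typing import Callable, Dict, Iterable, Optional, Tuple

Config = Dict[str, int]


def sweep(run: Callable[[Config], object], configs: Iterable[Config],
          repeats: int = 5) -> Tuple[Optional[Config], float]:
    """Time every candidate configuration and return the fastest one."""
    best_cfg: Optional[Config] = None
    best_time = float("inf")
    for cfg in configs:
        run(cfg)  # warm-up call; a real system would JIT-compile the kernel here
        start = time.perf_counter()
        for _ in range(repeats):
            run(cfg)
        elapsed = (time.perf_counter() - start) / repeats
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg, best_time


def run_kernel(cfg: Config, n: int = 100_000) -> int:
    # Hypothetical stand-in "kernel": sums n elements in tiles of cfg["tile"],
    # handling cfg["unroll"] tiles per outer step. Every config computes the same result.
    tile, unroll, total = cfg["tile"], cfg["unroll"], 0
    for start in range(0, n, tile * unroll):
        for u in range(unroll):
            s = start + u * tile
            total += sum(range(s, min(s + tile, n)))
    return total


space = [{"tile": t, "unroll": u} for t, u in itertools.product((16, 256, 4096), (1, 4))]
best, seconds = sweep(run_kernel, space)
print("fastest config:", best, f"{seconds * 1e3:.2f} ms")
```

Which configuration wins depends on the machine, so no particular winner is implied; the point is that correctness is shared across configurations while speed is measured, not predicted.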
So we built it because we wanted to write kernels for power retention. But in order to write those kernels, we realized we needed a good general CUDA framework, so we made one. And we implemented lots of other kernels as well: things like basic addition and other element-wise operations, matrix multiplication, copy kernels. We also implemented FlashAttention, just to see what would happen. And we actually get 20% speedups over Tri Dao's FlashAttention on a whole bunch of problem shapes, which we thought was pretty cool, a pretty good proof point for the framework. And of course, on power retention itself, we're now able to get 2x to 4x speedups over the Triton implementation, because we're really squeezing what can be squeezed out of the hardware. Yeah, it's interesting. You could easily see that going the other way, with added layers of abstraction, cleanup, and organization having a computational cost, but it sounds like the typical CUDA kernel implementation is fairly messy and has some inefficiencies that you're able to clean up somewhat systematically. Exactly. And I'm very, very proud of what the team was able to put together here, because it is exactly what you said, right? You usually think about going up a level of abstraction as introducing inefficiencies. But in this case, it actually allows us to search the space of possible low-level implementations to get the best one. And that's something that can't be done without good abstractions that partition what the different useful low-level implementations are. We rely heavily on something called the CuTe layout, which is a recent addition to the CUTLASS library by the team at NVIDIA. It's an absolutely brilliant innovation.
And I think we're probably the most advanced users of CuTe out there, because our entire system for computing which memory needs to be moved where, across basically any possible configuration you can imagine, relies on these layout manipulation operations in CuTe. Interesting. Have you been tracking what Chris Lattner and Modular are doing with Mojo? I think they raised some money this week. I initially thought of Mojo as another language, but it's kind of bringing more CUDA flexibility directly into Python-like semantics; I guess that's the way I think about it. Have you been following that? I think the work is amazing. We haven't gone deep on Mojo yet. We wanted to stick with CUDA because it's the most mature, and we just wanted some really fast working kernels instead of exploring other frameworks. But we really like the idea of Mojo and hope, at some point in the hopefully near future, to generalize the Vidrial framework to also work with Mojo as a backend. Ultimately, the core principles we're leveraging are completely agnostic to what's going on underneath. It's done in CUDA and C++, via C++ template metaprogramming, because that's the current best-practice way to do it. But there are really two core ideas. One is this separation of static and dynamic computation, where at static time, like at compile time, you learn everything you can about what memory is moving where for any given configuration. And when you say learn, do you mean learn as a human developer, or learn as a software optimizer that is building out the low-level components? Yeah, learn in the software sense. And it's not even machine learning. Maybe that was a bad choice of words, honestly. Really, you just infer it at static time. Profiling, kind of. Similar to this, yeah.
You can say, okay, if the problem shape is this, and the layout is this, and we're going to want to use a tile size of that and break it up into this many threads, then that means this thread needs to take this set of memory addresses and move them to here. And this can be known statically, right? It doesn't matter what data you're giving the kernel. You just need to know the shape you want in, the shape you want out, and what we call the configuration of the kernel, which is basically the way it's executed on the underlying hardware. And with that information, you can work out a whole bunch about how this computation is going to go. Then when you give it some data, it just executes all those operations and you're done. So that's one of the core ideas. And the other is this integration with Python. We want it to be very easily usable and very easily sweepable. We have this JIT system, so you can just-in-time compile, on the fly, a brand-new, hyper-specialized Vidrial kernel for any given problem shape and any given hardware platform. And by running a sweep over different possible configs, you can empirically measure each one and find the best for your very specific problem shape. So our gains over FlashAttention on normal problem shapes, the ones the authors of FlashAttention expected people to be using, are minimal. They're in the 1% to 2% range. But where we get the huge speedups is on less typical problem shapes, like if you have some weird number of tokens in your sequence length. For example, if you're doing prefill on a real document, right? That document isn't going to have a nice, even length of 1,024 tokens. It's going to have some crazy number, right? And in that case, our ability to sweep over different configurations will find the ideal way to execute that operation on the hardware. So I guess this doesn't really apply if you're just doing one document of that length.
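The static/dynamic split can be illustrated with a CuTe-style layout, which maps a logical coordinate to a memory offset through a shape and a stride. The sketch below is a simplified, hypothetical planner in Python, not Vidrial or CuTe code. Given only the problem shape, a tile shape, and a thread count, it precomputes which offsets each thread copies. `lru_cache` stands in for "compile time," so nothing in the plan depends on the data. The partial last tile is what an odd sequence length such as 1,000 tokens produces.

```python
from functools import lru_cache
from typing import Tuple

Shape = Tuple[int, int]


def layout_offset(coord: Shape, stride: Shape) -> int:
    # A (shape, stride) layout maps a coordinate to an offset: its dot product with the stride.
    return coord[0] * stride[0] + coord[1] * stride[1]


@lru_cache(maxsize=None)
def copy_plan(problem: Shape, tile: Shape, threads: int,
              tile_idx: Shape) -> Tuple[Tuple[int, ...], ...]:
    """Static plan: the global offsets each thread copies for one tile of a row-major problem."""
    rows, cols = problem
    stride = (cols, 1)                                  # row-major layout of the whole problem
    r0, c0 = tile_idx[0] * tile[0], tile_idx[1] * tile[1]
    offsets = [
        layout_offset((r0 + r, c0 + c), stride)
        for r in range(tile[0])
        for c in range(tile[1])
        if r0 + r < rows and c0 + c < cols              # skip out-of-bounds elements of edge tiles
    ]
    # Round-robin assignment: adjacent threads touch adjacent addresses (coalesced access).
    return tuple(tuple(offsets[t::threads]) for t in range(threads))


# Hypothetical shapes: 1,000 tokens x 64 features, 128 x 64 tiles, 128 threads.
n_tiles = -(-1000 // 128)                               # ceiling division: 8 tiles, last one partial
last = copy_plan((1000, 64), (128, 64), 128, (n_tiles - 1, 0))
print(n_tiles, sum(len(t) for t in last))
```

Here the last tile covers only 104 of its 128 rows. Different tile shapes waste different amounts on such edges, which is one reason sweeping configurations per problem shape can pay off.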
But if you're doing a whole bunch of documents, all of that weird length, then our ability to sweep over different configurations, find the best one, and run every single one with that best configuration really pays off. You also, as I mentioned, put out a model. Talk a little bit about the model itself. So this is the third pillar. The first one was that it needs to be mathematically and theoretically optimal. The next is that it needs to actually run fast on hardware. And the final one is that you need to want to train this thing. In the modern era of transformers, right, every good training run starts from some great initial weights. Unless you're one of the big labs, you're very rarely going to want to initialize your model randomly, right? You're always going to want to start from some pretrained weights that already do something useful. Now, switching the architecture from exponential attention to power retention is going to completely mess up the outputs, right? The existing weights are going to produce basically garbage. But by doing a little bit of additional retraining, you can get right back to the original performance. It's way, way cheaper to start from the bad predictions of a good set of weights on a power retention architecture, and train it to match the performance those weights got on the exponential attention architecture, than it is to train from scratch. So for example, StarCoder 2, a 3 billion parameter model trained by the BigCode project to be a fast local coding assistant, has good performance. I think it has around 30% accuracy on HumanEval, which is a Python coding eval. If you just take that, take those same weights... So this is a tune that you did for the power retention architecture, which you're offering as a starting place, as opposed to a ground-up model based on power retention. You don't need to build a ground-up model to use power retention.
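To make the swap concrete: both variants can reuse the same query, key, and value projections and differ only in how a raw q·k score becomes an attention weight. The NumPy sketch below contrasts exponential (softmax) weights with degree-p power weights, normalized the same way. It is a simplifying assumption about what "power" attention means here, with no gating, scaling, or other details of Manifest's actual formulation. It only shows why pretrained weights produce different outputs after the swap and so need a short retraining.

```python
import numpy as np


def causal_attention(q, k, v, score_fn):
    """Causal attention where score_fn turns raw q.k scores into nonnegative weights."""
    t = q.shape[0]
    scores = q @ k.T                                     # (t, t) raw inner products
    mask = np.tril(np.ones((t, t), dtype=bool))          # token i attends to tokens <= i
    w = np.where(mask, score_fn(scores), 0.0)
    return (w / w.sum(axis=1, keepdims=True)) @ v        # normalize each row, mix values


def exp_weights(s):
    return np.exp(s)                                     # softmax numerator (inputs here are small)


def power_weights(s, p=2):
    return s**p                                          # even p keeps weights nonnegative


rng = np.random.default_rng(0)
d, t = 8, 5
x = rng.normal(size=(t, d))
# Shared, stand-in "pretrained" projections used by both variants.
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
q, k, v = x @ w_q, x @ w_k, x @ w_v

out_exp = causal_attention(q, k, v, exp_weights)
out_pow = causal_attention(q, k, v, power_weights)
print(np.abs(out_exp - out_pow).max())                   # nonzero: same weights, different outputs
```

Only the first token, which can attend to nothing but itself, gets the same output from both variants; every later position diverges.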
You can tune a model that you like. Exactly. So yeah, PowerCoder is this tune of StarCoder. And it's two things. One is that it's just a useful model. If you want a coding assistant at the 3 billion parameter scale, and you want it to have really fast inference, use PowerCoder. It's just better than anything else out there. Now, the other thing it is, is a proof of concept, right? It shows that, hey, if there's a model you care about and it's pretrained as an exponential transformer, all you have to do is a little bit of retraining. For example, we were able to recover performance going from StarCoder 3B to PowerCoder 3B. The realignment period, which we call the metamorphosis, took around two hours on a cluster of 128 H100s. It's not nothing, but it's not massive. It's not something that you need big-lab pretraining scale to pull off. And is there a complex recipe required, or is it a fairly straightforward tune? Surprisingly, no complex recipe is required. We did various explorations with the recipe to see if we could squeeze a little bit more out, and you can. But the basic recipe gets you 99% of the way there. Literally just set it up like a training run, a training run like any other with the power retention architecture, but instead of initializing randomly, initialize to the known good StarCoder weights. After two hours of training, you've recovered the 30% score on HumanEval, and with more training, you can push it even beyond that. So it's actually a really, really easy way to turn any existing model into its power retention variant. Okay, I'm not sure I understand. The known good StarCoder weights, where do those come from? Like, how's that different from StarCoder? So if you just download the StarCoder weights and the StarCoder architecture, you can run StarCoder.
If you download the StarCoder weights and the StarCoder architecture, but then replace the call to FlashAttention in the StarCoder architecture with a call to power retention, you now have a new model with weights identical to the StarCoder weights. And if you take this model and train it for just two hours, it'll recover the performance that StarCoder initially had. So what I haven't heard you talk about in terms of performance, at least explicitly, is long context. Is the testing you've done with existing benchmarks focused on small or normal contexts, like showing you're performing better than StarCoder? Or can you take this StarCoder 3B model and demonstrate that you can then shove a much longer context into it? I don't think that works, because the model architecture kind of fixes the context, right? Yeah, but you can just change it. There are some parameters, like the RoPE (rotary positional embedding) parameters, that do fix the context length, but you can work around that. You can just tweak those too and do additional retraining, and that also works fairly well. And so, yeah, this is basically exactly what we do. Our initial training run is usually to recover performance. StarCoder 2, for example, was trained at a context length of 4K. So our initial performance recovery happens at context length 4K on the same data that StarCoder was trained on. But then you can actually do another big training run at 16K. In fact, this is currently an ongoing run; we'll be releasing updates to the weights when it's completed. And now, because you're paying a cost that's linear in context length instead of quadratic, that can be a tractable, almost cheap training run. So training PowerCoder on 16K is way cheaper than training StarCoder on 16K, for example.
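The linear-versus-quadratic cost rests on the fact that a power score can be rewritten as an inner product of expanded features. For p = 2, (q·k)² = ⟨q⊗q, k⊗k⟩, so the model can keep a fixed-size running state instead of a growing KV cache. The sketch below checks this numerically for the simplest case, p = 2 with no gating and a naive d²-sized expansion. It is an illustration of the identity, not Manifest's kernels. Practical implementations can also exploit the symmetry of q⊗q to shrink the state.

```python
import numpy as np


def phi(x):
    # Degree-2 feature map: flattened outer product, so phi(q) . phi(k) == (q . k) ** 2.
    return np.outer(x, x).ravel()


def power_attention_quadratic(q, k, v):
    """O(t^2) form: every query revisits every earlier key, like attending over a KV cache."""
    w = np.tril((q @ k.T) ** 2)
    return (w @ v) / w.sum(axis=1, keepdims=True)


def power_attention_recurrent(q, k, v):
    """O(t) form: a fixed-size state (d^2 x d_v, plus a d^2 normalizer) updated once per token."""
    d, d_v = q.shape[1], v.shape[1]
    state = np.zeros((d * d, d_v))
    norm = np.zeros(d * d)
    out = []
    for qi, ki, vi in zip(q, k, v):
        fk = phi(ki)
        state += np.outer(fk, vi)          # accumulate phi(k) v^T
        norm += fk                         # accumulate phi(k) for the normalizer
        fq = phi(qi)
        out.append((fq @ state) / (fq @ norm))
    return np.array(out)


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(6, 4)) for _ in range(3))
print(np.allclose(power_attention_quadratic(q, k, v), power_attention_recurrent(q, k, v)))
```

The recurrent state here is 16 × 4 regardless of how many tokens have been processed. It is larger than an RNN-style vector state but, unlike a KV cache, it does not grow with context.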
You could push it to 32K, 64K. At that point, it just becomes a resource question, right? How much do you want to invest in these long-context pretraining runs? And we're a small lab, so we make do with what we have and push it as long as is reasonable given our budget. But we're really excited to see, as this starts to pick up momentum and other people get excited about it too, what models other people train and how far they're able to push the context. Presumably you feel comfortable that it works without having demonstrated the ability to push the context out very long. What gives you that confidence? I guess that's where the mathematics comes in. But talk a little bit more about the practical realities of pushing context out. What do you think is possible? Maybe that's a way to ask the question. What do you think is possible with this model, and where do you think there will be limitations? I think the optimal context length is always going to follow a scaling law, just like every other axis of scaling. And it's one that's going to be highly dependent on the dataset. I think a lot of the big labs these days are just putting big numbers out there, saying this is my max context length, while not actually verifying that the performance within that context length is any good, right? And of course, as we were saying at the beginning of the chat, the most important thing is that the model is actually able to use those tokens of context. So by looking at in-context learning curves, where you can see how much the model is able to use each additional token of context, you can learn something deep about whether it's taking advantage of the extra information it was given.
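An in-context learning curve of the kind described here can be computed from per-token losses. Average the loss at each context position across many documents, then check how much it keeps dropping as the position grows. This is a minimal sketch that assumes you already have a (documents × positions) array of per-token losses from some model. The small array below is made up solely to exercise the functions and is not a result from any model.

```python
import numpy as np


def icl_curve(token_losses: np.ndarray) -> np.ndarray:
    """Mean loss at each context position, averaged over documents."""
    return token_losses.mean(axis=0)


def gain_from_extra_context(curve: np.ndarray, start: int, end: int) -> float:
    """How much lower the loss is at position `end` than at position `start`.

    A model that keeps using its context shows a gain that keeps growing with `end`;
    a model that has plateaued shows almost no additional gain past the knee.
    """
    return float(curve[start] - curve[end])


# Synthetic per-token losses for 3 documents x 6 positions (illustrative only).
losses = np.array([
    [4.0, 3.0, 2.6, 2.4, 2.3, 2.25],
    [4.2, 3.1, 2.7, 2.5, 2.4, 2.35],
    [3.8, 2.9, 2.5, 2.3, 2.2, 2.15],
])
curve = icl_curve(losses)
print(curve, gain_from_extra_context(curve, 1, 5))
```

In practice these curves are usually plotted against position on a log axis over thousands of positions, where a windowed model's flattening after its knee becomes easy to see.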
And what you find is that for any given computational budget, there's a particular context length that is best suited to that budget, just like there's always a particular model size that's best suited to it. And it's these scaling laws that really give us confidence. Scaling laws themselves are so well established at this point in all these other domains: model size, batch size, training steps, right? There's no question that scaling laws work. Now, state size scaling laws, which govern performance as you increase the state, and context length scaling laws, which do the same as you increase the context length, are less well understood. And I think a big part of this is that in transformer models, these two things are tied together. Increasing the context length means increasing the state size; it's the only way to do it, and vice versa, right? So nobody, before us, has really carefully isolated these two factors and explored their independent effects. But they do have independent effects, right? And you can really reliably see these scaling behaviors. When you look at the plots, they look just like all the other scaling plots that we know and love and trust deeply. So I think it's this more than anything else that gives me confidence that this is a fundamental, general principle, and one that will very naturally be able to achieve and maintain good performance as the context length gets longer. There are a lot of other approaches that we've measured, for example, windowed attention, where you're just arbitrarily truncating the attention at a certain size. We've looked at what the in-context learning curves look like for windowed attention. And guess what? You see that it learns quickly at the beginning. It's able to use the first couple of tokens of context, and then it plateaus.
And that long plateau, where it stops improving as you show it more tokens, is exactly what you would expect from a model that is only looking at a small recent window of tokens. And interestingly, the point at which this knee happens and it stops learning more is way shorter than what people call the effective context length of windowed transformers, which is the depth times the window size. So really, windowed transformers are just not able to use very much of their context. But our architecture is. And we think that as you continue to scale it up, all these trends will just continue, and we'll get the behavior we expect even at very, very large state sizes and very, very long contexts, and of course at very large parameter counts as well. Of course, as a small lab, all we can do is extrapolate, right? We're currently working on progressively scaling up. And every time we do, we see that things match what we expected to see, and then we move on. But ultimately, a big part of the reason we're open sourcing this work and are excited to share it with the community is so that lots of people can play with it themselves. You know, we really deeply believe in this research. Based on the results we've seen, we're fully convinced that this is something that actually works. So we're not afraid to put it into everyone's hands and let everyone see for themselves. And we expect this isn't something we're going to do alone. It's something we're going to do in collaboration with the broader community. For folks who are intrigued and want to play around with it, where do you suggest they start? And what do you want them to know or think about as they're getting started? I think a great place to start would be the PowerCoder model. It's available on Hugging Face. You can just download it and use it. You can play around with the inference.
You can see the really fast inference speeds. You can extend the context and see that the inference speed per token stays consistent. I also think it would be fun if people want to do some additional post-training. For example, this is not an instruction-tuned model. It's just trained on raw code. But a little instruction fine-tuning would be a cool project, just so you can see that this thing actually learns just like a transformer would in similar situations. Beyond that, I would love it if people who are already using open source models, especially people who are already fine-tuning or post-training open source models for their own needs, would consider doing a metamorphosis run: take their favorite model, maybe a Qwen Coder, a 30 billion parameter Qwen Coder or something, and rather than post-training it directly on their preferred post-training dataset, first metamorphose it into power retention and then do the post-training. And if their preferred post-training dataset has very long contexts, they'll see big, big speedups. For example, at 64K context, you'll get a 10X speedup at train time and a 100X speedup at inference time just by switching over to power retention. And yeah, I think just comparing those learning curves will let people convince themselves that this is actually going to be really useful for them. Awesome. Awesome. Well, Jacob, thanks for jumping on and sharing a little bit about what you've been working on. Super interesting stuff. Thanks so much for having me. Very fun conversation. Awesome. Thank you.
