Attention is a crucial component in generative neural architectures that produce continuous modalities like images and videos from natural language. More specifically, cross-attention helps contextualize the relationship between the natural language prompt and the media being generated.
With modern diffusion models (or shall we say “flow”) for condition-guided image and video generation, we saw the community going beyond cross-attention. For example, Stable Diffusion 3 (SD 3) [1] introduced “joint-attention” in its MMDiT architecture. SANA [2], on the other hand, introduced a variant of “linear attention”, moving away from the standard attention mechanism.
While the changes between these variants may appear architecturally simple, it can be helpful to understand the factors that distinguish them. In this post, we will investigate the popular forms of attention blocks used in modern diffusion models. We will tear them apart with simple PyTorch code and comment on some additional findings.
Readers are expected to be familiar with diffusion-based image generation models and encoder-based transformer architectures.

Image generated with Flux.1-Dev.
Self-attention
For the sake of completeness, let’s take a quick look at how self-attention is implemented. This will help us understand how it evolves into cross-attention and the other variants.
def forward(self, x, attn_mask=None):
    # x shape: (batch_size, seq_length, embed_dim)
    bsz, seq_length, _ = x.size()

    # Compute queries, keys, and values using separate linear layers
    q = self.to_q(x)  # shape: (batch_size, seq_length, embed_dim)
    k = self.to_k(x)  # shape: (batch_size, seq_length, embed_dim)
    v = self.to_v(x)  # shape: (batch_size, seq_length, embed_dim)

    # Reshape and transpose to get dimensions
    # (batch_size, num_heads, seq_length, head_dim)
    q = q.view(bsz, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
    k = k.view(bsz, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
    v = v.view(bsz, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    # Compute scaled dot product attention using PyTorch's built-in function
    attn_output = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=self.dropout.p, is_causal=False
    )  # shape: (batch_size, num_heads, seq_length, head_dim)

    # Combine the attention output from multiple heads
    attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_length, self.embed_dim)

    # Final linear projection
    output = self.out_proj(attn_output)
    return output
With self-attention, we model the interactions between the different parts of the same input.
Regarding the implementation above, the initialization part of the underlying class was intentionally left out in the interest of brevity. We will use snippets of this kind for the rest of this post.
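For readers who want to run the snippet, the omitted initialization might look like the sketch below. The class name `SelfAttention`, the constructor arguments, and the toy sizes are assumptions made for illustration; only the attribute names (`to_q`, `to_k`, `to_v`, `out_proj`, `dropout`, `num_heads`, `head_dim`, `embed_dim`) are taken from the `forward()` above. The sketch also passes a non-zero `dropout_p` only in training mode, since `F.scaled_dot_product_attention()` applies dropout whenever `dropout_p` is greater than zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    # Hypothetical wrapper class: the constructor below is an assumption,
    # not the initialization used by any particular model.
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.to_q = nn.Linear(embed_dim, embed_dim)
        self.to_k = nn.Linear(embed_dim, embed_dim)
        self.to_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        bsz, seq_length, _ = x.size()
        shape = (bsz, seq_length, self.num_heads, self.head_dim)
        q = self.to_q(x).view(shape).transpose(1, 2)
        k = self.to_k(x).view(shape).transpose(1, 2)
        v = self.to_v(x).view(shape).transpose(1, 2)
        # Only apply attention dropout while training.
        dropout_p = self.dropout.p if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_length, self.embed_dim)
        return self.out_proj(attn_output)


attn = SelfAttention(embed_dim=64, num_heads=4).eval()
x = torch.randn(2, 16, 64)  # (batch_size, seq_length, embed_dim)
out = attn(x)
print(out.shape)
```

The output has the same shape as the input, which is what lets us stack these blocks.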
Cross-attention
The premise of cross-attention is that we want to model how two different inputs interact with each other, for example, image patches and text tokens.
+ def forward(self, x, context=None, attn_mask=None):
      # x shape: (batch_size, seq_length, embed_dim)
+     bsz, target_seq_len, _ = x.size()
+     if context is None:
+         context = x
+     source_seq_len = context.size(1)

+     # Compute queries from x; keys and values from context
+     q = self.to_q(x)  # (batch_size, target_seq_length, embed_dim)
+     k = self.to_k(context)  # (batch_size, source_seq_length, embed_dim)
+     v = self.to_v(context)  # (batch_size, source_seq_length, embed_dim)

      # Reshape and transpose to get dimensions
+     q = q.view(bsz, target_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+     k = k.view(bsz, source_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+     v = v.view(bsz, source_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

      # Compute scaled dot product attention using PyTorch's built-in function
      attn_output = F.scaled_dot_product_attention(
          q, k, v, attn_mask=attn_mask, dropout_p=self.dropout.p, is_causal=False
      )  # shape: (batch_size, num_heads, seq_length, head_dim)

      # Combine the attention output from multiple heads
      attn_output = attn_output.transpose(1, 2).reshape(
          bsz, target_seq_len, self.embed_dim
      )

      # Final linear projection
      output = self.out_proj(attn_output)
      return output
For the context of this post, x would be the noisy latents we will denoise during inference and context would be the representations computed from input text prompts. In this case, the attention masks (attn_mask) are usually computed from the context. For example, for text prompts, the attention masks are constructed from the actual text tokens and the padding tokens.
Let’s consider the sentence - “a dog”. Without going into too many details, if we want to tokenize it with a maximum sequence length of 10, the attention masks would look like so:
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
The exact text tokens would change based on the tokenizer being used, but we get the idea of what attention masks might look like.
With attention masks, the padding tokens are excluded from the attention computation. When the underlying kernel can skip the masked positions, this can also speed up attention and reduce its memory footprint.
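To make the role of the mask concrete, here is a minimal sketch that feeds the padding mask from the “a dog” example into `F.scaled_dot_product_attention()` in a cross-attention setting. The tensor sizes and the random inputs are assumptions for illustration. With a boolean mask, `True` marks the positions that take part in attention, and the `(batch_size, 1, 1, source_seq_len)` shape broadcasts over heads and queries.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Padding mask for one prompt: 1 = real token, 0 = padding.
text_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])  # (batch_size, source_seq_len)

bsz, num_heads, head_dim = 1, 2, 8  # toy sizes (assumed)
target_seq_len, source_seq_len = 6, text_mask.shape[1]

q = torch.randn(bsz, num_heads, target_seq_len, head_dim)  # from the noisy latents
k = torch.randn(bsz, num_heads, source_seq_len, head_dim)  # from the text
v = torch.randn(bsz, num_heads, source_seq_len, head_dim)

# Boolean mask, True = attend. Broadcasts to (bsz, num_heads, target, source).
attn_mask = text_mask.bool()[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Overwrite the keys/values at the padded positions and recompute.
k2, v2 = k.clone(), v.clone()
k2[:, :, 4:] = torch.randn_like(k2[:, :, 4:])
v2[:, :, 4:] = torch.randn_like(v2[:, :, 4:])
out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=attn_mask)

print(out.shape, torch.allclose(out, out2, atol=1e-6))
```

Changing the keys and values at the padded positions leaves the output untouched, which is exactly what the mask is for.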
This implementation is often combined with other elements that help stabilize training, improve the end performance, extrapolate to larger resolutions, etc. Some of these popular elements include:
QK normalization
Introduced in ViT-22B [3], QK normalization is a commonly used technique to help stabilize the training of transformers at scale. In code, this is simple to implement:
...
q = self.to_q(x)
k = self.to_k(context)
...
+ q = self.q_norm_layer(q)
+ k = self.k_norm_layer(k)
...
Some choices of norms include LayerNorm, RMSNorm, and L2Norm, with the first two being the most common.
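As a rough illustration of why this helps, the sketch below compares the largest attention logit with and without QK normalization. Applying `nn.LayerNorm` per head over `head_dim`, the exaggerated input scale, and the helper `max_abs_logit()` are all assumptions made for this example; models differ in which norm they use and where exactly they apply it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bsz, seq_len, num_heads, head_dim = 2, 16, 4, 32  # toy sizes (assumed)

# Normalize each head's vector over head_dim. RMSNorm could be used the same way.
q_norm_layer = nn.LayerNorm(head_dim)
k_norm_layer = nn.LayerNorm(head_dim)

# Deliberately large activations to mimic a model whose logits are drifting.
q = 50.0 * torch.randn(bsz, seq_len, num_heads, head_dim)
k = 50.0 * torch.randn(bsz, seq_len, num_heads, head_dim)


def max_abs_logit(q, k):
    # Hypothetical helper: largest pre-softmax attention logit in magnitude.
    q, k = q.transpose(1, 2), k.transpose(1, 2)  # (bsz, num_heads, seq_len, head_dim)
    logits = q @ k.transpose(-2, -1) / head_dim**0.5
    return logits.abs().max().item()


with torch.no_grad():
    raw = max_abs_logit(q, k)
    normed = max_abs_logit(q_norm_layer(q), k_norm_layer(k))

print(f"max |logit| without QK norm: {raw:.1f}, with QK norm: {normed:.1f}")
```

With freshly initialized norm layers, every normalized vector has a squared norm of about `head_dim`, so the scaled logits cannot exceed roughly `sqrt(head_dim)` regardless of how large the raw activations grow.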
Grouped-query attention (GQA)
In standard attention, every query head has its own set of keys and values. But there may be redundancy [4] in this setup. We can maintain a reduced number of heads for the keys and the values, and repeat them across groups of query heads. In practice, this looks like so:
...
q = self.to_q(x)  # (batch_size, seq_length, embed_dim)
+ k = self.to_k(context)  # (batch_size, context_seq_length, embed_dim // reduced_kv_heads)
+ v = self.to_v(context)  # (batch_size, context_seq_length, embed_dim // reduced_kv_heads)
# Note there's no transpose yet (.transpose(1, 2)).
q = q.view(batch_size, target_seq_length, self.num_heads, self.head_dim)
+ k = k.view(batch_size, source_seq_length, self.kv_heads, self.head_dim)
+ v = v.view(batch_size, source_seq_length, self.kv_heads, self.head_dim)
+ n_rep = self.num_heads // self.kv_heads
+ if n_rep > 1:
+     # Perform repeats.
+     k = k.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+     v = v.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ # Complete the transpose to get (batch_size, num_heads, seq_length, head_dim).
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
# Apply `scaled_dot_product_attention()` ...
This helps reduce memory overhead without hurting performance too much. This is crucial when generating high-resolution images and videos.
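To see where the savings come from, here is a small sketch with assumed toy sizes (8 query heads sharing 2 key/value heads). It checks that the `unsqueeze`/`repeat`/`flatten` pattern used above matches `torch.repeat_interleave()`, and compares the parameter count of a full key projection against the reduced one. The layer names `to_k_full` and `to_k_gqa` are made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads, kv_heads = 64, 8, 2  # toy sizes (assumed)
head_dim = embed_dim // num_heads
n_rep = num_heads // kv_heads

# Standard key projection vs. the reduced projection used with GQA.
to_k_full = nn.Linear(embed_dim, num_heads * head_dim)
to_k_gqa = nn.Linear(embed_dim, kv_heads * head_dim)

context = torch.randn(2, 10, embed_dim)  # (batch_size, source_seq_length, embed_dim)
k = to_k_gqa(context).view(2, 10, kv_heads, head_dim)

# Repeat each key head n_rep times so that it lines up with the query heads.
k_rep = k.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
k_alt = k.repeat_interleave(n_rep, dim=2)  # equivalent one-liner

full_params = sum(p.numel() for p in to_k_full.parameters())
gqa_params = sum(p.numel() for p in to_k_gqa.parameters())
print(k_rep.shape, torch.equal(k_rep, k_alt), full_params, gqa_params)
```

The projection (and any cached keys and values) shrinks by a factor of `n_rep`, while the attention computation itself still sees `num_heads` heads.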
Rotary position embeddings (RoPE)
Rotary position embeddings have become the de facto choice as they help extrapolate to longer sequences. Image (and video) generation is no exception! The explanation of RoPE is out of this post’s scope. Interested readers should check out this post instead.
Below, we show where RoPE is usually incorporated when computing attention:
...
q = self.to_q(x) # (batch_size, seq_length, embed_dim)
k = self.to_k(context) # (batch_size, context_seq_length, embed_dim)
v = self.to_v(context) # (batch_size, context_seq_length, embed_dim)
# Note there's no transpose yet (.transpose(1, 2)).
q = q.view(batch_size, target_seq_length, self.num_heads, self.head_dim)
k = k.view(batch_size, source_seq_length, self.num_heads, self.head_dim)
v = v.view(batch_size, source_seq_length, self.num_heads, self.head_dim)
# The `*_rotary_emb` below are computed based on the inputs.
+ q = apply_rotary_emb(q, query_rotary_emb, use_real=False)
+ k = apply_rotary_emb(k, key_rotary_emb, use_real=False)
...
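For intuition only, below is a minimal 1D RoPE sketch. It is not the `apply_rotary_emb()` referenced above: the function names `rope_angles()` and `apply_rope()`, the base of 10000, and the pairing of adjacent channels are assumptions chosen for illustration, and image or video models typically use multi-axis variants. The example places the same query and key vector at every position and inspects the resulting scores.

```python
import torch

torch.manual_seed(0)


def rope_angles(seq_len: int, head_dim: int, base: float = 10000.0):
    # One rotation frequency per pair of channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return angles.cos(), angles.sin()


def apply_rope(x, cos, sin):
    # x: (batch_size, seq_length, num_heads, head_dim), i.e. before .transpose(1, 2).
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    # Rotate every (even, odd) channel pair by its position-dependent angle.
    rotated = torch.stack(
        (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1
    )
    return rotated.flatten(-2)


seq_len, head_dim = 6, 8
cos, sin = rope_angles(seq_len, head_dim)

# The same query/key vector placed at every position.
q = torch.randn(head_dim).expand(1, seq_len, 1, head_dim)
k = torch.randn(head_dim).expand(1, seq_len, 1, head_dim)
q_rot = apply_rope(q, cos, sin)[0, :, 0]  # (seq_len, head_dim)
k_rot = apply_rope(k, cos, sin)[0, :, 0]

scores = q_rot @ k_rot.T  # scores[m, n] = <q at position m, k at position n>
print(scores[2, 0].item(), scores[5, 3].item())
```

The two printed scores agree up to floating-point error because both pairs of positions are two steps apart: after the rotation, the dot product depends only on the relative offset.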
Popular models that use cross-attention include Stable Diffusion XL [5], PixArt-{Alpha, Sigma} [6, 7], Lumina-Next [8], LTX-Video [9], etc. Lumina-Next incorporates all the other elements as well.
Joint-attention
Through cross-attention, we also inherit any bias that might be present in the prompt embeddings computed with text encoders. For example, if a text encoder exhibits a unidirectional bias (through causal attention), that can creep unexpectedly into the diffusion model representations. Joint-attention alleviates this by allowing the representations coming from two different modalities to co-evolve with training.
The diagram below depicts the MMDiT architecture, which also introduces joint-attention.

MMDiT architecture. Figure taken from SD 3 paper [1].
If the diagram appears to be overwhelming, don’t worry, the implementation of it is simpler than one might think. In a nutshell, in joint-attention, the QKV projection is performed separately (with separate sets of params) on each of the two modalities shown above (c being the representation computed from the text prompts and x being noisy latents to be denoised). Before computing the attention, these projections are concatenated. Quoting from the SD3 paper [1]:
Since text and image embeddings are conceptually quite different, we use two separate sets of weights for the two modalities. […], this is equivalent to having two independent transformers for each modality, but joining the sequences of the two modalities for the attention operation, such that both representations can work in their own space yet take the other one into account.
Interested readers can check out this thread for more insights from the community.
Let’s now turn our attention to how it is implemented in practice.
def forward(self, x, context=None, attn_mask=None):
    if context is not None:
        source_seq_len = context.size(1)
    bsz, target_seq_len, _ = x.size()

    # Compute projections on the different modalities separately
    q = self.to_q(x)
    k = self.to_k(x)
    v = self.to_v(x)

    # Reshape and transpose for multi-head attention
    q = q.view(bsz, target_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = k.view(bsz, target_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    v = v.view(bsz, target_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Compute projections on the condition separately
    if context is not None:
        context_q = self.to_add_q(context)
        context_k = self.to_add_k(context)
        context_v = self.to_add_v(context)

        context_q = context_q.view(bsz, source_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        context_k = context_k.view(bsz, source_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        context_v = context_v.view(bsz, source_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Concatenate across the sequence length dimension
        q = torch.cat([context_q, q], dim=2)
        k = torch.cat([context_k, k], dim=2)
        v = torch.cat([context_v, v], dim=2)

    # Compute scaled dot product attention
    attn_output = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=self.dropout.p, is_causal=False
    )

    # Combine attention heads. The sequence length is source_seq_len + target_seq_len
    # when context is provided, so we let reshape() infer it.
    attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.embed_dim)

    # Separate context from latents and apply the final linear projections
    if context is not None:
        context, x = (
            attn_output[:, :source_seq_len],
            attn_output[:, source_seq_len:],
        )
        context = self.add_out_proj(context)
        x = self.out_proj(x)
        return x, context
    else:
        return self.out_proj(attn_output)
With joint-attention, it becomes unclear how to incorporate attention masks while computing attention and how much of a performance penalty it incurs due to that.
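One possible option, shown purely as a sketch and not as what SD3 or any other model does, is to extend the text padding mask with an all-ones mask for the latents, following the same order as the concatenation (context first, then latents). The tensor sizes and random inputs are assumptions. Note that the outputs at the padded text positions are still computed here; they would simply be ignored downstream.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

bsz, num_heads, head_dim = 1, 2, 8  # toy sizes (assumed)
source_seq_len, target_seq_len = 10, 6
total_len = source_seq_len + target_seq_len

# Text padding mask (True = real token); the latents are never padded here.
text_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0]], dtype=torch.bool)
latent_mask = torch.ones(bsz, target_seq_len, dtype=torch.bool)

# Key-side mask over the joint sequence, broadcast over heads and queries.
joint_mask = torch.cat([text_mask, latent_mask], dim=1)[:, None, None, :]

# Stand-ins for the concatenated [context, latents] projections.
q = torch.randn(bsz, num_heads, total_len, head_dim)
k = torch.randn(bsz, num_heads, total_len, head_dim)
v = torch.randn(bsz, num_heads, total_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=joint_mask)

# Overwriting the padded text keys/values should not change anything.
k2, v2 = k.clone(), v.clone()
k2[:, :, 4:source_seq_len] = 0.0
v2[:, :, 4:source_seq_len] = 0.0
out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=joint_mask)

print(out.shape, torch.allclose(out, out2, atol=1e-6))
```

Whether the extra masking is worth its cost is exactly the kind of question that needs an ablation.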
Part joint-attention, part self-attention
Subsequent works like AuraFlow [10] and Flux [11] introduced a small change in the original MMDiT architecture. They use joint attention for the first few layers within the diffusion transformer. They then concatenate the two different outputs and operate on the concatenated output. As per the AuraFlow authors, it helps with better FLOPs optimization. In pseudo-code, it looks like so:
# Regular MMDiT blocks.
for block in self.double_blocks:
    context, x = block(
        x=x,
        context=context,
        ...
    )

# Concatenate.
context_x = torch.cat([context, x], dim=1)

# Continue with the rest.
for block in self.single_blocks:
    context_x = block(x=context_x, ...)
The joint-attention implementation provided above is already equipped to handle situations where context is not provided while computing attention (through the context is not None checks).
Attention gymnastics
Flux additionally improves hardware efficiency by using parallel layers [3]. To better understand how parallel layers help, let’s look at the equations that govern an encoder-style transformer block with parallel layers, where the attention and MLP branches both read the same normalized input:
\[ \begin{aligned} & y^{\prime}=\operatorname{LayerNorm}(x) \\ & y=x+\operatorname{MLP}\left(y^{\prime}\right)+\operatorname{Attention}\left(y^{\prime}\right) \end{aligned} \]
We can combine the linear projection layers of attention (QKV) and the MLP. From the ViT-22B paper:
In particular, the matrix multiplication for query/key/value-projections and the first linear layer of the MLP are fused into a single operation, and the same is done for the attention out-projection and second linear layer of the MLP.
To understand how it is implemented in practice, we first need to understand QKV fusion. It lets us perform the three different projections involved in attention in one go. Instead of having three different projection layers, we only keep one:
...
# This becomes
self.to_q = nn.Linear(embed_dim, embed_dim)
self.to_k = nn.Linear(embed_dim, embed_dim)
self.to_v = nn.Linear(embed_dim, embed_dim)
# this ⬇️
self.to_qkv = nn.Linear(embed_dim, embed_dim * 3)
Then, during the forward pass, we do:
...
qkv = self.to_qkv(x)
split_size = qkv.shape[-1] // 3
q, k, v = torch.split(qkv, split_size, dim=-1)
# Rest of the process is the same
...
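To convince ourselves that the fused layer computes the same thing, here is a small sanity check. It copies the weights of three separate (randomly initialized) projections into a single fused layer and compares the outputs. The toy `embed_dim` and the input shape are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim = 32  # toy size (assumed)
to_q = nn.Linear(embed_dim, embed_dim)
to_k = nn.Linear(embed_dim, embed_dim)
to_v = nn.Linear(embed_dim, embed_dim)

# Fused layer: stack the three weight matrices (and biases) along the output dim.
to_qkv = nn.Linear(embed_dim, embed_dim * 3)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

x = torch.randn(2, 16, embed_dim)
q, k, v = torch.split(to_qkv(x), embed_dim, dim=-1)

fused_matches = all(
    torch.allclose(fused, separate, atol=1e-5)
    for fused, separate in [(q, to_q(x)), (k, to_k(x)), (v, to_v(x))]
)
print(fused_matches)
```

The math is unchanged; the benefit is that one larger matrix multiplication tends to use the hardware better than three smaller ones.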
Now, to incorporate the MLP into the mix, we need some changes in the initialization:
...
# MLP ratio is typically 4.
# The first part of fusion. QKV + first layer of an MLP from a transformer block.
self.qkv_mlp_first = nn.Linear(embed_dim, embed_dim * (3 + mlp_ratio))
# Second part of fusion. Attention out projection + second MLP layer.
self.attn_proj_mlp_second = nn.Linear(
    embed_dim + embed_dim * mlp_ratio, embed_dim
)
The forward pass would be:
qkv, mlp = torch.split(
    self.qkv_mlp_first(x_mod),
    [3 * self.embed_dim, int(self.embed_dim * self.mlp_ratio)],
    dim=-1
)
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
# Compute attention
...
# MLP computation
concat_attn_mlp_in = torch.cat((attn_output, self.mlp_act(mlp)), 2)
output = self.attn_proj_mlp_second(concat_attn_mlp_in)
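Putting the pieces together, a self-contained parallel block might look like the sketch below. The class name `ParallelBlock`, the plain `LayerNorm` standing in for `x_mod`, the GELU activation, and the ungated residual are assumptions that follow the equations above; actual models such as Flux add modulation and gating on top.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ParallelBlock(nn.Module):
    # Hypothetical block implementing y = x + MLP(LN(x)) + Attention(LN(x)).
    def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.mlp_hidden = embed_dim * mlp_ratio
        self.norm = nn.LayerNorm(embed_dim)
        # First fusion: QKV projection + first MLP layer.
        self.qkv_mlp_first = nn.Linear(embed_dim, 3 * embed_dim + self.mlp_hidden)
        self.mlp_act = nn.GELU()
        # Second fusion: attention out-projection + second MLP layer.
        self.attn_proj_mlp_second = nn.Linear(embed_dim + self.mlp_hidden, embed_dim)

    def forward(self, x):
        bsz, seq_len, _ = x.shape
        x_mod = self.norm(x)
        qkv, mlp = torch.split(
            self.qkv_mlp_first(x_mod), [3 * self.embed_dim, self.mlp_hidden], dim=-1
        )
        q, k, v = (
            t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for t in qkv.chunk(3, dim=-1)
        )
        attn_output = F.scaled_dot_product_attention(q, k, v)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.embed_dim)
        # A single matmul yields attention out-projection + MLP second layer, summed.
        concat_attn_mlp_in = torch.cat((attn_output, self.mlp_act(mlp)), dim=2)
        return x + self.attn_proj_mlp_second(concat_attn_mlp_in)


block = ParallelBlock(embed_dim=64, num_heads=4).eval()
x = torch.randn(2, 16, 64)
y = block(x)
print(y.shape)
```

Because a linear layer applied to a concatenation equals the sum of two linear layers applied to the parts, the last matmul produces the attention and MLP contributions already added together.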
The ViT-22B paper [3] also provides a great visualization for this:

Transformer block with parallel layers. Figure taken from the ViT-22B paper [3].
It is also worth mentioning that some of the other elements we discussed earlier (QK normalization, GQA, and RoPE) can also be combined with joint attention. One of the most popular models, Flux, uses QK normalization and RoPE. Lumina2 [12] combines all three and uses joint-attention with a twist:
- It first uses a few layers of self-attention transformer blocks separately on the noisy latents and the conditional representations.
- It then combines the two representations and runs them through a number of self-attention transformer blocks.
Interested readers can check out the implementation details here. The figure below provides a side-by-side comparison of the differences in attention used in SD3 and Lumina2:

Comparison between the attention schemes used in SD3 and Lumina2, respectively. Figures were intentionally simplified to convey the main idea.
Linear attention
It is well known that attention has quadratic time complexity in the sequence length. This can become prohibitive when operating with very long sequences, even with improved techniques like Flash Attention.
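A back-of-the-envelope count makes the gap tangible. The sketch below only counts the two matrix multiplications per head (ignoring the softmax, the projections, and memory traffic), and the head dimension of 32 and the latent grid sizes are assumptions picked for illustration. The function names are made up for this example.

```python
def softmax_attention_flops(seq_len: int, head_dim: int) -> int:
    # Q @ K^T is (N x D) @ (D x N); weights @ V is (N x N) @ (N x D).
    # Each multiply-add is counted as 2 FLOPs.
    return 2 * (2 * seq_len * seq_len * head_dim)


def linear_attention_flops(seq_len: int, head_dim: int) -> int:
    # K^T @ V is (D x N) @ (N x D); Q @ (K^T V) is (N x D) @ (D x D).
    return 2 * (2 * seq_len * head_dim * head_dim)


head_dim = 32  # assumed
for side in (32, 64, 128):  # assumed latent grid sizes
    n = side * side
    ratio = softmax_attention_flops(n, head_dim) / linear_attention_flops(n, head_dim)
    print(f"{n:>6} tokens: softmax / linear matmul FLOPs = {ratio:.0f}x")
```

Under this simplified count, the ratio is `seq_len / head_dim`, so the advantage of a linear formulation grows with resolution.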
SANA [2] replaced vanilla self-attention with linear attention. More specifically, in each of its transformer blocks, SANA has two kinds of attention:
- linear self-attention for the noisy latents,
- regular cross attention for the noisy latents (x) and the condition representations (context).
To facilitate local interactions between the tokens of the noisy latents, it used “Mix-FFN” blocks.

Linear attention block and the Mix-FFN block. Figure taken from the SANA paper [2].
Implementation of this linear-attention variant is by far the simplest. We show the main changes introduced when compared to the classic self-attention:
+ def forward(self, x):
      # x shape: (batch_size, seq_length, embed_dim)
      bsz, seq_length, _ = x.size()

      # Compute queries, keys, and values using separate linear layers
      q = self.to_q(x)  # shape: (batch_size, seq_length, embed_dim)
      k = self.to_k(x)  # shape: (batch_size, seq_length, embed_dim)
      v = self.to_v(x)  # shape: (batch_size, seq_length, embed_dim)

      # Reshape and transpose to get dimensions
      # (batch_size, num_heads, seq_length, head_dim)
      q = q.view(bsz, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
      k = k.view(bsz, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
      v = v.view(bsz, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

+     # Move the sequence dimension last for q and v:
+     # (batch_size, num_heads, head_dim, seq_length)
+     q = q.transpose(2, 3)
+     v = v.transpose(2, 3)

+     # Introduce non-linearity
+     q = F.relu(q)
+     k = F.relu(k)

+     # Append a row of ones to v so that the last output row holds the normalizer
+     v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1.0)

+     # Combine scores
+     scores = torch.matmul(v, k)  # (batch_size, num_heads, head_dim + 1, head_dim)
+     x = torch.matmul(scores, q)  # (batch_size, num_heads, head_dim + 1, seq_length)

+     # Scale
+     x = x[:, :, :-1] / (x[:, :, -1:] + 1e-15)

      # Combine the output from multiple heads
+     x = x.permute(0, 3, 1, 2).reshape(bsz, seq_length, self.embed_dim)

      # Final linear projection
      output = self.out_proj(x)
      return output
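To check that this reordering of the matrix products really is attention with a ReLU kernel, the sketch below compares it against an explicit quadratic reference that builds the full `(seq_length, seq_length)` weight matrix. The function names, the toy sizes, and the use of `float64` (to keep the comparison tight) are assumptions. The sketch pads `v` with a row of ones so that the last output row carries the normalizer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

bsz, num_heads, seq_len, head_dim = 2, 4, 16, 8  # toy sizes (assumed)
q = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=torch.float64)
k = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=torch.float64)
v = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=torch.float64)


def linear_attention(q, k, v, eps=1e-15):
    # Never materializes the (seq_len x seq_len) matrix.
    q, k = F.relu(q), F.relu(k)
    q_t = q.transpose(2, 3)  # (bsz, heads, head_dim, seq_len)
    v_t = v.transpose(2, 3)  # (bsz, heads, head_dim, seq_len)
    # Extra row of ones: its output row accumulates the normalizer.
    v_t = F.pad(v_t, (0, 0, 0, 1), mode="constant", value=1.0)
    scores = torch.matmul(v_t, k)    # (bsz, heads, head_dim + 1, head_dim)
    out = torch.matmul(scores, q_t)  # (bsz, heads, head_dim + 1, seq_len)
    out = out[:, :, :-1] / (out[:, :, -1:] + eps)
    return out.transpose(2, 3)       # (bsz, heads, seq_len, head_dim)


def quadratic_reference(q, k, v, eps=1e-15):
    # Same kernel, but with the full attention weight matrix.
    q, k = F.relu(q), F.relu(k)
    weights = q @ k.transpose(2, 3)  # (bsz, heads, seq_len, seq_len)
    weights = weights / (weights.sum(dim=-1, keepdim=True) + eps)
    return weights @ v


out_linear = linear_attention(q, k, v)
out_reference = quadratic_reference(q, k, v)
print(torch.allclose(out_linear, out_reference, atol=1e-10))
```

Both paths give the same result; the linear one simply multiplies `v` and `k` first, so its cost grows linearly with the sequence length.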
Thoughts
Throughout the course of this post, we saw many architectural configurations of the attention mechanism. Some questions that may still arise:
- Is there any benefit to using cross-attention for image-video generation at all?
- How can we compensate for the compute intensity of joint-attention?
- Is the Lumina2 way of doing joint-attention the way to go?
- Is it necessary to do masking in joint-attention? If so, what are the benefits?
In defense of the widely adopted and optimized vanilla attention, could we interleave quadratic attention and window attention (as done in Gemma2 [13])?
All of these questions (and possibly more) warrant a careful ablation study.
Acknowledgements: Thanks to Aritra Roy Gosthipaty for useful feedback.