论文地址: https://arxiv.org/abs/2411.17116
代码地址: https://github.com/NVIDIA/Star-Attention
偶然看到了这个项目,正好借着这个机会学习一下不同的 Attention 的机制;之后还有比如 MLA, MQA, GQA 等等的 attention 计算方法,也都计划学习一下
Star attention 的一个主要思路是用并行处理的方式来解决 long context 语境下的 attention 的 $N^2$ 复杂度的问题,整体的算法分为两步
整体上,因为原来的 context 被分成了很多块儿让多个 GPU 进行并发处理,所以能够实现速度上的提升,虽然在最后的 online softmax 上面有额外的 communication,但是整体上在 perfill 阶段节省下来的时间还是比在 decoding 时候的额外开销要多

之所以在每个 context block 里面要使用 anchor block, 一个主要原因是 attention_sink: https://arxiv.org/pdf/2309.17453 , 也就是说,即使最开始的几个 token 没什么 semantic 的意义,但是 attention score 还是会在这几个 token 上面很高
在 Star-Attention 的 repo 里面,reuse 了 huggingface 的 transformer 库里面的 Llama 模型的代码,比如 LlamaMLP, LlamaRotaryEmbedding 等; 唯一重写的就是 StarLlamaFlashAttention2, 不过这个也是继承了 LlamaFlashAttention2 里面的方法; forward 方法里面调用了自定义的 _flash_attention_forward 方法: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/star_attention/modeling_flash_attention_utils.py#L40; 这个方法其实就是个 routing function, 用来进行一些针对输入的处理,比如如果有 padding 的话需要进行 upad 然后使用 varlen 版本的 start attention 方法: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/star_attention/modeling_flash_attention_utils.py#L102 (这里面可能涉及到了一些 flash attention 的方法,有些地方我还不是很理解,需要找时间再去看一下 flash attention 的 document 和实现了)
在 model.py 这个文件中定义了 star-attention 的 phase 1 的处理方法, 然后使用了在 modeling_llama.py 里面定义的使用了 star attention 的 Llama 模型,也就是负责的 phase 2 的部分; 新生成的 token 会被放到最后一个 rank 的 KV cache 里面去,这里比较好理解,因为新生成的 token 会成为新的 context, 然后因为 context 已经被 chunk 了,只有最后一个的 chunk 可能是不满的,所以可以 append 到那里去 (具体不满的原因在这里,虽然最开始还是进行了 padding 但是在这里把 padding 的东西拿掉了: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/model.py#L252-L258); 在 star attention model 里面,主要的逻辑在这里: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/model.py#L240
def __call__(self, prompt_context: str, prompt_query: str) -> Dict[str, List[str]]:
# Prepare the context
ctx_ids, position_ids, ctx_len = self._tokenize_and_partition_context(prompt_context)
# Split the context into blocks and divide the blocks among the ranks
ctx_ids_blocks = torch.tensor_split(torch.stack(ctx_ids.split(self.block_size, dim=-1)), self.world_size)
position_ids_blocks = torch.tensor_split(
torch.stack(position_ids.split(self.block_size, dim=-1)), self.world_size
)
# Phase 1: Generate the KV cache for the local context
kv_rank = self._process_blockwise_context(ctx_ids_blocks, position_ids_blocks)
if self.rank == self.world_size - 1: # discard padding from the last rank
padding = ctx_ids.shape[-1] - ctx_len
if padding > 0:
kv_rank = [
[kv_rank[i][0][:, :, :-padding], kv_rank[i][1][:, :, :-padding]] for i in range(len(kv_rank))
]
# Phase 2: Process query with global attention
qry_ids = self._tokenize(prompt_query)
qry_position_ids = torch.arange(ctx_len, ctx_len + qry_ids.shape[-1]).unsqueeze(0).to(self.model.device)
output = self._generate_output(qry_ids, qry_position_ids, kv_rank)
# Get the generated text
generated_text = self._get_output_text(output)
return {'text': [generated_text]}
ctx_ids_blocks[0][0][:, : self.anchor_block_size], 但是实际上每个 rank 只需要存自己的那部分 context block 外加 anchor block 应该就足够了)def _process_blockwise_context(self, ctx_ids_blocks, position_ids_blocks):
"""Phase 1 of Star Attention: Blockwise Context Encoding with Anchor Blocks"""
# If the anchor block size is not provided, use the entire first block
if self.anchor_block_size is None:
self.anchor_block_size = ctx_ids_blocks[0][0].shape[-1]
kv_rank = []
# only check the block assigned to its own rank here
for idx in range(len(ctx_ids_blocks[self.rank])):
# Select the current block
ctx_block = ctx_ids_blocks[self.rank][idx]
position_block = position_ids_blocks[self.rank][idx]
# From 2nd block onwards, prepend the anchor block to the current block
if self.rank != 0 or idx > 0:
ctx_block = torch.cat((ctx_ids_blocks[0][0][:, : self.anchor_block_size], ctx_block), dim=-1)
position_block = torch.cat(
(position_ids_blocks[0][0][:, : self.anchor_block_size], position_block), dim=-1
)
with torch.no_grad():
kv_block = self.model(
ctx_block,
position_ids=position_block,
use_cache=True,
num_ring_steps=0, # disable ring attention (local blockwise attention)
enable_star_attn=False,
).past_key_values # type: ignore
# Discard the anchor block KV cache
if self.rank != 0 or idx > 0:
kv_block = [
[x[0][:, :, self.anchor_block_size :], x[1][:, :, self.anchor_block_size :]] for x in kv_block
]
# combine k, v; j loop over k & v (a.k.a 2)
# i loop over the blocks (-1 is dim, -2 is by ctx id)
kv_rank = (
kv_block
if not kv_rank
else [
[torch.cat((kv_rank[i][j], kv_block[i][j]), dim=-2) for j in range(2)] for i in range(len(kv_rank))
]
)
return kv_rank
Star attention 的第二个部分的逻辑在这里: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/star_attention/star_flash_attn/star_flash_attn.py#L23 , 这里面比较主要的几个地方是
all_gather 的方法来收集到所有的 host 上面跑出来的 block output 和 block lseN, seq_len, nhead, head_dim, 但是 lse 的 shape 确是 N, nhead, seq_len, sequence length 和 num head 的顺序是反的$$ out = \frac{s_{old}}{s_{old} + s_{block}} * out_{old} + \frac{s_{block}}{s_{old} + s_{block}} * out_{block} $$