论文地址: https://arxiv.org/abs/2411.17116

代码地址: https://github.com/NVIDIA/Star-Attention

偶然看到了这个项目,正好借着这个机会学习一下不同的 Attention 的机制;之后还有比如 MLA, MQA, GQA 等等的 attention 计算方法,也都计划学习一下

High Level

Star attention 的一个主要思路是用并行处理的方式来解决 long context 语境下的 attention 的 $N^2$ 复杂度的问题,整体的算法分为两步

整体上,因为原来的 context 被分成了很多块儿让多个 GPU 进行并发处理,所以能够实现速度上的提升,虽然在最后的 online softmax 上面有额外的 communication,但是整体上在 perfill 阶段节省下来的时间还是比在 decoding 时候的额外开销要多

Screenshot 2025-03-01 at 1.00.29 PM.png

之所以在每个 context block 里面要使用 anchor block, 一个主要原因是 attention_sink: https://arxiv.org/pdf/2309.17453 , 也就是说,即使最开始的几个 token 没什么 semantic 的意义,但是 attention score 还是会在这几个 token 上面很高

代码理解

在 Star-Attention 的 repo 里面,reuse 了 huggingface 的 transformer 库里面的 Llama 模型的代码,比如 LlamaMLP, LlamaRotaryEmbedding 等; 唯一重写的就是 StarLlamaFlashAttention2, 不过这个也是继承了 LlamaFlashAttention2 里面的方法; forward 方法里面调用了自定义的 _flash_attention_forward 方法: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/star_attention/modeling_flash_attention_utils.py#L40; 这个方法其实就是个 routing function, 用来进行一些针对输入的处理,比如如果有 padding 的话需要进行 upad 然后使用 varlen 版本的 start attention 方法: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/star_attention/modeling_flash_attention_utils.py#L102 (这里面可能涉及到了一些 flash attention 的方法,有些地方我还不是很理解,需要找时间再去看一下 flash attention 的 document 和实现了)

model.py 这个文件中定义了 star-attention 的 phase 1 的处理方法, 然后使用了在 modeling_llama.py 里面定义的使用了 star attention 的 Llama 模型,也就是负责的 phase 2 的部分; 新生成的 token 会被放到最后一个 rank 的 KV cache 里面去,这里比较好理解,因为新生成的 token 会成为新的 context, 然后因为 context 已经被 chunk 了,只有最后一个的 chunk 可能是不满的,所以可以 append 到那里去 (具体不满的原因在这里,虽然最开始还是进行了 padding 但是在这里把 padding 的东西拿掉了: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/model.py#L252-L258); 在 star attention model 里面,主要的逻辑在这里: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/model.py#L240

def __call__(self, prompt_context: str, prompt_query: str) -> Dict[str, List[str]]:
    # Prepare the context
    ctx_ids, position_ids, ctx_len = self._tokenize_and_partition_context(prompt_context)

    # Split the context into blocks and divide the blocks among the ranks
    ctx_ids_blocks = torch.tensor_split(torch.stack(ctx_ids.split(self.block_size, dim=-1)), self.world_size)
    position_ids_blocks = torch.tensor_split(
        torch.stack(position_ids.split(self.block_size, dim=-1)), self.world_size
    )

    # Phase 1: Generate the KV cache for the local context
    kv_rank = self._process_blockwise_context(ctx_ids_blocks, position_ids_blocks)
    if self.rank == self.world_size - 1:  # discard padding from the last rank
        padding = ctx_ids.shape[-1] - ctx_len
        if padding > 0:
            kv_rank = [
                [kv_rank[i][0][:, :, :-padding], kv_rank[i][1][:, :, :-padding]] for i in range(len(kv_rank))
            ]

    # Phase 2: Process query with global attention
    qry_ids = self._tokenize(prompt_query)
    qry_position_ids = torch.arange(ctx_len, ctx_len + qry_ids.shape[-1]).unsqueeze(0).to(self.model.device)
    output = self._generate_output(qry_ids, qry_position_ids, kv_rank)

    # Get the generated text
    generated_text = self._get_output_text(output)
    return {'text': [generated_text]}
def _process_blockwise_context(self, ctx_ids_blocks, position_ids_blocks):
    """Phase 1 of Star Attention: Blockwise Context Encoding with Anchor Blocks"""

    # If the anchor block size is not provided, use the entire first block
    if self.anchor_block_size is None:
        self.anchor_block_size = ctx_ids_blocks[0][0].shape[-1]

    kv_rank = []
    # only check the block assigned to its own rank here
    for idx in range(len(ctx_ids_blocks[self.rank])):
        # Select the current block
        ctx_block = ctx_ids_blocks[self.rank][idx]
        position_block = position_ids_blocks[self.rank][idx]

        # From 2nd block onwards, prepend the anchor block to the current block
        if self.rank != 0 or idx > 0:
            ctx_block = torch.cat((ctx_ids_blocks[0][0][:, : self.anchor_block_size], ctx_block), dim=-1)
            position_block = torch.cat(
                (position_ids_blocks[0][0][:, : self.anchor_block_size], position_block), dim=-1
            )

        with torch.no_grad():
            kv_block = self.model(
                ctx_block,
                position_ids=position_block,
                use_cache=True,
                num_ring_steps=0,  # disable ring attention (local blockwise attention)
                enable_star_attn=False,
            ).past_key_values  # type: ignore

        # Discard the anchor block KV cache
        if self.rank != 0 or idx > 0:
            kv_block = [
                [x[0][:, :, self.anchor_block_size :], x[1][:, :, self.anchor_block_size :]] for x in kv_block
            ]

				# combine k, v; j loop over k & v (a.k.a 2) 
				# i loop over the blocks (-1 is dim, -2 is by ctx id)
        kv_rank = (
            kv_block
            if not kv_rank
            else [
                [torch.cat((kv_rank[i][j], kv_block[i][j]), dim=-2) for j in range(2)] for i in range(len(kv_rank))
            ]
        )

    return kv_rank

Star attention 的第二个部分的逻辑在这里: https://github.com/NVIDIA/Star-Attention/blob/24a0092bfcb2af8ddcfb869b078ad57fe446d832/star_attention/star_flash_attn/star_flash_attn.py#L23 , 这里面比较主要的几个地方是

$$ out = \frac{s_{old}}{s_{old} + s_{block}} * out_{old} + \frac{s_{block}}{s_{old} + s_{block}} * out_{block} $$