Shortcuts

lmdeploy.pytorch 新模型支持

lmdeploy.pytorch 被设计用来简化新模型的支持以及原型的开发,新模型的支持依赖于 patch 机制,对原模型做修改以及功能添加,以期可以最大程度上复用模型的原始实现,减少工作量。

模型支持

我们以 transformers 中的 llama 实现来介绍模型支持的流程

在开始之前,我们首先要了解一下模型的输入。lmdeploy.pytorch 的输入与标准 transformers 模型的输入略有不同,差异主要体现在如下方面:

  1. 由于支持了 continuous batching,一个 batch 的输入 input_ids 会被拼接成一维的长序列,然后 unsqueeze(0) 来保证输入维度与 transformers 中相同。这样的输入不会影响 MLP 以及 RMSNorm 等模块的计算。

  2. 由于添加了对 paged attention 的支持,past_key_value 不再是原来的大小,而是一组形状为 [num_blocks, block_size, num_heads, head_dim] 的 cache 块,num_blocks 为总 block 数量,由可用显存大小决定,block_size 为预设的块大小。这样的输入改变会影响到 LlamaModel 和 LlamaAttention 的计算,因此要对这两个模块的实现进行修改。

  3. 由于上述输入的改变,模型中需要一些额外的输入来支持推理,比如 batch 中的序列起始位置和长度,kv cache 的 block table 等。这些输入并不在模块的 forward 参数列表中,我们需要维护一个上下文以获得这些输入。

上面的输入改动会影响 LlamaModel 和 LlamaAttention,首先我们来实现新的 LlamaModel,这是对原始实现的简化,我们删除了很多检查代码,以避免由于输入改变造成的断言失败,仅保留了最小程度的代码:

# lmdeploy/pytorch/models/llama.py

class LlamaModel(nn.Module):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite implementation of LlamaModel.forward."""
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx]
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

然后是对 LlamaAttention 模块的改写。按顺序实现如下操作:

  1. kqv proj

  2. rotary embedding

  3. 填充 kv cache

  4. MHA 计算

  5. o proj

continuous batching 和 kv cache 的改动对该模块的影响比较大

# lmdeploy/pytorch/models/llama.py
from lmdeploy.pytorch.kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd

class LlamaAttention(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of LlamaAttention.forward."""
        context = self.context.context
        history_lengths = context.history_lengths
        position_ids_1d = context.position_ids_1d
        block_offsets = context.block_offsets

        # qkv proj
        query_states = q_proj(hidden_states)
        key_states = k_proj(hidden_states)
        value_states = v_proj(hidden_states)
        query_states = query_states.view(-1, num_heads, head_dim)
        key_states = key_states.view(-1, num_kv_heads, head_dim)
        value_states = value_states.view(-1, num_kv_heads, head_dim)

        # rotary embedding
        max_seq_len = position_ids.size(-1)
        kv_seq_len = max_seq_len + max(history_lengths)
        if kv_seq_len >= self.rotary_emb.max_seq_len_cached:
            cos, sin = self.rotary_emb(value_states,
                                        seq_len=kv_seq_len + 128)
        query_states, key_states = apply_rotary_pos_emb(
            query_states,
            key_states,
            self.rotary_emb.cos_cached,
            self.rotary_emb.sin_cached,
            position_ids,
            position_ids_1d,
            q_embed=query_states,
            k_embed=key_states)

        # fill kv cache
        kv_seq_length = context.kv_seq_length
        q_seq_length = context.q_seq_length
        q_start_loc = context.q_start_loc
        fill_kv_cache(key_states,
                      value_states,
                      past_key_value[0],
                      past_key_value[1],
                      q_start_loc,
                      q_seq_length,
                      block_offsets=block_offsets,
                      history_lengths=history_lengths,
                      context=context)

        # attention
        attn_output = query_states
        block_size = past_key_value[0].size(1)
        paged_attention_fwd(
            query_states,
            past_key_value[0],
            past_key_value[1],
            attn_output,
            block_offsets,
            q_start_loc=q_start_loc,
            q_seqlens=q_seq_length,
            kv_seqlens=kv_seq_length,
            max_seqlen=max_seq_len,
        )
        hidden_size = num_heads * head_dim
        attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size)

        # o proj
        attn_output = o_proj(attn_output)
        return attn_output, None, past_key_value

上面的代码有几处值得注意的地方,首先是 context 对象。我们需要 history_lengths、block_offsets 等参数辅助运算,这些参数无法通过模型的 forward 函数传递进来。因此我们维护了一个 context 对象,把几乎所有可能用到的输入参数都保存在其中,方便在各个模块间共享。context 对象可以通过 self.context.context 来访问,结构可以参考 context-结构

另一个值得注意的地方就是自定义 kernel,由于输入形式的改变,原来的 LlamaAttention 实现变得不再适用,为了保证推理的速度和正确性,我们在 lmdeploy.pytorch.kernels 中实现了许多自定义的 triton kernel,上面的模块中就用到了 apply_rotary_pos_embfill_kv_cachepaged_attention_fwd ,分别负责实现 rotary embedding,填充 kv cache 还有 attention 的计算。

有了上述的两个模块后,还需要将他们注册到 lmdeploy/pytorch/models/module_map.py 中,进行原模块与 patch 模块的映射

# lmdeploy/pytorch/models/module_map.py
MODEL_MAP.update({
    'transformers.models.llama.modeling_llama.LlamaAttention':
    'lmdeploy.pytorch.models.llama.LlamaAttention',
    'transformers.models.llama.modeling_llama.LlamaModel':
    'lmdeploy.pytorch.models.llama.LlamaModel'
})

完成注册后,Engine 在启动时就会将这两个模块 patch 成新的实现,完成后续的部署任务。

Tensor 并发支持

为了支持 Tensor 并发,需要对模型的权重做切分。让我们试着为上面接入的 Llama 模型添加 TP 的支持。

Llama 中涉及到 Tensor 并发的模块是 LlamaAttention 中的 qkvo proj 和 LlamaMLP 中的 gate,up 和 down proj。其中 o_proj 和 down_proj 需要按行切分,剩下的按列切分。我们可以在对应的模块中实现 _distribution_partition_fn 函数:

# lmdeploy/pytorch/models/llama.py
from ..dist_utils import (colwise_parallelize_linear_fn,
                          rowwise_parallelize_linear_fn)

class LlamaAttention(nn.Module):
    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

class LlamaMLP(nn.Module):
    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['gate_proj', 'up_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['down_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

_distribute_partition_fn 会在加载模型权重时被调用,对应的权重会被按照特定的形式分配到对应的设备中。

按照目前的方案切分后的权重,需要对 o_proj 和 down_proj 的结果进行 all_reduce 操作才能得到正确的结果。可以选择将 all_reduce 放在模型的 forward 函数中,也可以选择另一种方案,添加 _distribute_output_fn 函数:

# lmdeploy/pytorch/models/llama.py
import torch.distributed as dist

class LlamaAttention(nn.Module):
    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

class LlamaMLP(nn.Module):
    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs)
        return outputs

最后别忘了将 LlamaMLP 也注册进 module_map 中

# lmdeploy/pytorch/models/module_map.py
MODEL_MAP.update({
    'transformers.models.llama.modeling_llama.LlamaMLP':
    'lmdeploy.pytorch.models.llama.LlamaMLP'
})

这样就可以利用多卡的优势,让更大的模型部署成为可能

模块调试

当模型的输出不符合预期时,我们会希望调试某个特定模块以确定添加的重写是否正确。lmdeploy.pytorch 提供了一些工具以帮助进行精度对齐。还是以上面提到的 LlamaAttention 模块为例。

首先,我们通过 transformers 的 API 得到想要调试的子模块的一个实例:

import torch
from transformers import AutoModelForCausalLM

# get module
model_path = 'meta-llama/Llama-2-7b-chat-hf'
dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(model_path).to(torch.float16).cuda()
self_attn = model.model.layers[0].self_attn

然后,使用 ModuleIOExtractor 工具可以生成该模块的一组输入输出

from lmdeploy.pytorch.tools.make_inputs import ModuleIOExtractor

# extract module input/output
input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda()
extractor = ModuleIOExtractor(model, self_attn)
attn_args, attn_kwargs, attn_output = extractor.extract(input_ids)

重写模块的输入与原模块略有不同,主要体现在三方面:

  1. 模型需要一些特殊输入输出,他们以 StepContext 的形式传入,可以使用 make_step_context 生成。

  2. input_idshidden_states 等数据都被 continuous 化,可以使用 continuous_tensor 进行处理。

  3. 由于 paged caching 的需要, past_key_value 需要被 page 化处理。

基于以上原因,我们要对提取的输入进行加工:

from lmdeploy.pytorch.tools.make_inputs import make_step_context
from lmdeploy.pytorch.tools.layout_convert import continuous_tensor

# create patched input/output
context = make_step_context(input_ids,
                            kv_cache_dtype=dtype,
                            num_key_value_heads=32)
seq_length = context.q_seq_length
attn_kwargs['hidden_states'] = continuous_tensor(
    attn_kwargs['hidden_states'],
    seq_length)
attn_kwargs['past_key_value'] = context.kv_caches[0]

然后就可以启动重写,并比较结果正确性了。(注意输出也要 continuous 化后进行比较)

from lmdeploy.pytorch.models import patch

# patch and test
patched_self_attn = patch(self_attn, extra_args=['context'])
with torch.inference_mode():
    patched_output = patched_self_attn.patched_forward(*attn_args,
                                                       **attn_kwargs,
                                                       context=context)
torch.testing.assert_close(patched_output[0],
                            continuous_tensor(attn_output[0], seq_length))

可以通过上述方法调试重写模块,直到精度满足预期。

附录

context 结构

@dataclass
class StepContext:
    """context of Model.
    """
    inputs: ModelInputs
    block_offsets: torch.LongTensor
    position_ids: torch.LongTensor
    position_ids_1d: torch.LongTensor
    q_start_loc: torch.LongTensor
    history_lengths: torch.LongTensor
    seq_length: torch.LongTensor
    max_seq_length: int
    kv_seq_length: torch.LongTensor
    kv_caches: List
    is_decoding: bool
    world_size: int = 1
    json_config: Dict = None
    local_adapter_ids: torch.LongTensor = None
    global_adapter_ids: torch.LongTensor = None
    adapter_offsets: torch.LongTensor = None
    max_rank: int = 0

FAQ

  • 如何访问 patch 前的模块?

有时我们只希望在函数前后加一个 hook 代码,不希望大段的拷贝函数,可以通过 self.origin_mod 访问 patch 前的模块。

  • 非 transformers 官方的模型该如何注册?

一些模型的实现代码可能是以 remote code 的形式添加的,这样的模块无法通过完整的 qualname 来定位。lmdeploy.pytorch 支持使用缩写的模块名进行注册:

MODULE_MAP.update({
    'modeling_internlm.InternLMAttention':
    'lmdeploy.pytorch.models.internlm.PatchedInternLMAttention',
})

[!NOTE]

缩写的优先级会更低,有条件的话还是鼓励使用完整的 qualname 进行注册。

  • 模块出现同名但不同实现怎么处理?

目前推荐的做法是同名就映射到同一个实现中,然后在实现内部根据模块的固有参数来判断模型该使用的类型,以 baichuan2 7b/13b 为例:

class BaichuanModel(nn.Module):
    def forward(self, ...):
        if self.config.num_hidden_layers == 32:
            return forward_7b(...)
        else:
            return forward_default(...)
  • 如果希望在推理前对模块进行初始化?

可以实现模块的 _update_model_fn 函数,它会在模块的权重都加载完,完成 TP 权重切分后被调用

class LlamaAttention:
    def _update_model_fn(self):
        # ADD YOUR CODE HERE
Read the Docs v: latest
Versions
latest
stable
v0.4.1
v0.4.0
v0.3.0
v0.2.6
v0.2.5
v0.2.4
v0.2.3
v0.2.2
v0.2.0
v0.1.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.