# Copyright (c) OpenMMLab. All rights reserved.
import enum
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Literal
import torch
from pydantic.dataclasses import dataclass as pydantic_dataclass
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
from .tokenizer import Tokenizer
from .utils import get_logger
# Module-level logger obtained from `get_logger` under the name 'lmdeploy'.
logger = get_logger('lmdeploy')
class QuantPolicy(enum.IntEnum):
    """Integer codes selecting how the KV cache is quantized.

    Being an ``IntEnum``, members compare equal to their plain integer
    values, so ``quant_policy=8`` and ``QuantPolicy.INT8`` are
    interchangeable in membership checks.
    """

    # KV cache is left unquantized.
    NONE = 0
    # KV cache stored with 4 bits.
    INT4 = 4
    # KV cache stored with 8 bits.
    INT8 = 8
    # TurboQuant: K=4bit QJL4 + V=2bit MSE
    TURBO_QUANT = 42
# Type alias for a user-supplied logits processor; it is the element type of
# `GenerationConfig.logits_processors`.
LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a tensor of input_ids, the logits
tensor for the next token, and returns a modified tensor of logits to sample
from."""
# (stray "[docs]" link text from the rendered documentation page removed)
@dataclass
class GenerationConfig:
    """Generation parameters used by inference engines.

    Args:
        n: Define how many chat completion choices to generate for each
            input message. **Only 1** is supported now.
        max_new_tokens: The maximum number of tokens that can be
            generated in the chat completion
        do_sample: Whether or not to use sampling, use greedy
            decoding otherwise. Default to be False.
        top_p: An alternative to sampling with temperature, called
            nucleus sampling, where the model considers the results of the
            tokens with top_p probability mass
        top_k: An alternative to sampling with temperature, where
            the model considers the top_k tokens with the highest probability
        min_p: Minimum token probability, which will be scaled by the
            probability of the most likely token. It must be a value between
            0 and 1. Typical values are in the 0.01-0.2 range, comparably
            selective as setting `top_p` in the 0.99-0.8 range (use the
            opposite of normal `top_p` values)
        temperature: Sampling temperature
        repetition_penalty: Penalty to prevent the model from
            generating repeated words or phrases. A value larger than
            1 discourages repetition
        ignore_eos: Indicator to ignore the eos_token_id or not
        random_seed: Seed used when sampling a token
        stop_words: Words that stop generating further tokens
        bad_words: Words that the engine will never generate
        stop_token_ids: list of tokens that stop the generation
            when they are generated. The returned output will not contain
            the stop tokens.
        bad_token_ids: list of tokens that the engine will never
            generate.
        min_new_tokens: The minimum numbers of tokens to generate,
            ignoring the number of tokens in the prompt.
        skip_special_tokens: Whether or not to remove special tokens
            in the decoding. Default to be True.
        spaces_between_special_tokens: Whether or not to add spaces
            around special tokens. The behavior of Fast tokenizers is to have
            this to False. This is setup to True in slow tokenizers.
        logprobs: Number of log probabilities to return per output token.
        response_format: Generate responses according to given formatting.

            Examples:

            .. code-block:: json

                {
                  "type": "json_schema",
                  "json_schema": {
                    "name": "test",
                    "schema": {
                      "properties": {
                        "name": {
                          "type": "string"
                        }
                      },
                      "required": ["name"],
                      "type": "object"
                    }
                  }
                }

            or,

            .. code-block:: json

                {
                  "type": "regex_schema",
                  "regex_schema": "call me [A-Za-z]{1,10}"
                }

        logits_processors: Custom logit processors.
        repetition_ngram_size: The size of n-grams to consider for repetition early stop.
            If either this value or `repetition_ngram_threshold` is not
            positive, both are reset to 0.
        repetition_ngram_threshold: The number of times an n-gram must be repeated to trigger early stop.
            If either this value or `repetition_ngram_size` is not
            positive, both are reset to 0.
    """

    n: int = 1
    max_new_tokens: int = 512
    do_sample: bool = False
    top_p: float = 1.0
    top_k: int = 50
    min_p: float = 0.0
    temperature: float = 0.8
    repetition_penalty: float = 1.0
    ignore_eos: bool = False
    random_seed: int | None = None
    stop_words: list[str] | None = None
    bad_words: list[str] | None = None
    stop_token_ids: list[int] | None = None
    bad_token_ids: list[int] | None = None
    min_new_tokens: int | None = None
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    logprobs: int | None = None
    response_format: dict | None = None
    logits_processors: list[LogitsProcessor] | None = None
    output_logits: Literal['all', 'generation'] | None = None
    output_last_hidden_state: Literal['all', 'generation'] | None = None
    include_stop_str_in_output: bool = False

    # for disaggregation
    with_cache: bool = False
    preserve_cache: bool = False
    migration_request: MigrationRequest | None = None

    # router replay
    return_routed_experts: bool = False

    # ngram, generation would stop if latest [size] tokens are repeated for [threshold] times
    repetition_ngram_size: int = 0
    repetition_ngram_threshold: int = 0

    def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
        """Convert stop_words/bad_words to ids and append the ids to
        stop_token_ids/bad_token_ids.

        Each word is looked up with ``tokenizer.indexes_containing_token``.
        The resulting ids are merged with any ids already present in
        ``stop_token_ids`` / ``bad_token_ids`` and de-duplicated. A field
        that ends up empty is stored as ``None``.

        Args:
            tokenizer: object providing ``indexes_containing_token(word)``.

        Raises:
            AssertionError: if ``stop_words`` or ``bad_words`` is neither
                ``None`` nor a list of ``str``.
        """

        def special_word_token_ids(words, field_name):
            # `field_name` only feeds the error message, so that the report
            # names the field that is actually malformed.
            if words is None:
                return []
            assert isinstance(words, list) and \
                all(isinstance(elem, str) for elem in words), \
                f'{field_name} must be a list of str but got {type(words)}'
            indexes = []
            for word in words:
                indexes += tokenizer.indexes_containing_token(word)
            return indexes

        stop_token_ids = special_word_token_ids(self.stop_words, 'stop_words')
        bad_token_ids = special_word_token_ids(self.bad_words, 'bad_words')
        stop_token_ids.extend(self.stop_token_ids or [])
        bad_token_ids.extend(self.bad_token_ids or [])
        # De-duplicate; an empty result is normalized to None.
        self.stop_token_ids = list(set(stop_token_ids)) or None
        self.bad_token_ids = list(set(bad_token_ids)) or None

    def update_from_hf_gen_cfg(self, generation_config, tokenizer_eos_token_id):
        """Update the stop_token_ids.

        Args:
            generation_config: mapping read with ``.get('eos_token_id')``;
                the value may be a single int or an iterable of ints.
            tokenizer_eos_token_id: the tokenizer's eos token id, or None
                to skip it.
        """
        stop_token_ids = set(self.stop_token_ids or [])

        # add tokenizer's eos_token_id
        if tokenizer_eos_token_id is not None:
            stop_token_ids.add(tokenizer_eos_token_id)

        # add eos_token_id from model's generation_config.json file if there
        # is any.
        eos_token_id = generation_config.get('eos_token_id')
        if eos_token_id is not None:
            if isinstance(eos_token_id, int):
                stop_token_ids.add(eos_token_id)
            else:
                stop_token_ids.update(eos_token_id)

        self.stop_token_ids = list(stop_token_ids)

    def __post_init__(self):
        """Check input validation.

        Raises:
            AssertionError: if `n`, `top_p`, `top_k`, `temperature` or
                `min_p` is outside its accepted range.
        """
        assert type(self.n) is int and self.n > 0, 'n is not a positive integer'
        assert 0 <= self.top_p <= 1, \
            f'top_p should be in range [0, 1], but found {self.top_p}'
        assert self.top_k >= 0, 'top_k can not be a negative integer'
        assert 0 <= self.temperature <= 2, \
            f'temperature should be in range [0, 2], but found {self.temperature}'
        assert 0 <= self.min_p <= 1, \
            f'min_p should be in range [0, 1], but found {self.min_p}'
        # The n-gram early stop needs both knobs to be positive; otherwise
        # both are zeroed.
        if self.repetition_ngram_size <= 0 or self.repetition_ngram_threshold <= 0:
            self.repetition_ngram_size = 0
            self.repetition_ngram_threshold = 0
# (stray "[docs]" link text from the rendered documentation page removed)
@pydantic_dataclass
class TurbomindEngineConfig:
    """TurboMind Engine config.

    Fields whose default is ``None`` are annotated ``X | None`` so that
    passing ``None`` explicitly is accepted by pydantic validation.

    Args:
        dtype: data type for model weights and activations. It can be
            one of the following values, ['auto', 'float16', 'bfloat16']
            The `auto` option will use FP16 precision for FP32 and FP16
            models, and BF16 precision for BF16 models.
        model_format: the layout of the deployed model. It can be one
            of the following values [hf, awq, gptq, compressed-tensors,
            fp8, mxfp4]. `hf` means a Hugging Face model (.bin,
            .safetensors), `awq` and `gptq` mean grouped 4-bit
            weight-only checkpoints, `compressed-tensors` means
            pack-quantized grouped int4 checkpoints and is usually
            auto-detected from the input model config, `fp8` means
            blocked fp8 checkpoints, and `mxfp4` means MXFP4 expert
            weights. If it is not specified, i.e. None, it will be
            extracted from the input model
        tp: the number of GPU cards used in tensor parallelism,
            default to 1
        session_len: the max session length of a sequence, default to
            None
        max_batch_size: the max batch size during inference. If it is
            not specified, the engine will automatically set it according to
            the device
        cache_max_entry_count: the percentage of gpu memory occupied
            by the k/v cache.
            For versions of lmdeploy between `v0.2.0` and `v0.2.1`, it
            defaults to 0.5, depicting the percentage of TOTAL GPU memory to
            be allocated to the k/v cache.
            For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8,
            signifying the percentage of FREE GPU memory to be reserved for
            the k/v cache.
            When it's an integer > 0, it represents the total number of k/v
            blocks.
        cache_chunk_size: The policy to apply for KV block from
            the block manager, default to -1.
        cache_block_seq_len: the length of the token sequence in
            a k/v block, default to 64
        enable_prefix_caching: enable cache prompts for block reuse,
            default to False
        quant_policy: default to 0. When k/v is quantized into 4 or 8
            bit, set it to 4 or 8, respectively. 42
            (`QuantPolicy.TURBO_QUANT`) is also accepted.
        rope_scaling_factor: scaling factor used for dynamic ntk,
            default to 0. TurboMind follows the implementation of transformer
            LlamaAttention
        use_logn_attn: whether or not to use log attn: default to False
        download_dir: Directory to download and load the weights,
            default to the default cache directory of huggingface.
        revision: The specific model version to use. It can be a branch
            name, a tag name, or a commit id. If unspecified, will use the
            default version.
        max_prefill_token_num: the number of tokens each iteration during
            prefill, default to 8192
        num_tokens_per_iter: the number of tokens processed in each
            forward pass. Working with `max_prefill_iters` enables the
            "Dynamic SplitFuse"-like scheduling
        max_prefill_iters: the max number of forward pass during prefill
            stage
        async_: enable async execution, default to 1 (enabled)
        devices: the used devices
        empty_init: Whether to load the model weights, you should set
            it to True if you want to update weights after create the pipeline
        hf_overrides: Huggingface overrides for the model.
            It can be used to override the default config of the model
        enable_metrics: enable metrics system
    """

    dtype: str = 'auto'
    model_format: str | None = None
    tp: int = 1
    dp: int = 1
    cp: int = 1
    device_num: int | None = None
    attn_tp_size: int | None = None
    attn_cp_size: int | None = None
    attn_dp_size: int | None = None
    mlp_tp_size: int | None = None
    mlp_dp_size: int | None = None
    outer_dp_size: int | None = None
    nnodes: int = 1
    node_rank: int = 0
    dist_init_addr: str | None = None
    # Declared once, in the position of its first original declaration, so
    # the field order (and positional construction) is unchanged.
    devices: list[int] | None = None
    session_len: int | None = None
    max_batch_size: int | None = None
    cache_max_entry_count: float = 0.8
    cache_chunk_size: int = -1
    cache_block_seq_len: int = 64
    enable_prefix_caching: bool = False
    quant_policy: int = 0
    rope_scaling_factor: float = 0.0
    use_logn_attn: bool = False
    download_dir: str | None = None
    revision: str | None = None
    max_prefill_token_num: int = 8192
    num_tokens_per_iter: int = 0
    max_prefill_iters: int = 1
    async_: int = 1
    empty_init: bool = False
    communicator: str = 'nccl'
    hf_overrides: dict[str, Any] | None = None
    enable_metrics: bool = True

    def __post_init__(self):
        """Check input validation.

        Raises:
            AssertionError: if any checked field is outside its accepted
                set or range.
        """
        assert self.dtype in ['auto', 'float16', 'bfloat16']
        assert self.tp >= 1, 'tp must be a positive integer'
        assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'
        assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
            'invalid quant_policy'
        assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor'
        assert self.max_prefill_token_num >= 0, \
            'invalid max_prefill_token_num'
        assert self.num_tokens_per_iter >= 0, 'invalid num_tokens_per_iter'
        assert self.async_ in (0, 1), 'async_ must be 0 (disabled) or 1 (enabled)'
# (stray "[docs]" link text from the rendered documentation page removed)
@dataclass
class PytorchEngineConfig:
    """PyTorch Engine Config.

    Args:
        dtype: data type for model weights and activations. It can be
            one of the following values, ['auto', 'float16', 'bfloat16']
            The `auto` option will use FP16 precision for FP32 and FP16
            models, and BF16 precision for BF16 models.
        tp: Tensor Parallelism. default 1.
        dp: Data Parallelism. default 1.
        dp_rank: rank of dp.
        ep: Expert Parallelism. default 1.
        session_len: Max session length. Default None.
        max_batch_size: Max batch size. If it is not specified,
            the engine will automatically set it according to the device
        attn_tp_size: tp size for attention, only works for dp>1
        mlp_tp_size: tp size for mlp, only works for dp>1
        moe_tp_size: tp size for moe, only works for dp>1
        cache_max_entry_count: the percentage of gpu memory occupied
            by the k/v cache. For lmdeploy versions greater than `v0.2.1`,
            it defaults to 0.8, signifying the percentage of FREE GPU memory
            to be reserved for the k/v cache. Must be strictly between
            0 and 1.
        prefill_interval: Interval to perform prefill,
            Default 16.
        block_size: paging cache block size, default 64.
        kernel_block_size: default -1, which means "use `block_size`".
            After that substitution it must be a power of 2 that is >= 16,
            and `block_size` must be an integer multiple of it.
        num_cpu_blocks: Num cpu blocks. If num is 0, cache
            would be allocate according to current environment.
        num_gpu_blocks: Num gpu blocks. If num is 0, cache
            would be allocate according to current environment.
        adapters: The path configs to lora adapters.
        max_prefill_token_num: tokens per iteration.
        thread_safe: thread safe engine instance.
        enable_prefix_caching: Enable token match and sharing caches.
        device_type: The inference device type, options
            ['cuda', 'ascend', 'maca', 'camb']. For 'camb', `block_size`
            and `kernel_block_size` are forced to 16 (with a warning) when
            `block_size` is not already 16.
        eager_mode: Enable "eager" mode or not
        custom_module_map: nn module map customized by users. Once
            provided, the original nn modules of the model will be
            substituted by the mapping ones
        download_dir: Directory to download and load the weights,
            default to the default cache directory of huggingface.
        revision: The specific model version to use.
            It can be a branch name, a tag name, or a commit id.
            If unspecified, will use the default version.
        quant_policy: default to 0. When k/v is quantized into 4 or 8
            bit, set it to 4 or 8, respectively. 42
            (`QuantPolicy.TURBO_QUANT`) is also accepted. Any value > 0
            requires `device_type` to be 'cuda' or 'ascend'.
        distributed_executor_backend: backend of distributed backend,
            options: ['uni', 'mp', 'ray']
        empty_init: Whether to load the model weights, you should set
            it to True if you want to update weights after create the pipeline
        enable_microbatch: enable microbatch for specified model
        enable_eplb: enable eplb for specified model
        enable_metrics: enable metrics system
        role: role of engine, options: ['Hybrid', 'Prefill',
            'Decode']. Default to `EngineRole.Hybrid`.
        migration_backend: migration backend. options: ['DLSlime'].
            Default to `MigrationBackend.DLSlime`.
        enable_mp_engine: run engine in multi-process mode.
        mp_engine_backend: backend of mp engine, options:
            ['mp', 'ray']. Default to `mp`.
        model_format: weight quantization policy, options: ['fp8'].
        hf_overrides: Huggingface overrides for the model.
            It can be used to override the default config of the model,
        disable_vision_encoder: Whether to disable loading vision
            encoder. Default to False.
        logprobs_mode: The mode of logprob, options: ['raw_logits', 'raw_logprobs']
        dllm_block_length: Block size of block diffusion model.
        dllm_unmasking_strategy: Dllm unmasking strategy, options:
            ['low_confidence_dynamic', 'low_confidence_static', 'sequential'].
        dllm_denoising_steps: Dllm denoising steps.
        dllm_confidence_threshold: dllm unmasking threshold for
            dynamic unmasking.
    """

    dtype: str = 'auto'
    tp: int = 1
    dp: int = 1
    dp_rank: int = 0
    ep: int = 1
    session_len: int | None = None
    max_batch_size: int | None = None
    attn_tp_size: int | None = None
    mlp_tp_size: int | None = None
    moe_tp_size: int | None = None
    cache_max_entry_count: float = 0.8
    prefill_interval: int = 16
    block_size: int = 64
    kernel_block_size: int = -1
    num_cpu_blocks: int = 0
    num_gpu_blocks: int = 0
    adapters: dict[str, str] | None = None
    max_prefill_token_num: int = 8192
    thread_safe: bool = False
    enable_prefix_caching: bool = False
    device_type: str = 'cuda'
    eager_mode: bool = False
    custom_module_map: dict[str, str] | None = None
    download_dir: str | None = None
    revision: str | None = None
    quant_policy: QuantPolicy = QuantPolicy.NONE
    distributed_executor_backend: str | None = None
    empty_init: bool = False
    enable_microbatch: bool = False
    enable_eplb: bool = False
    enable_mp_engine: bool = False
    mp_engine_backend: str = 'mp'
    model_format: str | None = None
    enable_metrics: bool = True
    hf_overrides: dict[str, Any] | None = None
    disable_vision_encoder: bool = False
    logprobs_mode: str | None = None

    # router replay
    enable_return_routed_experts: bool = False
    enable_transfer_obj_ref: bool = False

    # dllm
    dllm_block_length: int | None = None
    dllm_unmasking_strategy: str = 'low_confidence_dynamic'
    dllm_denoising_steps: int | None = None
    dllm_confidence_threshold: float = 0.85

    role: EngineRole = EngineRole.Hybrid
    migration_backend: MigrationBackend = MigrationBackend.DLSlime

    def __post_init__(self):
        """Check input validation.

        Raises:
            AssertionError: if any checked field is outside its accepted
                set or range.
        """
        # -1 is the "inherit from block_size" sentinel.
        if self.kernel_block_size == -1:
            self.kernel_block_size = self.block_size

        assert self.dtype in ['auto', 'float16', 'bfloat16']
        assert self.tp >= 1, 'invalid tp'
        assert self.dp >= 1, 'invalid dp'
        assert self.ep >= 1, 'invalid ep'
        assert 0 < self.cache_max_entry_count < 1, \
            'invalid cache_max_entry_count'
        assert self.num_cpu_blocks >= 0, 'invalid num_cpu_blocks'
        assert self.max_prefill_token_num >= 0, \
            'invalid max_prefill_token_num'
        assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
        assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
            'invalid quant_policy'
        assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
        # `x & (x - 1) == 0` holds only for powers of two (given x >= 16).
        assert self.kernel_block_size >= 16 and \
            (self.kernel_block_size & (self.kernel_block_size - 1)) == 0, \
            f'kernel_block_size must be >= 16 and a power of 2, but got {self.kernel_block_size}'
        assert self.block_size >= self.kernel_block_size and \
            self.block_size % self.kernel_block_size == 0, \
            (f'block_size must be >= kernel_block_size and an integer multiple '
             f'of kernel_block_size, but got block_size {self.block_size} '
             f'and kernel_block_size {self.kernel_block_size}')
        assert self.quant_policy <= 0 or self.device_type in ['cuda', 'ascend'], \
            'kv cache quantization only works for CUDA and ASCEND.'
        # This override runs after the block-size assertions above, so the
        # user-supplied sizes must already have passed them.
        if self.device_type == 'camb' and self.block_size != 16:
            self.block_size = 16
            self.kernel_block_size = 16
            logger.warning('Currently, camb device requires block_size and kernel_block_size to be 16, '
                           'setting both to 16.')
class ResponseType(enum.Enum):
    """Status codes attached to an engine response.

    Values are assigned by ``enum.auto()`` in declaration order, so the
    order of the members below must not be changed.
    """

    SUCCESS = enum.auto()
    FINISH = enum.auto()
    ENGINE_STOP_ERROR = enum.auto()
    SESSION_REPEAT = enum.auto()
    SESSION_NOT_EXIST = enum.auto()
    HANDLER_NOT_EXIST = enum.auto()
    INPUT_LENGTH_ERROR = enum.auto()
    INTERNAL_ENGINE_ERROR = enum.auto()
    CANCEL = enum.auto()
    PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE = enum.auto()
    NO_QUEUE = enum.auto()
@dataclass
class Response:
"""Pack all response information together.
Args:
text: the response text from the server. If the output text is
an empty str and the finish_reason is length, it means the session
length is reached.
generate_token_len: the response token length.
input_token_len: the input prompt token length. Note that it may
contains chat template part.
finish_reason: the reason the model stopped
generating tokens. This will be 'stop' if the model hit a natural
stop point or a provided stop sequence, 'length' if the maximum
number of tokens specified in the request was reached.
token_ids: the output token ids.
logprobs: the top logprobs for each output position.
index: it refers to the position index of the input request batch.
"""
text: str
generate_token_len: int
input_token_len: int
finish_reason: Literal['stop', 'length'] | None = None
token_ids: list[int] = field(default_factory=list)
logprobs: list[dict[int, float]] = None
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
index: int = 0
routed_experts: Any = None
def __str__(self):
return f'text={self.text}\n{self._format_none_text_fields()}'
def __repr__(self):
return f'text={self.text!r}\n{self._format_none_text_fields()}'
def _format_none_text_fields(self):
fields = []
fields.append(f'input_token_len={self.input_token_len}')
fields.append(f'generate_token_len={self.generate_token_len}')
fields.append(f'finish_reason="{self.finish_reason}"')
fields.append(f'token_ids={self.token_ids}')
fields.append(f'logprobs={self.logprobs}')
# Helper function to format tensor information
def _format_tensor(name: str, tensor: torch.Tensor | None) -> list[str]:
if tensor is None:
return [f'{name}=None']
try:
return [f'{name}.shape={tensor.shape}', f'{name}={tensor}']
except: # noqa
# in case tensor is not torch.Tensor or has no shape
return [f'{name}={tensor}']
# Format tensor fields
fields.extend(_format_tensor('logits', self.logits))
fields.extend(_format_tensor('last_hidden_state', self.last_hidden_state))
fields.extend(_format_tensor('routed_experts', self.routed_experts))
return '\n'.join(fields)
def extend(self, other: 'Response') -> 'Response':
"""Extend this response with another response.
This method merges the content of another Response into this one,
similar to list.extend(). The text, token_ids, and logprobs are
concatenated, while other fields are updated from the other response.
Args:
other: Another Response to append to this one.
Returns:
Self (for method chaining).
"""
self.text += other.text
self.generate_token_len = other.generate_token_len
self.input_token_len = other.input_token_len
self.finish_reason = other.finish_reason
self.index = other.index
if other.token_ids:
self.token_ids += other.token_ids
if other.logprobs:
self.logprobs = self.logprobs or []
self.logprobs += other.logprobs
self.routed_experts = other.routed_experts
return self
# modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py
class EventType(enum.IntEnum):
    """Kinds of events recorded for a request.

    QUEUED - the request was enqueued by the engine
    SCHEDULED - the request was first scheduled for execution
    PREEMPTED - the request has been put back in the waiting queue in order to make room for other requests to
                complete. It will be re-scheduled in future and re-start its prefill phase
    """

    QUEUED = 1
    SCHEDULED = 2
    # FIXME, currently ignored for simplicity
    PREEMPTED = 3
# modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py
@dataclass
class EngineEvent:
    """An engine event for a request, stamped with the time it happened.

    Attributes:
        type: which kind of life-cycle event this is.
        timestamp: the WALL-CLOCK time at which the event happened.
    """

    type: EventType
    timestamp: float

    @classmethod
    def new_event(cls, event_type: EventType, timestamp: float | None = None) -> 'EngineEvent':
        """Build an event, stamping it with the current time if none is given.

        Args:
            event_type: the kind of event.
            timestamp: explicit wall-clock timestamp; when None,
                ``time.time()`` is used.

        Returns:
            The newly created event.
        """
        # Timestamps MUST use wall-clock time (time.time()) to maintain consistency
        # between csrc(std::chrono::system_clock) and python
        if timestamp is None:
            timestamp = time.time()
        return cls(event_type, timestamp)
@dataclass
class ScheduleMetrics:
    """Counters describing sequences, cache blocks and prefix-cache hit rate.

    Every field defaults to zero.
    """

    # sequence counters
    active_seqs: int = 0
    waiting_seqs: int = 0

    # block counters
    total_blocks: int = 0
    active_blocks: int = 0
    cached_blocks: int = 0
    free_blocks: int = 0

    prefix_cache_hit_rate: float = 0
@dataclass
class RequestMetrics:
    """Per-request metrics.

    Attributes:
        token_timestamp: A wall-clock time when a token is generated.
        engine_events: list of engine events during inference; each
            instance gets its own fresh list.
        spec_info: optional dict of extra information, None by default.
    """

    token_timestamp: float = 0.0
    engine_events: list[EngineEvent] = field(default_factory=list)
    spec_info: dict[str, Any] | None = None
@dataclass
class EngineOutput:
    """Output produced by the turbomind/pytorch engine.

    Only `status` and `token_ids` are required; every other field
    defaults to None.

    Args:
        status: the response type.
        token_ids: the newly generated token ids in each iteration.
        logprobs: the top logprobs for each output
            position.
        cache_block_ids: send cache blocks back for migration in
            Disaggregated LLM Serving when Prefill Engine is Done.
        req_metrics: request metrics information
    """

    status: ResponseType
    token_ids: list[int]
    logprobs: list[dict[int, float]] = None
    logits: torch.Tensor = None
    last_hidden_state: torch.Tensor = None
    cache_block_ids: list[int] | None = None
    req_metrics: RequestMetrics | None = None
    routed_experts: torch.Tensor = None
@dataclass
class VisionConfig:
    """Settings for the vision model.

    Args:
        max_batch_size: the max image size passed to the model, since
            some models will use image patch, the actual running batch could
            be larger than this value. Defaults to 1.
        thread_safe: Specifies whether the engine instance is
            thread-safe. Please set it to True when using the pipeline
            in a multi-threaded environment. Defaults to False.
    """

    max_batch_size: int = 1
    thread_safe: bool = False
@dataclass
class SpeculativeConfig:
    """Settings for speculative decoding.

    Args:
        method: the speculative decoding method (required).
        model: the path of speculative model. Defaults to an empty string.
        num_speculative_tokens: number of generated token of draft model per step.
            Defaults to 1.
    """

    method: str
    model: str = ''
    num_speculative_tokens: int = 1