# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import annotations
import asyncio
import atexit
import concurrent.futures
import os
from collections.abc import Iterator
from contextlib import closing
from functools import partial
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING
import torch
import tqdm
from typing_extensions import deprecated
from .archs import autoget_backend_config, get_task
from .messages import GenerationConfig, PytorchEngineConfig, Response, SpeculativeConfig, TurbomindEngineConfig
from .model import ChatTemplateConfig
from .serve.processors import MultimodalProcessor
from .utils import get_logger, get_model
if TYPE_CHECKING:
from PIL.Image import Image
from .serve.managers import Session
logger = get_logger('lmdeploy')
[docs]
class Pipeline:
"""Pipeline - User-facing API layer for inference."""
[docs]
def __init__(self,
             model_path: str,
             backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None,
             chat_template_config: ChatTemplateConfig | None = None,
             log_level: str = 'WARNING',
             max_log_len: int | None = None,
             trust_remote_code: bool = False,
             speculative_config: SpeculativeConfig | None = None,
             **kwargs):
    """Create the inference engine and the background event loop that drives it.

    Args:
        model_path: Path to the model. A path that does not exist locally is resolved via ``get_model``.
        backend_config: Backend configuration. Its ``download_dir`` and ``revision`` are used when the
            model has to be fetched.
        chat_template_config: Chat template configuration.
        log_level: Log level. Applied to the module logger and used as the default for ``TM_LOG_LEVEL``.
        max_log_len: Max number of prompt characters or prompt tokens being printed in log.
        trust_remote_code: Whether to trust remote code from model repositories.
        speculative_config: Speculative decoding configuration. Its ``model`` field is overwritten with
            the resolved path when the draft model has to be fetched.
        **kwargs: Additional keyword arguments forwarded to the engine class.
    """
    # setdefault keeps a TM_LOG_LEVEL that is already present in the environment
    os.environ.setdefault('TM_LOG_LEVEL', log_level)
    logger.setLevel(log_level)

    # Main model: fetch it when the given path is not present on disk
    if not os.path.exists(model_path):
        if backend_config:
            download_dir, revision = backend_config.download_dir, backend_config.revision
        else:
            download_dir = revision = None
        model_path = get_model(model_path, download_dir, revision)

    # Draft model for speculative decoding: same treatment, but no revision is passed
    draft_model = speculative_config.model if speculative_config else None
    if draft_model and not os.path.exists(draft_model):
        download_dir = backend_config.download_dir if backend_config else None
        speculative_config.model = get_model(draft_model, download_dir)

    # Pick the backend and the matching engine class, then build the engine
    backend, backend_config = autoget_backend_config(model_path,
                                                     backend_config,
                                                     trust_remote_code=trust_remote_code)
    _, pipeline_class = get_task(backend, model_path, trust_remote_code=trust_remote_code)
    engine = pipeline_class(model_path,
                            backend=backend,
                            backend_config=backend_config,
                            chat_template_config=chat_template_config,
                            max_log_len=max_log_len,
                            trust_remote_code=trust_remote_code,
                            speculative_config=speculative_config,
                            **kwargs)
    self.async_engine = engine

    # Dedicated daemon thread whose event loop runs every coroutine of this pipeline
    self.internal_thread = _EventLoopThread(daemon=True)
    # Created lazily by _get_limiter
    self.limiter: asyncio.Semaphore = None
    self.session_mgr = engine.session_mgr
    self.backend_config = engine.backend_config
    engine.start_loop(self.internal_thread.loop, use_async_api=False)
[docs]
def infer(self,
          prompts: list[str] | str | list[dict] | list[list[dict]] | tuple | list[tuple],
          gen_config: GenerationConfig | list[GenerationConfig] | None = None,
          do_preprocess: bool = True,
          adapter_name: str | None = None,
          use_tqdm: bool = False,
          **kwargs):
    """Run inference on the prompts and block until every response is complete.

    Args:
        prompts: Prompts for inference. It can be a single prompt, a list of prompts, a list of tuples, or a tuple.
            tuple can be (prompt, image or [images]) or (image or [images], prompt).
        gen_config: Generation configuration(s).
        do_preprocess: Whether to pre-process messages.
        adapter_name: Adapter name.
        use_tqdm: Whether to use progress bar.
        **kwargs: Additional keyword arguments forwarded with every request.

    Returns:
        Response | list[Response]: A single response when ``prompts`` is a single prompt,
        otherwise one response per prompt.
    """
    # Decide single-vs-batch on the raw input, before it is reformatted
    is_single = self._is_single(prompts)
    # format prompts to openai message format, which is a list of dicts
    prompts = MultimodalProcessor.format_prompts(prompts)
    pbar = tqdm.tqdm(total=len(prompts)) if use_tqdm else None
    outputs = []
    try:
        requests = self._request_generator(prompts,
                                           gen_config=gen_config,
                                           do_preprocess=do_preprocess,
                                           adapter_name=adapter_name,
                                           stream_response=False,
                                           **kwargs)
        # With multiplex=False each item is an iterator over the chunks of one prompt
        for g in self._infer(requests, multiplex=False, pbar=pbar):
            res = None
            for out in g:
                # Fold the streamed chunks of this prompt into one response
                res = res.extend(out) if res else out
            outputs.append(res)
    finally:
        # A tqdm bar is falsy when its total is 0, so test against None:
        # otherwise a bar created for an empty prompt list is never closed.
        if pbar is not None:
            pbar.close()
    if is_single:
        return outputs[0]
    return outputs
@deprecated('This method is deprecated. Please use "Pipeline.infer" instead.')
def batch_infer(self, *args, **kwargs):
    """Deprecated alias of ``infer``; all arguments are forwarded unchanged."""
    outputs = self.infer(*args, **kwargs)
    return outputs
[docs]
def stream_infer(self,
                 prompts: list[str] | str | list[dict] | list[list[dict]] | tuple | list[tuple],
                 sessions: Session | list[Session] | None = None,
                 gen_config: GenerationConfig | list[GenerationConfig] | None = None,
                 do_preprocess: bool = True,
                 adapter_name: str | None = None,
                 stream_response: bool = True,
                 **kwargs):
    """Start inference and hand back an iterator over the outputs.

    Args:
        prompts: Prompts to inference. It can be a single prompt, a list of prompts, a list of tuples, or a tuple.
            tuple can be (prompt, image or [images]) or (image or [images], prompt).
        sessions: Sessions. Each of which corresponds to a prompt.
        gen_config: Generation configuration(s).
        do_preprocess: Whether to pre-process messages.
        adapter_name: Adapter name.
        stream_response: Whether to stream the response. If True, the generator will stream the response.
            Otherwise, the generator will run until finish and return the final response. This argument
            is introduced to support the streaming and non-streaming modes of Pipeline.chat.
        **kwargs: Additional keyword arguments forwarded with every request.

    Returns:
        Iterator: An iterator that yields the outputs (i.e. instances of class ``Response``) of all
        prompts through one shared stream.
    """
    formatted = MultimodalProcessor.format_prompts(prompts)
    request_options = dict(sessions=sessions,
                           gen_config=gen_config,
                           do_preprocess=do_preprocess,
                           adapter_name=adapter_name,
                           stream_response=stream_response,
                           **kwargs)
    requests = self._request_generator(formatted, **request_options)
    # multiplex=True: the outputs of every request arrive on a single queue
    return self._infer(requests, multiplex=True)
def close(self):
    """Shut the pipeline down: first the event-loop thread, then the engine."""
    for owner in ('internal_thread', 'async_engine'):
        getattr(self, owner).close()
[docs]
def chat(self,
         prompt: str | tuple[str, Image | list[Image]],
         session=None,
         gen_config: GenerationConfig | None = None,
         stream_response=False,
         adapter_name=None,
         **kwargs) -> Session | Iterator:
    """Run one chat turn inside a session.

    Args:
        prompt: prompt string or a tuple of (prompt, image or [images]).
        session: the chat session. A new one is taken from the session manager when omitted.
        gen_config: an instance of GenerationConfig. Default to None.
        stream_response: whether to stream the response.
        adapter_name: adapter name.
        **kwargs: additional keyword arguments.

    Returns:
        Session | Iterator: the updated session, or a streaming iterator if stream_response is True.
    """
    if session is None:
        session = self.session_mgr.get()
    session.update(prompt=prompt, response=None)

    messages = MultimodalProcessor.format_prompts(prompt)
    # A session whose step is still 0 has not processed any tokens yet
    is_first_turn = session.step == 0
    stream = self.stream_infer(prompts=messages,
                               sessions=session,
                               gen_config=gen_config,
                               stream_response=stream_response,
                               adapter_name=adapter_name,
                               multiplex=True,
                               sequence_start=is_first_turn,
                               sequence_end=False,
                               step=session.step,
                               **kwargs)

    def _relay():
        """Yield every chunk while folding them into the turn's full response."""
        merged = None
        try:
            for chunk in stream:
                merged = merged.extend(chunk) if merged else chunk
                yield chunk
        except BaseException:
            # Any interruption, including the consumer closing this generator
            # early, aborts the session before the exception propagates.
            self._run(coro=session.async_abort())
            raise
        else:
            # Turn completed: record the response and advance the session state
            session.response = merged
            session.step += merged.generate_token_len + merged.input_token_len
            session.history.append((session.prompt, merged.text))

    if stream_response:
        return _relay()

    # run the generator until finish
    with closing(_relay()) as drain:
        for _ in drain:
            pass
    session.generator = None
    return session
def session(self) -> Session:
    """Return a new session obtained from the session manager."""
    new_session = self.session_mgr.get()
    return new_session
def get_reward_score(self, input_ids: list) -> list[float]:
    """Compute the reward score of one or several token sequences.

    Args:
        input_ids: a list of token ids, or a list of token-id lists.

    Returns:
        list[float]: one score per sequence. A flat list of token ids is treated
        as a single sequence, so the result then has length 1.

    Raises:
        ValueError: if the loaded model architecture is not a supported reward model.
    """
    supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']
    arch = self.async_engine.arch
    if arch not in supported_reward_models:
        raise ValueError(f'{arch} is not in reward model list: {supported_reward_models}')

    assert isinstance(input_ids, list)
    only_ints = all(isinstance(tok, int) for tok in input_ids)
    assert only_ints or all(isinstance(seq, list) for seq in input_ids)

    # Normalise to a batch: a flat list of token ids becomes a batch of one
    batch = [input_ids] if isinstance(input_ids[0], int) else input_ids
    future = self._run(coro=self.async_engine.async_get_logits(input_ids=batch))
    squeezed = [seq_logits.squeeze() for seq_logits in future.result()]
    # The score of a sequence is the value at its last position
    return [seq_logits[-1].cpu().item() for seq_logits in squeezed]
[docs]
def get_ppl(self, input_ids: list[int] | list[list[int]]) -> list[float]:
    """Score one or several token sequences.

    Sequences are processed longest first and grouped into batches by
    ``_batch_iterator``; a sequence too long for one batch is scored chunk by
    chunk through ``_get_long_text_ppl``. The scores are returned in the order
    of the input.

    Args:
        input_ids: a list of token ids or a batch of such lists. Every sequence
            must hold more than one token.

    Returns:
        list[float]: one score per input sequence.
    """
    assert isinstance(input_ids, list)
    if isinstance(input_ids[0], int):
        input_ids = [input_ids]
    assert all(len(_) > 1 for _ in input_ids)

    # TODO: a better way to determine `max_input_len`, at most allocate
    # 2G mem for logits with shape [bs, max_input_len, vocab_size]
    vocab_size = self.async_engine.hf_cfg.vocab_size
    max_input_len = 2 * 1024**3 // (vocab_size * 4)

    # Positions of the inputs ordered by descending length (stable for ties)
    order = sorted(range(len(input_ids)), key=lambda pos: len(input_ids[pos]), reverse=True)
    lengths = [len(input_ids[pos]) for pos in order]
    logger.info(f'sorted sizes: {lengths}')
    logger.info(f'sorted indices: {order}')

    ranked = []
    for start, end in self._batch_iterator(lengths, max_input_len):
        logger.info(f'start: {start}, end: {end}')
        if start == end:
            # Empty interval: this sequence alone exceeds the budget
            long_ids = input_ids[order[start]]
            session = self.session_mgr.get()
            ranked.append(self._get_long_text_ppl(session, input_ids=long_ids, max_input_len=max_input_len))
            self.session_mgr.remove(session)
            continue
        chunk = [input_ids[order[rank]] for rank in range(start, end)]
        sessions = [self.session_mgr.get() for _ in range(start, end)]
        ranked.extend(self._get_ppl(sessions=sessions, input_ids=chunk, max_input_len=max_input_len))
        for session in sessions:
            self.session_mgr.remove(session)

    # Undo the sort so that scores line up with the caller's input order
    output = list(range(len(ranked)))
    for rank, original_pos in enumerate(order):
        output[original_pos] = ranked[rank]
    return output
def __call__(self,
prompts: list[str] | str | list[dict] | list[list[dict]],
gen_config: GenerationConfig | list[GenerationConfig] | None = None,
**kwargs):
return self.infer(prompts, gen_config=gen_config, **kwargs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
@deprecated('This method is deprecated. Please use "AsyncEngine.generate" instead.')
async def generate(self, *args, **kwargs):
    """Deprecated async generator that re-yields the engine's ``generate`` stream.

    All arguments are passed to ``async_engine.generate`` untouched and every
    item it produces is yielded as is.
    """
    stream = self.async_engine.generate(*args, **kwargs)
    async for chunk in stream:
        yield chunk
@staticmethod
def _is_single(prompts):
"""Check if prompts is a single prompt."""
return (isinstance(prompts, str) or (isinstance(prompts, tuple) and len(prompts) == 2)
or (isinstance(prompts, list) and len(prompts) > 0 and isinstance(prompts[0], dict)))
def _request_generator(self,
prompts: list[str] | str | list[dict] | list[list[dict]],
sessions: list[Session] | Session | None = None,
gen_config: GenerationConfig | list[GenerationConfig] | None = None,
**kwargs):
"""Generate requests."""
is_single = self._is_single(prompts)
prompts = [prompts] if is_single else prompts
if sessions is None:
sessions = [self.session_mgr.get() for _ in prompts]
elif isinstance(sessions, list):
sessions = sessions
else:
sessions = [sessions]
if len(prompts) != len(sessions):
raise ValueError(f'prompts and sessions should have the same length. '
f'Got {len(prompts)} prompts and {len(sessions)} sessions')
if gen_config is None:
gen_configs = [GenerationConfig()] * len(prompts)
elif isinstance(gen_config, list):
gen_configs = gen_config
else:
gen_configs = [gen_config] * len(prompts)
if len(prompts) != len(gen_configs):
raise ValueError(f'input gen_config length differs from the length of prompts. '
f'Got {len(prompts)} prompts and {len(gen_configs)} gen_configs')
for prompt, gen_cfg, session in zip(prompts, gen_configs, sessions):
# Use session_id is for backward compatibility. We will remove it in the future.
# Since AsyncEngine.generate defines session_id in the argument lists, here we
# use session_id to pass the session to the AsyncEngine.generate. It's
yield dict(session_id=session, messages=prompt, gen_config=gen_cfg, **kwargs)
def _get_limiter(self):
if not self.limiter:
self.limiter = asyncio.Semaphore(self.backend_config.max_batch_size)
return self.limiter
def _infer(self, requests: Iterator[dict], multiplex: bool, pbar=None, loop=None) -> Iterator[Iterator[Response]]:
async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
async for out in g:
que.put(out.to_response(idx))
sem.release()
if not multiplex:
que.put(None) # sentinel of inner generator
if pbar:
pbar.update(1)
que = Queue()
async def _infer():
sem = self._get_limiter()
tasks = []
for idx, req in enumerate(requests):
await sem.acquire()
gen = self.async_engine.generate(**req)
dst = que if multiplex else Queue()
if not multiplex:
que.put(iter(dst.get, None))
# create a task to send the responses
task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
tasks.append(task)
if not multiplex: # sentinel of outer generator
que.put(None)
await asyncio.gather(*tasks)
if multiplex:
que.put(None) # sentinel of inner generator
loop = loop or self.internal_thread.loop
# submit the coroutine to async world
asyncio.run_coroutine_threadsafe(_infer(),
loop).add_done_callback(lambda f: None if f.cancelled() else f.result())
return iter(que.get, None)
def _run(self, fn=None, coro=None):
assert (fn or coro) and not (fn and coro)
loop = self.internal_thread.loop
if fn:
async def _coro():
return fn()
coro = _coro()
return asyncio.run_coroutine_threadsafe(coro, loop)
def _batch_iterator(self, sizes, max_value):
"""Return an iterator that calculates intervals (start, end) of a
descend-order list, in which the sum of values in the range is the
maximum number not less than max_value. By "the sum of values",
here it means $$len(sizes[start:end]) * sizes[start]$$
"""
i = 0
while i < len(sizes):
current_sum = 0
start_index = i
while i < len(sizes) and current_sum + sizes[start_index] <= max_value:
current_sum += sizes[start_index]
i += 1
yield (start_index, i)
if i > start_index:
continue
else:
i += 1
def _get_long_text_ppl(self, session, input_ids, max_input_len):
assert all(isinstance(_, int) for _ in input_ids)
seq_len = len(input_ids)
assert seq_len > max_input_len
logger.info(f'get long text ppl: seq_len {seq_len}')
losses = []
target_counts = []
for i in range(0, seq_len, max_input_len):
token_ids = input_ids[i:i + max_input_len]
session.update(step=i)
# shift token_ids by 1 to the left
target_ids = input_ids[i + 1:i + 1 + max_input_len]
loss = self._get_ppl(sessions=[session],
input_ids=[token_ids],
max_input_len=len(token_ids),
target_ids=[target_ids],
sequence_start=(i == 0),
sequence_end=False)
losses.extend(loss)
target_counts.append(len(target_ids))
losses = [loss * target_count for loss, target_count in zip(losses, target_counts)]
loss_sum = sum(losses)
target_count = sum(target_counts)
return loss_sum / target_count
def _get_ppl(self,
sessions: list[Session],
input_ids: list[list[int]],
max_input_len: int,
target_ids=None,
sequence_start: bool = True,
sequence_end: bool = True):
assert (isinstance(input_ids, list) and all(isinstance(_, list) for _ in input_ids))
assert target_ids is None or len(target_ids) == len(input_ids)
assert len(sessions) == len(input_ids)
lens = [len(_) for _ in input_ids]
total_len = sum(lens)
assert sum(lens) <= max_input_len
logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
f'total_len: {total_len}')
torch.cuda.empty_cache()
logits = self._run(coro=self.async_engine.async_get_logits(
input_ids=input_ids, sessions=sessions, sequence_start=sequence_start, sequence_end=sequence_end)).result()
padding_token_id = -100
if target_ids is None:
target_ids = [x[1:] + [padding_token_id] for x in input_ids]
else:
target_ids = [
target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]
for i in range(len(input_ids))
]
target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids]
result = []
for _logits, _target_ids in zip(logits, target_ids):
_logits = _logits.float()
vocab_size = _logits.shape[-1]
_target_ids = _target_ids.to(_logits.device)
target_mask = _target_ids != padding_token_id
# compute cross entropy loss
flat_logits = _logits.contiguous().view(-1, vocab_size)
flat_target_ids = _target_ids.contiguous().view(-1)
flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits,
flat_target_ids,
reduction='none',
ignore_index=padding_token_id)
loss = flat_loss_matrix.sum()
target_count = target_mask.sum()
result.append(loss.item() / target_count.item())
logger.info(f'ppl result: {result}')
return result
class _EventLoopThread:
def __init__(self, daemon=False):
fut = concurrent.futures.Future()
self.thread = Thread(target=partial(self._thread_entry, fut), daemon=daemon)
self.thread.start()
self.loop: asyncio.AbstractEventLoop = fut.result()
self.closed = False
if daemon:
atexit.register(self.close)
def _thread_entry(self, fut):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
fut.set_result(loop)
try:
loop.run_forever()
except BaseException as e:
logger.error(f'[internal_thread] {type(e).__name__} {e}')
finally:
try:
self._cancel_all_tasks()
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
def _cancel_all_tasks(self):
"""Modified from asyncio/runners.py."""
to_cancel = asyncio.all_tasks(self.loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
async def _gather():
await asyncio.gather(*to_cancel, return_exceptions=True)
self.loop.run_until_complete(_gather())
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
self.loop.call_exception_handler({
'message': 'unhandled exception during worker thread shutdown',
'exception': task.exception(),
'task': task,
})
def close(self):
if self.closed:
return
self.closed = True
self.loop.call_soon_threadsafe(self.loop.stop)
self.thread.join()