Skip to content

vllm.v1.worker.gpu.mm.encoder_runner

EncoderRunner

Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
class EncoderRunner:
    """Runs multimodal encoders and merges their outputs into token embeddings.

    The runner keeps, per request, the list of ``MultiModalFeatureSpec``
    objects registered via ``add_request``. It caches encoder outputs keyed by
    each feature's ``identifier`` (its multimodal hash). Each step it gathers
    the cached slices that overlap the scheduled tokens.
    """

    def __init__(
        self,
        max_num_tokens: int,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Allocate the persistent buffers used across steps.

        Args:
            max_num_tokens: Row capacity of ``inputs_embeds`` and size of the
                staging pool used for the multimodal mask.
            hidden_size: Width of each embedding row.
            dtype: Dtype of ``inputs_embeds``.
            device: Device on which ``inputs_embeds`` is allocated.
        """
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

        # Allocated once: get_inputs_embeds copies into this buffer "for CUDA
        # graphs", so the same tensor is reused every step.
        self.inputs_embeds = torch.zeros(
            max_num_tokens,
            hidden_size,
            dtype=dtype,
            device=device,
        )
        # Request id -> feature specs registered via add_request.
        self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
        # Multimodal hash (feature identifier) -> encoder output.
        self.encoder_cache: dict[str, torch.Tensor] = {}

        # Staging pool whose copy_to_gpu moves the boolean mask built in
        # gather_mm_embeddings onto the device.
        self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool)

    def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
        """Register ``mm_features`` for ``req_id``, replacing any earlier list."""
        self.req_id_to_mm_features[req_id] = mm_features

    def free_encoder_cache(self, mm_hash: str) -> None:
        """Drop the cached encoder output for ``mm_hash``; no-op if absent."""
        self.encoder_cache.pop(mm_hash, None)

    def remove_request(self, req_id: str) -> None:
        """Forget the feature specs of ``req_id``; no-op if unknown.

        Cached encoder outputs are not touched here; they are released
        separately through ``free_encoder_cache``.
        """
        self.req_id_to_mm_features.pop(req_id, None)

    def prepare_mm_inputs(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
    ) -> tuple[list[str], list[MultiModalKwargsItem]]:
        """Collect hashes and encoder kwargs for the scheduled encoder inputs.

        Args:
            scheduled_encoder_inputs: Maps a request id to indices into that
                request's registered feature list.

        Returns:
            ``(mm_hashes, mm_kwargs)``, aligned element-wise. Features whose
            ``data`` is ``None`` are skipped and appear in neither list.

        Raises:
            KeyError: If a request id was never registered.
        """
        mm_hashes: list[str] = []
        mm_kwargs: list[MultiModalKwargsItem] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            mm_features = self.req_id_to_mm_features[req_id]
            for mm_input_id in encoder_input_ids:
                mm_feature = mm_features[mm_input_id]
                if mm_feature.data is None:
                    continue
                mm_hashes.append(mm_feature.identifier)
                mm_kwargs.append(mm_feature.data)
        return mm_hashes, mm_kwargs

    @torch.inference_mode()
    def execute_mm_encoder(
        self,
        model: SupportsMultiModal,
        mm_hashes: list[str],
        mm_kwargs: list[MultiModalKwargsItem],
    ) -> list[torch.Tensor]:
        """Run the model's multimodal encoder and cache the outputs by hash.

        Args:
            model: Model providing ``embed_multimodal``.
            mm_hashes: Cache key for each item in ``mm_kwargs`` (same order).
            mm_kwargs: Encoder inputs, grouped by modality before encoding.

        Returns:
            One encoder output per input item, in input order.

        Raises:
            ValueError: If ``mm_hashes`` and ``mm_kwargs`` differ in length.
                Without this check, ``zip`` would silently drop the surplus
                and leave outputs uncached or mis-keyed.
        """
        if not mm_hashes:
            return []
        if len(mm_hashes) != len(mm_kwargs):
            raise ValueError(
                f"Got {len(mm_hashes)} mm_hashes but {len(mm_kwargs)} mm_kwargs."
            )

        encoder_outputs: list[torch.Tensor] = []
        for _modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=False,
        ):
            curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
            # Guard that each group produced exactly one output per item.
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=num_items,
            )
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs by mm_hash.
        for mm_hash, output in zip(mm_hashes, encoder_outputs):
            self.encoder_cache[mm_hash] = output
        return encoder_outputs

    def gather_mm_embeddings(
        self,
        req_ids: list[str],
        total_num_scheduled_tokens: int,
        num_scheduled_tokens: np.ndarray,
        query_start_loc: np.ndarray,
        prefill_lens: np.ndarray,
        computed_prefill_lens: np.ndarray,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Gather cached encoder embeddings that overlap this step's tokens.

        Args:
            req_ids: Request ids in batch order.
            total_num_scheduled_tokens: Length of the returned mask.
            num_scheduled_tokens: Tokens scheduled per request this step.
            query_start_loc: Offset of each request's first token in the batch.
            prefill_lens: Prefill length of each request.
            computed_prefill_lens: Prefill tokens already computed per request.

        Returns:
            ``(mm_embeds, is_mm_embed)``. The first element is the embedding
            slices in gather order. The second is a boolean mask of length
            ``total_num_scheduled_tokens`` marking the batch positions those
            slices fill.
        """
        is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
        all_decode = not any(is_prefilling)
        if all_decode:
            # All decode requests, so no need to gather any embeddings.
            return [], torch.zeros(
                total_num_scheduled_tokens,
                dtype=torch.bool,
                device=self.device,
            )

        # Per-request window [query_start, query_end) of prompt positions
        # processed in this step.
        query_start = computed_prefill_lens.tolist()
        query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()

        mm_embeds: list[torch.Tensor] = []
        is_mm_embed = torch.zeros(
            total_num_scheduled_tokens,
            dtype=torch.bool,
            device="cpu",
            pin_memory=False,
        )
        for i, req_id in enumerate(req_ids):
            if not is_prefilling[i]:
                # OPTIMIZATION: Skip decode requests.
                continue

            mm_features = self.req_id_to_mm_features[req_id]
            for mm_feature in mm_features:
                pos_info = mm_feature.mm_position
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                if start_pos >= query_end[i]:
                    # The encoder output is not needed in this step.
                    # NOTE(review): breaking (not continuing) assumes the
                    # features are ordered by offset — confirm with caller.
                    break
                if start_pos + num_encoder_tokens <= query_start[i]:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                # Overlap of the placeholder with this step's window,
                # relative to the placeholder's own start.
                start_idx = max(query_start[i] - start_pos, 0)
                end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
                assert start_idx < end_idx
                curr_embeds_start, curr_embeds_end = (
                    pos_info.get_embeds_indices_in_range(start_idx, end_idx)
                )
                # If there are no embeddings in the current range, we skip
                # gathering the embeddings.
                if curr_embeds_start == curr_embeds_end:
                    continue

                mm_hash = mm_feature.identifier
                encoder_output = self.encoder_cache.get(mm_hash, None)
                assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."

                if (is_embed := pos_info.is_embed) is not None:
                    # Only some placeholder positions carry embeddings; the
                    # encoder output is indexed by embedding rows instead.
                    is_embed = is_embed[start_idx:end_idx]
                    mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
                else:
                    mm_embeds_item = encoder_output[start_idx:end_idx]

                # Batch position corresponding to the placeholder's start.
                req_start_pos = query_start_loc[i] + start_pos - query_start[i]
                is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                    True if is_embed is None else is_embed
                )
                mm_embeds.append(mm_embeds_item)

        # Copy the is_mm_embed tensor to the GPU.
        is_mm_embed = self.tmp_is_mm_embed.copy_to_gpu(is_mm_embed)
        return mm_embeds, is_mm_embed

    @torch.inference_mode()
    def get_inputs_embeds(
        self,
        model: SupportsMultiModal,
        input_ids: torch.Tensor,
        mm_embeds: list[torch.Tensor],
        is_mm_embed: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging in ``mm_embeds`` where masked.

        Returns:
            The whole persistent ``inputs_embeds`` buffer. Only the first
            ``x.shape[0]`` rows are written by this call; later rows keep
            whatever they held before.
        """
        x = model.embed_input_ids(
            input_ids,
            multimodal_embeddings=mm_embeds,
            is_multimodal=is_mm_embed,
        )
        # Copy to the pre-allocated buffer for CUDA graphs.
        self.inputs_embeds[: x.shape[0]] = x
        return self.inputs_embeds

device instance-attribute

device = device

dtype instance-attribute

dtype = dtype

encoder_cache instance-attribute

encoder_cache: dict[str, Tensor] = {}

hidden_size instance-attribute

hidden_size = hidden_size

inputs_embeds instance-attribute

inputs_embeds = zeros(
    max_num_tokens, hidden_size, dtype=dtype, device=device
)

max_num_tokens instance-attribute

max_num_tokens = max_num_tokens

req_id_to_mm_features instance-attribute

req_id_to_mm_features: dict[
    str, list[MultiModalFeatureSpec]
] = {}

tmp_is_mm_embed instance-attribute

tmp_is_mm_embed = UvaBufferPool(max_num_tokens, bool)

__init__

__init__(
    max_num_tokens: int,
    hidden_size: int,
    dtype: dtype,
    device: device,
)
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
def __init__(
    self,
    max_num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Store construction parameters and allocate the reusable buffers.

    Args:
        max_num_tokens: Row capacity of ``inputs_embeds`` and size of the
            staging pool for the multimodal mask.
        hidden_size: Width of each embedding row.
        dtype: Dtype of ``inputs_embeds``.
        device: Device on which ``inputs_embeds`` is allocated.
    """
    self.max_num_tokens, self.hidden_size = max_num_tokens, hidden_size
    self.dtype, self.device = dtype, device

    # One fixed-size embedding buffer, reused every step (the embedding
    # step copies into it for CUDA graphs).
    self.inputs_embeds = torch.zeros(
        (max_num_tokens, hidden_size), dtype=dtype, device=device
    )

    # Request id -> registered multimodal feature specs.
    self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
    # Multimodal hash -> encoder output tensor.
    self.encoder_cache: dict[str, torch.Tensor] = {}

    # Staging pool used to move the boolean multimodal mask to the device.
    self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool)

add_request

add_request(
    req_id: str, mm_features: list[MultiModalFeatureSpec]
)
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
    """Register the multimodal feature specs belonging to ``req_id``.

    Calling again with the same ``req_id`` replaces the earlier list.
    """
    self.req_id_to_mm_features.update({req_id: mm_features})

execute_mm_encoder

execute_mm_encoder(
    model: SupportsMultiModal,
    mm_hashes: list[str],
    mm_kwargs: list[MultiModalKwargsItem],
) -> list[Tensor]
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
@torch.inference_mode()
def execute_mm_encoder(
    self,
    model: SupportsMultiModal,
    mm_hashes: list[str],
    mm_kwargs: list[MultiModalKwargsItem],
) -> list[torch.Tensor]:
    """Run the model's multimodal encoder and cache the outputs by hash.

    Args:
        model: Model providing ``embed_multimodal``.
        mm_hashes: Cache key for each item in ``mm_kwargs`` (same order).
        mm_kwargs: Encoder inputs, grouped by modality before encoding.

    Returns:
        One encoder output per input item, in input order.

    Raises:
        ValueError: If ``mm_hashes`` and ``mm_kwargs`` differ in length.
            Without this check, ``zip`` would silently drop the surplus and
            leave outputs uncached or mis-keyed.
    """
    if not mm_hashes:
        return []
    if len(mm_hashes) != len(mm_kwargs):
        raise ValueError(
            f"Got {len(mm_hashes)} mm_hashes but {len(mm_kwargs)} mm_kwargs."
        )

    encoder_outputs: list[torch.Tensor] = []
    for _modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
        mm_kwargs,
        device=self.device,
        pin_memory=False,
    ):
        curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
        # Guard that each group produced exactly one output per item.
        sanity_check_mm_encoder_outputs(
            curr_group_outputs,
            expected_num_items=num_items,
        )
        encoder_outputs.extend(curr_group_outputs)

    # Cache the encoder outputs by mm_hash.
    for mm_hash, output in zip(mm_hashes, encoder_outputs):
        self.encoder_cache[mm_hash] = output
    return encoder_outputs

free_encoder_cache

free_encoder_cache(mm_hash: str) -> None
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
def free_encoder_cache(self, mm_hash: str) -> None:
    """Evict the cached encoder output for ``mm_hash``; unknown hashes are ignored."""
    cache = self.encoder_cache
    if mm_hash in cache:
        del cache[mm_hash]

gather_mm_embeddings

gather_mm_embeddings(
    req_ids: list[str],
    total_num_scheduled_tokens: int,
    num_scheduled_tokens: ndarray,
    query_start_loc: ndarray,
    prefill_lens: ndarray,
    computed_prefill_lens: ndarray,
) -> tuple[list[Tensor], Tensor]
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
def gather_mm_embeddings(
    self,
    req_ids: list[str],
    total_num_scheduled_tokens: int,
    num_scheduled_tokens: np.ndarray,
    query_start_loc: np.ndarray,
    prefill_lens: np.ndarray,
    computed_prefill_lens: np.ndarray,
) -> tuple[list[torch.Tensor], torch.Tensor]:
    """Collect cached encoder embeddings that overlap this step's tokens.

    Args:
        req_ids: Request ids in batch order.
        total_num_scheduled_tokens: Length of the returned mask.
        num_scheduled_tokens: Tokens scheduled per request this step.
        query_start_loc: Offset of each request's first token in the batch.
        prefill_lens: Prefill length of each request.
        computed_prefill_lens: Prefill tokens already computed per request.

    Returns:
        ``(mm_embeds, is_mm_embed)``. The first element is the embedding
        slices in gather order. The second is a boolean mask over the batch
        marking the positions those slices fill.
    """
    prefill_flags = (computed_prefill_lens < prefill_lens).tolist()
    if not any(prefill_flags):
        # Only decode requests this step: nothing to gather.
        return [], torch.zeros(
            total_num_scheduled_tokens, dtype=torch.bool, device=self.device
        )

    # Window [window_lo[i], window_hi[i]) of prompt positions that request i
    # processes in this step.
    window_lo = computed_prefill_lens.tolist()
    window_hi = (computed_prefill_lens + num_scheduled_tokens).tolist()

    gathered: list[torch.Tensor] = []
    mask = torch.zeros(
        total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=False
    )
    for i, req_id in enumerate(req_ids):
        if not prefill_flags[i]:
            # Decode requests consume no encoder output; skip them.
            continue

        features = self.req_id_to_mm_features[req_id]
        lo, hi = window_lo[i], window_hi[i]
        for feature in features:
            placeholder = feature.mm_position
            offset = placeholder.offset
            length = placeholder.length

            if offset >= hi:
                # Starts after this step's window; later features are not
                # visited either (assumes ordering by offset — confirm).
                break
            if offset + length <= lo:
                # Fully handled by earlier steps (already in the KV cache).
                continue

            # Part of the placeholder, relative to its offset, in the window.
            seg_lo = max(lo - offset, 0)
            seg_hi = min(hi - offset, length)
            assert seg_lo < seg_hi
            emb_lo, emb_hi = placeholder.get_embeds_indices_in_range(seg_lo, seg_hi)
            if emb_lo == emb_hi:
                # No embedding rows fall inside this segment.
                continue

            mm_hash = feature.identifier
            cached = self.encoder_cache.get(mm_hash, None)
            assert cached is not None, f"Encoder cache miss for {mm_hash}."

            embed_mask = placeholder.is_embed
            if embed_mask is None:
                piece = cached[seg_lo:seg_hi]
                fill = True
            else:
                # Sparse placeholder: select embedding rows, and mark only
                # the positions flagged in the per-position mask.
                fill = embed_mask[seg_lo:seg_hi]
                piece = cached[emb_lo:emb_hi]

            base = query_start_loc[i] + offset - lo
            mask[base + seg_lo : base + seg_hi] = fill
            gathered.append(piece)

    # Hand the mask to the staging pool to place it on the device.
    return gathered, self.tmp_is_mm_embed.copy_to_gpu(mask)

get_inputs_embeds

get_inputs_embeds(
    model: SupportsMultiModal,
    input_ids: Tensor,
    mm_embeds: list[Tensor],
    is_mm_embed: Tensor,
) -> Tensor
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
@torch.inference_mode()
def get_inputs_embeds(
    self,
    model: SupportsMultiModal,
    input_ids: torch.Tensor,
    mm_embeds: list[torch.Tensor],
    is_mm_embed: torch.Tensor,
) -> torch.Tensor:
    """Embed ``input_ids`` with multimodal rows merged in, into the shared buffer.

    Returns:
        The entire persistent ``inputs_embeds`` buffer. Only its leading rows,
        as many as the model produced, are overwritten by this call.
    """
    merged = model.embed_input_ids(
        input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
    )
    buffer = self.inputs_embeds
    num_rows = merged.shape[0]
    # Write into the pre-allocated buffer kept for CUDA graphs.
    buffer[:num_rows] = merged
    return buffer

prepare_mm_inputs

prepare_mm_inputs(
    scheduled_encoder_inputs: dict[str, list[int]],
) -> tuple[list[str], list[MultiModalKwargsItem]]
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
def prepare_mm_inputs(
    self,
    scheduled_encoder_inputs: dict[str, list[int]],
) -> tuple[list[str], list[MultiModalKwargsItem]]:
    """Resolve scheduled encoder inputs into parallel hash/kwargs lists.

    Args:
        scheduled_encoder_inputs: Request id -> indices into that request's
            registered feature list.

    Returns:
        ``(mm_hashes, mm_kwargs)`` aligned element-wise; features without
        ``data`` are left out of both.

    Raises:
        KeyError: If a request id has no registered features.
    """
    pending = []
    for req_id, input_ids in scheduled_encoder_inputs.items():
        features = self.req_id_to_mm_features[req_id]
        pending.extend(features[j] for j in input_ids)

    ready = [feature for feature in pending if feature.data is not None]
    return [f.identifier for f in ready], [f.data for f in ready]

remove_request

remove_request(req_id: str) -> None
Source code in vllm/v1/worker/gpu/mm/encoder_runner.py
def remove_request(self, req_id: str) -> None:
    """Unregister the feature specs of ``req_id``; unknown ids are ignored."""
    registry = self.req_id_to_mm_features
    if req_id in registry:
        del registry[req_id]