Skip to content

vllm.model_executor.layers.pooler.special

__all__ module-attribute

# Public names re-exported by this module.
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]

BOSEOSFilter

Bases: Pooler

Filters the BOS and EOS token results from outputs.

Source code in vllm/model_executor/layers/pooler/special.py
class BOSEOSFilter(Pooler):
    """Wrap another pooler and drop the BOS/EOS rows from its per-request output.

    The wrapped pooler must return a list holding one tensor per request whose
    first dimension equals that request's prompt length. For each request the
    first row is removed when the prompt starts with ``bos_token_id`` and the
    last row is removed when the prompt ends with ``eos_token_id``.
    """

    def __init__(
        self,
        pooler: Pooler,
        bos_token_id: int = -1,  # -1 disables the filtering
        eos_token_id: int = -1,
    ) -> None:
        """Store the wrapped pooler and the special token ids to strip.

        Args:
            pooler: Inner pooler whose per-token outputs are filtered.
            bos_token_id: Id whose leading occurrence is dropped; the default
                -1 is meant to disable this (presumably no real token has
                id -1).
            eos_token_id: Id whose trailing occurrence is dropped; -1 likewise
                disables it.
        """
        super().__init__()

        self.pooler = pooler
        self.bos_token_id, self.eos_token_id = bos_token_id, eos_token_id

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Filtering does not change which tasks are available.
        inner = self.pooler
        return inner.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # forward() reads pooling_metadata.prompt_token_ids, so token ids are
        # requested for every task.
        # NOTE(review): the wrapped pooler's own updates are not merged in.
        update = PoolingParamsUpdate(requires_token_ids=True)
        return update

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run the wrapped pooler, then trim BOS/EOS rows for each request.

        Returns the list produced by the wrapped pooler, modified in place.
        """
        results = self.pooler(hidden_states, pooling_metadata)
        assert isinstance(results, list)

        for idx, seq_len in enumerate(pooling_metadata.prompt_lens):
            per_token = results[idx]
            assert (
                isinstance(per_token, torch.Tensor)
                and per_token.shape[0] == seq_len
            )

            ids = pooling_metadata.prompt_token_ids[idx, :seq_len]
            # Leading BOS -> start at row 1; trailing EOS -> stop one row early.
            keep = slice(
                1 if ids[0] == self.bos_token_id else None,
                -1 if ids[-1] == self.eos_token_id else None,
            )
            # NOTE(review): squeeze() removes *every* size-1 dimension, so a
            # single remaining row or a size-1 feature dim is collapsed too.
            results[idx] = per_token[keep].squeeze()

        return results

bos_token_id instance-attribute

bos_token_id = bos_token_id

eos_token_id instance-attribute

eos_token_id = eos_token_id

pooler instance-attribute

pooler = pooler

__init__

__init__(
    pooler: Pooler,
    bos_token_id: int = -1,
    eos_token_id: int = -1,
) -> None
Source code in vllm/model_executor/layers/pooler/special.py
def __init__(
    self,
    pooler: Pooler,
    bos_token_id: int = -1,  # -1 disables the filtering
    eos_token_id: int = -1,
) -> None:
    """Store the wrapped pooler and the special token ids to strip.

    Args:
        pooler: Inner pooler whose per-token outputs are filtered.
        bos_token_id: Id whose leading occurrence is dropped; the default -1
            is meant to disable this (presumably no real token has id -1).
        eos_token_id: Id whose trailing occurrence is dropped; -1 likewise
            disables it.
    """
    super().__init__()

    self.pooler = pooler
    self.bos_token_id, self.eos_token_id = bos_token_id, eos_token_id

forward

forward(
    hidden_states: Tensor | list[Tensor],
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput
Source code in vllm/model_executor/layers/pooler/special.py
def forward(
    self,
    hidden_states: torch.Tensor | list[torch.Tensor],
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Run the wrapped pooler, then trim BOS/EOS rows for each request.

    Returns the list produced by the wrapped pooler, modified in place.
    """
    results = self.pooler(hidden_states, pooling_metadata)
    assert isinstance(results, list)

    for idx, seq_len in enumerate(pooling_metadata.prompt_lens):
        per_token = results[idx]
        assert (
            isinstance(per_token, torch.Tensor)
            and per_token.shape[0] == seq_len
        )

        ids = pooling_metadata.prompt_token_ids[idx, :seq_len]
        # Leading BOS -> start at row 1; trailing EOS -> stop one row early.
        keep = slice(
            1 if ids[0] == self.bos_token_id else None,
            -1 if ids[-1] == self.eos_token_id else None,
        )
        # NOTE(review): squeeze() removes *every* size-1 dimension, so a
        # single remaining row or a size-1 feature dim is collapsed too.
        results[idx] = per_token[keep].squeeze()

    return results

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/special.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    # forward() reads pooling_metadata.prompt_token_ids, so token ids are
    # requested for every task.
    # NOTE(review): the wrapped pooler's own updates are not merged in.
    update = PoolingParamsUpdate(requires_token_ids=True)
    return update

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/special.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    # Filtering does not change which tasks are available.
    inner = self.pooler
    return inner.get_supported_tasks()

DispatchPooler

Bases: Pooler

Dispatches calls to a sub-pooler based on the pooling task.

Source code in vllm/model_executor/layers/pooler/special.py
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task.

    Each task maps to the sub-pooler that handles it. ``forward`` splits the
    batch into consecutive runs of requests sharing the same task and routes
    each run to that task's sub-pooler, concatenating the results in order.
    """

    @classmethod
    def for_embedding(cls, pooler_config: PoolerConfig):
        """Build a dispatcher for the ``token_embed`` and ``embed`` tasks."""
        return cls(
            {
                "token_embed": pooler_for_token_embed(pooler_config),
                "embed": pooler_for_embed(pooler_config),
            },
        )

    @classmethod
    def for_seq_cls(
        cls,
        pooler_config: PoolerConfig,
        *,
        pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
        classifier: ClassifierFn | None = None,
    ):
        """Build a dispatcher for ``token_classify``, ``classify`` and ``score``.

        ``token_classify`` always pools with ``AllPool()``. ``classify`` and
        ``score`` share ``pooling`` and ``classifier`` and differ only in the
        ``act_fn`` passed to ``pooler_for_classify``.
        """
        return cls(
            {
                "token_classify": pooler_for_token_classify(
                    pooler_config,
                    pooling=AllPool(),
                    classifier=classifier,
                ),
                "classify": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="classify",
                ),
                "score": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="score",
                ),
            }
        )

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Validate and store the task -> sub-pooler mapping.

        Raises:
            ValueError: if a sub-pooler does not list its task among its own
                supported tasks.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        # NOTE(review): kept as a plain mapping, so if Pooler is an nn.Module
        # the sub-poolers are not registered as child modules (absent from
        # parameters()/state_dict()/repr). Confirm this is intentional.
        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Raises KeyError for a task with no registered sub-pooler.
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Route each contiguous run of same-task requests to its sub-pooler.

        ``hidden_states`` is passed whole to every sub-pooler, together with
        the slice of ``pooling_metadata`` for that run.

        Raises:
            ValueError: if a request's task has no registered sub-pooler.
        """
        poolers_by_task = self.poolers_by_task

        outputs = list[torch.Tensor | None]()
        offset = 0
        # groupby merges only *adjacent* equal tasks, so every run maps to a
        # contiguous slice of pooling_metadata; a task appearing in several
        # separate runs is dispatched once per run.
        for task, group in groupby(pooling_metadata.tasks):
            # Explicit None check: a Pooler defining __len__/__bool__ must not
            # be mistaken for a missing entry.
            if (pooler := poolers_by_task.get(task)) is None:
                raise ValueError(
                    f"Unsupported task: {task!r}. "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            # NOTE(review): if a sub-pooler returns a single tensor rather
            # than a list, extend() iterates over its first dimension.
            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s

poolers_by_task instance-attribute

poolers_by_task = poolers_by_task

__init__

__init__(
    poolers_by_task: Mapping[PoolingTask, Pooler],
) -> None
Source code in vllm/model_executor/layers/pooler/special.py
def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
    """Validate and store the task -> sub-pooler mapping.

    Raises:
        ValueError: if a sub-pooler does not list its task among its own
            supported tasks.
    """
    super().__init__()

    # `pooler` / `task` names are load-bearing: the f-string `=` specifiers
    # render them verbatim in the error message.
    for task, pooler in poolers_by_task.items():
        supported = pooler.get_supported_tasks()
        if task in supported:
            continue
        raise ValueError(
            f"{pooler=} does not support {task=}. "
            f"Supported tasks: {pooler.get_supported_tasks()}"
        )

    self.poolers_by_task = poolers_by_task

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/pooler/special.py
def extra_repr(self) -> str:
    """Summarize the dispatcher by the set of tasks it serves."""
    return f"supported_task={self.get_supported_tasks()}"

for_embedding classmethod

for_embedding(pooler_config: PoolerConfig)
Source code in vllm/model_executor/layers/pooler/special.py
@classmethod
def for_embedding(cls, pooler_config: PoolerConfig):
    """Build a dispatcher serving the ``token_embed`` and ``embed`` tasks."""
    token_embed = pooler_for_token_embed(pooler_config)
    embed = pooler_for_embed(pooler_config)
    return cls({"token_embed": token_embed, "embed": embed})

for_seq_cls classmethod

for_seq_cls(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod
    | SequencePoolingFn
    | None = None,
    classifier: ClassifierFn | None = None,
)
Source code in vllm/model_executor/layers/pooler/special.py
@classmethod
def for_seq_cls(
    cls,
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
    classifier: ClassifierFn | None = None,
):
    """Build a dispatcher for ``token_classify``, ``classify`` and ``score``.

    ``token_classify`` always pools with ``AllPool()``; ``classify`` and
    ``score`` share ``pooling`` and ``classifier``.
    """
    poolers = {
        "token_classify": pooler_for_token_classify(
            pooler_config,
            pooling=AllPool(),
            classifier=classifier,
        ),
    }
    # The task name doubles as the act_fn handed to pooler_for_classify.
    for task in ("classify", "score"):
        poolers[task] = pooler_for_classify(
            pooler_config,
            pooling=pooling,
            classifier=classifier,
            act_fn=task,
        )
    return cls(poolers)

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> PoolerOutput
Source code in vllm/model_executor/layers/pooler/special.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Route each contiguous run of same-task requests to its sub-pooler.

    ``hidden_states`` is passed whole to every sub-pooler, together with the
    slice of ``pooling_metadata`` for that run.

    Raises:
        ValueError: if a request's task has no registered sub-pooler.
    """
    poolers_by_task = self.poolers_by_task

    outputs = list[torch.Tensor | None]()
    offset = 0
    # groupby merges only *adjacent* equal tasks, so every run maps to a
    # contiguous slice of pooling_metadata; a task appearing in several
    # separate runs is dispatched once per run.
    for task, group in groupby(pooling_metadata.tasks):
        # Explicit None check: a Pooler defining __len__/__bool__ must not be
        # mistaken for a missing entry.
        if (pooler := poolers_by_task.get(task)) is None:
            raise ValueError(
                f"Unsupported task: {task!r}. "
                f"Supported tasks: {self.get_supported_tasks()}"
            )

        num_items = len(list(group))
        group_output: PoolerOutput = pooler(
            hidden_states,
            pooling_metadata[offset : offset + num_items],
        )

        # NOTE(review): if a sub-pooler returns a single tensor rather than a
        # list, extend() iterates over its first dimension.
        outputs.extend(group_output)
        offset += num_items

    return outputs

get_pooling_updates

get_pooling_updates(
    task: PoolingTask,
) -> PoolingParamsUpdate
Source code in vllm/model_executor/layers/pooler/special.py
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
    # Delegate to the task's sub-pooler; an unknown task raises KeyError.
    handler = self.poolers_by_task[task]
    return handler.get_pooling_updates(task)

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/special.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    # The supported tasks are exactly the keys of the dispatch mapping.
    return set(self.poolers_by_task.keys())

IdentityPooler

Bases: Pooler

Returns the hidden states unchanged; the pooling metadata is ignored.

Source code in vllm/model_executor/layers/pooler/special.py
class IdentityPooler(Pooler):
    """Pass-through pooler that returns the hidden states unchanged."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Return ``hidden_states`` as-is; ``pooling_metadata`` is unused."""
        return hidden_states

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Only the "plugin" and "score" tasks are advertised.
        return set(("plugin", "score"))

forward

forward(
    hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> PoolerOutput
Source code in vllm/model_executor/layers/pooler/special.py
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Return ``hidden_states`` as-is; ``pooling_metadata`` is unused."""
    passthrough = hidden_states
    return passthrough

get_supported_tasks

get_supported_tasks() -> Set[PoolingTask]
Source code in vllm/model_executor/layers/pooler/special.py
def get_supported_tasks(self) -> Set[PoolingTask]:
    # Only the "plugin" and "score" tasks are advertised.
    return set(("plugin", "score"))