
vllm.model_executor.models.terratorch

Wrapper around Terratorch models
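
This wrapper lets Terratorch-based geospatial backbones be served through vLLM's pooling API: the raw multimodal tensors are passed through unchanged and a single dummy token stands in for the text prompt. A minimal offline-inference sketch, illustrative only; the checkpoint path and the "pixel_values" / "location_coords" field names are assumptions that depend on the 'input' section of the checkpoint's pretrained_cfg:

import torch

from vllm import LLM

# Minimal sketch: checkpoint path and input field names are assumptions.
llm = LLM(model="path/to/terratorch-checkpoint", skip_tokenizer_init=True)

prompt = {
    "prompt_token_ids": [1],  # a single dummy token is required
    "multi_modal_data": {
        "pixel_values": torch.randn(6, 512, 512),  # assumed input field
        "location_coords": torch.zeros(1, 2),      # assumed input field
    },
}

# Pooling API: the raw Terratorch output is returned as the pooled result.
outputs = llm.encode(prompt)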

logger module-attribute

logger = init_logger(__name__)

Terratorch

Bases: Module, IsAttentionFree, SupportsMultiModal

Source code in vllm/model_executor/models/terratorch.py
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = IdentityPooler()

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        # We do not actually use any input tokens, so there are no embeddings
        # to calculate. However, since the input prompt must contain token
        # ids, we pass a single token, and the size of the dummy embedding
        # tensors must reflect that.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)

        batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
        model_output = self.inference_runner.forward(**batched_kwargs).output

        # The leading dimension of hidden states needs to equal input length
        return model_output.expand(
            input_len, *(-1 for _ in range(model_output.ndim - 1))
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if isinstance(value, (dict, OrderedDict)):
                if key == "state_dict":
                    weights_to_parse = value
                    for name, weight in weights_to_parse.items():
                        name = f"inference_runner.{name}"

                        if "pos_embed" in name:
                            continue

                        if "_timm_module." in name:
                            name = name.replace("_timm_module.", "")

                        # This model requires a couple of buffers to be loaded
                        # that the AutoWeightsLoader cannot handle.
                        if name in model_buffers:
                            if "_timm_module." in name:
                                name = name.replace("_timm_module.", "")
                            buffer = model_buffers[name]
                            weight_loader = getattr(
                                buffer, "weight_loader", default_weight_loader
                            )
                            weight_loader(buffer, weight)
                            loaded_buffers.append(name)
                        else:
                            params_list.append((name, weight))
                    break

            elif isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))

inference_runner instance-attribute

inference_runner = InferenceRunner(config)

is_pooling_model class-attribute instance-attribute

is_pooling_model = True

model instance-attribute

model = inference_runner.model

pooler instance-attribute

pooler = IdentityPooler()

supports_multimodal_raw_input_only class-attribute instance-attribute

supports_multimodal_raw_input_only = True

__init__

__init__(vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/terratorch.py
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()

    config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

    self.inference_runner = InferenceRunner(config)
    self.model = self.inference_runner.model

    pooler_config = vllm_config.model_config.pooler_config
    assert pooler_config is not None

    self.pooler = IdentityPooler()

embed_input_ids

embed_input_ids(
    input_ids: Tensor,
    multimodal_embeddings: MultiModalEmbeddings
    | None = None,
    *,
    is_multimodal: Tensor | None = None,
    handle_oov_mm_token: bool = False,
) -> Tensor
Source code in vllm/model_executor/models/terratorch.py
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    # We do not actually use any input tokens, so there are no embeddings
    # to calculate. However, since the input prompt must contain token
    # ids, we pass a single token, and the size of the dummy embedding
    # tensors must reflect that.
    return torch.empty((input_ids.shape[0], 0))
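
Since the model consumes only the raw multimodal tensors, the returned "embeddings" are zero-width placeholders whose leading dimension simply matches the number of prompt tokens. A small illustration:

import torch

input_ids = torch.tensor([1])  # the single mandatory dummy token
dummy_embeds = torch.empty((input_ids.shape[0], 0))
print(dummy_embeds.shape)  # torch.Size([1, 0])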

forward

forward(
    input_ids: Tensor | None,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
)
Source code in vllm/model_executor/models/terratorch.py
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
):
    input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)

    batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
    model_output = self.inference_runner.forward(**batched_kwargs).output

    # The leading dimension of hidden states needs to equal input length
    return model_output.expand(
        input_len, *(-1 for _ in range(model_output.ndim - 1))
    )
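
The final expand only broadcasts the singleton leading dimension of the Terratorch output so that the hidden states line up with the prompt length; no data is copied. A small illustration with made-up shapes:

import torch

model_output = torch.randn(1, 4, 4)  # hypothetical single-item output
input_len = 2
hidden = model_output.expand(
    input_len, *(-1 for _ in range(model_output.ndim - 1))
)
print(hidden.shape)  # torch.Size([2, 4, 4])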

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> str | None
Source code in vllm/model_executor/models/terratorch.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    if modality.startswith("image"):
        return None

    raise ValueError("Only image modality is supported")

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/terratorch.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    params_list = []
    model_buffers = dict(self.named_buffers())
    loaded_buffers = []
    for key, value in weights:
        if isinstance(value, (dict, OrderedDict)):
            if key == "state_dict":
                weights_to_parse = value
                for name, weight in weights_to_parse.items():
                    name = f"inference_runner.{name}"

                    if "pos_embed" in name:
                        continue

                    if "_timm_module." in name:
                        name = name.replace("_timm_module.", "")

                    # This model requires a couple of buffers to be loaded
                    # that the AutoWeightsLoader cannot handle.
                    if name in model_buffers:
                        if "_timm_module." in name:
                            name = name.replace("_timm_module.", "")
                        buffer = model_buffers[name]
                        weight_loader = getattr(
                            buffer, "weight_loader", default_weight_loader
                        )
                        weight_loader(buffer, weight)
                        loaded_buffers.append(name)
                    else:
                        params_list.append((name, weight))
                break

        elif isinstance(value, torch.Tensor):
            params_list.append((f"inference_runner.model.{key}", value))

    # Load the remaining model parameters
    loader = AutoWeightsLoader(self)
    autoloaded_weights = loader.load_weights(params_list)

    return autoloaded_weights.union(set(loaded_buffers))
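
The loader accepts two shapes of input: a single ("state_dict", mapping) pair, whose entries are prefixed with "inference_runner." (with buffers loaded directly), or flat (name, tensor) pairs, which are prefixed with "inference_runner.model." and handed to the AutoWeightsLoader. A sketch with made-up parameter names:

import torch

# Nested checkpoint format: a single ("state_dict", mapping) entry.
nested = [("state_dict", {"encoder.blocks.0.mlp.fc1.weight": torch.randn(8, 8)})]

# Flat format: plain (name, tensor) pairs.
flat = [("encoder.blocks.0.mlp.fc1.weight", torch.randn(8, 8))]

# 'model' is an already constructed Terratorch instance.
loaded_names = model.load_weights(nested)  # or model.load_weights(flat)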

TerratorchInputBuilder

Bases: BaseDummyInputsBuilder[TerratorchProcessingInfo]

Source code in vllm/model_executor/models/terratorch.py
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    def __init__(self, info: TerratorchProcessingInfo):
        super().__init__(info)
        self.dummy_data_generator = DummyDataGenerator(
            self.info.get_hf_config().to_dict()["pretrained_cfg"]
        )

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        # Dummy data is generated based on the 'input' section
        # defined in the HF configuration file

        if mm_options:
            logger.warning(
                "Configurable multimodal profiling "
                "options are not supported for Terratorch. "
                "They are ignored for now."
            )

        return self.dummy_data_generator.get_dummy_mm_data()

dummy_data_generator instance-attribute

dummy_data_generator = DummyDataGenerator(
    info.get_hf_config().to_dict()["pretrained_cfg"]
)

__init__

__init__(info: TerratorchProcessingInfo)
Source code in vllm/model_executor/models/terratorch.py
def __init__(self, info: TerratorchProcessingInfo):
    super().__init__(info)
    self.dummy_data_generator = DummyDataGenerator(
        self.info.get_hf_config().to_dict()["pretrained_cfg"]
    )

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions]
    | None = None,
) -> MultiModalDataDict
Source code in vllm/model_executor/models/terratorch.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
    # Dummy data is generated based on the 'input' section
    # defined in the HF configuration file

    if mm_options:
        logger.warning(
            "Configurable multimodal profiling "
            "options are not supported for Terratorch. "
            "They are ignored for now."
        )

    return self.dummy_data_generator.get_dummy_mm_data()

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/terratorch.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    return ""

TerratorchMultiModalDataParser

Bases: MultiModalDataParser

Source code in vllm/model_executor/models/terratorch.py
class TerratorchMultiModalDataParser(MultiModalDataParser):
    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields=_terratorch_field_names(self.input_definition),
                fields_factory=_terratorch_field_factory(self.input_definition),
            )

        return super()._parse_image_data(data)

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        if "image" not in mm_data:
            mm_data = {"image": mm_data}

        return super().parse_mm_data(mm_data)

input_definition instance-attribute

input_definition = input_definition

__init__

__init__(
    input_definition: InputDefinition, *args, **kwargs
)
Source code in vllm/model_executor/models/terratorch.py
def __init__(self, input_definition: InputDefinition, *args, **kwargs):
    super().__init__(*args, **kwargs)

    self.input_definition = input_definition

_parse_image_data

_parse_image_data(
    data: dict[str, Tensor] | ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None
Source code in vllm/model_executor/models/terratorch.py
def _parse_image_data(
    self,
    data: dict[str, torch.Tensor] | ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None:
    if isinstance(data, dict):
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(self.input_definition),
            fields_factory=_terratorch_field_factory(self.input_definition),
        )

    return super()._parse_image_data(data)

parse_mm_data

parse_mm_data(
    mm_data: MultiModalDataDict,
) -> MultiModalDataItems
Source code in vllm/model_executor/models/terratorch.py
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
    if "image" not in mm_data:
        mm_data = {"image": mm_data}

    return super().parse_mm_data(mm_data)
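
If the incoming multimodal data does not already use the "image" key, the raw field dict is wrapped under it before the standard parsing runs, so callers can pass the Terratorch input fields directly. Roughly (field name assumed):

# A raw field dict without an "image" key ...
mm_data = {"pixel_values": pixel_values}

# ... is parsed as if the caller had written:
mm_data = {"image": {"pixel_values": pixel_values}}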

TerratorchMultiModalProcessor

Bases: BaseMultiModalProcessor[TerratorchProcessingInfo]

Source code in vllm/model_executor/models/terratorch.py
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _terratorch_field_factory(self.info.input_definition)(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        return []

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        _, passthrough_data = self._get_hf_mm_data(mm_items)
        mm_processed_data = BatchFeature(dict(passthrough_data), tensor_type="pt")
        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/terratorch.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    return _terratorch_field_factory(self.info.input_definition)(hf_inputs)

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/terratorch.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    return []

apply

apply(
    prompt: str | list[int],
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object] | None = None,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs
Source code in vllm/model_executor/models/terratorch.py
def apply(
    self,
    prompt: str | list[int],
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object] | None = None,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
    if tokenization_kwargs is None:
        tokenization_kwargs = {}

    mm_hashes = self._hash_mm_items(
        mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
    )

    _, passthrough_data = self._get_hf_mm_data(mm_items)
    mm_processed_data = BatchFeature(dict(passthrough_data), tensor_type="pt")
    mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

    mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
        mm_processed_data,
        self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
    )

    return MultiModalInputs(
        type="multimodal",
        prompt_token_ids=[1],
        mm_kwargs=mm_kwargs,
        mm_hashes=mm_hashes,
        mm_placeholders=mm_placeholders,
    )
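
apply skips the usual HF processor call entirely: the raw tensors are passed through, the prompt collapses to a single dummy token, and the image placeholder has zero length. The resulting inputs look roughly like this (illustrative):

# MultiModalInputs(
#     type="multimodal",
#     prompt_token_ids=[1],          # the single dummy token
#     mm_kwargs=...,                 # raw tensors, passed through unchanged
#     mm_hashes=...,                 # per-item hashes used for caching
#     mm_placeholders={"image": [PlaceholderRange(offset=0, length=0)]},
# )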

TerratorchProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/terratorch.py
class TerratorchProcessingInfo(BaseProcessingInfo):
    @cached_property
    def input_definition(self) -> InputDefinition:
        pretrained_cfg = self.get_hf_config().to_dict()["pretrained_cfg"]
        return InputDefinition(**pretrained_cfg["input"])

    def get_data_parser(self):
        return TerratorchMultiModalDataParser(
            self.input_definition,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

input_definition cached property

input_definition: InputDefinition

get_data_parser

get_data_parser()
Source code in vllm/model_executor/models/terratorch.py
def get_data_parser(self):
    return TerratorchMultiModalDataParser(
        self.input_definition,
        expected_hidden_size=self._get_expected_hidden_size(),
    )

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, int | None]
Source code in vllm/model_executor/models/terratorch.py
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
    return {"image": None}

_terratorch_field_factory

_terratorch_field_factory(
    input_definition: InputDefinition,
)
Source code in vllm/model_executor/models/terratorch.py
def _terratorch_field_factory(input_definition: InputDefinition):
    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        fields = dict[str, MultiModalFieldConfig]()
        for name, input in input_definition.data.items():
            modality = "image"
            if input.type == InputTypeEnum.tensor:
                fields[name] = MultiModalFieldConfig.shared(modality, batch_size=1)

        return fields

    return _terratorch_field_config
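
Every tensor-typed entry in the input definition is mapped to a shared image field with batch_size=1, i.e. the whole tensor belongs to the single image item. With hypothetical field names, the factory would produce:

# {
#     "pixel_values": MultiModalFieldConfig.shared("image", batch_size=1),
#     "location_coords": MultiModalFieldConfig.shared("image", batch_size=1),
# }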

_terratorch_field_names

_terratorch_field_names(input_definition: InputDefinition)
Source code in vllm/model_executor/models/terratorch.py
def _terratorch_field_names(input_definition: InputDefinition):
    return set(input_definition.data.keys())