vllm.model_executor.model_loader.reload

Layerwise weight reloading utilities for vLLM.

This module provides functionality to reload model weights layer-by-layer, which is useful for weight updates without full model reconstruction.

Limitations:

1. Composition with CPU offloading has not been implemented.
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented.
3. Tied parameters will only reflect processing from one of the parent layers (for example, only processing from embed_tokens will have an effect).
4. This design assumes that the number of weights loaded from disk is the same as the number of weights created at model init time. This is not true for quant methods which (1) pad weights or (2) load qkv weights into the same parameter. Both of these cases are non-issues for today's quant methods, but future quantizations may cause reloading to fail.

Modules:

layerwise
meta
sanitize
torchao_decorator
types
utils

__all__ module-attribute

__all__ = [
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
    "finalize_layerwise_reload",
    "set_torchao_reload_attrs",
    "support_quantized_model_reload_from_hp_weights",
]

finalize_layerwise_reload

finalize_layerwise_reload(
    model: Module, model_config: ModelConfig
)

Remove the outermost layer of weight loading wrappers.

This function should be applied after initialize_layerwise_reload to unwrap the layerwise weight loaders.

Also processes Attention/MLA layers, which must be processed after all other layers.

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
    """
    Remove the outermost layer of weight loading wrappers.

    This function should be applied after `initialize_layerwise_reload` to
    unwrap the layerwise weight loaders.

    Also processes Attention/MLA layers, which must be processed after all other layers.
    """
    model._do_torchao_reload = model._original_do_torchao_reload

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
            if info.load_numel > 0:
                raise NotImplementedError(
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )

            else:
                _place_kernel_tensors(layer, info)
                layer.process_weights_after_loading(model_config.dtype)

        # No weights were loaded, place kernel tensors back
        elif info.can_process() and info.load_numel <= 0:
            _place_kernel_tensors(layer, info)

        # Process non-attention layers which did not load all elements. This can happen
        # if the created weight has extra padding elements which are not loaded
        # Having too many of these delayed layers can lead to excess memory usage
        # see Limitations(4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
            _layerwise_process(layer, info)

        info.reset()

initialize_layerwise_reload

initialize_layerwise_reload(model: Module)

Set up layerwise weight loading with deferred processing.

Must be called after record_metadata_for_reloading. This function:

1. Saves current kernel tensors for later copying
2. Restores layer parameters/buffers from metadata (on meta device)
3. Wraps weight loaders to defer processing until all weights are loaded

When all weights for a layer are loaded, the wrapped loaders will:

1. Materialize the layer onto the target device
2. Load all cached weights
3. Run quantization processing if applicable
4. Copy processed values back to original tensor storage

Source code in vllm/model_executor/model_loader/reload/layerwise.py
@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
    """
    Set up layerwise weight loading with deferred processing.

    Must be called after `record_metadata_for_reloading`. This function:
    1. Saves current kernel tensors for later copying
    2. Restores layer parameters/buffers from metadata (on meta device)
    3. Wraps weight loaders to defer processing until all weights are loaded

    When all weights for a layer are loaded, the wrapped loaders will:
    1. Materialize the layer onto the target device
    2. Load all cached weights
    3. Run quantization processing if applicable
    4. Copy processed values back to original tensor storage
    """
    # disable torchao reloading to avoid infinite recursion
    model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
    model._do_torchao_reload = False

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Skip if the layer has already been initialized
        if info.can_process():
            continue

        # Save current tensors for later copying
        info.kernel_tensors = get_layer_params_buffers(layer)

        # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)

        # Track loading progress to determine when to process/copy
        info.load_numel = 0
        info.load_numel_total = get_layer_size(layer)

        # Wrap each parameter's weight loader
        # Note that nested wrapping will occur for shared tensors
        for name, tensor in get_layer_tensors(layer).items():
            if _get_weight_loader(tensor).__name__ != "online_process_loader":
                tensor.weight_loader = make_online_process_loader(layer, name)
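
The wrapping itself is done by make_online_process_loader, whose body is not shown here. As a rough, hypothetical sketch of the deferral pattern it implements (DeferredState and make_deferred_loader are illustrative names, not vLLM's):

from dataclasses import dataclass, field

import torch

@dataclass
class DeferredState:
    # Illustrative stand-in for the per-layer bookkeeping described above
    load_numel: int = 0
    load_numel_total: int = 0
    cached: dict[str, torch.Tensor] = field(default_factory=dict)

def make_deferred_loader(layer: torch.nn.Module, name: str, state: DeferredState):
    def deferred_loader(param: torch.Tensor, loaded: torch.Tensor) -> None:
        # Cache the incoming weight and track how many elements have arrived
        state.cached[name] = loaded
        state.load_numel += loaded.numel()
        if state.load_numel >= state.load_numel_total:
            # All weights present: materialize the layer off the meta device,
            # load the cached weights, run quantization processing, and copy
            # the results back into the original kernel tensor storage
            # (steps 1-4 above)
            ...
    return deferred_loader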

record_metadata_for_reloading

record_metadata_for_reloading(model: Module)

Record layer metadata needed for later reloading.

Stores parameter and buffer metadata as meta tensors for restoration. Must be called before initialize_layerwise_reload.

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def record_metadata_for_reloading(model: torch.nn.Module):
    """
    Record layer metadata needed for later reloading.

    Stores parameter and buffer metadata as meta tensors for restoration.
    Must be called before `initialize_layerwise_reload`.
    """
    for layer in model.modules():
        info = get_layerwise_info(layer)
        info.restore_metadata = capture_layer_to_meta(layer)
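
The restoration metadata relies on PyTorch's meta device, which records shape and dtype without allocating storage. A small standalone example of the idea:

import torch

w = torch.empty(4096, 4096, dtype=torch.bfloat16)
meta = w.to("meta")  # keeps shape/dtype, allocates no data
print(meta.shape, meta.dtype, meta.device)
# torch.Size([4096, 4096]) torch.bfloat16 meta

# A meta tensor can later be materialized at full size on a real device
restored = torch.empty_like(meta, device="cpu")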

set_torchao_reload_attrs

set_torchao_reload_attrs(
    model: Module, model_config: ModelConfig
)

Source code in vllm/model_executor/model_loader/reload/torchao_decorator.py
def set_torchao_reload_attrs(model: torch.nn.Module, model_config: ModelConfig):
    model._do_torchao_reload = True
    model._model_config = model_config
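
This sets the flag that the decorator below checks. A minimal usage sketch, assuming model, model_config, and an hp_weights_iterator of high-precision weights already exist:

# Enable the torchao reload path before streaming in high-precision weights
set_torchao_reload_attrs(model, model_config)

# Subsequent calls to the decorated load_weights now take the layerwise
# reload path instead of the normal loading path
model.load_weights(hp_weights_iterator)  # hp_weights_iterator is assumed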

support_quantized_model_reload_from_hp_weights

support_quantized_model_reload_from_hp_weights(
    original_load_weights: FunctionType,
)

Decorator for AutoWeightsLoader.load_weights that supports reloading high precision (bfloat16/float16/float32) weights into an already quantized model. This involves restoring the weights to high precision and then quantizing the weights online.

Only applies to torchao quantized models. Assumes that all model weights are loaded within a single weights iterator (cannot perform batched updates).

Source code in vllm/model_executor/model_loader/reload/torchao_decorator.py
def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
    """
    Decorator for `AutoWeightsLoader.load_weights` that supports reloading
    high precision (bfloat16/float16/float32) weights into an already
    quantized model. This involves restoring the weights to high precision
    and then quantizing the weights online.

    Only applies to torchao quantized models. Assumes that all model weights
    are loaded within a single weights iterator (cannot perform batched
    updates).
    """

    @wraps(original_load_weights)
    def patched_model_load_weights(
        self: "AutoWeightsLoader",
        weights: Iterable[tuple[str, torch.Tensor]],
        *args,
        **kwargs,
    ):
        model = self.module

        if not getattr(model, "_do_torchao_reload", False):
            return original_load_weights(self, weights, *args, **kwargs)

        initialize_layerwise_reload(model)
        loaded_weights = original_load_weights(self, weights, *args, **kwargs)
        finalize_layerwise_reload(model, model._model_config)

        return loaded_weights

    return patched_model_load_weights
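
As a usage sketch, the decorator would be applied to AutoWeightsLoader.load_weights once at import time (the import path below is an assumption):

# Import path assumed; vLLM applies this decorator to
# AutoWeightsLoader.load_weights itself
from vllm.model_executor.models.utils import AutoWeightsLoader

AutoWeightsLoader.load_weights = support_quantized_model_reload_from_hp_weights(
    AutoWeightsLoader.load_weights
)

When _do_torchao_reload is False, the patched method is a passthrough to the original load_weights, so normal loading is unaffected.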