vllm.model_executor.model_loader.reload.layerwise

LAYERWISE_INFO module-attribute

Module-level mapping from layer modules to their LayerReloadingInfo entries; get_layerwise_info creates an entry lazily for layers that do not yet have one.

__all__ module-attribute

__all__ = [
    "get_layerwise_info",
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
    "finalize_layerwise_reload",
]

logger module-attribute

logger = init_logger(__name__)

_get_original_loader

_get_original_loader(tensor: Tensor) -> Callable

Return the weight loader with any layerwise wrappers removed

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def _get_original_loader(tensor: torch.Tensor) -> Callable:
    """Return the weight loader with any layerwise wrappers removed"""
    loader = _get_weight_loader(tensor)
    while loader.__name__ == "online_process_loader":
        loader = loader.__wrapped__  # type: ignore[union-attr]

    return loader

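The unwrapping loop relies on functools.wraps setting __wrapped__ on each wrapper while leaving the wrapper's own __name__ intact (make_online_process_loader only copies __doc__ and __annotations__). A minimal sketch of that mechanism, using a stand-in base loader rather than vLLM's real default_weight_loader:

import functools

def base_loader(param, loaded_weight):  # stands in for default_weight_loader
    param.copy_(loaded_weight)

@functools.wraps(base_loader, assigned=("__doc__", "__annotations__"))
def online_process_loader(*args, **kwargs):
    ...

# The wrapper keeps its own name (so the while-loop condition matches) and
# functools.wraps records the wrapped callable for unwrapping.
assert online_process_loader.__name__ == "online_process_loader"
assert online_process_loader.__wrapped__ is base_loader
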
_get_weight_loader

_get_weight_loader(tensor: Tensor)
Source code in vllm/model_executor/model_loader/reload/layerwise.py
def _get_weight_loader(tensor: torch.Tensor):
    return getattr(tensor, "weight_loader", default_weight_loader)

_layerwise_process

_layerwise_process(layer: Module, info: LayerReloadingInfo)

Finalize layer loading after all weights have been cached.

This function:

1. Materializes the layer onto the target device
2. Loads all cached weights
3. Runs quantization processing if applicable
4. Copies processed values back to original tensor storage

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    """
    Finalize layer loading after all weights have been cached.

    This function:
    1. Materializes the layer onto the target device
    2. Loads all cached weights
    3. Runs quantization processing if applicable
    4. Copies processed values back to original tensor storage
    """
    # Materialize layer tensors onto device
    materialize_layer(layer)

    # Unwrap layerwise loading wrappers
    for param in get_layer_tensors(layer).values():
        param.weight_loader = _get_original_loader(param)

    # Load all cached weights into materialized layer (using original loaders)
    for name, args in info.loaded_weights:
        param = getattr(layer, name)
        args.arguments["param"] = param
        param.weight_loader(*args.args, **args.kwargs)

    # Process weights (quantization, repacking, etc.)
    # Attention/MLA are processed in `finalize_layerwise_reload`
    quant_method = getattr(layer, "quant_method", None)
    if isinstance(quant_method, QuantizeMethodBase):
        quant_method.process_weights_after_loading(layer)

    # Copy processed values into original tensor storage (preserves cudagraph refs)
    # This is a no-op when not reloading (because kernel_tensors is empty)
    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        param.data.copy_(getattr(layer, name))
    for name, buffer in buffers.items():
        buffer.data.copy_(getattr(layer, name))

    _place_kernel_tensors(layer, info)

    info.reset()
    logger.debug("%s: Processed", layer.__class__.__name__)

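The copy-back step matters because captured CUDA graphs hold references to the original tensor objects; copying in place updates the values without reallocating storage. A minimal sketch of the idea, using plain CPU tensors purely for illustration:

import torch

original = torch.zeros(4)        # stands in for a pre-reload kernel tensor
processed = torch.arange(4.0)    # stands in for the freshly processed weight

ptr_before = original.data_ptr()
original.data.copy_(processed)   # in-place copy, no new allocation
assert original.data_ptr() == ptr_before
assert torch.equal(original, processed)
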
_place_kernel_tensors

_place_kernel_tensors(
    layer: Module, info: LayerReloadingInfo
)
Source code in vllm/model_executor/model_loader/reload/layerwise.py
def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
    for name in get_layer_tensors(layer):
        delattr(layer, name)

    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        layer.register_parameter(name, param)
    for name, buffer in buffers.items():
        layer.register_buffer(name, buffer)

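The swap works because deleting a module attribute removes it from the module's parameter/buffer registry, after which the saved kernel tensor can be registered under the same name. An illustrative sketch with a throwaway nn.Linear standing in for a real layer:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
saved = nn.Parameter(torch.zeros(4, 4))  # stands in for a saved kernel tensor

delattr(layer, "weight")                  # drop the current parameter
layer.register_parameter("weight", saved)
assert layer.weight is saved              # original object is visible again
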
finalize_layerwise_reload

finalize_layerwise_reload(
    model: Module, model_config: ModelConfig
)

Remove the outermost layer of weight loading wrappers.

This function should be applied after initialize_layerwise_reload in order to unwrap the layerwise weight loaders.

Also processes Attention/MLA layers, which must be processed after all other layers

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
    """
    Remove the outermost layer of weight loading wrappers.

    This function should be applied after `initialize_layerwise_reload` in order to
    unwrap the layerwise weight loaders.

    Also processes Attention/MLA layers, which must be processed after all other layers
    """
    model._do_torchao_reload = model._original_do_torchao_reload

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
            if info.load_numel > 0:
                raise NotImplementedError(
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )

            else:
                _place_kernel_tensors(layer, info)
                layer.process_weights_after_loading(model_config.dtype)

        # No weights were loaded, place kernel tensors back
        elif info.can_process() and info.load_numel <= 0:
            _place_kernel_tensors(layer, info)

        # Process non-attention layers which did not load all elements. This can happen
        # if the created weight has extra padding elements which are not loaded.
        # Having too many of these delayed layers can lead to excess memory usage,
        # see Limitations(4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
            _layerwise_process(layer, info)

        info.reset()

get_layerwise_info

get_layerwise_info(layer: Module) -> LayerReloadingInfo

Get information related to restoring and layerwise processing. If no previous information existed, a new entry is constructed

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
    """
    Get information related to restoring and layerwise processing. If no previous
    information existed, a new entry is constructed
    """
    if layer not in LAYERWISE_INFO:
        LAYERWISE_INFO[layer] = LayerReloadingInfo()

    return LAYERWISE_INFO[layer]

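Lookups are idempotent per layer: the first call creates an entry in LAYERWISE_INFO, and later calls return the same object. A small example with a hypothetical placeholder module:

import torch.nn as nn

layer = nn.Linear(16, 16)  # placeholder module for illustration

info = get_layerwise_info(layer)
assert get_layerwise_info(layer) is info  # same entry on repeat lookups
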
initialize_layerwise_reload

initialize_layerwise_reload(model: Module)

Set up layerwise weight loading with deferred processing.

Must be called after record_metadata_for_reloading. This function:

1. Saves current kernel tensors for later copying
2. Restores layer parameters/buffers from metadata (on meta device)
3. Wraps weight loaders to defer processing until all weights are loaded

When all weights for a layer are loaded, the wrapped loaders will:

1. Materialize the layer onto the target device
2. Load all cached weights
3. Run quantization processing if applicable
4. Copy processed values back to original tensor storage

Source code in vllm/model_executor/model_loader/reload/layerwise.py
@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
    """
    Set up layerwise weight loading with deferred processing.

    Must be called after `record_metadata_for_reloading`. This function:
    1. Saves current kernel tensors for later copying
    2. Restores layer parameters/buffers from metadata (on meta device)
    3. Wraps weight loaders to defer processing until all weights are loaded

    When all weights for a layer are loaded, the wrapped loaders will:
    1. Materialize the layer onto the target device
    2. Load all cached weights
    3. Run quantization processing if applicable
    4. Copy processed values back to original tensor storage
    """
    # disable torchao reloading to avoid infinite recursion
    model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
    model._do_torchao_reload = False

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Skip if the layer has already been initialized
        if info.can_process():
            continue

        # Save current tensors for later copying
        info.kernel_tensors = get_layer_params_buffers(layer)

        # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)

        # Track loading progress to determine when to process/copy
        info.load_numel = 0
        info.load_numel_total = get_layer_size(layer)

        # Wrap each parameter's weight loader
        # Note that nested wrapping will occur for shared tensors
        for name, tensor in get_layer_tensors(layer).items():
            if _get_weight_loader(tensor).__name__ != "online_process_loader":
                tensor.weight_loader = make_online_process_loader(layer, name)

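A sketch of where this fits in the overall reload sequence; model, model_config, and weights_iter are placeholders here, and the real call sites live in vLLM's model loaders:

# Record restore metadata once, while the model still holds its original tensors
record_metadata_for_reloading(model)

# ... later, when new weights should be loaded ...
initialize_layerwise_reload(model)              # move layers to meta, wrap weight loaders
model.load_weights(weights_iter)                # each layer is processed as its weights complete
finalize_layerwise_reload(model, model_config)  # Attention/MLA and delayed layers
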
make_online_process_loader

make_online_process_loader(
    layer: Module, param_name: str
) -> Callable

Create a wrapped weight loader that defers processing.

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
    """Create a wrapped weight loader that defers processing."""
    info = get_layerwise_info(layer)
    param = getattr(layer, param_name)
    original_loader = _get_original_loader(param)
    loader_signature = inspect.signature(original_loader)

    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
    def online_process_loader(*args, **kwargs):
        if not info.can_process():
            # Unfortunately, some qconfigs are set up to load the same weight
            # multiple times. For example, CT_WNA16 loads `weight_shape` for
            # each of the qkv partitions. This results in layers loading extra
            # weights (beyond load_numel_total) after it's already processed.
            #
            # Best solution is to ensure that `load_numel_total` reflects the
            # actual number of weights loaded, either by modifying qconfigs to
            # create as many weights as loaded (see padding issue as well)
            # or maybe capturing how many weights are loaded on first pass
            #
            # For now, `load_numel_total` is still safe to use as long as
            # there's no way to reach `load_numel_total` without loading all
            # necessary weights. `weight_shape` is very small, so this is safe.
            # see Limitations(4)
            logger.debug("%s: Excessive loading", layer.__class__.__name__)
            return

        # Bind and normalize arguments
        bound_args = loader_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # Cache loaded weights, track loading progress
        info.loaded_weights.append((param_name, bound_args))
        num_loaded, ret = get_numel_loaded(original_loader, bound_args)
        info.load_numel += num_loaded

        logger.debug(
            "%s: %d / %d",
            layer.__class__.__name__,
            info.load_numel,
            info.load_numel_total,
        )

        # Process and copy when all weights are loaded
        if info.load_numel >= info.load_numel_total and not isinstance(  # type: ignore[operator]
            layer, (Attention, MLAAttention)
        ):
            _layerwise_process(layer, info)

        return ret

    return online_process_loader

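The wrapper caches each call as inspect.BoundArguments so it can be replayed later with the param argument swapped for the materialized tensor (see _layerwise_process). A minimal sketch of the bind/apply_defaults pattern, using a hypothetical loader signature:

import inspect

def example_loader(param, loaded_weight, shard_id=None):  # hypothetical signature
    ...

bound = inspect.signature(example_loader).bind("param_stub", "weight_stub")
bound.apply_defaults()                     # fills in shard_id=None
bound.arguments["param"] = "new_param"     # swapped in before replay
example_loader(*bound.args, **bound.kwargs)
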
record_metadata_for_reloading

record_metadata_for_reloading(model: Module)

Record layer metadata needed for later reloading.

Stores parameter and buffer metadata as meta tensors for restoration. Must be called before initialize_layerwise_reload.

Source code in vllm/model_executor/model_loader/reload/layerwise.py
def record_metadata_for_reloading(model: torch.nn.Module):
    """
    Record layer metadata needed for later reloading.

    Stores parameter and buffer metadata as meta tensors for restoration.
    Must be called before `initialize_layerwise_reload`.
    """
    for layer in model.modules():
        info = get_layerwise_info(layer)
        info.restore_metadata = capture_layer_to_meta(layer)
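
The recorded metadata is cheap to hold because meta tensors carry shape, dtype, and layout information without allocating storage, which is all that is needed to rebuild each layer's parameters later. A small illustration of the idea:

import torch

weight = torch.randn(8, 8)
meta_copy = weight.to("meta")     # shape/dtype preserved, no storage allocated
assert meta_copy.is_meta and meta_copy.shape == weight.shape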