nemotron-speech (sha256:32e686ba622247e38ac0f3240c659f12d4c09e920e68bf802fcd2b34b7ae1575)

Published 2026-04-05 09:42:00 +00:00 by j

Installation

docker pull forge.jde.nz/public/nemotron-speech@sha256:32e686ba622247e38ac0f3240c659f12d4c09e920e68bf802fcd2b34b7ae1575
sha256:32e686ba622247e38ac0f3240c659f12d4c09e920e68bf802fcd2b34b7ae1575

Image layers

ARG RELEASE
ARG LAUNCHPAD_BUILD_ARCH
LABEL org.opencontainers.image.ref.name=ubuntu
LABEL org.opencontainers.image.version=24.04
ADD file:b4619a63cd7829e1338ddaa4995ca17003002dd54b0dfd675a6f54a2b69151a6 in /
CMD ["/bin/bash"]
ENV NVARCH=x86_64
ENV NVIDIA_REQUIRE_CUDA=cuda>=13.0 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=vpc,driver>=535,driver<536 brand=vcs,driver>=535,driver<536 brand=vws,driver>=535,driver<536 brand=cloudgaming,driver>=535,driver<536 brand=unknown,driver>=550,driver<551 brand=grid,driver>=550,driver<551 brand=tesla,driver>=550,driver<551 brand=nvidia,driver>=550,driver<551 brand=quadro,driver>=550,driver<551 brand=quadrortx,driver>=550,driver<551 brand=nvidiartx,driver>=550,driver<551 brand=vapps,driver>=550,driver<551 brand=vpc,driver>=550,driver<551 brand=vcs,driver>=550,driver<551 brand=vws,driver>=550,driver<551 brand=cloudgaming,driver>=550,driver<551 brand=unknown,driver>=565,driver<566 brand=grid,driver>=565,driver<566 brand=tesla,driver>=565,driver<566 brand=nvidia,driver>=565,driver<566 brand=quadro,driver>=565,driver<566 brand=quadrortx,driver>=565,driver<566 brand=nvidiartx,driver>=565,driver<566 brand=vapps,driver>=565,driver<566 brand=vpc,driver>=565,driver<566 brand=vcs,driver>=565,driver<566 brand=vws,driver>=565,driver<566 brand=cloudgaming,driver>=565,driver<566 brand=unknown,driver>=570,driver<571 brand=grid,driver>=570,driver<571 brand=tesla,driver>=570,driver<571 brand=nvidia,driver>=570,driver<571 brand=quadro,driver>=570,driver<571 brand=quadrortx,driver>=570,driver<571 brand=nvidiartx,driver>=570,driver<571 brand=vapps,driver>=570,driver<571 brand=vpc,driver>=570,driver<571 brand=vcs,driver>=570,driver<571 brand=vws,driver>=570,driver<571 brand=cloudgaming,driver>=570,driver<571 brand=unknown,driver>=575,driver<576 brand=grid,driver>=575,driver<576 brand=tesla,driver>=575,driver<576 brand=nvidia,driver>=575,driver<576 brand=quadro,driver>=575,driver<576 brand=quadrortx,driver>=575,driver<576 
brand=nvidiartx,driver>=575,driver<576 brand=vapps,driver>=575,driver<576 brand=vpc,driver>=575,driver<576 brand=vcs,driver>=575,driver<576 brand=vws,driver>=575,driver<576 brand=cloudgaming,driver>=575,driver<576
ENV NV_CUDA_CUDART_VERSION=13.0.48-1
ARG TARGETARCH
LABEL maintainer=NVIDIA CORPORATION <cudatools@nvidia.com>
RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates && curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${NVARCH}/3bf863cc.pub | apt-key add - && echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${NVARCH} /" > /etc/apt/sources.list.d/cuda.list && apt-get purge --autoremove -y curl && rm -rf /var/lib/apt/lists/* # buildkit
ENV CUDA_VERSION=13.0.0
RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends cuda-cudart-13-0=${NV_CUDA_CUDART_VERSION} cuda-compat-13-0 && rm -rf /var/lib/apt/lists/* # buildkit
RUN |1 TARGETARCH=amd64 /bin/sh -c echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/nvidia.conf # buildkit
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64
COPY NGC-DL-CONTAINER-LICENSE / # buildkit
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NV_CUDA_LIB_VERSION=13.0.0-1
ENV NV_NVTX_VERSION=13.0.39-1
ENV NV_LIBNPP_VERSION=13.0.0.50-1
ENV NV_LIBNPP_PACKAGE=libnpp-13-0=13.0.0.50-1
ENV NV_LIBCUSPARSE_VERSION=12.6.2.49-1
ENV NV_LIBCUBLAS_PACKAGE_NAME=libcublas-13-0
ENV NV_LIBCUBLAS_VERSION=13.0.0.19-1
ENV NV_LIBCUBLAS_PACKAGE=libcublas-13-0=13.0.0.19-1
ENV NV_LIBNCCL_PACKAGE_NAME=libnccl2
ENV NV_LIBNCCL_PACKAGE_VERSION=2.27.7-1
ENV NCCL_VERSION=2.27.7-1
ENV NV_LIBNCCL_PACKAGE=libnccl2=2.27.7-1+cuda13.0
ARG TARGETARCH
LABEL maintainer=NVIDIA CORPORATION <cudatools@nvidia.com>
RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends cuda-libraries-13-0=${NV_CUDA_LIB_VERSION} ${NV_LIBNPP_PACKAGE} cuda-nvtx-13-0=${NV_NVTX_VERSION} libcusparse-13-0=${NV_LIBCUSPARSE_VERSION} ${NV_LIBCUBLAS_PACKAGE} ${NV_LIBNCCL_PACKAGE} && rm -rf /var/lib/apt/lists/* # buildkit
RUN |1 TARGETARCH=amd64 /bin/sh -c apt-mark hold ${NV_LIBCUBLAS_PACKAGE_NAME} ${NV_LIBNCCL_PACKAGE_NAME} # buildkit
COPY entrypoint.d/ /opt/nvidia/entrypoint.d/ # buildkit
COPY nvidia_entrypoint.sh /opt/nvidia/ # buildkit
ENV NVIDIA_PRODUCT_NAME=CUDA
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
ENV NV_CUDA_LIB_VERSION=13.0.0-1
ENV NV_CUDA_CUDART_DEV_VERSION=13.0.48-1
ENV NV_NVML_DEV_VERSION=13.0.39-1
ENV NV_LIBCUSPARSE_DEV_VERSION=12.6.2.49-1
ENV NV_LIBNPP_DEV_VERSION=13.0.0.50-1
ENV NV_LIBNPP_DEV_PACKAGE=libnpp-dev-13-0=13.0.0.50-1
ENV NV_LIBCUBLAS_DEV_VERSION=13.0.0.19-1
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME=libcublas-dev-13-0
ENV NV_LIBCUBLAS_DEV_PACKAGE=libcublas-dev-13-0=13.0.0.19-1
ENV NV_CUDA_NSIGHT_COMPUTE_VERSION=13.0.0-1
ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE=cuda-nsight-compute-13-0=13.0.0-1
ENV NV_LIBNCCL_DEV_PACKAGE_NAME=libnccl-dev
ENV NV_LIBNCCL_DEV_PACKAGE_VERSION=2.27.7-1
ENV NCCL_VERSION=2.27.7-1
ENV NV_LIBNCCL_DEV_PACKAGE=libnccl-dev=2.27.7-1+cuda13.0
ARG TARGETARCH
LABEL maintainer=NVIDIA CORPORATION <cudatools@nvidia.com>
RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends cuda-cudart-dev-13-0=${NV_CUDA_CUDART_DEV_VERSION} cuda-command-line-tools-13-0=${NV_CUDA_LIB_VERSION} cuda-minimal-build-13-0=${NV_CUDA_LIB_VERSION} cuda-libraries-dev-13-0=${NV_CUDA_LIB_VERSION} cuda-nvml-dev-13-0=${NV_NVML_DEV_VERSION} ${NV_LIBNPP_DEV_PACKAGE} libcusparse-dev-13-0=${NV_LIBCUSPARSE_DEV_VERSION} ${NV_LIBCUBLAS_DEV_PACKAGE} ${NV_LIBNCCL_DEV_PACKAGE} ${NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE} && rm -rf /var/lib/apt/lists/* # buildkit
RUN |1 TARGETARCH=amd64 /bin/sh -c apt-mark hold ${NV_LIBCUBLAS_DEV_PACKAGE_NAME} ${NV_LIBNCCL_DEV_PACKAGE_NAME} # buildkit
ENV LIBRARY_PATH=/usr/local/cuda/lib64/stubs
LABEL maintainer=nemotron-speech
LABEL description=Unified ASR + TTS + LLM container (ARM64 sm_121 CUDA 13.1 / x86_64 sm_120 CUDA 13.0)
LABEL version=1.2
ENV DEBIAN_FRONTEND=noninteractive
RUN /bin/sh -c apt-get update && apt-get install -y --no-install-recommends python3.12 python3.12-dev python3.12-venv python3-pip git curl wget cmake ninja-build ccache libopenblas-dev libomp-dev libffi-dev libssl-dev libnuma-dev libcurl4-openssl-dev ffmpeg sox libsndfile1 && rm -rf /var/lib/apt/lists/* && ln -sf /usr/bin/python3.12 /usr/bin/python3 && ln -sf /usr/bin/python3 /usr/bin/python # buildkit
COPY /uv /uvx /bin/ # buildkit
ENV UV_SYSTEM_PYTHON=1
ENV UV_BREAK_SYSTEM_PACKAGES=1
COPY /usr/lib/*/libcudnn* /tmp/cudnn_libs/ # buildkit
COPY /usr/include/cudnn* /usr/include/ # buildkit
COPY /usr/lib/*/libnccl* /tmp/nccl_libs/ # buildkit
COPY /usr/include/nccl.h /usr/include/ # buildkit
COPY /usr/include/nccl_device.h /usr/include/ # buildkit
COPY /usr/include/nccl_device/ /usr/include/nccl_device/ # buildkit
RUN /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then LIB_DIR="/usr/lib/aarch64-linux-gnu"; else LIB_DIR="/usr/lib/x86_64-linux-gnu"; fi && echo "=== Detected architecture: $ARCH, using $LIB_DIR ===" && mkdir -p "$LIB_DIR" && mv /tmp/cudnn_libs/* "$LIB_DIR/" && mv /tmp/nccl_libs/* "$LIB_DIR/" && rmdir /tmp/cudnn_libs /tmp/nccl_libs && echo "=== cuDNN libraries ===" && ls -la "$LIB_DIR"/libcudnn* | head -3 && echo "=== NCCL libraries ===" && ls -la "$LIB_DIR"/libnccl* | head -3 && echo "=== NCCL device headers ===" && ls -la /usr/include/nccl_device/ | head -3 && echo "=== Checking ncclDevCommDestroy symbol ===" && nm -D "$LIB_DIR"/libnccl.so* 2>/dev/null | grep ncclDevCommDestroy | head -1 || echo "Note: nm check may fail on stripped library" && ldconfig # buildkit
RUN /bin/sh -c uv pip install --no-cache numpy pyyaml typing_extensions sympy filelock networkx jinja2 fsspec packaging setuptools wheel cffi future requests dataclasses pillow expecttest hypothesis pytest # buildkit
ARG PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f
WORKDIR /build
RUN |1 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c git clone --recursive https://github.com/pytorch/pytorch.git && cd pytorch && git checkout ${PYTORCH_COMMIT} # buildkit
WORKDIR /build/pytorch
RUN |1 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c git submodule sync && git submodule update --init --recursive # buildkit
ENV USE_CUDA=1
ENV USE_CUDNN=1
ENV USE_MKLDNN=1
ENV USE_DISTRIBUTED=1
ENV USE_NCCL=1
ENV USE_TENSORPIPE=0
ENV USE_SYSTEM_NCCL=1
ENV NCCL_ROOT=/usr
ENV NCCL_INCLUDE_DIR=/usr/include
ENV BUILD_TEST=0
ENV MAX_JOBS=8
ENV CMAKE_BUILD_TYPE=Release
ENV CUDNN_INCLUDE_DIR=/usr/include
ENV USE_PRIORITIZED_TEXT_FOR_LD=1
RUN |1 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c ln -sf /usr/local/cuda/include/cccl/cub /usr/local/cuda/include/cub && ln -sf /usr/local/cuda/include/cccl/thrust /usr/local/cuda/include/thrust && echo "=== Verifying CUB access ===" && ls /usr/local/cuda/include/cub/cub.cuh && ls /usr/local/cuda/include/cccl/cub/cub.cuh # buildkit
ENV CUB_INCLUDE_DIR=/usr/local/cuda/include/cccl
RUN |1 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; export NCCL_LIB_DIR="/usr/lib/aarch64-linux-gnu"; export CUDNN_LIB_DIR="/usr/lib/aarch64-linux-gnu"; else export TORCH_CUDA_ARCH_LIST="12.0"; export NCCL_LIB_DIR="/usr/lib/x86_64-linux-gnu"; export CUDNN_LIB_DIR="/usr/lib/x86_64-linux-gnu"; fi && echo "=== Building PyTorch for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && python3 setup.py bdist_wheel > /tmp/pytorch_build.log 2>&1 || { tail -50 /tmp/pytorch_build.log; exit 1; } && uv pip install --no-cache dist/*.whl && mkdir -p /tmp/pytorch_wheel && cp dist/*.whl /tmp/pytorch_wheel/ # buildkit
ARG TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73
WORKDIR /build
RUN |2 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 /bin/sh -c git clone --recursive https://github.com/pytorch/audio.git && cd audio && git checkout ${TORCHAUDIO_COMMIT} # buildkit
WORKDIR /build/audio
ENV BUILD_SOX=0
ENV USE_CUDA=1
RUN |2 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; else export TORCH_CUDA_ARCH_LIST="12.0"; fi && echo "=== Building torchaudio for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && python3 setup.py bdist_wheel > /tmp/torchaudio_build.log 2>&1 || { tail -50 /tmp/torchaudio_build.log; exit 1; } && uv pip install --no-cache dist/*.whl && mkdir -p /tmp/torchaudio_wheel && cp dist/*.whl /tmp/torchaudio_wheel/ # buildkit
WORKDIR /workspace
RUN |2 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 /bin/sh -c rm -rf /build/pytorch /build/audio # buildkit
ARG NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc
RUN |3 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c git clone https://github.com/NVIDIA/NeMo.git /opt/nemo && cd /opt/nemo && git checkout ${NEMO_COMMIT} # buildkit
RUN |3 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c uv pip install --no-cache Cython hydra-core>=1.3.0 omegaconf>=2.3 pytorch-lightning>=2.0 torchmetrics>=0.11.0 transformers>=4.36.0 sentencepiece webdataset lhotse>=1.20.0 braceexpand editdistance g2p_en inflect kaldi-python-io kaldiio librosa>=0.10.0 marshmallow ruamel.yaml soundfile text-unidecode numba kaldialign # buildkit
RUN |3 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c cd /opt/nemo && uv pip install --no-cache -e ".[asr,tts]" # buildkit
COPY <<EOF /tmp/patch_nvrtc.py # buildkit
RUN |3 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c python3 /tmp/patch_nvrtc.py || echo "Patch may have already been applied" # buildkit
RUN |3 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c uv pip install --no-cache tokenizers>=0.19 fastapi uvicorn[standard] pydantic>=2.0 prometheus_client py-cpuinfo tiktoken lm-format-enforcer outlines xgrammar pyzmq msgspec gguf compressed-tensors importlib_metadata mistral_common>=1.5.0 partial-json-parser # buildkit
ARG VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a
ARG VLLM_CACHE_BUSTER=v1
WORKDIR /build
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c git clone https://github.com/vllm-project/vllm.git && cd vllm && git checkout ${VLLM_COMMIT} # buildkit
WORKDIR /build/vllm
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c sed -i 's/cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f"/cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f"/g' CMakeLists.txt && echo "=== Verifying SM120/SM121 patch applied ===" && grep -n "SCALED_MM_ARCHS" CMakeLists.txt | head -5 # buildkit
ENV VLLM_TARGET_DEVICE=cuda
ENV MAX_JOBS=8
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c curl -fsSL https://raw.githubusercontent.com/pipecat-ai/nemotron-january-2026/main/patches/vllm-pr31607-sm121-support.patch -o /tmp/vllm-pr31607-sm121-support.patch # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c curl -fsSL https://raw.githubusercontent.com/pipecat-ai/nemotron-january-2026/main/patches/apply-vllm-pr31607.py -o /tmp/apply-vllm-pr31607.py # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c uv pip install --no-cache cmake>=3.26 ninja packaging setuptools>=61 setuptools-scm>=8 wheel jinja2 # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; else export TORCH_CUDA_ARCH_LIST="12.0"; fi && echo "=== Building vLLM for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && uv pip install --no-cache -e . --no-build-isolation > /tmp/vllm_build.log 2>&1 || { tail -50 /tmp/vllm_build.log; exit 1; } # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 /tmp/apply-vllm-pr31607.py && patch -p1 < /tmp/vllm-pr31607-sm121-support.patch || echo "Patch may already be applied" # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py") text = path.read_text() if "VLLM_FP8_FORCE_BACKEND" not in text: if "import os" not in text: needle = "from vllm.logger import init_logger\n" if needle not in text: raise SystemExit("Failed to locate logger import for os insertion") text = text.replace(needle, "import os\n" + needle, 1) insert = ( " if current_platform.is_device_capability(121):\n" " self.fp8_linear.preferred_backend = \"cutlass\"\n" " logger.info_once(\n" " \"Forcing FP8 linear backend to cutlass on SM121 (flashinfer is incorrect)\"\n" " )\n" "\n" " override = os.getenv(\"VLLM_FP8_FORCE_BACKEND\")\n" " if override:\n" " self.fp8_linear.preferred_backend = override.lower()\n" " logger.info_once(\n" " \"Forcing FP8 linear backend to %s via VLLM_FP8_FORCE_BACKEND\",\n" " self.fp8_linear.preferred_backend,\n" " )\n" "\n" ) if "Forcing FP8 linear backend" not in text: lines = text.splitlines() start = None end = None for i, line in enumerate(lines): if "self.fp8_linear = Fp8LinearOp(" in line: start = i continue if start is not None and line.strip() == ")": end = i break if start is None or end is None: raise SystemExit("Failed to locate fp8_linear init block in ModelOptFp8LinearMethod") insert_lines = insert.rstrip("\n").split("\n") lines[end + 1:end + 1] = insert_lines text = "\n".join(lines) + "\n" path.write_text(text) PY # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py") text = path.read_text() if "from vllm.platforms import current_platform" not in text: needle = "import vllm.envs as envs" if needle not in text: raise SystemExit("Failed to locate envs import for insertion") text = text.replace( needle, "import vllm.envs as envs\nfrom vllm.platforms import current_platform", 1, ) old = "if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():" new = ( "if current_platform.is_device_capability(121):\n" " logger.info_once(\n" " \"Disabling FlashInfer FP8 MoE kernels on SM121; using fallback.\"\n" " )\n" " elif envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():" ) if old in text and "Disabling FlashInfer FP8 MoE kernels on SM121" not in text: text = text.replace(old, new) path.write_text(text) PY # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py") text = path.read_text() if "import os" not in text: needle = "import torch\n" if needle not in text: raise SystemExit("Failed to locate torch import for insertion") text = text.replace(needle, "import os\nimport torch\n", 1) path.write_text(text) PY # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py") text = path.read_text() if "per_tensor_dequantize" not in text: needle = "from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6\n" if needle not in text: raise SystemExit("Failed to locate quantization util imports") text = text.replace( needle, needle + "from vllm.model_executor.layers.quantization.utils.w8a8_utils import per_tensor_dequantize\n", 1, ) path.write_text(text) PY # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py") text = path.read_text() if "_FP8_MOE_ACT_SCALE_LOGGED" not in text: needle = "logger = init_logger(__name__)\n" if needle not in text: raise SystemExit("Failed to locate logger for insertion") text = text.replace(needle, needle + "\n_FP8_MOE_ACT_SCALE_LOGGED = 0\n", 1) if "global _FP8_MOE_ACT_SCALE_LOGGED" not in text: lines = text.splitlines(keepends=True) def_line = None end_sig = None for i, line in enumerate(lines): if def_line is None and line.startswith("def fused_experts("): def_line = i continue if def_line is not None and end_sig is None and line.strip() == ") -> torch.Tensor:": end_sig = i break if end_sig is None: raise SystemExit("Failed to locate fused_experts signature end") lines.insert(end_sig + 1, " global _FP8_MOE_ACT_SCALE_LOGGED\n") impl_def_line = None impl_end_sig = None for i, line in enumerate(lines): if impl_def_line is None and line.startswith("def fused_experts_impl("): impl_def_line = i continue if impl_def_line is not None and impl_end_sig is None and line.strip() == ") -> torch.Tensor:": impl_end_sig = i break if impl_end_sig is None: raise SystemExit("Failed to locate fused_experts_impl signature end") lines.insert(impl_end_sig + 1, " global _FP8_MOE_ACT_SCALE_LOGGED\n") text = "".join(lines) anchor = ( "qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(\n" " A=curr_hidden_states,\n" " A_scale=a1_scale,\n" " quant_dtype=quant_dtype,\n" " per_act_token_quant=per_channel_quant,\n" " block_shape=block_shape,\n" " )\n" ) if "FP8 MoE a1q_scale" not in text and anchor in text: insert = ( "qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(\n" " 
A=curr_hidden_states,\n" " A_scale=a1_scale,\n" " quant_dtype=quant_dtype,\n" " per_act_token_quant=per_channel_quant,\n" " block_shape=block_shape,\n" " )\n" " if os.getenv(\"VLLM_FP8_MOE_ACT_SCALE_DEBUG\") == \"1\":\n" " if _FP8_MOE_ACT_SCALE_LOGGED < 1:\n" " _FP8_MOE_ACT_SCALE_LOGGED += 1\n" " def _sinfo(name, t):\n" " if t is None:\n" " return f\"{name}=None\"\n" " t_float = t.float()\n" " return (\n" " f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n" " f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\"\n" " )\n" " logger.info_once(\n" " \"FP8 MoE a1q_scale: %s\",\n" " _sinfo(\"a1q_scale\", a1q_scale),\n" " )\n" ) text = text.replace(anchor, insert, 1) anchor2 = ( "qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(\n" " A=intermediate_cache2,\n" " A_scale=a2_scale,\n" " quant_dtype=quant_dtype,\n" " per_act_token_quant=per_channel_quant,\n" " block_shape=block_shape,\n" " )\n" ) if "FP8 MoE a2q_scale" not in text and anchor2 in text: insert2 = ( "qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(\n" " A=intermediate_cache2,\n" " A_scale=a2_scale,\n" " quant_dtype=quant_dtype,\n" " per_act_token_quant=per_channel_quant,\n" " block_shape=block_shape,\n" " )\n" " if os.getenv(\"VLLM_FP8_MOE_ACT_SCALE_DEBUG\") == \"1\":\n" " if _FP8_MOE_ACT_SCALE_LOGGED == 1:\n" " _FP8_MOE_ACT_SCALE_LOGGED += 1\n" " def _sinfo(name, t):\n" " if t is None:\n" " return f\"{name}=None\"\n" " t_float = t.float()\n" " return (\n" " f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n" " f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\"\n" " )\n" " logger.info_once(\n" " \"FP8 MoE a2q_scale: %s\",\n" " _sinfo(\"a2q_scale\", a2q_scale),\n" " )\n" ) text = text.replace(anchor2, insert2, 1) path.write_text(text) PY # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/models/nemotron_h.py") text = path.read_text() if "import os" not in text: needle = "import typing\n" if needle not in text: raise SystemExit("Failed to locate import typing for insertion") text = text.replace(needle, needle + "import os\n", 1) old = ( " fused_moe_out = self.experts(\n" " hidden_states=hidden_states, router_logits=router_logits\n" " )\n" ) if old in text and "VLLM_NEMO_BYPASS_MOE" not in text: new = ( " if os.getenv(\"VLLM_NEMO_BYPASS_MOE\") == \"1\":\n" " if self.shared_experts is not None and not self.use_latent_moe:\n" " shared_output = self.shared_experts(hidden_states)\n" " fused_moe_out = (shared_output, hidden_states)\n" " else:\n" " fused_moe_out = self.experts(\n" " hidden_states=hidden_states, router_logits=router_logits\n" " )\n" ) text = text.replace(old, new, 1) path.write_text(text) PY # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/models/nemotron_h.py") text = path.read_text() if "init_logger" not in text: needle = "from torch import nn\n\n" if needle not in text: raise SystemExit("Failed to locate torch import block for logger insertion") insert = needle + "from vllm.logger import init_logger\n\nlogger = init_logger(__name__)\n\n_NEMO_SHARED_REF_USED = 0\n\n" text = text.replace(needle, insert, 1) block = ( " if self.use_latent_moe:\n" " _, final_hidden_states = fused_moe_out\n" " else:\n" " shared_output, final_hidden_states = fused_moe_out\n" ) if block in text and "VLLM_NEMO_SHARED_REF_COMPARE" not in text: extra = ( " if (\n" " os.getenv(\"VLLM_NEMO_SHARED_REF_COMPARE\") == \"1\"\n" " and (not self.use_latent_moe)\n" " and shared_output is not None\n" " and self.shared_experts is not None\n" " ):\n" " global _NEMO_SHARED_REF_USED\n" " if _NEMO_SHARED_REF_USED == 0:\n" " ref = self.shared_experts(hidden_states.to(dtype=torch.bfloat16))\n" " diff = (ref.to(torch.float32) - shared_output.to(torch.float32)).abs()\n" " logger.info(\n" " \"NemotronH shared_experts ref compare: mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n" " diff.mean().item(),\n" " diff.max().item(),\n" " ref.to(torch.float32).norm().item(),\n" " shared_output.to(torch.float32).norm().item(),\n" " )\n" " _NEMO_SHARED_REF_USED = 1\n" ) text = text.replace(block, block + extra, 1) path.write_text(text) PY # buildkit
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py") text = path.read_text() if "_FP8_MOE_REF_USED" not in text: needle = "logger = init_logger(__name__)\n" if needle not in text: raise SystemExit("Failed to locate logger for insertion") text = text.replace(needle, needle + "\n_FP8_MOE_REF_USED = 0\n", 1) if "_pytorch_fp8_moe_fallback" not in text: anchor = ( "def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:\n" " return torch.ops.vllm.outplace_fused_experts(**kwargs)\n\n\n" ) if anchor not in text: raise SystemExit("Failed to locate torch_vllm_outplace_fused_experts anchor") fallback = '''def _pytorch_fp8_moe_fallback( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str, apply_router_weight_on_input: bool, expert_map: torch.Tensor | None, quant_config: FusedMoEQuantConfig, ) -> torch.Tensor: # Slow but safe fallback for SM121 FP8 MoE. 
if expert_map is not None: topk_ids = expert_map[topk_ids] num_tokens = hidden_states.size(0) top_k = topk_ids.size(1) out = torch.zeros_like(hidden_states, dtype=torch.float32) w1_scale = quant_config.w1_scale w2_scale = quant_config.w2_scale a1_scale = quant_config.a1_scale a2_scale = quant_config.a2_scale if os.getenv("VLLM_MOE_FALLBACK_DEBUG") == "1": def _tinfo(name, t): if t is None: return f\"{name}=None\" with torch.no_grad(): t_float = t.float() return ( f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \" f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\" ) print(\"[moe_fallback_debug]\", _tinfo(\"w1_scale\", w1_scale)) print(\"[moe_fallback_debug]\", _tinfo(\"w2_scale\", w2_scale)) print(\"[moe_fallback_debug]\", _tinfo(\"a1_scale\", a1_scale)) print(\"[moe_fallback_debug]\", _tinfo(\"a2_scale\", a2_scale)) force_bf16 = os.getenv("VLLM_FP8_MOE_FALLBACK_BF16") == "1" for t in range(num_tokens): x = hidden_states[t].to(torch.float32) for k in range(top_k): expert_id = int(topk_ids[t, k].item()) if expert_id < 0 or expert_id >= w1.size(0): continue weight = topk_weights[t, k].to(torch.float32) x_in = x * weight if apply_router_weight_on_input else x # a1_scale intentionally ignored in this variant w1_e = w1[expert_id] w2_e = w2[expert_id] if w1_scale is not None: s1 = w1_scale[expert_id] if s1.numel() > 1: s1 = s1.max() w1_e = per_tensor_dequantize(w1_e, s1) else: w1_e = w1_e.to(torch.float16) if w2_scale is not None: s2 = w2_scale[expert_id] if s2.numel() > 1: s2 = s2.max() w2_e = per_tensor_dequantize(w2_e, s2) else: w2_e = w2_e.to(torch.float16) if force_bf16: w1_e = w1_e.to(torch.bfloat16) w2_e = w2_e.to(torch.bfloat16) x_in = x_in.to(torch.bfloat16) else: w1_e = w1_e.to(torch.float32) w2_e = w2_e.to(torch.float32) z = torch.matmul(w1_e, x_in) if activation == "silu": if z.numel() % 2 == 0: a, b = z.chunk(2, dim=-1) h = F.silu(a) * b else: h = F.silu(z) elif activation == "relu2_no_mul": h = F.relu(z) ** 2 elif activation == "gelu": h = 
F.gelu(z) else: h = F.silu(z) # a2_scale intentionally ignored in this variant y = torch.matmul(w2_e, h) if not apply_router_weight_on_input: y = y * weight out[t] += y.float() return out.to(hidden_states.dtype) ''' text = text.replace(anchor, anchor + fallback, 1) # Inject fallback / reference use in fused_experts old = " if quant_config is None:\\n quant_config = FUSED_MOE_UNQUANTIZED_CONFIG\\n\\n" new = ( " if quant_config is None:\\n" " quant_config = FUSED_MOE_UNQUANTIZED_CONFIG\\n\\n" " if quant_config.use_fp8_w8a8 and current_platform.is_device_capability(121):\\n" " if os.getenv(\"VLLM_FP8_MOE_REF_ONCE\") == \"1\":\\n" " global _FP8_MOE_REF_USED\\n" " if _FP8_MOE_REF_USED < 1:\\n" " _FP8_MOE_REF_USED += 1\\n" " ref_out = _pytorch_fp8_moe_fallback(\\n" " hidden_states=hidden_states,\\n" " w1=w1,\\n" " w2=w2,\\n" " topk_weights=topk_weights,\\n" " topk_ids=topk_ids,\\n" " activation=activation,\\n" " apply_router_weight_on_input=apply_router_weight_on_input,\\n" " expert_map=expert_map,\\n" " quant_config=quant_config,\\n" " )\\n" " if os.getenv(\"VLLM_FP8_MOE_REF_COMPARE\") == \"1\":\\n" " cmp_out = dispatch_fused_experts_func(inplace)(\\n" " hidden_states=hidden_states,\\n" " w1=w1,\\n" " w2=w2,\\n" " topk_weights=topk_weights,\\n" " topk_ids=topk_ids,\\n" " activation=activation,\\n" " apply_router_weight_on_input=apply_router_weight_on_input,\\n" " use_fp8_w8a8=quant_config.use_fp8_w8a8,\\n" " use_int8_w8a8=quant_config.use_int8_w8a8,\\n" " use_int8_w8a16=quant_config.use_int8_w8a16,\\n" " use_int4_w4a16=quant_config.use_int4_w4a16,\\n" " ocp_mx_scheme=quant_config.ocp_mx_scheme,\\n" " per_channel_quant=quant_config.per_act_token_quant,\\n" " global_num_experts=global_num_experts,\\n" " expert_map=expert_map,\\n" " w1_scale=quant_config.w1_scale,\\n" " w2_scale=quant_config.w2_scale,\\n" " w1_zp=quant_config.w1_zp,\\n" " w2_zp=quant_config.w2_zp,\\n" " a1_scale=quant_config.a1_scale,\\n" " a2_scale=quant_config.a2_scale,\\n" " 
block_shape=quant_config.block_shape,\\n" " w1_bias=quant_config.w1_bias,\\n" " w2_bias=quant_config.w2_bias,\\n" " )\\n" " with torch.no_grad():\\n" " diff = (ref_out.float() - cmp_out.float()).abs()\\n" " logger.info_once(\\n" " \"FP8 MoE ref compare: mean_abs=%.6g max_abs=%.6g ref_norm=%.6g out_norm=%.6g\",\\n" " diff.mean().item(),\\n" " diff.max().item(),\\n" " ref_out.float().norm().item(),\\n" " cmp_out.float().norm().item(),\\n" " )\\n" " if os.getenv(\"VLLM_FP8_MOE_REF_RETURN\", \"1\") != \"0\":\\n" " return ref_out\\n" " if os.getenv(\"VLLM_FP8_MOE_FORCE_FALLBACK\", \"1\") != \"0\":\\n" " return _pytorch_fp8_moe_fallback(\\n" " hidden_states=hidden_states,\\n" " w1=w1,\\n" " w2=w2,\\n" " topk_weights=topk_weights,\\n" " topk_ids=topk_ids,\\n" " activation=activation,\\n" " apply_router_weight_on_input=apply_router_weight_on_input,\\n" " expert_map=expert_map,\\n" " quant_config=quant_config,\\n" " )\\n\\n" ) if old in text and "_pytorch_fp8_moe_fallback" in text: text = text.replace(old, new, 1) path.write_text(text) PY # buildkit
# Patch vLLM fused_moe.py with an inline Python script. Only when the file does not yet
# mention VLLM_FP8_MOE_REF_ONCE, the script:
#   - adds a module-level `_FP8_MOE_REF_USED = 0` counter after the logger definition
#     (unless the counter name is already present);
#   - inserts, right after the `quant_config is None` default, a branch for FP8 w8a8 on
#     device capability 121 that is gated by VLLM_FP8_MOE_REF_ONCE=1, calls
#     `_pytorch_fp8_moe_fallback`, and can log differences against the dispatched
#     fused-experts output (VLLM_FP8_MOE_REF_COMPARE, VLLM_FP8_MOE_BF16_REF_COMPARE)
#     and the min/max of the quantization scales.
# The build aborts with SystemExit when an anchor string it needs is missing.
# NOTE(review): the preceding layer appears to inject "VLLM_FP8_MOE_REF_ONCE" into the
# same file when its own anchor matches; in that case the outer guard here is false and
# this layer changes nothing. Confirm which variant is meant to win.
# NOTE(review): the injected code sets os.environ["VLLM_FP8_MOE_FALLBACK_BF16"] around the
# BF16 reference call without try/finally, so an exception would leave it set.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py") text = path.read_text() if "VLLM_FP8_MOE_REF_ONCE" not in text: if "_FP8_MOE_REF_USED" not in text: needle = "logger = init_logger(__name__)\n" if needle not in text: raise SystemExit("Failed to locate logger for insertion") text = text.replace(needle, needle + "\n_FP8_MOE_REF_USED = 0\n", 1) needle = " if quant_config is None:\n quant_config = FUSED_MOE_UNQUANTIZED_CONFIG\n\n" if needle not in text: raise SystemExit("Failed to locate quant_config init for insertion") insert = "".join([ needle, " if quant_config.use_fp8_w8a8 and current_platform.is_device_capability(121):\n", " if os.getenv(\"VLLM_FP8_MOE_REF_ONCE\") == \"1\":\n", " global _FP8_MOE_REF_USED\n", " if _FP8_MOE_REF_USED < 1:\n", " _FP8_MOE_REF_USED += 1\n", " ref_out = _pytorch_fp8_moe_fallback(\n", " hidden_states=hidden_states,\n", " w1=w1,\n", " w2=w2,\n", " topk_weights=topk_weights,\n", " topk_ids=topk_ids,\n", " activation=activation,\n", " apply_router_weight_on_input=apply_router_weight_on_input,\n", " expert_map=expert_map,\n", " quant_config=quant_config,\n", " )\n", " if os.getenv(\"VLLM_FP8_MOE_REF_COMPARE\") == \"1\":\n", " cmp_out = dispatch_fused_experts_func(inplace)(\n", " hidden_states=hidden_states,\n", " w1=w1,\n", " w2=w2,\n", " topk_weights=topk_weights,\n", " topk_ids=topk_ids,\n", " activation=activation,\n", " apply_router_weight_on_input=apply_router_weight_on_input,\n", " use_fp8_w8a8=quant_config.use_fp8_w8a8,\n", " use_int8_w8a8=quant_config.use_int8_w8a8,\n", " use_int8_w8a16=quant_config.use_int8_w8a16,\n", " use_int4_w4a16=quant_config.use_int4_w4a16,\n", " 
ocp_mx_scheme=quant_config.ocp_mx_scheme,\n", " per_channel_quant=quant_config.per_act_token_quant,\n", " global_num_experts=global_num_experts,\n", " expert_map=expert_map,\n", " w1_scale=quant_config.w1_scale,\n", " w2_scale=quant_config.w2_scale,\n", " w1_zp=quant_config.w1_zp,\n", " w2_zp=quant_config.w2_zp,\n", " a1_scale=quant_config.a1_scale,\n", " a2_scale=quant_config.a2_scale,\n", " block_shape=quant_config.block_shape,\n", " w1_bias=quant_config.w1_bias,\n", " w2_bias=quant_config.w2_bias,\n", " )\n", " with torch.no_grad():\n", " diff = (ref_out.float() - cmp_out.float()).abs()\n", " logger.info_once(\n", " \"FP8 MoE ref compare: mean_abs=%.6g max_abs=%.6g ref_norm=%.6g out_norm=%.6g\",\n", " diff.mean().item(),\n", " diff.max().item(),\n", " ref_out.float().norm().item(),\n", " cmp_out.float().norm().item(),\n", " )\n", " def _scale_info(name, t):\n", " if t is None:\n", " return f\"{name}=None\"\n", " t_float = t.float()\n", " return (\n", " f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n", " f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\"\n", " )\n", " logger.info_once(\n", " \"FP8 MoE ref scales: %s | %s | %s | %s\",\n", " _scale_info(\"w1_scale\", quant_config.w1_scale),\n", " _scale_info(\"w2_scale\", quant_config.w2_scale),\n", " _scale_info(\"a1_scale\", quant_config.a1_scale),\n", " _scale_info(\"a2_scale\", quant_config.a2_scale),\n", " )\n", " if os.getenv(\"VLLM_FP8_MOE_BF16_REF_COMPARE\") == \"1\":\n", " prev = os.getenv(\"VLLM_FP8_MOE_FALLBACK_BF16\")\n", " os.environ[\"VLLM_FP8_MOE_FALLBACK_BF16\"] = \"1\"\n", " bf16_ref = _pytorch_fp8_moe_fallback(\n", " hidden_states=hidden_states,\n", " w1=w1,\n", " w2=w2,\n", " topk_weights=topk_weights,\n", " topk_ids=topk_ids,\n", " activation=activation,\n", " apply_router_weight_on_input=apply_router_weight_on_input,\n", " expert_map=expert_map,\n", " quant_config=quant_config,\n", " )\n", " if prev is None:\n", " del os.environ[\"VLLM_FP8_MOE_FALLBACK_BF16\"]\n", " 
else:\n", " os.environ[\"VLLM_FP8_MOE_FALLBACK_BF16\"] = prev\n", " diff_bf16 = (bf16_ref.float() - cmp_out.float()).abs()\n", " logger.info_once(\n", " \"FP8 MoE BF16 ref compare: mean_abs=%.6g max_abs=%.6g ref_norm=%.6g out_norm=%.6g\",\n", " diff_bf16.mean().item(),\n", " diff_bf16.max().item(),\n", " bf16_ref.float().norm().item(),\n", " cmp_out.float().norm().item(),\n", " )\n", " if os.getenv(\"VLLM_FP8_MOE_REF_RETURN\", \"1\") != \"0\":\n", " return ref_out\n", ]) text = text.replace(needle, insert, 1) path.write_text(text) PY # buildkit
# Patch vLLM mamba/ops/ssd_combined.py with an inline Python script. Unless the file
# already references mamba_ssm.ops.triton.ssd_combined, the script adds after the
# TRITON_22 line:
#   - an optional import of mamba_ssm's `mamba_chunk_scan_combined` (None if it fails);
#   - `_use_original_mamba_ssm_varlen`, true only when that import worked, the device
#     capability is 121, cu_seqlens is given and intermediate states are not requested;
#   - `_mamba_ssm_chunk_scan_varlen`, which copies the packed inputs into zero-padded
#     (batch, max_len, ...) tensors, calls the mamba_ssm function, copies each sequence's
#     result back into `out` and returns the varlen states (cast to state_dtype if given).
# It then inserts a dispatch to that helper in front of the
# "cu_seqlens must be provided" assertion.
# NOTE(review): a missing TRITON_22 anchor aborts the build, but a missing assertion
# anchor is skipped silently, leaving the helpers defined yet never called.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/mamba/ops/ssd_combined.py") text = path.read_text() if "mamba_ssm.ops.triton.ssd_combined" not in text: needle = "TRITON_22 = version.parse(triton.__version__) >= version.parse(\"2.2.0\")\n" if needle not in text: raise SystemExit("Failed to locate TRITON_22 for insertion") insert = ( needle + "\n" + "try:\n" + " from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_ssm_chunk_scan_combined\n" + "except Exception:\n" + " _mamba_ssm_chunk_scan_combined = None\n" + "from vllm.platforms import current_platform\n\n" + "def _use_original_mamba_ssm_varlen(cu_seqlens, return_intermediate_states):\n" + " return (\n" + " _mamba_ssm_chunk_scan_combined is not None\n" + " and current_platform.is_device_capability(121)\n" + " and cu_seqlens is not None\n" + " and not return_intermediate_states\n" + " )\n\n" + "def _mamba_ssm_chunk_scan_varlen(\n" + " x,\n" + " dt,\n" + " A,\n" + " B,\n" + " C,\n" + " chunk_size,\n" + " out,\n" + " D=None,\n" + " z=None,\n" + " dt_bias=None,\n" + " initial_states=None,\n" + " cu_seqlens=None,\n" + " dt_softplus=False,\n" + " dt_limit=(0.0, float(\"inf\")),\n" + " state_dtype=None,\n" + "):\n" + " batch = int(cu_seqlens.numel() - 1)\n" + " lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64)\n" + " max_len = int(lengths.max().item()) if batch > 0 else 0\n" + " seqlen_total, nheads, headdim = x.shape\n" + " _, ngroups, dstate = B.shape\n" + " if out is None:\n" + " out = torch.empty_like(x)\n" + " if max_len == 0:\n" + " empty_dtype = state_dtype if state_dtype is not None else B.dtype\n" + " return torch.zeros((batch, nheads, headdim, dstate), device=x.device, 
dtype=empty_dtype)\n" + " x_b = x.new_zeros((batch, max_len, nheads, headdim))\n" + " dt_b = dt.new_zeros((batch, max_len, nheads))\n" + " B_b = B.new_zeros((batch, max_len, ngroups, dstate))\n" + " C_b = C.new_zeros((batch, max_len, ngroups, dstate))\n" + " z_b = None\n" + " if z is not None:\n" + " z_b = z.new_zeros((batch, max_len, nheads, headdim))\n" + " for i in range(batch):\n" + " start = int(cu_seqlens[i].item())\n" + " end = int(cu_seqlens[i + 1].item())\n" + " length = end - start\n" + " if length <= 0:\n" + " continue\n" + " x_b[i, :length] = x[start:end]\n" + " dt_b[i, :length] = dt[start:end]\n" + " B_b[i, :length] = B[start:end]\n" + " C_b[i, :length] = C[start:end]\n" + " if z_b is not None:\n" + " z_b[i, :length] = z[start:end]\n" + " cu_seqlens_b = torch.zeros(batch + 1, device=cu_seqlens.device, dtype=cu_seqlens.dtype)\n" + " if batch > 0:\n" + " cu_seqlens_b[1:] = torch.cumsum(lengths, dim=0)\n" + " out_b, varlen_states = _mamba_ssm_chunk_scan_combined(\n" + " x_b,\n" + " dt_b,\n" + " A,\n" + " B_b,\n" + " C_b,\n" + " chunk_size,\n" + " D=D,\n" + " z=z_b,\n" + " dt_bias=dt_bias,\n" + " initial_states=initial_states,\n" + " seq_idx=None,\n" + " cu_seqlens=cu_seqlens_b,\n" + " dt_softplus=dt_softplus,\n" + " dt_limit=dt_limit,\n" + " return_final_states=False,\n" + " return_varlen_states=True,\n" + " )\n" + " for i in range(batch):\n" + " start = int(cu_seqlens[i].item())\n" + " end = int(cu_seqlens[i + 1].item())\n" + " length = end - start\n" + " if length <= 0:\n" + " continue\n" + " out[start:end].copy_(out_b[i, :length])\n" + " if state_dtype is not None and varlen_states.dtype != state_dtype:\n" + " varlen_states = varlen_states.to(state_dtype)\n" + " return varlen_states\n" ) text = text.replace(needle, insert, 1) needle = " assert cu_seqlens is not None, \"cu_seqlens must be provided assuming varlen input\"" if needle in text and "_mamba_ssm_chunk_scan_varlen" in text: insert = ( " if _use_original_mamba_ssm_varlen(cu_seqlens, 
return_intermediate_states):\n" " return _mamba_ssm_chunk_scan_varlen(\n" " x=x,\n" " dt=dt,\n" " A=A,\n" " B=B,\n" " C=C,\n" " chunk_size=chunk_size,\n" " out=out,\n" " D=D,\n" " z=z,\n" " dt_bias=dt_bias,\n" " initial_states=initial_states,\n" " cu_seqlens=cu_seqlens,\n" " dt_softplus=dt_softplus,\n" " dt_limit=dt_limit,\n" " state_dtype=state_dtype,\n" " )\n\n" ) text = text.replace(needle, insert + needle, 1) path.write_text(text) PY # buildkit
# Patch vLLM quantization/modelopt.py with an inline Python script: at the top of the
# first `process_weights_after_loading(self, layer: Module)` found, insert an early
# return for layers flagged `fp8_keep_scales` that stores the transposed weight and
# re-wraps weight_scale / input_scale as non-trainable Parameters.
# The build aborts if the method signature is not found; the insertion is skipped when
# the exact inserted text is already present.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py") text = path.read_text() needle = " def process_weights_after_loading(self, layer: Module) -> None:\n" if needle not in text: raise SystemExit("Failed to locate process_weights_after_loading") insert = "".join( [ needle, " weight = layer.weight\n", " if getattr(layer, \"fp8_keep_scales\", False):\n", " layer.weight = Parameter(weight.t(), requires_grad=False)\n", " layer.weight_scale = Parameter(layer.weight_scale.detach(), requires_grad=False)\n", " layer.input_scale = Parameter(layer.input_scale.detach(), requires_grad=False)\n", " return\n", ] ) if insert not in text: text = text.replace(needle, insert, 1) path.write_text(text) PY # buildkit
# Patch vLLM models/nemotron_h.py with an inline Python script. It:
#   - adds imports for per_tensor_dequantize and the forward-context helpers when the
#     corresponding names are absent and their anchor import lines are found;
#   - defines `_NEMO_ATT_REF_USED` and `_dequantize_linear_weight_for_ref` right after
#     the `_NEMO_SHARED_REF_USED = 0` line (build aborts if that line is missing);
#   - extends the attention forward so that, with VLLM_FP8_ATTENTION_BF16_REF=1, the
#     first call that has attention metadata recomputes qkv / attention / o_proj from
#     dequantized BF16 weights and logs the difference to the real output.
# The original `output` is still what the patched method returns.
# NOTE(review): the injected code calls os.getenv and logger, but this script adds
# neither `import os` nor a logger to nemotron_h.py. Presumably an earlier patch or
# upstream provides them; verify.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/models/nemotron_h.py") text = path.read_text() import_line = "from vllm.model_executor.layers.quantization import QuantizationConfig\n" if "per_tensor_dequantize" not in text and import_line in text: text = text.replace( import_line, import_line + "from vllm.model_executor.layers.quantization.utils.w8a8_utils import per_tensor_dequantize\n", 1, ) if "is_forward_context_available" not in text: text = text.replace( "from vllm.distributed.parallel_state import get_pp_group\n", "from vllm.distributed.parallel_state import get_pp_group\n" "from vllm.forward_context import get_forward_context, is_forward_context_available\n", 1, ) if "_NEMO_ATT_REF_USED" not in text: needle = "_NEMO_SHARED_REF_USED = 0\n" if needle not in text: raise SystemExit("Failed to locate _NEMO_SHARED_REF_USED for insertion") helper = "".join( [ "_NEMO_ATT_REF_USED = 0\n\n", "def _dequantize_linear_weight_for_ref(layer, logical_widths=None):\n", " weight = layer.weight\n", " scale = getattr(layer, \"weight_scale\", None)\n", " if scale is None:\n", " return weight.to(torch.bfloat16)\n", " scale = scale.detach()\n", " if scale.numel() == 1 or not logical_widths:\n", " s = scale.max()\n", " return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n", " if scale.numel() == len(logical_widths):\n", " splits = torch.split(weight, logical_widths, dim=1)\n", " dq_splits = [\n", " per_tensor_dequantize(w, scale[idx]) for idx, w in enumerate(splits)\n", " ]\n", " return torch.cat(dq_splits, dim=1).to(torch.bfloat16)\n", " s = scale.max()\n", " return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n\n", ] ) text = text.replace(needle, needle + helper, 1) if 
"NemotronH attention BF16 ref compare" not in text: old = " output, _ = self.o_proj(attn_output)\n return output\n" if old not in text: raise SystemExit("Failed to locate NemotronHAttention output return") new = "".join( [ " output, _ = self.o_proj(attn_output)\n", " if os.getenv(\"VLLM_FP8_ATTENTION_BF16_REF\") == \"1\":\n", " global _NEMO_ATT_REF_USED\n", " if (\n", " is_forward_context_available()\n", " and get_forward_context().attn_metadata is not None\n", " ):\n", " if _NEMO_ATT_REF_USED < 1:\n", " _NEMO_ATT_REF_USED += 1\n", " with torch.no_grad():\n", " x_bf16 = hidden_states.to(torch.bfloat16)\n", " qkv_w = _dequantize_linear_weight_for_ref(\n", " self.qkv_proj,\n", " getattr(self.qkv_proj, \"output_sizes\", None),\n", " )\n", " if qkv_w.shape[0] == x_bf16.shape[-1]:\n", " qkv_ref = torch.matmul(x_bf16, qkv_w)\n", " else:\n", " qkv_ref = torch.matmul(x_bf16, qkv_w.t())\n", " if self.qkv_proj.bias is not None:\n", " qkv_ref = qkv_ref + self.qkv_proj.bias\n", " q_ref, k_ref, v_ref = qkv_ref.split(\n", " [self.q_size, self.kv_size, self.kv_size], dim=-1\n", " )\n", " attn_ref = self.attn(q_ref, k_ref, v_ref)\n", " o_w = _dequantize_linear_weight_for_ref(self.o_proj)\n", " attn_ref_bf16 = attn_ref.to(torch.bfloat16)\n", " if o_w.shape[0] == attn_ref_bf16.shape[-1]:\n", " ref_out = torch.matmul(attn_ref_bf16, o_w)\n", " else:\n", " ref_out = torch.matmul(attn_ref_bf16, o_w.t())\n", " if self.o_proj.bias is not None:\n", " ref_out = ref_out + self.o_proj.bias\n", " diff = (ref_out.float() - output.float()).abs()\n", " logger.info_once(\n", " \"NemotronH attention BF16 ref compare: mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n", " diff.mean().item(),\n", " diff.max().item(),\n", " ref_out.float().norm().item(),\n", " output.float().norm().item(),\n", " )\n", " return output\n", ] ) text = text.replace(old, new, 1) path.write_text(text) PY # buildkit
# Patch vLLM mamba/mamba_mixer2.py with an inline Python script. In order, it:
#   1. tries to add `import os`, init_logger / is_forward_context_available and
#      per_tensor_dequantize imports when those names are absent;
#   2. adds a module logger, the `_MAMBA_REF_USED` / `_MAMBA_INPROJ_BYPASS_USED`
#      counters and the helpers `_dequantize_linear_weight_for_ref` and `_matmul_weight`
#      after the Mamba2AttentionMetadata import;
#   3. replaces the single in_proj call with a per-chunk FP8 path (one fp8_linear.apply
#      per entry of in_proj.output_sizes, each with its own weight/input scale), used
#      only when the layer is tagged fp8_keep_scales and has matching multi-element
#      scales; otherwise the original in_proj call runs;
#   4. tags in_proj with `fp8_keep_scales = True` after its constructor call;
#   5. after the mup_vector block, adds a BF16 recomputation of in_proj gated by
#      VLLM_FP8_MAMBA_INPROJ_BF16_BYPASS ("1" or "all") that overwrites projected_states;
#   6. before the final return, adds a one-shot comparison gated by
#      VLLM_FP8_MAMBA_BF16_REF=1 that logs in_proj / ssm / out_proj differences.
# NOTE(review): steps 2, 3, 5 and 6 abort the build when their anchor is missing, while
# steps 1 and 4 are skipped silently. A silent miss in step 4 would leave the per-chunk
# FP8 path permanently disabled.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/mamba/mamba_mixer2.py") text = path.read_text() if "import os\n" not in text: text = text.replace( "import torch\n", "import os\nimport torch\n", 1, ) if "init_logger" not in text: text = text.replace( "from vllm.forward_context import ForwardContext, get_forward_context\n", "from vllm.forward_context import ForwardContext, get_forward_context, is_forward_context_available\n" "from vllm.logger import init_logger\n", 1, ) if "per_tensor_dequantize" not in text: text = text.replace( "from vllm.model_executor.layers.quantization import QuantizationConfig\n", "from vllm.model_executor.layers.quantization import QuantizationConfig\n" "from vllm.model_executor.layers.quantization.utils.w8a8_utils import per_tensor_dequantize\n", 1, ) if "_MAMBA_REF_USED" not in text: insert_after = "from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata\n" if insert_after not in text: raise SystemExit("Failed to locate Mamba2AttentionMetadata import") helper = "".join( [ "\nlogger = init_logger(__name__)\n", "_MAMBA_REF_USED = 0\n", "_MAMBA_INPROJ_BYPASS_USED = 0\n\n", "def _dequantize_linear_weight_for_ref(layer, logical_widths=None):\n", " weight = layer.weight\n", " scale = getattr(layer, \"weight_scale\", None)\n", " if scale is None:\n", " return weight.to(torch.bfloat16)\n", " scale = scale.detach()\n", " if scale.numel() == 1 or not logical_widths:\n", " s = scale.max()\n", " return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n", " if scale.numel() == len(logical_widths):\n", " splits = torch.split(weight, logical_widths, dim=1)\n", " dq_splits = [\n", " per_tensor_dequantize(w, scale[idx]) for idx, w in 
enumerate(splits)\n", " ]\n", " return torch.cat(dq_splits, dim=1).to(torch.bfloat16)\n", " s = scale.max()\n", " return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n", "def _matmul_weight(x, w):\n", " if w.shape[0] == x.shape[-1]:\n", " return torch.matmul(x, w)\n", " return torch.matmul(x, w.t())\n", ] ) text = text.replace(insert_after, insert_after + helper, 1) if "Mamba BF16 ref compare" not in text: old = " # 1. Gated MLP's linear projection\n projected_states, _ = self.in_proj(hidden_states)\n" if old not in text: raise SystemExit("Failed to locate MambaMixer2 projection block") new = "".join( [ " input_states = hidden_states\n", " # 1. Gated MLP's linear projection\n", " weight_scale = getattr(self.in_proj, \"weight_scale\", None)\n", " input_scale = getattr(self.in_proj, \"input_scale\", None)\n", " use_fp8_inproj = (\n", " hasattr(self.in_proj, \"fp8_keep_scales\")\n", " and self.in_proj.fp8_keep_scales\n", " and weight_scale is not None\n", " and input_scale is not None\n", " and weight_scale.numel() > 1\n", " and input_scale.numel() == weight_scale.numel()\n", " and hasattr(self.in_proj, \"quant_method\")\n", " and hasattr(self.in_proj.quant_method, \"fp8_linear\")\n", " )\n", " output_sizes = getattr(self.in_proj, \"output_sizes\", None)\n", " if isinstance(output_sizes, list):\n", " output_sizes = tuple(output_sizes)\n", " if os.getenv(\"VLLM_FP8_MAMBA_INPROJ_DEBUG\"):\n", " logger.info_once(\n", " \"Mamba in_proj per-chunk FP8 eligible=%s weight_scale=%s input_scale=%s output_sizes=%s\",\n", " use_fp8_inproj,\n", " None if weight_scale is None else tuple(weight_scale.shape),\n", " None if input_scale is None else tuple(input_scale.shape),\n", " output_sizes,\n", " )\n", " if use_fp8_inproj:\n", " # Mamba in_proj unfused FP8: apply per-chunk scales.\n", " weight = self.in_proj.weight\n", " weight_scales = weight_scale\n", " input_scales = input_scale\n", " outputs = []\n", " bias = self.in_proj.bias\n", " bias_chunks = None\n", " if bias is 
not None:\n", " bias_chunks = torch.split(bias, self.in_proj.output_sizes, dim=0)\n", " for idx, out_size in enumerate(self.in_proj.output_sizes):\n", " w_chunk = weight[:, :out_size]\n", " weight = weight[:, out_size:]\n", " b_chunk = None\n", " if bias_chunks is not None:\n", " b_chunk = bias_chunks[idx]\n", " chunk_out = self.in_proj.quant_method.fp8_linear.apply(\n", " input=hidden_states,\n", " weight=w_chunk,\n", " weight_scale=weight_scales[idx],\n", " input_scale=input_scales[idx],\n", " bias=b_chunk,\n", " )\n", " outputs.append(chunk_out)\n", " projected_states = torch.cat(outputs, dim=-1)\n", " else:\n", " projected_states, _ = self.in_proj(hidden_states)\n", ] ) text = text.replace(old, new, 1) # Mark Mamba in_proj to keep per-chunk scales for FP8. if "fp8_keep_scales" not in text: tag = " self.in_proj.fp8_keep_scales = True\n" merged_block = "".join( [ " self.in_proj = MergedColumnParallelLinear(\n", " input_size=hidden_size,\n", " output_sizes=[\n", " intermediate_size,\n", " intermediate_size,\n", " self.groups_ssm_state_size,\n", " self.groups_ssm_state_size,\n", " self.num_heads,\n", " ],\n", " bias=use_bias,\n", " quant_config=quant_config,\n", " prefix=f\"{prefix}.in_proj\",\n", " )\n", ] ) column_block = "".join( [ " self.in_proj = ColumnParallelLinear(\n", " input_size=hidden_size,\n", " output_size=intermediate_size + self.conv_dim + self.num_heads,\n", " bias=use_bias,\n", " quant_config=quant_config,\n", " prefix=f\"{prefix}.in_proj\",\n", " )\n", ] ) if merged_block in text: text = text.replace(merged_block, merged_block + tag, 1) if column_block in text: text = text.replace(column_block, column_block + tag, 1) # Per-chunk FP8 in_proj path is installed by the projection block replacement. 
bypass_old = ( " if mup_vector is not None:\n" " projected_states = projected_states * mup_vector\n" ) if bypass_old not in text: raise SystemExit("Failed to locate MambaMixer2 mup_vector block") bypass_insert = "".join( [ bypass_old, " bypass_mode = os.getenv(\"VLLM_FP8_MAMBA_INPROJ_BF16_BYPASS\")\n", " if bypass_mode in {\"1\", \"all\"}:\n", " global _MAMBA_INPROJ_BYPASS_USED\n", " if (\n", " is_forward_context_available()\n", " and get_forward_context().attn_metadata is not None\n", " ):\n", " if bypass_mode == \"all\" or _MAMBA_INPROJ_BYPASS_USED < 1:\n", " _MAMBA_INPROJ_BYPASS_USED += 1\n", " with torch.no_grad():\n", " x_bf16 = input_states.to(torch.bfloat16)\n", " in_w = _dequantize_linear_weight_for_ref(\n", " self.in_proj,\n", " getattr(self.in_proj, \"output_sizes\", None),\n", " )\n", " proj_ref = _matmul_weight(x_bf16, in_w)\n", " if self.in_proj.bias is not None:\n", " proj_ref = proj_ref + self.in_proj.bias\n", " if mup_vector is not None:\n", " proj_ref = proj_ref * mup_vector.to(torch.bfloat16)\n", " projected_states = proj_ref.to(projected_states.dtype)\n", " if bypass_mode == \"all\":\n", " logger.info_once(\n", " \"Mamba in_proj BF16 bypass enabled (all layers)\")\n", " else:\n", " logger.info_once(\n", " \"Mamba in_proj BF16 bypass enabled (single call)\")\n", ] ) text = text.replace(bypass_old, bypass_insert, 1) old_out = " output, _ = self.out_proj(hidden_states)\n\n return output\n" if old_out not in text: raise SystemExit("Failed to locate MambaMixer2 output return") insert = "".join( [ " output, _ = self.out_proj(hidden_states)\n", " if os.getenv(\"VLLM_FP8_MAMBA_BF16_REF\") == \"1\":\n", " global _MAMBA_REF_USED\n", " if (\n", " is_forward_context_available()\n", " and get_forward_context().attn_metadata is not None\n", " ):\n", " if _MAMBA_REF_USED < 1:\n", " _MAMBA_REF_USED += 1\n", " with torch.no_grad():\n", " x_bf16 = input_states.to(torch.bfloat16)\n", " in_w = _dequantize_linear_weight_for_ref(\n", " self.in_proj,\n", " 
getattr(self.in_proj, \"output_sizes\", None),\n", " )\n", " proj_ref = _matmul_weight(x_bf16, in_w)\n", " if self.in_proj.bias is not None:\n", " proj_ref = proj_ref + self.in_proj.bias\n", " if mup_vector is not None:\n", " proj_ref = proj_ref * mup_vector.to(torch.bfloat16)\n", " diff_in = (proj_ref.float() - projected_states.float()).abs()\n", " logger.info_once(\n", " \"Mamba BF16 ref compare (in_proj): mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n", " diff_in.mean().item(),\n", " diff_in.max().item(),\n", " proj_ref.float().norm().item(),\n", " projected_states.float().norm().item(),\n", " )\n", " ssm_ref = torch.empty(\n", " [\n", " input_states.shape[0],\n", " (self.num_heads // self.tp_size) * self.head_dim,\n", " ],\n", " dtype=proj_ref.dtype,\n", " device=proj_ref.device,\n", " )\n", " torch.ops.vllm.mamba_mixer2(\n", " proj_ref,\n", " ssm_ref,\n", " self.prefix,\n", " )\n", " diff_ssm = (ssm_ref.float() - ssm_output.float()).abs()\n", " logger.info_once(\n", " \"Mamba BF16 ref compare (ssm): mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n", " diff_ssm.mean().item(),\n", " diff_ssm.max().item(),\n", " ssm_ref.float().norm().item(),\n", " ssm_output.float().norm().item(),\n", " )\n", " gate_ref = proj_ref[..., : self.tped_intermediate_size]\n", " hidden_ref = self.norm(ssm_ref, gate_ref)\n", " out_w = _dequantize_linear_weight_for_ref(self.out_proj)\n", " hidden_ref_bf16 = hidden_ref.to(torch.bfloat16)\n", " ref_out = _matmul_weight(hidden_ref_bf16, out_w)\n", " if self.out_proj.bias is not None:\n", " ref_out = ref_out + self.out_proj.bias\n", " diff_out = (ref_out.float() - output.float()).abs()\n", " logger.info_once(\n", " \"Mamba BF16 ref compare (out_proj): mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n", " diff_out.mean().item(),\n", " diff_out.max().item(),\n", " ref_out.float().norm().item(),\n", " output.float().norm().item(),\n", " )\n", "\n", " return output\n", ] ) text = text.replace(old_out, insert, 1) path.write_text(text) 
PY # buildkit
# Patch vLLM quantization/modelopt.py with an inline Python script: unless
# VLLM_MOE_FALLBACK_DEBUG is already mentioned, insert in front of the
# "# Expert selection" block a branch gated by VLLM_MOE_FALLBACK_DEBUG=1 that logs shape,
# dtype and min/max of the MoE w1/w2/a1/a2 scales (build aborts if that anchor is
# missing). It also adds `import os` ahead of `import vllm.envs as envs` when the file
# has no "import os" substring.
# NOTE(review): if `import vllm.envs as envs` is absent, the os import is silently not
# added even though the injected code calls os.getenv.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py") text = path.read_text() if "VLLM_MOE_FALLBACK_DEBUG" not in text: needle = " # Expert selection\n topk_weights, topk_ids = layer.select_experts(\n" if needle not in text: raise SystemExit("Failed to locate expert selection block") insert = ( " if os.getenv(\"VLLM_MOE_FALLBACK_DEBUG\") == \"1\":\n" " qc = self.moe_quant_config\n" " def _tinfo(name, t):\n" " if t is None:\n" " return f\"{name}=None\"\n" " with torch.no_grad():\n" " tf = t.float()\n" " return (\n" " f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n" " f\"min={tf.min().item():.6g} max={tf.max().item():.6g}\"\n" " )\n" " if qc is not None:\n" " logger.info_once(\"[moe_debug] %s\", _tinfo(\"w1_scale\", qc.w1_scale))\n" " logger.info_once(\"[moe_debug] %s\", _tinfo(\"w2_scale\", qc.w2_scale))\n" " logger.info_once(\"[moe_debug] %s\", _tinfo(\"a1_scale\", qc.a1_scale))\n" " logger.info_once(\"[moe_debug] %s\", _tinfo(\"a2_scale\", qc.a2_scale))\n" "\n" ) text = text.replace(needle, insert + needle, 1) if "import os" not in text: needle = "import vllm.envs as envs\n" if needle in text: text = text.replace(needle, "import os\nimport vllm.envs as envs\n", 1) path.write_text(text) PY # buildkit
# Patch vLLM mamba/ops/mamba_ssm.py with an inline Python script: unless the file already
# references mamba_ssm.ops.triton.selective_state_update, add an optional import of
# mamba_ssm's `selective_state_update` (None if the import fails) plus the predicate
# `_use_original_selective_state_update`. Then, right after `assert out.shape == x.shape`,
# insert a dispatch to the mamba_ssm function on device capability 121 when
# num_accepted_tokens, cu_seqlens and dst_state_batch_indices are all None; the result
# is copied into `out` when `out` is not None.
# NOTE(review): a missing triton-import anchor aborts the build, but the dispatch
# insertion is skipped silently if the assert anchor is missing.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/mamba/ops/mamba_ssm.py") text = path.read_text() if "mamba_ssm.ops.triton.selective_state_update" not in text: needle = "from vllm.triton_utils import HAS_TRITON, tl, triton\n" if needle not in text: raise SystemExit("Failed to locate triton import for insertion") insert = ( needle + "from vllm.platforms import current_platform\n" + "try:\n" + " from mamba_ssm.ops.triton.selective_state_update import selective_state_update as _mamba_ssm_selective_state_update\n" + "except Exception:\n" + " _mamba_ssm_selective_state_update = None\n\n" + "def _use_original_selective_state_update(num_accepted_tokens, cu_seqlens, dst_state_batch_indices):\n" + " return (\n" + " _mamba_ssm_selective_state_update is not None\n" + " and current_platform.is_device_capability(121)\n" + " and num_accepted_tokens is None\n" + " and cu_seqlens is None\n" + " and dst_state_batch_indices is None\n" + " )\n\n" ) text = text.replace(needle, insert, 1) needle = " assert out.shape == x.shape\n" if needle in text and "_use_original_selective_state_update" in text: insert = ( " if _use_original_selective_state_update(num_accepted_tokens, cu_seqlens, dst_state_batch_indices):\n" " result = _mamba_ssm_selective_state_update(\n" " state,\n" " x,\n" " dt,\n" " A,\n" " B,\n" " C,\n" " D=D,\n" " z=z,\n" " dt_bias=dt_bias,\n" " dt_softplus=dt_softplus,\n" " state_batch_indices=state_batch_indices,\n" " )\n" " if out is not None:\n" " out.copy_(result)\n" " return out\n" " return result\n\n" ) text = text.replace(needle, needle + insert, 1) path.write_text(text) PY # buildkit
# Patch vLLM mamba/ops/causal_conv1d.py with an inline Python script: unless the file
# already references causal_conv1d_interface, add an optional import of causal_conv1d's
# `causal_conv1d_update` (None if the import fails) plus the predicate
# `_use_original_causal_conv1d_update`. Then, after the batch/dim/seqlen setup block,
# insert a call to the imported function that is taken on device capability 121 when
# query_start_loc, num_accepted_tokens, block_idx_last_scheduled_token and
# initial_state_idx are None and max_query_len == -1; its result is cast to
# original_x_dtype and returned.
# NOTE(review): a missing triton-import anchor aborts the build, but the dispatch
# insertion is skipped silently if its anchor block is missing.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY' from pathlib import Path path = Path("/build/vllm/vllm/model_executor/layers/mamba/ops/causal_conv1d.py") text = path.read_text() if "causal_conv1d_interface" not in text: needle = "from vllm.triton_utils import tl, triton\n" if needle not in text: raise SystemExit("Failed to locate triton import for insertion") insert = ( needle + "from vllm.platforms import current_platform\n" + "try:\n" + " from causal_conv1d.causal_conv1d_interface import causal_conv1d_update as _causal_conv1d_update\n" + "except Exception:\n" + " _causal_conv1d_update = None\n\n" + "def _use_original_causal_conv1d_update(\n" + " query_start_loc,\n" + " num_accepted_tokens,\n" + " block_idx_last_scheduled_token,\n" + " initial_state_idx,\n" + " max_query_len,\n" + "):\n" + " return (\n" + " _causal_conv1d_update is not None\n" + " and current_platform.is_device_capability(121)\n" + " and query_start_loc is None\n" + " and num_accepted_tokens is None\n" + " and block_idx_last_scheduled_token is None\n" + " and initial_state_idx is None\n" + " and max_query_len == -1\n" + " )\n\n" ) text = text.replace(needle, insert, 1) needle = " if query_start_loc is None:\n batch, dim, seqlen = x.shape\n else:\n assert conv_state_indices is not None\n batch = conv_state_indices.size(0)\n dim = x.size(1)\n seqlen = max_query_len\n" if needle in text and "_use_original_causal_conv1d_update" in text: insert = ( needle + "\n" + " if _use_original_causal_conv1d_update(\n" + " query_start_loc,\n" + " num_accepted_tokens,\n" + " block_idx_last_scheduled_token,\n" + " initial_state_idx,\n" + " max_query_len,\n" + " ):\n" + " out = _causal_conv1d_update(\n" + " x,\n" + " conv_state,\n" + " weight,\n" + " bias=bias,\n" + " 
activation=activation,\n" + " cache_seqlens=None,\n" + " conv_state_indices=conv_state_indices,\n" + " )\n" + " return out.to(original_x_dtype)\n\n" ) text = text.replace(needle, insert, 1) path.write_text(text) PY # buildkit
# Swap in the PyTorch wheel found under /tmp/pytorch_wheel.
# The uninstall of torchvision / torchaudio / the cu12 NCCL and cuDNN wheels is
# best-effort (its failure is ignored); the reinstall and ldconfig must succeed.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c set -e; \
    uv pip uninstall torchvision torchaudio nvidia-nccl-cu12 nvidia-cudnn-cu12 || true; \
    uv pip install --no-cache --reinstall /tmp/pytorch_wheel/torch*.whl; \
    ldconfig # buildkit
# Relative paths in later steps (e.g. the torchaudio-rebuild clone) resolve under /build.
WORKDIR /build
# Rebuild torchaudio from source at the pinned TORCHAUDIO_COMMIT and install the wheel.
# TORCH_CUDA_ARCH_LIST is 12.1 on aarch64 hosts and 12.0 otherwise. SoX is disabled and
# CUDA enabled for the build. The clone is removed in the same layer.
# Submodules are re-synced after the checkout: `git clone --recursive` populates them at
# the revisions recorded by the default-branch tip, which need not match the pinned commit.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c ARCH=$(uname -m) && \
    if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; else export TORCH_CUDA_ARCH_LIST="12.0"; fi && \
    echo "=== Rebuilding torchaudio for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && \
    git clone --recursive https://github.com/pytorch/audio.git torchaudio-rebuild && \
    cd torchaudio-rebuild && \
    git checkout "${TORCHAUDIO_COMMIT}" && \
    git submodule update --init --recursive && \
    BUILD_SOX=0 USE_CUDA=1 python3 setup.py bdist_wheel && \
    uv pip install --no-cache dist/*.whl && \
    cd .. && \
    rm -rf torchaudio-rebuild # buildkit
# Install triton, then replace the ptxas binary bundled in its NVIDIA backend with a
# symlink to the CUDA toolkit's /usr/local/cuda/bin/ptxas.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c TRITON_PTXAS=/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && \
    uv pip install --no-cache triton && \
    rm -f "$TRITON_PTXAS" && \
    ln -s /usr/local/cuda/bin/ptxas "$TRITON_PTXAS" # buildkit
# Installs mamba-ssm 2.3.0 and causal-conv1d 1.6.0 at exact pinned versions.
# --no-build-isolation makes any source build use the packages already present
# in the environment instead of a fresh isolated build environment.
# NOTE(review): presumably this is so both packages compile against the torch
# wheel installed from /tmp/pytorch_wheel -- confirm.
RUN |5 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c uv pip install --no-cache --no-build-isolation mamba-ssm==2.3.0 causal-conv1d==1.6.0 # buildkit
ARG LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805
# Downloads the llama.cpp hybrid-cache patch to /tmp. curl fails on HTTP errors,
# stays silent except for error messages, and follows redirects.
# NOTE(review): the URL tracks the `main` branch and the download is not
# checksum-verified, so the patch content can change between builds even though
# LLAMACPP_COMMIT is pinned -- consider pinning the URL to a commit.
RUN |6 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c curl --fail --silent --show-error --location --output /tmp/llama-cpp-hybrid-cache-fix.patch https://raw.githubusercontent.com/pipecat-ai/nemotron-january-2026/main/patches/llama-cpp-hybrid-cache-fix.patch # buildkit
# Builds llama.cpp with CUDA at the pinned LLAMACPP_COMMIT, applies the
# hybrid-cache patch from /tmp, and installs the server/cli/quantize/bench
# binaries into /usr/local/bin.
#   - CUDA_ARCH is 121a when `uname -m` reports aarch64 and 120a otherwise.
#   - Each cmake step writes to a log under /tmp; on failure the last 50 log
#     lines are printed and the build exits 1. Both fallbacks are brace-grouped
#     so they fire only for their own cmake step, not for an earlier
#     clone/checkout/patch failure.
#   - Only the shared-library copy is best-effort (there may be no *.so files).
#     It is brace-grouped so its `|| true` cannot swallow a failed copy of the
#     llama-* binaries, which the ungrouped AND-OR list previously allowed.
#   - The build tree and .git directory are deleted in this same layer so they
#     do not persist in the image.
RUN |6 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then CUDA_ARCH="121a"; else CUDA_ARCH="120a"; fi && echo "=== Building llama.cpp for $ARCH with CUDA arch $CUDA_ARCH ===" && git clone https://github.com/ggerganov/llama.cpp.git /opt/llama.cpp && cd /opt/llama.cpp && git checkout "${LLAMACPP_COMMIT}" && echo "=== Applying hybrid cache fix patch ===" && patch -p1 < /tmp/llama-cpp-hybrid-cache-fix.patch && { cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_F16=ON -DCMAKE_CUDA_ARCHITECTURES="$CUDA_ARCH" -DCMAKE_BUILD_TYPE=Release > /tmp/llamacpp_cmake.log 2>&1 || { tail -50 /tmp/llamacpp_cmake.log; exit 1; }; } && { cmake --build build --config Release -j"$(nproc)" > /tmp/llamacpp_build.log 2>&1 || { tail -50 /tmp/llamacpp_build.log; exit 1; }; } && cp build/bin/llama-server build/bin/llama-cli build/bin/llama-quantize build/bin/llama-bench /usr/local/bin/ && { cp build/bin/*.so* /usr/local/lib/ 2>/dev/null || true; } && ldconfig && rm -rf /opt/llama.cpp/build /opt/llama.cpp/.git # buildkit
# Build-time smoke check for llama-server.
# `command -v` fails the build when the binary is not on PATH. The version call
# itself stays best-effort because, per the fallback message, it needs a GPU
# that the build environment may not have. Previously a missing binary was
# reported as "installed".
RUN |6 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c command -v llama-server >/dev/null && { llama-server --version || echo "llama-server installed (version check requires GPU)"; } # buildkit
WORKDIR /workspace
# Deletes the wheel staging directories and the nvrtc patch script from /tmp.
# NOTE(review): /tmp/pytorch_wheel already existed in an earlier layer (the
# torch wheel is installed from it), and files written by an earlier layer stay
# in the image even when a later layer removes them. This step therefore tidies
# the final filesystem but presumably does not reduce image size; a bind or
# cache mount in the step that consumes the wheels would -- confirm where
# /tmp/torchaudio_wheel and /tmp/patch_nvrtc.py are created.
RUN |6 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c rm -rf /tmp/pytorch_wheel /tmp/torchaudio_wheel && rm -f /tmp/patch_nvrtc.py # buildkit
# Installs the websocket/logging/HTTP client dependencies with minimum versions.
# Each requirement is quoted: unquoted, /bin/sh parses `>=12.0` as a redirection
# of stdout to a file named `=12.0`, so the version constraints were silently
# dropped and stray `=12.0`, `=0.7.0` and `=0.25.0` files were written to the
# working directory.
RUN |6 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c uv pip install --no-cache "websockets>=12.0" "loguru>=0.7.0" "httpx>=0.25.0" # buildkit

Labels

Key Value
description Unified ASR + TTS + LLM container (ARM64 sm_121 CUDA 13.1 / x86_64 sm_120 CUDA 13.0)
maintainer nemotron-speech
org.opencontainers.image.ref.name ubuntu
org.opencontainers.image.version 24.04
version 1.2
Details
Container
2026-04-05 09:42:00 +00:00
7
OCI / Docker
linux/amd64
23 GiB
Versions (1) View all
latest 2026-04-05