| ARG RELEASE |
| ARG LAUNCHPAD_BUILD_ARCH |
| LABEL org.opencontainers.image.ref.name=ubuntu |
| LABEL org.opencontainers.image.version=24.04 |
| ADD file:b4619a63cd7829e1338ddaa4995ca17003002dd54b0dfd675a6f54a2b69151a6 in / |
| CMD ["/bin/bash"] |
| ENV NVARCH=x86_64 |
| ENV NVIDIA_REQUIRE_CUDA=cuda>=13.0 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=vpc,driver>=535,driver<536 brand=vcs,driver>=535,driver<536 brand=vws,driver>=535,driver<536 brand=cloudgaming,driver>=535,driver<536 brand=unknown,driver>=550,driver<551 brand=grid,driver>=550,driver<551 brand=tesla,driver>=550,driver<551 brand=nvidia,driver>=550,driver<551 brand=quadro,driver>=550,driver<551 brand=quadrortx,driver>=550,driver<551 brand=nvidiartx,driver>=550,driver<551 brand=vapps,driver>=550,driver<551 brand=vpc,driver>=550,driver<551 brand=vcs,driver>=550,driver<551 brand=vws,driver>=550,driver<551 brand=cloudgaming,driver>=550,driver<551 brand=unknown,driver>=565,driver<566 brand=grid,driver>=565,driver<566 brand=tesla,driver>=565,driver<566 brand=nvidia,driver>=565,driver<566 brand=quadro,driver>=565,driver<566 brand=quadrortx,driver>=565,driver<566 brand=nvidiartx,driver>=565,driver<566 brand=vapps,driver>=565,driver<566 brand=vpc,driver>=565,driver<566 brand=vcs,driver>=565,driver<566 brand=vws,driver>=565,driver<566 brand=cloudgaming,driver>=565,driver<566 brand=unknown,driver>=570,driver<571 brand=grid,driver>=570,driver<571 brand=tesla,driver>=570,driver<571 brand=nvidia,driver>=570,driver<571 brand=quadro,driver>=570,driver<571 brand=quadrortx,driver>=570,driver<571 brand=nvidiartx,driver>=570,driver<571 brand=vapps,driver>=570,driver<571 brand=vpc,driver>=570,driver<571 brand=vcs,driver>=570,driver<571 brand=vws,driver>=570,driver<571 brand=cloudgaming,driver>=570,driver<571 brand=unknown,driver>=575,driver<576 brand=grid,driver>=575,driver<576 brand=tesla,driver>=575,driver<576 brand=nvidia,driver>=575,driver<576 brand=quadro,driver>=575,driver<576 brand=quadrortx,driver>=575,driver<576 
brand=nvidiartx,driver>=575,driver<576 brand=vapps,driver>=575,driver<576 brand=vpc,driver>=575,driver<576 brand=vcs,driver>=575,driver<576 brand=vws,driver>=575,driver<576 brand=cloudgaming,driver>=575,driver<576 |
| ENV NV_CUDA_CUDART_VERSION=13.0.48-1 |
| ARG TARGETARCH |
| LABEL maintainer=NVIDIA CORPORATION <cudatools@nvidia.com> |
| RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates && curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${NVARCH}/3bf863cc.pub | apt-key add - && echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${NVARCH} /" > /etc/apt/sources.list.d/cuda.list && apt-get purge --autoremove -y curl && rm -rf /var/lib/apt/lists/* # buildkit |
| ENV CUDA_VERSION=13.0.0 |
| RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends cuda-cudart-13-0=${NV_CUDA_CUDART_VERSION} cuda-compat-13-0 && rm -rf /var/lib/apt/lists/* # buildkit |
| RUN |1 TARGETARCH=amd64 /bin/sh -c echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/nvidia.conf # buildkit |
| ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin |
| ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64 |
| COPY NGC-DL-CONTAINER-LICENSE / # buildkit |
| ENV NVIDIA_VISIBLE_DEVICES=all |
| ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility |
| ENV NV_CUDA_LIB_VERSION=13.0.0-1 |
| ENV NV_NVTX_VERSION=13.0.39-1 |
| ENV NV_LIBNPP_VERSION=13.0.0.50-1 |
| ENV NV_LIBNPP_PACKAGE=libnpp-13-0=13.0.0.50-1 |
| ENV NV_LIBCUSPARSE_VERSION=12.6.2.49-1 |
| ENV NV_LIBCUBLAS_PACKAGE_NAME=libcublas-13-0 |
| ENV NV_LIBCUBLAS_VERSION=13.0.0.19-1 |
| ENV NV_LIBCUBLAS_PACKAGE=libcublas-13-0=13.0.0.19-1 |
| ENV NV_LIBNCCL_PACKAGE_NAME=libnccl2 |
| ENV NV_LIBNCCL_PACKAGE_VERSION=2.27.7-1 |
| ENV NCCL_VERSION=2.27.7-1 |
| ENV NV_LIBNCCL_PACKAGE=libnccl2=2.27.7-1+cuda13.0 |
| ARG TARGETARCH |
| LABEL maintainer=NVIDIA CORPORATION <cudatools@nvidia.com> |
| RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends cuda-libraries-13-0=${NV_CUDA_LIB_VERSION} ${NV_LIBNPP_PACKAGE} cuda-nvtx-13-0=${NV_NVTX_VERSION} libcusparse-13-0=${NV_LIBCUSPARSE_VERSION} ${NV_LIBCUBLAS_PACKAGE} ${NV_LIBNCCL_PACKAGE} && rm -rf /var/lib/apt/lists/* # buildkit |
| RUN |1 TARGETARCH=amd64 /bin/sh -c apt-mark hold ${NV_LIBCUBLAS_PACKAGE_NAME} ${NV_LIBNCCL_PACKAGE_NAME} # buildkit |
| COPY entrypoint.d/ /opt/nvidia/entrypoint.d/ # buildkit |
| COPY nvidia_entrypoint.sh /opt/nvidia/ # buildkit |
| ENV NVIDIA_PRODUCT_NAME=CUDA |
| ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"] |
| ENV NV_CUDA_LIB_VERSION=13.0.0-1 |
| ENV NV_CUDA_CUDART_DEV_VERSION=13.0.48-1 |
| ENV NV_NVML_DEV_VERSION=13.0.39-1 |
| ENV NV_LIBCUSPARSE_DEV_VERSION=12.6.2.49-1 |
| ENV NV_LIBNPP_DEV_VERSION=13.0.0.50-1 |
| ENV NV_LIBNPP_DEV_PACKAGE=libnpp-dev-13-0=13.0.0.50-1 |
| ENV NV_LIBCUBLAS_DEV_VERSION=13.0.0.19-1 |
| ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME=libcublas-dev-13-0 |
| ENV NV_LIBCUBLAS_DEV_PACKAGE=libcublas-dev-13-0=13.0.0.19-1 |
| ENV NV_CUDA_NSIGHT_COMPUTE_VERSION=13.0.0-1 |
| ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE=cuda-nsight-compute-13-0=13.0.0-1 |
| ENV NV_LIBNCCL_DEV_PACKAGE_NAME=libnccl-dev |
| ENV NV_LIBNCCL_DEV_PACKAGE_VERSION=2.27.7-1 |
| ENV NCCL_VERSION=2.27.7-1 |
| ENV NV_LIBNCCL_DEV_PACKAGE=libnccl-dev=2.27.7-1+cuda13.0 |
| ARG TARGETARCH |
| LABEL maintainer=NVIDIA CORPORATION <cudatools@nvidia.com> |
| RUN |1 TARGETARCH=amd64 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends cuda-cudart-dev-13-0=${NV_CUDA_CUDART_DEV_VERSION} cuda-command-line-tools-13-0=${NV_CUDA_LIB_VERSION} cuda-minimal-build-13-0=${NV_CUDA_LIB_VERSION} cuda-libraries-dev-13-0=${NV_CUDA_LIB_VERSION} cuda-nvml-dev-13-0=${NV_NVML_DEV_VERSION} ${NV_LIBNPP_DEV_PACKAGE} libcusparse-dev-13-0=${NV_LIBCUSPARSE_DEV_VERSION} ${NV_LIBCUBLAS_DEV_PACKAGE} ${NV_LIBNCCL_DEV_PACKAGE} ${NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE} && rm -rf /var/lib/apt/lists/* # buildkit |
| RUN |1 TARGETARCH=amd64 /bin/sh -c apt-mark hold ${NV_LIBCUBLAS_DEV_PACKAGE_NAME} ${NV_LIBNCCL_DEV_PACKAGE_NAME} # buildkit |
| ENV LIBRARY_PATH=/usr/local/cuda/lib64/stubs |
| ARG APT_PROXY=http://pb-lxcaptcache:3142 |
| RUN |1 APT_PROXY=http://pb-lxcaptcache:3142 /bin/sh -c test -z "$APT_PROXY" || echo "Acquire::http::Proxy \"$APT_PROXY\";" > /etc/apt/apt.conf.d/01proxy # buildkit |
| LABEL maintainer=nemotron-speech |
| LABEL description=PyTorch + torchaudio from source (Blackwell CUDA 13.0/13.1) - base layer |
| ENV DEBIAN_FRONTEND=noninteractive |
| RUN |1 APT_PROXY=http://pb-lxcaptcache:3142 /bin/sh -c apt-get update && apt-get install -y --no-install-recommends python3.12 python3.12-dev python3.12-venv python3-pip git curl wget cmake ninja-build ccache libopenblas-dev libomp-dev libffi-dev libssl-dev libnuma-dev libcurl4-openssl-dev ffmpeg sox libsndfile1 && rm -rf /var/lib/apt/lists/* && ln -sf /usr/bin/python3.12 /usr/bin/python3 && ln -sf /usr/bin/python3 /usr/bin/python # buildkit |
| COPY /uv /uvx /bin/ # buildkit |
| ENV UV_SYSTEM_PYTHON=1 |
| ENV UV_BREAK_SYSTEM_PACKAGES=1 |
| COPY /usr/lib/*/libcudnn* /tmp/cudnn_libs/ # buildkit |
| COPY /usr/include/cudnn* /usr/include/ # buildkit |
| COPY /usr/lib/*/libnccl* /tmp/nccl_libs/ # buildkit |
| COPY /usr/include/nccl.h /usr/include/ # buildkit |
| COPY /usr/include/nccl_device.h /usr/include/ # buildkit |
| COPY /usr/include/nccl_device/ /usr/include/nccl_device/ # buildkit |
| RUN |1 APT_PROXY=http://pb-lxcaptcache:3142 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then LIB_DIR="/usr/lib/aarch64-linux-gnu"; else LIB_DIR="/usr/lib/x86_64-linux-gnu"; fi && echo "=== Detected architecture: $ARCH, using $LIB_DIR ===" && mkdir -p "$LIB_DIR" && mv /tmp/cudnn_libs/* "$LIB_DIR/" && mv /tmp/nccl_libs/* "$LIB_DIR/" && rmdir /tmp/cudnn_libs /tmp/nccl_libs && echo "=== cuDNN libraries ===" && ls -la "$LIB_DIR"/libcudnn* | head -3 && echo "=== NCCL libraries ===" && ls -la "$LIB_DIR"/libnccl* | head -3 && echo "=== NCCL device headers ===" && ls -la /usr/include/nccl_device/ | head -3 && echo "=== Checking ncclDevCommDestroy symbol ===" && { nm -D "$LIB_DIR"/libnccl.so* 2>/dev/null | grep -m1 ncclDevCommDestroy || echo "Note: nm check may fail on stripped library"; } && ldconfig # buildkit |
| RUN |1 APT_PROXY=http://pb-lxcaptcache:3142 /bin/sh -c uv pip install --no-cache numpy pyyaml typing_extensions sympy filelock networkx jinja2 fsspec packaging setuptools wheel cffi future requests dataclasses pillow expecttest hypothesis pytest # buildkit |
| ARG PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f |
| WORKDIR /build |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c git clone --recursive https://github.com/pytorch/pytorch.git && cd pytorch && git checkout ${PYTORCH_COMMIT} # buildkit |
| WORKDIR /build/pytorch |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c git submodule sync && git submodule update --init --recursive # buildkit |
| ENV USE_CUDA=1 |
| ENV USE_CUDNN=1 |
| ENV USE_MKLDNN=1 |
| ENV USE_DISTRIBUTED=1 |
| ENV USE_NCCL=1 |
| ENV USE_TENSORPIPE=0 |
| ENV USE_SYSTEM_NCCL=1 |
| ENV NCCL_ROOT=/usr |
| ENV NCCL_INCLUDE_DIR=/usr/include |
| ENV BUILD_TEST=0 |
| ENV MAX_JOBS=8 |
| ENV CMAKE_BUILD_TYPE=Release |
| ENV CUDNN_INCLUDE_DIR=/usr/include |
| ENV USE_PRIORITIZED_TEXT_FOR_LD=1 |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c ln -sf /usr/local/cuda/include/cccl/cub /usr/local/cuda/include/cub && ln -sf /usr/local/cuda/include/cccl/thrust /usr/local/cuda/include/thrust && echo "=== Verifying CUB access ===" && ls /usr/local/cuda/include/cub/cub.cuh && ls /usr/local/cuda/include/cccl/cub/cub.cuh # buildkit |
| ENV CUB_INCLUDE_DIR=/usr/local/cuda/include/cccl |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; export NCCL_LIB_DIR="/usr/lib/aarch64-linux-gnu"; export CUDNN_LIB_DIR="/usr/lib/aarch64-linux-gnu"; else export TORCH_CUDA_ARCH_LIST="12.0"; export NCCL_LIB_DIR="/usr/lib/x86_64-linux-gnu"; export CUDNN_LIB_DIR="/usr/lib/x86_64-linux-gnu"; fi && echo "=== Building PyTorch for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && python3 setup.py bdist_wheel > /tmp/pytorch_build.log 2>&1 || { tail -50 /tmp/pytorch_build.log; exit 1; } && uv pip install --no-cache dist/*.whl && mkdir -p /tmp/pytorch_wheel && cp dist/*.whl /tmp/pytorch_wheel/ # buildkit |
| ARG TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 |
| WORKDIR /build |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 /bin/sh -c git clone --recursive https://github.com/pytorch/audio.git && cd audio && git checkout ${TORCHAUDIO_COMMIT} # buildkit |
| WORKDIR /build/audio |
| ENV BUILD_SOX=0 |
| ENV USE_CUDA=1 |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; else export TORCH_CUDA_ARCH_LIST="12.0"; fi && echo "=== Building torchaudio for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && python3 setup.py bdist_wheel > /tmp/torchaudio_build.log 2>&1 || { tail -50 /tmp/torchaudio_build.log; exit 1; } && uv pip install --no-cache dist/*.whl # buildkit |
| WORKDIR /workspace |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 PYTORCH_COMMIT=32cb1dac896fe212d77073a4a53fee840c13442f TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 /bin/sh -c rm -rf /build/pytorch /build/audio # buildkit |
| ARG APT_PROXY=http://pb-lxcaptcache:3142 |
| RUN |1 APT_PROXY=http://pb-lxcaptcache:3142 /bin/sh -c test -z "$APT_PROXY" || echo "Acquire::http::Proxy \"$APT_PROXY\";" > /etc/apt/apt.conf.d/01proxy # buildkit |
| ARG TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 /bin/sh -c uv pip install --no-cache 'tokenizers>=0.19' fastapi 'uvicorn[standard]' 'pydantic>=2.0' prometheus_client py-cpuinfo tiktoken lm-format-enforcer outlines xgrammar pyzmq msgspec gguf compressed-tensors importlib_metadata 'mistral_common>=1.5.0' partial-json-parser # buildkit |
| ARG VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a |
| ARG VLLM_CACHE_BUSTER=v1 |
| WORKDIR /build |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c git clone https://github.com/vllm-project/vllm.git && cd vllm && git checkout ${VLLM_COMMIT} # buildkit |
| WORKDIR /build/vllm |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c sed -i 's/cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f"/cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f"/g' CMakeLists.txt && echo "=== Verifying SM120/SM121 patch applied ===" && grep -n "SCALED_MM_ARCHS" CMakeLists.txt | head -5 # buildkit |
| ENV VLLM_TARGET_DEVICE=cuda |
| ENV MAX_JOBS=8 |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c curl -fsSL https://raw.githubusercontent.com/pipecat-ai/nemotron-january-2026/main/patches/vllm-pr31607-sm121-support.patch -o /tmp/vllm-pr31607-sm121-support.patch # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c curl -fsSL https://raw.githubusercontent.com/pipecat-ai/nemotron-january-2026/main/patches/apply-vllm-pr31607.py -o /tmp/apply-vllm-pr31607.py # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c uv pip install --no-cache 'cmake>=3.26' ninja packaging 'setuptools>=61' 'setuptools-scm>=8' wheel jinja2 # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; else export TORCH_CUDA_ARCH_LIST="12.0"; fi && echo "=== Building vLLM for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && uv pip install --no-cache -e . --no-build-isolation > /tmp/vllm_build.log 2>&1 || { tail -50 /tmp/vllm_build.log; exit 1; } # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 /tmp/apply-vllm-pr31607.py && { patch -p1 < /tmp/vllm-pr31607-sm121-support.patch || echo "Patch may already be applied"; } # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Edits vllm/model_executor/layers/quantization/modelopt.py in place: adds
# `import os` when missing and splices the `insert` snippet right after the
# closing ")" of the `self.fp8_linear = Fp8LinearOp(` call.
# NOTE(review): leading indentation of this script (and, it appears, the runs of
# spaces inside its string literals) was lost in this dump, so the block nesting
# described here is presumed rather than shown - confirm against the Dockerfile.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py")
text = path.read_text()
# Idempotency guard: the env-var name doubles as the "already patched" marker.
if "VLLM_FP8_FORCE_BACKEND" not in text:
if "import os" not in text:
needle = "from vllm.logger import init_logger\n"
if needle not in text:
raise SystemExit("Failed to locate logger import for os insertion")
text = text.replace(needle, "import os\n" + needle, 1)
# Snippet to inject: sets preferred_backend to "cutlass" when
# is_device_capability(121) is true, then lets a non-empty VLLM_FP8_FORCE_BACKEND
# (lower-cased) override whatever backend was chosen.
insert = (
" if current_platform.is_device_capability(121):\n"
" self.fp8_linear.preferred_backend = \"cutlass\"\n"
" logger.info_once(\n"
" \"Forcing FP8 linear backend to cutlass on SM121 (flashinfer is incorrect)\"\n"
" )\n"
"\n"
" override = os.getenv(\"VLLM_FP8_FORCE_BACKEND\")\n"
" if override:\n"
" self.fp8_linear.preferred_backend = override.lower()\n"
" logger.info_once(\n"
" \"Forcing FP8 linear backend to %s via VLLM_FP8_FORCE_BACKEND\",\n"
" self.fp8_linear.preferred_backend,\n"
" )\n"
"\n"
)
if "Forcing FP8 linear backend" not in text:
lines = text.splitlines()
start = None
end = None
# start = first line containing the Fp8LinearOp( call; end = the next line
# that is exactly ")" once stripped, taken as the end of that call.
for i, line in enumerate(lines):
if "self.fp8_linear = Fp8LinearOp(" in line:
start = i
continue
if start is not None and line.strip() == ")":
end = i
break
if start is None or end is None:
raise SystemExit("Failed to locate fp8_linear init block in ModelOptFp8LinearMethod")
insert_lines = insert.rstrip("\n").split("\n")
# Slice-assign to an empty slice: inserts the snippet after line `end`.
lines[end + 1:end + 1] = insert_lines
text = "\n".join(lines) + "\n"
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Second in-place edit of modelopt.py: imports current_platform (after
# `import vllm.envs as envs`) when it is not already imported, and rewrites the
# FlashInfer FP8 MoE condition so is_device_capability(121) is tested first and
# the original condition becomes the `elif` arm.
# NOTE(review): indentation was stripped from this dump; nesting is presumed.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py")
text = path.read_text()
if "from vllm.platforms import current_platform" not in text:
needle = "import vllm.envs as envs"
if needle not in text:
raise SystemExit("Failed to locate envs import for insertion")
text = text.replace(
needle,
"import vllm.envs as envs\nfrom vllm.platforms import current_platform",
1,
)
old = "if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():"
new = (
"if current_platform.is_device_capability(121):\n"
" logger.info_once(\n"
" \"Disabling FlashInfer FP8 MoE kernels on SM121; using fallback.\"\n"
" )\n"
" elif envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():"
)
# Best-effort: nothing is raised when `old` is absent. No count is passed to
# replace(), so every occurrence of `old` is rewritten.
if old in text and "Disabling FlashInfer FP8 MoE kernels on SM121" not in text:
text = text.replace(old, new)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Makes sure fused_moe.py contains `import os`, inserting it in front of the
# first occurrence of "import torch\n". The presence test is a plain substring
# check on the whole file. The file is rewritten in either case.
# NOTE(review): indentation was stripped from this dump; nesting is presumed.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py")
text = path.read_text()
if "import os" not in text:
needle = "import torch\n"
if needle not in text:
raise SystemExit("Failed to locate torch import for insertion")
text = text.replace(needle, "import os\nimport torch\n", 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Adds the `per_tensor_dequantize` import (from ...quantization.utils.w8a8_utils)
# to fused_moe.py directly after the existing `dequant_mxfp6` import line.
# Skipped when the name "per_tensor_dequantize" already appears anywhere in the
# file; aborts the build when the anchor import line cannot be found.
# NOTE(review): indentation was stripped from this dump; nesting is presumed.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py")
text = path.read_text()
if "per_tensor_dequantize" not in text:
needle = "from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6\n"
if needle not in text:
raise SystemExit("Failed to locate quantization util imports")
text = text.replace(
needle,
needle
+ "from vllm.model_executor.layers.quantization.utils.w8a8_utils import per_tensor_dequantize\n",
1,
)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Instruments fused_moe.py with debug logging of the quantized activation
# scales (a1q_scale, then a2q_scale). The injected code only logs when
# VLLM_FP8_MOE_ACT_SCALE_DEBUG == "1" at runtime and is gated by the module-level
# counter _FP8_MOE_ACT_SCALE_LOGGED.
# NOTE(review): indentation was stripped from this dump (including, it appears,
# inside the string literals); nesting described here is presumed.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py")
text = path.read_text()
# Step 1: define the counter right after the module's logger assignment.
if "_FP8_MOE_ACT_SCALE_LOGGED" not in text:
needle = "logger = init_logger(__name__)\n"
if needle not in text:
raise SystemExit("Failed to locate logger for insertion")
text = text.replace(needle, needle + "\n_FP8_MOE_ACT_SCALE_LOGGED = 0\n", 1)
# Step 2: insert a `global` declaration as the line immediately following the
# ") -> torch.Tensor:" signature terminator of fused_experts and of
# fused_experts_impl.
# NOTE(review): if either function starts with a docstring, this places the
# statement ahead of it - confirm that is acceptable.
if "global _FP8_MOE_ACT_SCALE_LOGGED" not in text:
lines = text.splitlines(keepends=True)
def_line = None
end_sig = None
for i, line in enumerate(lines):
if def_line is None and line.startswith("def fused_experts("):
def_line = i
continue
if def_line is not None and end_sig is None and line.strip() == ") -> torch.Tensor:":
end_sig = i
break
if end_sig is None:
raise SystemExit("Failed to locate fused_experts signature end")
lines.insert(end_sig + 1, " global _FP8_MOE_ACT_SCALE_LOGGED\n")
impl_def_line = None
impl_end_sig = None
for i, line in enumerate(lines):
if impl_def_line is None and line.startswith("def fused_experts_impl("):
impl_def_line = i
continue
if impl_def_line is not None and impl_end_sig is None and line.strip() == ") -> torch.Tensor:":
impl_end_sig = i
break
if impl_end_sig is None:
raise SystemExit("Failed to locate fused_experts_impl signature end")
lines.insert(impl_end_sig + 1, " global _FP8_MOE_ACT_SCALE_LOGGED\n")
text = "".join(lines)
# Step 3: locate the first-stage quantize call (anchor) and replace it with the
# same call followed by a block that logs a1q_scale while the counter is < 1.
# Silently skipped when the anchor text is not present.
anchor = (
"qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(\n"
" A=curr_hidden_states,\n"
" A_scale=a1_scale,\n"
" quant_dtype=quant_dtype,\n"
" per_act_token_quant=per_channel_quant,\n"
" block_shape=block_shape,\n"
" )\n"
)
if "FP8 MoE a1q_scale" not in text and anchor in text:
insert = (
"qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(\n"
" A=curr_hidden_states,\n"
" A_scale=a1_scale,\n"
" quant_dtype=quant_dtype,\n"
" per_act_token_quant=per_channel_quant,\n"
" block_shape=block_shape,\n"
" )\n"
" if os.getenv(\"VLLM_FP8_MOE_ACT_SCALE_DEBUG\") == \"1\":\n"
" if _FP8_MOE_ACT_SCALE_LOGGED < 1:\n"
" _FP8_MOE_ACT_SCALE_LOGGED += 1\n"
" def _sinfo(name, t):\n"
" if t is None:\n"
" return f\"{name}=None\"\n"
" t_float = t.float()\n"
" return (\n"
" f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n"
" f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\"\n"
" )\n"
" logger.info_once(\n"
" \"FP8 MoE a1q_scale: %s\",\n"
" _sinfo(\"a1q_scale\", a1q_scale),\n"
" )\n"
)
text = text.replace(anchor, insert, 1)
# Step 4: same treatment for the second-stage quantize call. The injected block
# logs a2q_scale only while the counter == 1, i.e. after the a1q_scale block
# above has bumped it. Silently skipped when the anchor text is not present.
anchor2 = (
"qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(\n"
" A=intermediate_cache2,\n"
" A_scale=a2_scale,\n"
" quant_dtype=quant_dtype,\n"
" per_act_token_quant=per_channel_quant,\n"
" block_shape=block_shape,\n"
" )\n"
)
if "FP8 MoE a2q_scale" not in text and anchor2 in text:
insert2 = (
"qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(\n"
" A=intermediate_cache2,\n"
" A_scale=a2_scale,\n"
" quant_dtype=quant_dtype,\n"
" per_act_token_quant=per_channel_quant,\n"
" block_shape=block_shape,\n"
" )\n"
" if os.getenv(\"VLLM_FP8_MOE_ACT_SCALE_DEBUG\") == \"1\":\n"
" if _FP8_MOE_ACT_SCALE_LOGGED == 1:\n"
" _FP8_MOE_ACT_SCALE_LOGGED += 1\n"
" def _sinfo(name, t):\n"
" if t is None:\n"
" return f\"{name}=None\"\n"
" t_float = t.float()\n"
" return (\n"
" f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n"
" f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\"\n"
" )\n"
" logger.info_once(\n"
" \"FP8 MoE a2q_scale: %s\",\n"
" _sinfo(\"a2q_scale\", a2q_scale),\n"
" )\n"
)
text = text.replace(anchor2, insert2, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Edits vllm/model_executor/models/nemotron_h.py in place: adds `import os`
# after "import typing\n" when missing, and wraps the `self.experts(...)` call
# in a VLLM_NEMO_BYPASS_MOE == "1" switch. The bypass arm builds fused_moe_out
# from the shared_experts output and the unmodified hidden_states instead of
# calling self.experts.
# NOTE(review): indentation was stripped from this dump, so which `if` the
# injected `else:` pairs with cannot be read off here - confirm in the Dockerfile.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/models/nemotron_h.py")
text = path.read_text()
if "import os" not in text:
needle = "import typing\n"
if needle not in text:
raise SystemExit("Failed to locate import typing for insertion")
text = text.replace(needle, needle + "import os\n", 1)
old = (
" fused_moe_out = self.experts(\n"
" hidden_states=hidden_states, router_logits=router_logits\n"
" )\n"
)
# Best-effort: nothing is raised when `old` is not found in the file.
if old in text and "VLLM_NEMO_BYPASS_MOE" not in text:
new = (
" if os.getenv(\"VLLM_NEMO_BYPASS_MOE\") == \"1\":\n"
" if self.shared_experts is not None and not self.use_latent_moe:\n"
" shared_output = self.shared_experts(hidden_states)\n"
" fused_moe_out = (shared_output, hidden_states)\n"
" else:\n"
" fused_moe_out = self.experts(\n"
" hidden_states=hidden_states, router_logits=router_logits\n"
" )\n"
)
text = text.replace(old, new, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Edits nemotron_h.py in place: adds a module logger plus the
# _NEMO_SHARED_REF_USED counter, and appends a one-time diagnostic after the
# fused_moe_out unpacking. When VLLM_NEMO_SHARED_REF_COMPARE == "1" the injected
# code re-runs shared_experts on a bfloat16 copy of hidden_states and logs the
# mean/max absolute difference and the norms against shared_output.
# NOTE(review): indentation was stripped from this dump; nesting is presumed.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/models/nemotron_h.py")
text = path.read_text()
# NOTE(review): `_NEMO_SHARED_REF_USED = 0` is only inserted as part of this
# logger insertion. If the file already mentions "init_logger", the counter is
# presumably never defined although the code injected below reads it - confirm.
if "init_logger" not in text:
needle = "from torch import nn\n\n"
if needle not in text:
raise SystemExit("Failed to locate torch import block for logger insertion")
insert = needle + "from vllm.logger import init_logger\n\nlogger = init_logger(__name__)\n\n_NEMO_SHARED_REF_USED = 0\n\n"
text = text.replace(needle, insert, 1)
block = (
" if self.use_latent_moe:\n"
" _, final_hidden_states = fused_moe_out\n"
" else:\n"
" shared_output, final_hidden_states = fused_moe_out\n"
)
# Best-effort: nothing is raised when `block` is not found in the file.
if block in text and "VLLM_NEMO_SHARED_REF_COMPARE" not in text:
extra = (
" if (\n"
" os.getenv(\"VLLM_NEMO_SHARED_REF_COMPARE\") == \"1\"\n"
" and (not self.use_latent_moe)\n"
" and shared_output is not None\n"
" and self.shared_experts is not None\n"
" ):\n"
" global _NEMO_SHARED_REF_USED\n"
" if _NEMO_SHARED_REF_USED == 0:\n"
" ref = self.shared_experts(hidden_states.to(dtype=torch.bfloat16))\n"
" diff = (ref.to(torch.float32) - shared_output.to(torch.float32)).abs()\n"
" logger.info(\n"
" \"NemotronH shared_experts ref compare: mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n"
" diff.mean().item(),\n"
" diff.max().item(),\n"
" ref.to(torch.float32).norm().item(),\n"
" shared_output.to(torch.float32).norm().item(),\n"
" )\n"
" _NEMO_SHARED_REF_USED = 1\n"
)
text = text.replace(block, block + extra, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Edits fused_moe.py in place in three steps: (1) adds the module counter
# _FP8_MOE_REF_USED, (2) defines _pytorch_fp8_moe_fallback after
# torch_vllm_outplace_fused_experts, (3) attempts to route fused_experts through
# that fallback when use_fp8_w8a8 is set and is_device_capability(121) is true.
# NOTE(review): indentation was stripped from this dump (including, it appears,
# inside the string literals); nesting described here is presumed.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py")
text = path.read_text()
if "_FP8_MOE_REF_USED" not in text:
needle = "logger = init_logger(__name__)\n"
if needle not in text:
raise SystemExit("Failed to locate logger for insertion")
text = text.replace(needle, needle + "\n_FP8_MOE_REF_USED = 0\n", 1)
if "_pytorch_fp8_moe_fallback" not in text:
anchor = (
"def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:\n"
" return torch.ops.vllm.outplace_fused_experts(**kwargs)\n\n\n"
)
if anchor not in text:
raise SystemExit("Failed to locate torch_vllm_outplace_fused_experts anchor")
# Source text of the fallback: loops over every token and each of its top-k
# experts, dequantizes w1/w2 with per_tensor_dequantize (collapsing a multi-element
# scale to its max), applies the activation, and accumulates into a float32 buffer.
# Runtime switches it reads: VLLM_MOE_FALLBACK_DEBUG=1 prints scale statistics,
# VLLM_FP8_MOE_FALLBACK_BF16=1 computes in bfloat16 instead of float32.
# The \" sequences inside this ''' literal are ordinary Python escapes for ".
# NOTE(review): the snippet calls F.silu/F.relu/F.gelu, so the target file must
# already bind `F` (presumably torch.nn.functional) - not visible here; confirm.
fallback = '''def _pytorch_fp8_moe_fallback(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
apply_router_weight_on_input: bool,
expert_map: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
) -> torch.Tensor:
# Slow but safe fallback for SM121 FP8 MoE.
if expert_map is not None:
topk_ids = expert_map[topk_ids]
num_tokens = hidden_states.size(0)
top_k = topk_ids.size(1)
out = torch.zeros_like(hidden_states, dtype=torch.float32)
w1_scale = quant_config.w1_scale
w2_scale = quant_config.w2_scale
a1_scale = quant_config.a1_scale
a2_scale = quant_config.a2_scale
if os.getenv("VLLM_MOE_FALLBACK_DEBUG") == "1":
def _tinfo(name, t):
if t is None:
return f\"{name}=None\"
with torch.no_grad():
t_float = t.float()
return (
f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"
f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\"
)
print(\"[moe_fallback_debug]\", _tinfo(\"w1_scale\", w1_scale))
print(\"[moe_fallback_debug]\", _tinfo(\"w2_scale\", w2_scale))
print(\"[moe_fallback_debug]\", _tinfo(\"a1_scale\", a1_scale))
print(\"[moe_fallback_debug]\", _tinfo(\"a2_scale\", a2_scale))
force_bf16 = os.getenv("VLLM_FP8_MOE_FALLBACK_BF16") == "1"
for t in range(num_tokens):
x = hidden_states[t].to(torch.float32)
for k in range(top_k):
expert_id = int(topk_ids[t, k].item())
if expert_id < 0 or expert_id >= w1.size(0):
continue
weight = topk_weights[t, k].to(torch.float32)
x_in = x * weight if apply_router_weight_on_input else x
# a1_scale intentionally ignored in this variant
w1_e = w1[expert_id]
w2_e = w2[expert_id]
if w1_scale is not None:
s1 = w1_scale[expert_id]
if s1.numel() > 1:
s1 = s1.max()
w1_e = per_tensor_dequantize(w1_e, s1)
else:
w1_e = w1_e.to(torch.float16)
if w2_scale is not None:
s2 = w2_scale[expert_id]
if s2.numel() > 1:
s2 = s2.max()
w2_e = per_tensor_dequantize(w2_e, s2)
else:
w2_e = w2_e.to(torch.float16)
if force_bf16:
w1_e = w1_e.to(torch.bfloat16)
w2_e = w2_e.to(torch.bfloat16)
x_in = x_in.to(torch.bfloat16)
else:
w1_e = w1_e.to(torch.float32)
w2_e = w2_e.to(torch.float32)
z = torch.matmul(w1_e, x_in)
if activation == "silu":
if z.numel() % 2 == 0:
a, b = z.chunk(2, dim=-1)
h = F.silu(a) * b
else:
h = F.silu(z)
elif activation == "relu2_no_mul":
h = F.relu(z) ** 2
elif activation == "gelu":
h = F.gelu(z)
else:
h = F.silu(z)
# a2_scale intentionally ignored in this variant
y = torch.matmul(w2_e, h)
if not apply_router_weight_on_input:
y = y * weight
out[t] += y.float()
return out.to(hidden_states.dtype)
'''
text = text.replace(anchor, anchor + fallback, 1)
# Inject fallback / reference use in fused_experts
# NOTE(review): the heredoc delimiter is quoted (<<'PY'), so the shell passes the
# `\\n` sequences below through untouched and Python reads each as a backslash
# followed by "n", not as a line break. `old` therefore presumably never matches
# the multi-line source text, the guarded replace below is skipped without any
# error, and this step is a no-op. The following layer repeats the same injection
# using real "\n" escapes - confirm and consider dropping this dead step.
old = " if quant_config is None:\\n quant_config = FUSED_MOE_UNQUANTIZED_CONFIG\\n\\n"
new = (
" if quant_config is None:\\n"
" quant_config = FUSED_MOE_UNQUANTIZED_CONFIG\\n\\n"
" if quant_config.use_fp8_w8a8 and current_platform.is_device_capability(121):\\n"
" if os.getenv(\"VLLM_FP8_MOE_REF_ONCE\") == \"1\":\\n"
" global _FP8_MOE_REF_USED\\n"
" if _FP8_MOE_REF_USED < 1:\\n"
" _FP8_MOE_REF_USED += 1\\n"
" ref_out = _pytorch_fp8_moe_fallback(\\n"
" hidden_states=hidden_states,\\n"
" w1=w1,\\n"
" w2=w2,\\n"
" topk_weights=topk_weights,\\n"
" topk_ids=topk_ids,\\n"
" activation=activation,\\n"
" apply_router_weight_on_input=apply_router_weight_on_input,\\n"
" expert_map=expert_map,\\n"
" quant_config=quant_config,\\n"
" )\\n"
" if os.getenv(\"VLLM_FP8_MOE_REF_COMPARE\") == \"1\":\\n"
" cmp_out = dispatch_fused_experts_func(inplace)(\\n"
" hidden_states=hidden_states,\\n"
" w1=w1,\\n"
" w2=w2,\\n"
" topk_weights=topk_weights,\\n"
" topk_ids=topk_ids,\\n"
" activation=activation,\\n"
" apply_router_weight_on_input=apply_router_weight_on_input,\\n"
" use_fp8_w8a8=quant_config.use_fp8_w8a8,\\n"
" use_int8_w8a8=quant_config.use_int8_w8a8,\\n"
" use_int8_w8a16=quant_config.use_int8_w8a16,\\n"
" use_int4_w4a16=quant_config.use_int4_w4a16,\\n"
" ocp_mx_scheme=quant_config.ocp_mx_scheme,\\n"
" per_channel_quant=quant_config.per_act_token_quant,\\n"
" global_num_experts=global_num_experts,\\n"
" expert_map=expert_map,\\n"
" w1_scale=quant_config.w1_scale,\\n"
" w2_scale=quant_config.w2_scale,\\n"
" w1_zp=quant_config.w1_zp,\\n"
" w2_zp=quant_config.w2_zp,\\n"
" a1_scale=quant_config.a1_scale,\\n"
" a2_scale=quant_config.a2_scale,\\n"
" block_shape=quant_config.block_shape,\\n"
" w1_bias=quant_config.w1_bias,\\n"
" w2_bias=quant_config.w2_bias,\\n"
" )\\n"
" with torch.no_grad():\\n"
" diff = (ref_out.float() - cmp_out.float()).abs()\\n"
" logger.info_once(\\n"
" \"FP8 MoE ref compare: mean_abs=%.6g max_abs=%.6g ref_norm=%.6g out_norm=%.6g\",\\n"
" diff.mean().item(),\\n"
" diff.max().item(),\\n"
" ref_out.float().norm().item(),\\n"
" cmp_out.float().norm().item(),\\n"
" )\\n"
" if os.getenv(\"VLLM_FP8_MOE_REF_RETURN\", \"1\") != \"0\":\\n"
" return ref_out\\n"
" if os.getenv(\"VLLM_FP8_MOE_FORCE_FALLBACK\", \"1\") != \"0\":\\n"
" return _pytorch_fp8_moe_fallback(\\n"
" hidden_states=hidden_states,\\n"
" w1=w1,\\n"
" w2=w2,\\n"
" topk_weights=topk_weights,\\n"
" topk_ids=topk_ids,\\n"
" activation=activation,\\n"
" apply_router_weight_on_input=apply_router_weight_on_input,\\n"
" expert_map=expert_map,\\n"
" quant_config=quant_config,\\n"
" )\\n\\n"
)
if old in text and "_pytorch_fp8_moe_fallback" in text:
text = text.replace(old, new, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Build-time patch for vLLM's fused_moe.py.
# Injects an opt-in, one-shot FP8 MoE "reference" path directly after the
# place where quant_config falls back to FUSED_MOE_UNQUANTIZED_CONFIG.
# The patch is skipped when the VLLM_FP8_MOE_REF_ONCE marker string is
# already present in the target file.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py")
text = path.read_text()
if "VLLM_FP8_MOE_REF_ONCE" not in text:
# Module-level call counter used by the injected code; appended right after
# the module's logger definition unless a counter of that name already exists.
if "_FP8_MOE_REF_USED" not in text:
needle = "logger = init_logger(__name__)\n"
if needle not in text:
raise SystemExit("Failed to locate logger for insertion")
text = text.replace(needle, needle + "\n_FP8_MOE_REF_USED = 0\n", 1)
# Anchor: the quant_config defaulting block. The generated code is appended
# to it (the anchor itself is kept as the first element of the join below).
needle = " if quant_config is None:\n quant_config = FUSED_MOE_UNQUANTIZED_CONFIG\n\n"
if needle not in text:
raise SystemExit("Failed to locate quant_config init for insertion")
insert = "".join([
needle,
# Generated code: active only for FP8 w8a8 configs when
# current_platform.is_device_capability(121) holds and
# VLLM_FP8_MOE_REF_ONCE=1; the counter limits it to the first call.
# NOTE(review): _pytorch_fp8_moe_fallback, os and current_platform are not
# defined by this script; presumably provided by an earlier patch step or the
# target module itself — confirm.
" if quant_config.use_fp8_w8a8 and current_platform.is_device_capability(121):\n",
" if os.getenv(\"VLLM_FP8_MOE_REF_ONCE\") == \"1\":\n",
" global _FP8_MOE_REF_USED\n",
" if _FP8_MOE_REF_USED < 1:\n",
" _FP8_MOE_REF_USED += 1\n",
" ref_out = _pytorch_fp8_moe_fallback(\n",
" hidden_states=hidden_states,\n",
" w1=w1,\n",
" w2=w2,\n",
" topk_weights=topk_weights,\n",
" topk_ids=topk_ids,\n",
" activation=activation,\n",
" apply_router_weight_on_input=apply_router_weight_on_input,\n",
" expert_map=expert_map,\n",
" quant_config=quant_config,\n",
" )\n",
# Generated code: with VLLM_FP8_MOE_REF_COMPARE=1 the regular dispatch path
# is also run on the same inputs and abs-diff / norm statistics are logged.
" if os.getenv(\"VLLM_FP8_MOE_REF_COMPARE\") == \"1\":\n",
" cmp_out = dispatch_fused_experts_func(inplace)(\n",
" hidden_states=hidden_states,\n",
" w1=w1,\n",
" w2=w2,\n",
" topk_weights=topk_weights,\n",
" topk_ids=topk_ids,\n",
" activation=activation,\n",
" apply_router_weight_on_input=apply_router_weight_on_input,\n",
" use_fp8_w8a8=quant_config.use_fp8_w8a8,\n",
" use_int8_w8a8=quant_config.use_int8_w8a8,\n",
" use_int8_w8a16=quant_config.use_int8_w8a16,\n",
" use_int4_w4a16=quant_config.use_int4_w4a16,\n",
" ocp_mx_scheme=quant_config.ocp_mx_scheme,\n",
" per_channel_quant=quant_config.per_act_token_quant,\n",
" global_num_experts=global_num_experts,\n",
" expert_map=expert_map,\n",
" w1_scale=quant_config.w1_scale,\n",
" w2_scale=quant_config.w2_scale,\n",
" w1_zp=quant_config.w1_zp,\n",
" w2_zp=quant_config.w2_zp,\n",
" a1_scale=quant_config.a1_scale,\n",
" a2_scale=quant_config.a2_scale,\n",
" block_shape=quant_config.block_shape,\n",
" w1_bias=quant_config.w1_bias,\n",
" w2_bias=quant_config.w2_bias,\n",
" )\n",
" with torch.no_grad():\n",
" diff = (ref_out.float() - cmp_out.float()).abs()\n",
" logger.info_once(\n",
" \"FP8 MoE ref compare: mean_abs=%.6g max_abs=%.6g ref_norm=%.6g out_norm=%.6g\",\n",
" diff.mean().item(),\n",
" diff.max().item(),\n",
" ref_out.float().norm().item(),\n",
" cmp_out.float().norm().item(),\n",
" )\n",
# Generated code: helper that formats shape/dtype/min/max of a scale tensor
# (or "<name>=None"), used to log the four quantization scales once.
" def _scale_info(name, t):\n",
" if t is None:\n",
" return f\"{name}=None\"\n",
" t_float = t.float()\n",
" return (\n",
" f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n",
" f\"min={t_float.min().item():.6g} max={t_float.max().item():.6g}\"\n",
" )\n",
" logger.info_once(\n",
" \"FP8 MoE ref scales: %s | %s | %s | %s\",\n",
" _scale_info(\"w1_scale\", quant_config.w1_scale),\n",
" _scale_info(\"w2_scale\", quant_config.w2_scale),\n",
" _scale_info(\"a1_scale\", quant_config.a1_scale),\n",
" _scale_info(\"a2_scale\", quant_config.a2_scale),\n",
" )\n",
# Generated code: with VLLM_FP8_MOE_BF16_REF_COMPARE=1 the fallback is run
# again with VLLM_FP8_MOE_FALLBACK_BF16 temporarily forced to "1" in
# os.environ; the previous value (or its absence) is restored afterwards.
# The result is compared against cmp_out, so this presumably sits inside the
# REF_COMPARE branch — nesting is not recoverable from this dump.
" if os.getenv(\"VLLM_FP8_MOE_BF16_REF_COMPARE\") == \"1\":\n",
" prev = os.getenv(\"VLLM_FP8_MOE_FALLBACK_BF16\")\n",
" os.environ[\"VLLM_FP8_MOE_FALLBACK_BF16\"] = \"1\"\n",
" bf16_ref = _pytorch_fp8_moe_fallback(\n",
" hidden_states=hidden_states,\n",
" w1=w1,\n",
" w2=w2,\n",
" topk_weights=topk_weights,\n",
" topk_ids=topk_ids,\n",
" activation=activation,\n",
" apply_router_weight_on_input=apply_router_weight_on_input,\n",
" expert_map=expert_map,\n",
" quant_config=quant_config,\n",
" )\n",
" if prev is None:\n",
" del os.environ[\"VLLM_FP8_MOE_FALLBACK_BF16\"]\n",
" else:\n",
" os.environ[\"VLLM_FP8_MOE_FALLBACK_BF16\"] = prev\n",
" diff_bf16 = (bf16_ref.float() - cmp_out.float()).abs()\n",
" logger.info_once(\n",
" \"FP8 MoE BF16 ref compare: mean_abs=%.6g max_abs=%.6g ref_norm=%.6g out_norm=%.6g\",\n",
" diff_bf16.mean().item(),\n",
" diff_bf16.max().item(),\n",
" bf16_ref.float().norm().item(),\n",
" cmp_out.float().norm().item(),\n",
" )\n",
# Generated code: the reference output is returned to the caller unless
# VLLM_FP8_MOE_REF_RETURN is set to "0" (default is "1").
" if os.getenv(\"VLLM_FP8_MOE_REF_RETURN\", \"1\") != \"0\":\n",
" return ref_out\n",
])
# Only the first occurrence of the anchor is rewritten.
text = text.replace(needle, insert, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Build-time patch for vLLM's mamba/ops/ssd_combined.py.
# Adds an optional code path that delegates the varlen chunk scan to the
# mamba_ssm package when it is importable, and hooks that path in front of the
# existing cu_seqlens assertion. Skipped when the mamba_ssm import string is
# already present in the target file.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/mamba/ops/ssd_combined.py")
text = path.read_text()
if "mamba_ssm.ops.triton.ssd_combined" not in text:
# Anchor: the TRITON_22 version constant; helpers are appended after it.
needle = "TRITON_22 = version.parse(triton.__version__) >= version.parse(\"2.2.0\")\n"
if needle not in text:
raise SystemExit("Failed to locate TRITON_22 for insertion")
insert = (
needle
+ "\n"
# Generated code: best-effort import; any failure leaves the symbol as None,
# which disables the delegated path.
+ "try:\n"
+ " from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_ssm_chunk_scan_combined\n"
+ "except Exception:\n"
+ " _mamba_ssm_chunk_scan_combined = None\n"
+ "from vllm.platforms import current_platform\n\n"
# Generated code: predicate — mamba_ssm available, is_device_capability(121),
# cu_seqlens given, and intermediate states not requested.
+ "def _use_original_mamba_ssm_varlen(cu_seqlens, return_intermediate_states):\n"
+ " return (\n"
+ " _mamba_ssm_chunk_scan_combined is not None\n"
+ " and current_platform.is_device_capability(121)\n"
+ " and cu_seqlens is not None\n"
+ " and not return_intermediate_states\n"
+ " )\n\n"
# Generated code: wrapper that copies each sequence (as delimited by
# cu_seqlens) into zero-padded (batch, max_len, ...) buffers, calls the
# mamba_ssm kernel, copies the per-sequence outputs back into `out`, and
# returns the varlen states (cast to state_dtype when one is given).
+ "def _mamba_ssm_chunk_scan_varlen(\n"
+ " x,\n"
+ " dt,\n"
+ " A,\n"
+ " B,\n"
+ " C,\n"
+ " chunk_size,\n"
+ " out,\n"
+ " D=None,\n"
+ " z=None,\n"
+ " dt_bias=None,\n"
+ " initial_states=None,\n"
+ " cu_seqlens=None,\n"
+ " dt_softplus=False,\n"
+ " dt_limit=(0.0, float(\"inf\")),\n"
+ " state_dtype=None,\n"
+ "):\n"
+ " batch = int(cu_seqlens.numel() - 1)\n"
+ " lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64)\n"
+ " max_len = int(lengths.max().item()) if batch > 0 else 0\n"
+ " seqlen_total, nheads, headdim = x.shape\n"
+ " _, ngroups, dstate = B.shape\n"
+ " if out is None:\n"
+ " out = torch.empty_like(x)\n"
# Generated code: nothing to scan — return an all-zero state tensor.
+ " if max_len == 0:\n"
+ " empty_dtype = state_dtype if state_dtype is not None else B.dtype\n"
+ " return torch.zeros((batch, nheads, headdim, dstate), device=x.device, dtype=empty_dtype)\n"
+ " x_b = x.new_zeros((batch, max_len, nheads, headdim))\n"
+ " dt_b = dt.new_zeros((batch, max_len, nheads))\n"
+ " B_b = B.new_zeros((batch, max_len, ngroups, dstate))\n"
+ " C_b = C.new_zeros((batch, max_len, ngroups, dstate))\n"
+ " z_b = None\n"
+ " if z is not None:\n"
+ " z_b = z.new_zeros((batch, max_len, nheads, headdim))\n"
# Generated code: pack -> padded copy; each .item() call is a host sync.
+ " for i in range(batch):\n"
+ " start = int(cu_seqlens[i].item())\n"
+ " end = int(cu_seqlens[i + 1].item())\n"
+ " length = end - start\n"
+ " if length <= 0:\n"
+ " continue\n"
+ " x_b[i, :length] = x[start:end]\n"
+ " dt_b[i, :length] = dt[start:end]\n"
+ " B_b[i, :length] = B[start:end]\n"
+ " C_b[i, :length] = C[start:end]\n"
+ " if z_b is not None:\n"
+ " z_b[i, :length] = z[start:end]\n"
+ " cu_seqlens_b = torch.zeros(batch + 1, device=cu_seqlens.device, dtype=cu_seqlens.dtype)\n"
+ " if batch > 0:\n"
+ " cu_seqlens_b[1:] = torch.cumsum(lengths, dim=0)\n"
# NOTE(review): the inputs are padded to (batch, max_len, ...) while
# cu_seqlens_b is rebuilt from the per-sequence lengths; confirm that
# mamba_chunk_scan_combined accepts cu_seqlens / return_varlen_states with a
# batch dimension larger than 1 and interprets the offsets as intended.
+ " out_b, varlen_states = _mamba_ssm_chunk_scan_combined(\n"
+ " x_b,\n"
+ " dt_b,\n"
+ " A,\n"
+ " B_b,\n"
+ " C_b,\n"
+ " chunk_size,\n"
+ " D=D,\n"
+ " z=z_b,\n"
+ " dt_bias=dt_bias,\n"
+ " initial_states=initial_states,\n"
+ " seq_idx=None,\n"
+ " cu_seqlens=cu_seqlens_b,\n"
+ " dt_softplus=dt_softplus,\n"
+ " dt_limit=dt_limit,\n"
+ " return_final_states=False,\n"
+ " return_varlen_states=True,\n"
+ " )\n"
# Generated code: padded -> packed copy of the outputs.
+ " for i in range(batch):\n"
+ " start = int(cu_seqlens[i].item())\n"
+ " end = int(cu_seqlens[i + 1].item())\n"
+ " length = end - start\n"
+ " if length <= 0:\n"
+ " continue\n"
+ " out[start:end].copy_(out_b[i, :length])\n"
+ " if state_dtype is not None and varlen_states.dtype != state_dtype:\n"
+ " varlen_states = varlen_states.to(state_dtype)\n"
+ " return varlen_states\n"
)
text = text.replace(needle, insert, 1)
# Second edit: place the delegating early-return immediately before the
# existing cu_seqlens assertion. It is applied only if that assertion is found
# and the helper is present; a missing anchor is silently ignored here.
needle = " assert cu_seqlens is not None, \"cu_seqlens must be provided assuming varlen input\""
if needle in text and "_mamba_ssm_chunk_scan_varlen" in text:
insert = (
" if _use_original_mamba_ssm_varlen(cu_seqlens, return_intermediate_states):\n"
" return _mamba_ssm_chunk_scan_varlen(\n"
" x=x,\n"
" dt=dt,\n"
" A=A,\n"
" B=B,\n"
" C=C,\n"
" chunk_size=chunk_size,\n"
" out=out,\n"
" D=D,\n"
" z=z,\n"
" dt_bias=dt_bias,\n"
" initial_states=initial_states,\n"
" cu_seqlens=cu_seqlens,\n"
" dt_softplus=dt_softplus,\n"
" dt_limit=dt_limit,\n"
" state_dtype=state_dtype,\n"
" )\n\n"
)
text = text.replace(needle, insert + needle, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Build-time patch for vLLM's quantization/modelopt.py.
# Prepends an early-return branch to the first process_weights_after_loading
# definition in the file: layers carrying a truthy `fp8_keep_scales` attribute
# keep their weight_scale / input_scale as-is and get a transposed weight.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py")
text = path.read_text()
needle = " def process_weights_after_loading(self, layer: Module) -> None:\n"
if needle not in text:
raise SystemExit("Failed to locate process_weights_after_loading")
insert = "".join(
[
needle,
# Generated code: wraps weight.t() and the detached scales in non-trainable
# Parameters, then returns before the method's original body runs.
" weight = layer.weight\n",
" if getattr(layer, \"fp8_keep_scales\", False):\n",
" layer.weight = Parameter(weight.t(), requires_grad=False)\n",
" layer.weight_scale = Parameter(layer.weight_scale.detach(), requires_grad=False)\n",
" layer.input_scale = Parameter(layer.input_scale.detach(), requires_grad=False)\n",
" return\n",
]
)
# Idempotence: do nothing if the exact patched text is already in the file.
# Only the first matching method signature is rewritten (count=1).
if insert not in text:
text = text.replace(needle, insert, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Build-time patch for vLLM's models/nemotron_h.py.
# Adds imports, a weight-dequantization helper, and an opt-in one-shot BF16
# reference comparison after the attention output projection
# (enabled at runtime with VLLM_FP8_ATTENTION_BF16_REF=1).
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/models/nemotron_h.py")
text = path.read_text()
# Import per_tensor_dequantize right after the QuantizationConfig import,
# unless the name already appears anywhere in the file.
import_line = "from vllm.model_executor.layers.quantization import QuantizationConfig\n"
if "per_tensor_dequantize" not in text and import_line in text:
text = text.replace(
import_line,
import_line
+ "from vllm.model_executor.layers.quantization.utils.w8a8_utils import per_tensor_dequantize\n",
1,
)
# Import the forward-context helpers after the get_pp_group import. There is
# no check that this anchor exists; if it is absent the replace is a no-op.
if "is_forward_context_available" not in text:
text = text.replace(
"from vllm.distributed.parallel_state import get_pp_group\n",
"from vllm.distributed.parallel_state import get_pp_group\n"
"from vllm.forward_context import get_forward_context, is_forward_context_available\n",
1,
)
# Add the call counter and the dequantization helper after the
# _NEMO_SHARED_REF_USED counter, which this script requires to already exist
# (presumably added by an earlier patch step — not visible in this section).
if "_NEMO_ATT_REF_USED" not in text:
needle = "_NEMO_SHARED_REF_USED = 0\n"
if needle not in text:
raise SystemExit("Failed to locate _NEMO_SHARED_REF_USED for insertion")
helper = "".join(
[
"_NEMO_ATT_REF_USED = 0\n\n",
# Generated code: returns a bfloat16 copy of layer.weight. Without a
# weight_scale it is a plain cast; with a single scale (or no logical_widths)
# it dequantizes with scale.max(); with one scale per logical width it
# dequantizes each dim=1 split separately; otherwise falls back to scale.max().
"def _dequantize_linear_weight_for_ref(layer, logical_widths=None):\n",
" weight = layer.weight\n",
" scale = getattr(layer, \"weight_scale\", None)\n",
" if scale is None:\n",
" return weight.to(torch.bfloat16)\n",
" scale = scale.detach()\n",
" if scale.numel() == 1 or not logical_widths:\n",
" s = scale.max()\n",
" return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n",
" if scale.numel() == len(logical_widths):\n",
" splits = torch.split(weight, logical_widths, dim=1)\n",
" dq_splits = [\n",
" per_tensor_dequantize(w, scale[idx]) for idx, w in enumerate(splits)\n",
" ]\n",
" return torch.cat(dq_splits, dim=1).to(torch.bfloat16)\n",
" s = scale.max()\n",
" return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n\n",
]
)
text = text.replace(needle, needle + helper, 1)
# Replace the o_proj call + return with a version that, once per process and
# only when a forward context with attention metadata is available, recomputes
# qkv -> attention -> o_proj from dequantized BF16 weights and logs the
# difference to the real output. The function still returns `output`.
if "NemotronH attention BF16 ref compare" not in text:
old = " output, _ = self.o_proj(attn_output)\n return output\n"
if old not in text:
raise SystemExit("Failed to locate NemotronHAttention output return")
new = "".join(
[
" output, _ = self.o_proj(attn_output)\n",
# NOTE(review): the generated code uses `os` and `logger`; this script does
# not add them, so they are assumed to exist in nemotron_h.py — confirm.
" if os.getenv(\"VLLM_FP8_ATTENTION_BF16_REF\") == \"1\":\n",
" global _NEMO_ATT_REF_USED\n",
" if (\n",
" is_forward_context_available()\n",
" and get_forward_context().attn_metadata is not None\n",
" ):\n",
" if _NEMO_ATT_REF_USED < 1:\n",
" _NEMO_ATT_REF_USED += 1\n",
" with torch.no_grad():\n",
" x_bf16 = hidden_states.to(torch.bfloat16)\n",
" qkv_w = _dequantize_linear_weight_for_ref(\n",
" self.qkv_proj,\n",
" getattr(self.qkv_proj, \"output_sizes\", None),\n",
" )\n",
# Generated code: the weight is used as-is when its first dim matches the
# input feature size, otherwise transposed.
" if qkv_w.shape[0] == x_bf16.shape[-1]:\n",
" qkv_ref = torch.matmul(x_bf16, qkv_w)\n",
" else:\n",
" qkv_ref = torch.matmul(x_bf16, qkv_w.t())\n",
" if self.qkv_proj.bias is not None:\n",
" qkv_ref = qkv_ref + self.qkv_proj.bias\n",
" q_ref, k_ref, v_ref = qkv_ref.split(\n",
" [self.q_size, self.kv_size, self.kv_size], dim=-1\n",
" )\n",
# NOTE(review): self.attn is invoked a second time for the reference; whether
# that has side effects (e.g. on cached state) is not visible here — confirm.
" attn_ref = self.attn(q_ref, k_ref, v_ref)\n",
" o_w = _dequantize_linear_weight_for_ref(self.o_proj)\n",
" attn_ref_bf16 = attn_ref.to(torch.bfloat16)\n",
" if o_w.shape[0] == attn_ref_bf16.shape[-1]:\n",
" ref_out = torch.matmul(attn_ref_bf16, o_w)\n",
" else:\n",
" ref_out = torch.matmul(attn_ref_bf16, o_w.t())\n",
" if self.o_proj.bias is not None:\n",
" ref_out = ref_out + self.o_proj.bias\n",
" diff = (ref_out.float() - output.float()).abs()\n",
" logger.info_once(\n",
" \"NemotronH attention BF16 ref compare: mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n",
" diff.mean().item(),\n",
" diff.max().item(),\n",
" ref_out.float().norm().item(),\n",
" output.float().norm().item(),\n",
" )\n",
" return output\n",
]
)
text = text.replace(old, new, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Build-time patch for vLLM's mamba/mamba_mixer2.py (first part of the script).
# Adds imports and module-level helpers, then replaces the in_proj call with a
# version that can apply FP8 scales per output chunk.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/mamba/mamba_mixer2.py")
text = path.read_text()
# Ensure `import os` exists (inserted before the first `import torch` line).
if "import os\n" not in text:
text = text.replace(
"import torch\n",
"import os\nimport torch\n",
1,
)
# Extend the forward_context import and add the init_logger import.
# NOTE(review): is_forward_context_available is only added when "init_logger"
# is absent from the file; if the target already mentions init_logger, the
# generated code's is_forward_context_available() calls would have no import
# from this step — confirm against the pinned vLLM commit.
if "init_logger" not in text:
text = text.replace(
"from vllm.forward_context import ForwardContext, get_forward_context\n",
"from vllm.forward_context import ForwardContext, get_forward_context, is_forward_context_available\n"
"from vllm.logger import init_logger\n",
1,
)
# Import per_tensor_dequantize after the QuantizationConfig import.
if "per_tensor_dequantize" not in text:
text = text.replace(
"from vllm.model_executor.layers.quantization import QuantizationConfig\n",
"from vllm.model_executor.layers.quantization import QuantizationConfig\n"
"from vllm.model_executor.layers.quantization.utils.w8a8_utils import per_tensor_dequantize\n",
1,
)
# Module-level additions, placed after the Mamba2AttentionMetadata import:
# a logger, two call counters, and two helper functions.
if "_MAMBA_REF_USED" not in text:
insert_after = "from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata\n"
if insert_after not in text:
raise SystemExit("Failed to locate Mamba2AttentionMetadata import")
helper = "".join(
[
"\nlogger = init_logger(__name__)\n",
"_MAMBA_REF_USED = 0\n",
"_MAMBA_INPROJ_BYPASS_USED = 0\n\n",
# Generated code: returns a bfloat16 copy of layer.weight, dequantized with
# scale.max() or, when there is one scale per logical width, per dim=1 split.
"def _dequantize_linear_weight_for_ref(layer, logical_widths=None):\n",
" weight = layer.weight\n",
" scale = getattr(layer, \"weight_scale\", None)\n",
" if scale is None:\n",
" return weight.to(torch.bfloat16)\n",
" scale = scale.detach()\n",
" if scale.numel() == 1 or not logical_widths:\n",
" s = scale.max()\n",
" return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n",
" if scale.numel() == len(logical_widths):\n",
" splits = torch.split(weight, logical_widths, dim=1)\n",
" dq_splits = [\n",
" per_tensor_dequantize(w, scale[idx]) for idx, w in enumerate(splits)\n",
" ]\n",
" return torch.cat(dq_splits, dim=1).to(torch.bfloat16)\n",
" s = scale.max()\n",
" return per_tensor_dequantize(weight, s).to(torch.bfloat16)\n",
# Generated code: x @ w when w's first dim matches x's last dim, else x @ w.T.
"def _matmul_weight(x, w):\n",
" if w.shape[0] == x.shape[-1]:\n",
" return torch.matmul(x, w)\n",
" return torch.matmul(x, w.t())\n",
]
)
text = text.replace(insert_after, insert_after + helper, 1)
# Replace the in_proj call. The generated code keeps the original input in
# `input_states` and, when the layer is eligible, runs fp8_linear.apply once
# per entry of in_proj.output_sizes and concatenates the results.
if "Mamba BF16 ref compare" not in text:
old = " # 1. Gated MLP's linear projection\n projected_states, _ = self.in_proj(hidden_states)\n"
if old not in text:
raise SystemExit("Failed to locate MambaMixer2 projection block")
new = "".join(
[
" input_states = hidden_states\n",
" # 1. Gated MLP's linear projection\n",
" weight_scale = getattr(self.in_proj, \"weight_scale\", None)\n",
" input_scale = getattr(self.in_proj, \"input_scale\", None)\n",
# Generated code: eligibility requires a truthy fp8_keep_scales attribute on
# in_proj, multi-element weight/input scales of equal count, and a
# quant_method exposing fp8_linear.
" use_fp8_inproj = (\n",
" hasattr(self.in_proj, \"fp8_keep_scales\")\n",
" and self.in_proj.fp8_keep_scales\n",
" and weight_scale is not None\n",
" and input_scale is not None\n",
" and weight_scale.numel() > 1\n",
" and input_scale.numel() == weight_scale.numel()\n",
" and hasattr(self.in_proj, \"quant_method\")\n",
" and hasattr(self.in_proj.quant_method, \"fp8_linear\")\n",
" )\n",
" output_sizes = getattr(self.in_proj, \"output_sizes\", None)\n",
" if isinstance(output_sizes, list):\n",
" output_sizes = tuple(output_sizes)\n",
# Generated code: optional one-time diagnostic, enabled by any non-empty
# value of VLLM_FP8_MAMBA_INPROJ_DEBUG.
" if os.getenv(\"VLLM_FP8_MAMBA_INPROJ_DEBUG\"):\n",
" logger.info_once(\n",
" \"Mamba in_proj per-chunk FP8 eligible=%s weight_scale=%s input_scale=%s output_sizes=%s\",\n",
" use_fp8_inproj,\n",
" None if weight_scale is None else tuple(weight_scale.shape),\n",
" None if input_scale is None else tuple(input_scale.shape),\n",
" output_sizes,\n",
" )\n",
" if use_fp8_inproj:\n",
" # Mamba in_proj unfused FP8: apply per-chunk scales.\n",
" weight = self.in_proj.weight\n",
" weight_scales = weight_scale\n",
" input_scales = input_scale\n",
" outputs = []\n",
" bias = self.in_proj.bias\n",
" bias_chunks = None\n",
" if bias is not None:\n",
" bias_chunks = torch.split(bias, self.in_proj.output_sizes, dim=0)\n",
# Generated code: the weight is consumed chunk by chunk along dim 1; each
# iteration takes the leading out_size columns and drops them from `weight`.
" for idx, out_size in enumerate(self.in_proj.output_sizes):\n",
" w_chunk = weight[:, :out_size]\n",
" weight = weight[:, out_size:]\n",
" b_chunk = None\n",
" if bias_chunks is not None:\n",
" b_chunk = bias_chunks[idx]\n",
" chunk_out = self.in_proj.quant_method.fp8_linear.apply(\n",
" input=hidden_states,\n",
" weight=w_chunk,\n",
" weight_scale=weight_scales[idx],\n",
" input_scale=input_scales[idx],\n",
" bias=b_chunk,\n",
" )\n",
" outputs.append(chunk_out)\n",
" projected_states = torch.cat(outputs, dim=-1)\n",
" else:\n",
" projected_states, _ = self.in_proj(hidden_states)\n",
]
)
text = text.replace(old, new, 1)
# Mark Mamba in_proj to keep per-chunk scales for FP8.
#
# Appends `self.in_proj.fp8_keep_scales = True` after whichever in_proj
# constructor call(s) the target file contains (the MergedColumnParallelLinear
# form and/or the ColumnParallelLinear form).
#
# The idempotence guard looks for the assignment statement itself rather than
# the bare substring "fp8_keep_scales": the projection-block replacement
# performed earlier in this script already writes that substring into `text`
# (inside its hasattr() check), so a bare-substring guard is always false on a
# fresh run and the marker would never be installed.
tag = " self.in_proj.fp8_keep_scales = True\n"
if tag.strip() not in text:
    merged_block = "".join(
        [
            " self.in_proj = MergedColumnParallelLinear(\n",
            " input_size=hidden_size,\n",
            " output_sizes=[\n",
            " intermediate_size,\n",
            " intermediate_size,\n",
            " self.groups_ssm_state_size,\n",
            " self.groups_ssm_state_size,\n",
            " self.num_heads,\n",
            " ],\n",
            " bias=use_bias,\n",
            " quant_config=quant_config,\n",
            " prefix=f\"{prefix}.in_proj\",\n",
            " )\n",
        ]
    )
    column_block = "".join(
        [
            " self.in_proj = ColumnParallelLinear(\n",
            " input_size=hidden_size,\n",
            " output_size=intermediate_size + self.conv_dim + self.num_heads,\n",
            " bias=use_bias,\n",
            " quant_config=quant_config,\n",
            " prefix=f\"{prefix}.in_proj\",\n",
            " )\n",
        ]
    )
    # Both forms are checked independently, so a file containing both
    # constructor calls gets the marker after each of them.
    tagged = False
    if merged_block in text:
        text = text.replace(merged_block, merged_block + tag, 1)
        tagged = True
    if column_block in text:
        text = text.replace(column_block, column_block + tag, 1)
        tagged = True
    if not tagged:
        # Stay best-effort (no build failure), but do not fail silently:
        # without the marker the per-chunk FP8 in_proj path is never taken.
        import sys
        print(
            "warning: in_proj constructor not found in mamba_mixer2.py; "
            "fp8_keep_scales marker was not installed",
            file=sys.stderr,
        )
# Per-chunk FP8 in_proj path is installed by the projection block replacement.
# This part of the script adds two opt-in runtime features after that point:
#   * VLLM_FP8_MAMBA_INPROJ_BF16_BYPASS ("1" = first eligible call only,
#     "all" = every eligible call): overwrite projected_states with a BF16
#     recomputation from dequantized in_proj weights.
#   * VLLM_FP8_MAMBA_BF16_REF=1: one-shot BF16 reference comparison of the
#     in_proj, SSM and out_proj stages, logged via logger.info_once.
# Unlike the earlier edits in this script, these two replacements have no
# "already patched" guard of their own.
bypass_old = (
" if mup_vector is not None:\n"
" projected_states = projected_states * mup_vector\n"
)
if bypass_old not in text:
raise SystemExit("Failed to locate MambaMixer2 mup_vector block")
bypass_insert = "".join(
[
bypass_old,
# Generated code: the bypass only runs when a forward context with attention
# metadata is available; in "1" mode the counter limits it to a single call.
" bypass_mode = os.getenv(\"VLLM_FP8_MAMBA_INPROJ_BF16_BYPASS\")\n",
" if bypass_mode in {\"1\", \"all\"}:\n",
" global _MAMBA_INPROJ_BYPASS_USED\n",
" if (\n",
" is_forward_context_available()\n",
" and get_forward_context().attn_metadata is not None\n",
" ):\n",
" if bypass_mode == \"all\" or _MAMBA_INPROJ_BYPASS_USED < 1:\n",
" _MAMBA_INPROJ_BYPASS_USED += 1\n",
" with torch.no_grad():\n",
" x_bf16 = input_states.to(torch.bfloat16)\n",
" in_w = _dequantize_linear_weight_for_ref(\n",
" self.in_proj,\n",
" getattr(self.in_proj, \"output_sizes\", None),\n",
" )\n",
" proj_ref = _matmul_weight(x_bf16, in_w)\n",
" if self.in_proj.bias is not None:\n",
" proj_ref = proj_ref + self.in_proj.bias\n",
" if mup_vector is not None:\n",
" proj_ref = proj_ref * mup_vector.to(torch.bfloat16)\n",
# Generated code: the BF16 result replaces the real projection, cast back to
# the dtype projected_states already had.
" projected_states = proj_ref.to(projected_states.dtype)\n",
" if bypass_mode == \"all\":\n",
" logger.info_once(\n",
" \"Mamba in_proj BF16 bypass enabled (all layers)\")\n",
" else:\n",
" logger.info_once(\n",
" \"Mamba in_proj BF16 bypass enabled (single call)\")\n",
]
)
text = text.replace(bypass_old, bypass_insert, 1)
# Replace the out_proj call + return with a version that performs the one-shot
# reference comparison before returning the unchanged `output`.
old_out = " output, _ = self.out_proj(hidden_states)\n\n return output\n"
if old_out not in text:
raise SystemExit("Failed to locate MambaMixer2 output return")
insert = "".join(
[
" output, _ = self.out_proj(hidden_states)\n",
" if os.getenv(\"VLLM_FP8_MAMBA_BF16_REF\") == \"1\":\n",
" global _MAMBA_REF_USED\n",
" if (\n",
" is_forward_context_available()\n",
" and get_forward_context().attn_metadata is not None\n",
" ):\n",
" if _MAMBA_REF_USED < 1:\n",
" _MAMBA_REF_USED += 1\n",
" with torch.no_grad():\n",
# Generated code, stage 1: BF16 in_proj reference vs. projected_states.
" x_bf16 = input_states.to(torch.bfloat16)\n",
" in_w = _dequantize_linear_weight_for_ref(\n",
" self.in_proj,\n",
" getattr(self.in_proj, \"output_sizes\", None),\n",
" )\n",
" proj_ref = _matmul_weight(x_bf16, in_w)\n",
" if self.in_proj.bias is not None:\n",
" proj_ref = proj_ref + self.in_proj.bias\n",
" if mup_vector is not None:\n",
" proj_ref = proj_ref * mup_vector.to(torch.bfloat16)\n",
" diff_in = (proj_ref.float() - projected_states.float()).abs()\n",
" logger.info_once(\n",
" \"Mamba BF16 ref compare (in_proj): mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n",
" diff_in.mean().item(),\n",
" diff_in.max().item(),\n",
" proj_ref.float().norm().item(),\n",
" projected_states.float().norm().item(),\n",
" )\n",
# Generated code, stage 2: rerun the mamba_mixer2 custom op on the reference
# projection into a fresh buffer and compare with ssm_output.
# NOTE(review): this invokes torch.ops.vllm.mamba_mixer2 a second time for the
# same layer/step; if the op updates cached state, the reference run may
# affect later outputs — confirm.
" ssm_ref = torch.empty(\n",
" [\n",
" input_states.shape[0],\n",
" (self.num_heads // self.tp_size) * self.head_dim,\n",
" ],\n",
" dtype=proj_ref.dtype,\n",
" device=proj_ref.device,\n",
" )\n",
" torch.ops.vllm.mamba_mixer2(\n",
" proj_ref,\n",
" ssm_ref,\n",
" self.prefix,\n",
" )\n",
" diff_ssm = (ssm_ref.float() - ssm_output.float()).abs()\n",
" logger.info_once(\n",
" \"Mamba BF16 ref compare (ssm): mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n",
" diff_ssm.mean().item(),\n",
" diff_ssm.max().item(),\n",
" ssm_ref.float().norm().item(),\n",
" ssm_output.float().norm().item(),\n",
" )\n",
# Generated code, stage 3: gated norm + BF16 out_proj reference vs. output.
" gate_ref = proj_ref[..., : self.tped_intermediate_size]\n",
" hidden_ref = self.norm(ssm_ref, gate_ref)\n",
" out_w = _dequantize_linear_weight_for_ref(self.out_proj)\n",
" hidden_ref_bf16 = hidden_ref.to(torch.bfloat16)\n",
" ref_out = _matmul_weight(hidden_ref_bf16, out_w)\n",
" if self.out_proj.bias is not None:\n",
" ref_out = ref_out + self.out_proj.bias\n",
" diff_out = (ref_out.float() - output.float()).abs()\n",
" logger.info_once(\n",
" \"Mamba BF16 ref compare (out_proj): mean_abs=%s max_abs=%s ref_norm=%s out_norm=%s\",\n",
" diff_out.mean().item(),\n",
" diff_out.max().item(),\n",
" ref_out.float().norm().item(),\n",
" output.float().norm().item(),\n",
" )\n",
"\n",
" return output\n",
]
)
text = text.replace(old_out, insert, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Build-time patch for vLLM's quantization/modelopt.py.
# Inserts opt-in debug logging (VLLM_MOE_FALLBACK_DEBUG=1) of the MoE
# quantization scales immediately before the expert-selection call, and makes
# sure `import os` is present. Skipped for the logging part when the env-var
# name already appears in the file.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/quantization/modelopt.py")
text = path.read_text()
if "VLLM_MOE_FALLBACK_DEBUG" not in text:
needle = " # Expert selection\n topk_weights, topk_ids = layer.select_experts(\n"
if needle not in text:
raise SystemExit("Failed to locate expert selection block")
# Generated code: _tinfo formats shape/dtype/min/max of a tensor (or
# "<name>=None"); the four scales of self.moe_quant_config are logged once
# each when that config is not None.
# NOTE(review): the generated code relies on `logger` and `torch` being
# available in modelopt.py; this script does not add them — confirm.
insert = (
" if os.getenv(\"VLLM_MOE_FALLBACK_DEBUG\") == \"1\":\n"
" qc = self.moe_quant_config\n"
" def _tinfo(name, t):\n"
" if t is None:\n"
" return f\"{name}=None\"\n"
" with torch.no_grad():\n"
" tf = t.float()\n"
" return (\n"
" f\"{name}: shape={tuple(t.shape)} dtype={t.dtype} \"\n"
" f\"min={tf.min().item():.6g} max={tf.max().item():.6g}\"\n"
" )\n"
" if qc is not None:\n"
" logger.info_once(\"[moe_debug] %s\", _tinfo(\"w1_scale\", qc.w1_scale))\n"
" logger.info_once(\"[moe_debug] %s\", _tinfo(\"w2_scale\", qc.w2_scale))\n"
" logger.info_once(\"[moe_debug] %s\", _tinfo(\"a1_scale\", qc.a1_scale))\n"
" logger.info_once(\"[moe_debug] %s\", _tinfo(\"a2_scale\", qc.a2_scale))\n"
"\n"
)
# The debug block goes in front of the anchor; first occurrence only.
text = text.replace(needle, insert + needle, 1)
# The substring test also matches e.g. "import os.path"; when `import os` is
# missing and the envs import is not found either, nothing is added.
if "import os" not in text:
needle = "import vllm.envs as envs\n"
if needle in text:
text = text.replace(needle, "import os\nimport vllm.envs as envs\n", 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
# Build-time patch for vLLM's mamba/ops/mamba_ssm.py.
# Adds an optional delegation of selective_state_update to the mamba_ssm
# package and hooks it in right after the `out.shape == x.shape` assertion.
# Skipped when the mamba_ssm import string is already present in the file.
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/mamba/ops/mamba_ssm.py")
text = path.read_text()
if "mamba_ssm.ops.triton.selective_state_update" not in text:
needle = "from vllm.triton_utils import HAS_TRITON, tl, triton\n"
if needle not in text:
raise SystemExit("Failed to locate triton import for insertion")
insert = (
needle
+ "from vllm.platforms import current_platform\n"
# Generated code: best-effort import; any failure disables the delegation.
+ "try:\n"
+ " from mamba_ssm.ops.triton.selective_state_update import selective_state_update as _mamba_ssm_selective_state_update\n"
+ "except Exception:\n"
+ " _mamba_ssm_selective_state_update = None\n\n"
# Generated code: predicate — mamba_ssm available, is_device_capability(121),
# and none of num_accepted_tokens / cu_seqlens / dst_state_batch_indices set.
+ "def _use_original_selective_state_update(num_accepted_tokens, cu_seqlens, dst_state_batch_indices):\n"
+ " return (\n"
+ " _mamba_ssm_selective_state_update is not None\n"
+ " and current_platform.is_device_capability(121)\n"
+ " and num_accepted_tokens is None\n"
+ " and cu_seqlens is None\n"
+ " and dst_state_batch_indices is None\n"
+ " )\n\n"
)
text = text.replace(needle, insert, 1)
# Second edit: the early-return branch is placed AFTER the shape assertion
# (needle + insert). A missing anchor is silently ignored here.
needle = " assert out.shape == x.shape\n"
if needle in text and "_use_original_selective_state_update" in text:
# Generated code: the mamba_ssm result is copied into `out` when one was
# supplied (and `out` is returned); otherwise the result itself is returned.
insert = (
" if _use_original_selective_state_update(num_accepted_tokens, cu_seqlens, dst_state_batch_indices):\n"
" result = _mamba_ssm_selective_state_update(\n"
" state,\n"
" x,\n"
" dt,\n"
" A,\n"
" B,\n"
" C,\n"
" D=D,\n"
" z=z,\n"
" dt_bias=dt_bias,\n"
" dt_softplus=dt_softplus,\n"
" state_batch_indices=state_batch_indices,\n"
" )\n"
" if out is not None:\n"
" out.copy_(result)\n"
" return out\n"
" return result\n\n"
)
text = text.replace(needle, needle + insert, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c python3 - <<'PY'
from pathlib import Path
path = Path("/build/vllm/vllm/model_executor/layers/mamba/ops/causal_conv1d.py")
text = path.read_text()
if "causal_conv1d_interface" not in text:
needle = "from vllm.triton_utils import tl, triton\n"
if needle not in text:
raise SystemExit("Failed to locate triton import for insertion")
insert = (
needle
+ "from vllm.platforms import current_platform\n"
+ "try:\n"
+ " from causal_conv1d.causal_conv1d_interface import causal_conv1d_update as _causal_conv1d_update\n"
+ "except Exception:\n"
+ " _causal_conv1d_update = None\n\n"
+ "def _use_original_causal_conv1d_update(\n"
+ " query_start_loc,\n"
+ " num_accepted_tokens,\n"
+ " block_idx_last_scheduled_token,\n"
+ " initial_state_idx,\n"
+ " max_query_len,\n"
+ "):\n"
+ " return (\n"
+ " _causal_conv1d_update is not None\n"
+ " and current_platform.is_device_capability(121)\n"
+ " and query_start_loc is None\n"
+ " and num_accepted_tokens is None\n"
+ " and block_idx_last_scheduled_token is None\n"
+ " and initial_state_idx is None\n"
+ " and max_query_len == -1\n"
+ " )\n\n"
)
text = text.replace(needle, insert, 1)
needle = " if query_start_loc is None:\n batch, dim, seqlen = x.shape\n else:\n assert conv_state_indices is not None\n batch = conv_state_indices.size(0)\n dim = x.size(1)\n seqlen = max_query_len\n"
if needle in text and "_use_original_causal_conv1d_update" in text:
insert = (
needle
+ "\n"
+ " if _use_original_causal_conv1d_update(\n"
+ " query_start_loc,\n"
+ " num_accepted_tokens,\n"
+ " block_idx_last_scheduled_token,\n"
+ " initial_state_idx,\n"
+ " max_query_len,\n"
+ " ):\n"
+ " out = _causal_conv1d_update(\n"
+ " x,\n"
+ " conv_state,\n"
+ " weight,\n"
+ " bias=bias,\n"
+ " activation=activation,\n"
+ " cache_seqlens=None,\n"
+ " conv_state_indices=conv_state_indices,\n"
+ " )\n"
+ " return out.to(original_x_dtype)\n\n"
)
text = text.replace(needle, insert, 1)
path.write_text(text)
PY # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c uv pip uninstall torchvision torchaudio nvidia-nccl-cu12 nvidia-cudnn-cu12 || true && uv pip install --no-cache --reinstall /tmp/pytorch_wheel/torch*.whl && ldconfig # buildkit |
| WORKDIR /build |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then export TORCH_CUDA_ARCH_LIST="12.1"; else export TORCH_CUDA_ARCH_LIST="12.0"; fi && echo "=== Rebuilding torchaudio for $ARCH with CUDA arch $TORCH_CUDA_ARCH_LIST ===" && git clone --recursive https://github.com/pytorch/audio.git torchaudio-rebuild && cd torchaudio-rebuild && git checkout ${TORCHAUDIO_COMMIT} && BUILD_SOX=0 USE_CUDA=1 python3 setup.py bdist_wheel && uv pip install --no-cache dist/*.whl && cd .. && rm -rf torchaudio-rebuild # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c uv pip install --no-cache triton && rm -f /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && ln -s /usr/local/cuda/bin/ptxas /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas # buildkit |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c uv pip install --no-cache --no-build-isolation mamba-ssm==2.3.0 causal-conv1d==1.6.0 # buildkit |
| WORKDIR /workspace |
| RUN |4 APT_PROXY=http://pb-lxcaptcache:3142 TORCHAUDIO_COMMIT=0764cfdedb769e63f3ab8b90bc06541a6a2c0b73 VLLM_COMMIT=bb80f69bc98cbf062bf030cb11185f7ba526e28a VLLM_CACHE_BUSTER=v1 /bin/sh -c rm -rf /tmp/pytorch_wheel # buildkit |
| ARG APT_PROXY=http://pb-lxcaptcache:3142 |
| RUN |1 APT_PROXY=http://pb-lxcaptcache:3142 /bin/sh -c if [ -n "$APT_PROXY" ]; then echo "Acquire::http::Proxy \"$APT_PROXY\";" > /etc/apt/apt.conf.d/01proxy; fi # buildkit |
| ARG NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c git clone https://github.com/NVIDIA/NeMo.git /opt/nemo && cd /opt/nemo && git checkout ${NEMO_COMMIT} # buildkit |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c uv pip install --no-cache Cython hydra-core>=1.3.0 omegaconf>=2.3 pytorch-lightning>=2.0 torchmetrics>=0.11.0 transformers>=4.36.0 sentencepiece webdataset lhotse>=1.20.0 braceexpand editdistance g2p_en inflect kaldi-python-io kaldiio librosa>=0.10.0 marshmallow ruamel.yaml soundfile text-unidecode numba kaldialign # buildkit |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c cd /opt/nemo && uv pip install --no-cache -e ".[asr,tts]" # buildkit |
| COPY <<EOF /tmp/patch_nvrtc.py # buildkit |
| RUN |2 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc /bin/sh -c python3 /tmp/patch_nvrtc.py || echo "Patch may have already been applied" # buildkit |
| ARG LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c curl -fsSL https://raw.githubusercontent.com/pipecat-ai/nemotron-january-2026/main/patches/llama-cpp-hybrid-cache-fix.patch -o /tmp/llama-cpp-hybrid-cache-fix.patch # buildkit |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c ARCH=$(uname -m) && if [ "$ARCH" = "aarch64" ]; then CUDA_ARCH="121a"; else CUDA_ARCH="120a"; fi && echo "=== Building llama.cpp for $ARCH with CUDA arch $CUDA_ARCH ===" && git clone https://github.com/ggerganov/llama.cpp.git /opt/llama.cpp && cd /opt/llama.cpp && git checkout ${LLAMACPP_COMMIT} && echo "=== Applying hybrid cache fix patch ===" && patch -p1 < /tmp/llama-cpp-hybrid-cache-fix.patch && cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_F16=ON -DCMAKE_CUDA_ARCHITECTURES="$CUDA_ARCH" -DCMAKE_BUILD_TYPE=Release > /tmp/llamacpp_cmake.log 2>&1 || { tail -50 /tmp/llamacpp_cmake.log; exit 1; } && cmake --build build --config Release -j$(nproc) > /tmp/llamacpp_build.log 2>&1 || { tail -50 /tmp/llamacpp_build.log; exit 1; } && cp build/bin/llama-server /usr/local/bin/ && cp build/bin/llama-cli /usr/local/bin/ && cp build/bin/llama-quantize /usr/local/bin/ && cp build/bin/llama-bench /usr/local/bin/ && cp build/bin/*.so* /usr/local/lib/ 2>/dev/null || true && ldconfig && rm -rf /opt/llama.cpp/build && rm -rf /opt/llama.cpp/.git # buildkit |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c llama-server --version || echo "llama-server installed (version check requires GPU)" # buildkit |
| WORKDIR /workspace |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c rm -f /tmp/patch_nvrtc.py # buildkit |
| RUN |3 APT_PROXY=http://pb-lxcaptcache:3142 NEMO_COMMIT=644201898480ec8c8d0a637f0c773825509ac4dc LLAMACPP_COMMIT=c18428423018ed214c004e6ecaedb0cbdda06805 /bin/sh -c uv pip install --no-cache websockets>=12.0 loguru>=0.7.0 httpx>=0.25.0 # buildkit |