Training a model with Primus and JAX MaxText — ROCm Documentation

Training a model with Primus and JAX MaxText#

2026-05-19

55 min read time

Applies to Linux

The JAX MaxText for ROCm training Docker image provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs, with essential components such as JAX, XLA, ROCm libraries, and MaxText utilities.

The image also integrates with Primus, a high-level training framework that supports multiple backends. You can use the unified primus-cli to run training jobs using the JAX MaxText backend.

It includes the following software components:

rocm/jax-training:maxtext-v26.3

Software component    Version
ROCm                  7.2.1
JAX                   0.8.2
Python                3.12
Transformer Engine    2.8.0.dev0+9b312832
hipBLASLt             1.3.0+bfcf25fa18

MaxText on ROCm provides the following key features to train large language models efficiently:

  • Transformer Engine (TE)

  • Flash Attention (FA) 3 – with or without sequence input packing

  • GEMM tuning

  • Multi-node support

  • NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support

Supported models#

The following models are pre-optimized for performance on AMD Instinct GPUs. Some instructions, commands, and available training configurations in this documentation might vary by model – select one to get started.

Model         Variants
Meta Llama    Llama 2 7B, Llama 2 70B, Llama 3 8B, Llama 3 70B, Llama 3.1 8B, Llama 3.1 70B, Llama 3.1 405B (multi-node), Llama 3.3 70B
DeepSeek      DeepSeek-V2-Lite (16B)
Mistral AI    Mixtral 8x7B
Qwen          Qwen 14B, Qwen 30B A3B

Note

Some models, such as Llama 3, require an external license agreement through a third party (for example, Meta).

System validation#

Before running AI workloads, it’s important to validate that your AMD hardware is configured correctly and performing optimally.

If you have already validated your system settings, including aspects like NUMA auto-balancing, you can skip this step. Otherwise, complete the procedures in the System validation and optimization guide to properly configure your system settings before starting training.

To test for optimal performance, consult the recommended System health benchmarks. This suite of tests will help you verify and fine-tune your system’s configuration.

Environment setup#

This Docker image is optimized for specific model configurations outlined as follows. Performance can vary for other training workloads, as AMD doesn’t validate configurations and run conditions outside those described.

Pull the Docker image#

Use the following command to pull the Docker image from Docker Hub.

docker pull rocm/jax-training:maxtext-v26.3

Multi-node configuration#

See Multi-node setup for AI workloads to configure your environment for multi-node training.

Benchmarking#

Once the setup is complete, choose one of the following options to reproduce the benchmark results. The options available vary by model:

Primus benchmarking

The following run commands are tailored to Llama 2 7B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
      --device /dev/dri \
      --device /dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/
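The docker run command in step 2 reads HF_HOME and MAD_SECRETS_HFTOKEN from the host shell. If either is empty, the volume mount or the token passed into the container will be empty too. The following is a minimal sketch of a pre-flight check. The function name require_env is hypothetical and not part of Primus or MAD.

```shell
# Hypothetical helper: report every named variable that is unset or empty.
# Returns 0 when all are set, 1 otherwise.
require_env() {
  missing=0
  for name in "$@"; do
    # Indirect lookup of the variable whose name is stored in $name.
    eval "value=\${$name:-}"
    if [ -z "$value" ]; then
      echo "error: $name is not set" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Example use before launching the container:
# require_env HF_HOME MAD_SECRETS_HFTOKEN || exit 1
```

Running the check first turns a confusing failure inside the container into an explicit message on the host.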

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
    MI300X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml
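The primus-cli examples above differ only in the GPU directory and the model name inside the --config path. As a sketch, a small wrapper can assemble that path. The function name primus_config_path is hypothetical, and the pattern examples/maxtext/configs/<GPU>/<model>-pretrain.yaml is inferred from the examples in this guide, so confirm that the file exists in your Primus checkout before relying on it.

```shell
# Hypothetical helper: build the path to a Primus MaxText pretrain config.
#   $1: GPU directory as used in this guide (MI355X or MI300X)
#   $2: model name as it appears in the file name (for example, llama2_7B)
primus_config_path() {
  case "$1" in
    MI355X|MI300X) ;;
    *) echo "error: unsupported GPU directory: $1" >&2; return 1 ;;
  esac
  echo "examples/maxtext/configs/$1/$2-pretrain.yaml"
}

# Example use from the root of the Primus repository:
# ./primus-cli direct -- train pretrain --config "$(primus_config_path MI300X llama2_7B)"
```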
MAD-integrated benchmarking

The following run command is tailored to Llama 2 7B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Llama 2 7B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
      --tags jax_maxtext_train_llama-2-7b \
      --keep-model-dir \
      --live-output \
      --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_llama-2-7b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.
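To inspect the newest result without opening the whole report, the following sketch prints the last row of a CSV file as name and value pairs. It makes no assumption about which columns MAD writes. It does assume plain comma-separated values with a header row and no quoted commas, and the function name show_latest_row is hypothetical.

```shell
# Hypothetical helper: print the last data row of a CSV as "column: value" lines.
# Assumes a header row and no commas inside quoted fields.
show_latest_row() {
  awk -F',' '
    NR == 1 { for (i = 1; i <= NF; i++) name[i] = $i; next }
    NF > 0  { for (i = 1; i <= NF; i++) value[i] = $i; count = NF }
    END     { for (i = 1; i <= count; i++) printf "%s: %s\n", name[i], value[i] }
  ' "$1"
}

# Example use after a madengine run:
# show_latest_row ~/MAD/perf.csv
```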

Standalone benchmarking

The following commands are optimized for Llama 2 7B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
      --device=/dev/dri \
      --device=/dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Llama-2-7B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-2-7B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
    MI325X and MI300X

    For nanoo_fp8 quantized training on MI300X series GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q nanoo_fp8
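The -q value depends on the GPU generation: fp8 for MI355X and MI350X (gfx950), and nanoo_fp8 for MI325X and MI300X (gfx942). A minimal sketch of that mapping follows. The function name quant_mode_for_arch is hypothetical, and the commented usage line assumes rocminfo is available in the container and reports the architecture name.

```shell
# Hypothetical helper: map a GPU architecture to the -q value used in this guide.
quant_mode_for_arch() {
  case "$1" in
    gfx950) echo "fp8" ;;        # MI355X and MI350X
    gfx942) echo "nanoo_fp8" ;;  # MI325X and MI300X
    *) echo "error: no quantization mode listed for $1" >&2; return 1 ;;
  esac
}

# Example use (assumes rocminfo prints a line containing the gfx name):
# arch=$(rocminfo | grep -o -m1 'gfx[0-9a-f]*')
# ./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q "$(quant_mode_for_arch "$arch")"
```

Rejecting unknown architectures keeps a script from silently passing an unsupported quantization mode.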

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Llama 2 7B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_llama2_7b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/llama2_7b.yml
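In the two examples above, the gfx950 configuration file carries a gfx950_ prefix and the gfx942 file has none. The sketch below encodes that naming pattern. It is inferred only from these examples, so check the env_scripts/ directory for the actual file before submitting a job. The function name multinode_config is hypothetical.

```shell
# Hypothetical helper: pick the env_scripts config for a model and architecture.
#   $1: GPU architecture (gfx950 or gfx942)
#   $2: model name as used in the file names (for example, llama2_7b)
multinode_config() {
  case "$1" in
    gfx950) echo "env_scripts/gfx950_$2.yml" ;;
    gfx942) echo "env_scripts/$2.yml" ;;
    *) echo "error: unknown architecture: $1" >&2; return 1 ;;
  esac
}

# Example use from MAD/scripts/jax-maxtext, outside any container:
# sbatch -N 4 jax_maxtext_multinode_benchmark.sh "$(multinode_config gfx950 llama2_7b)"
```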
Primus benchmarking

The following run commands are tailored to Llama 2 70B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
      --device /dev/dri \
      --device /dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama2_70B-pretrain.yaml
    MI300X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama2_70B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama2_70B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama2_70B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama2_70B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama2_70B-pretrain.yaml
MAD-integrated benchmarking

The following run command is tailored to Llama 2 70B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Llama 2 70B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
      --tags jax_maxtext_train_llama-2-70b \
      --keep-model-dir \
      --live-output \
      --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_llama-2-70b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.

Standalone benchmarking

The following commands are optimized for Llama 2 70B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
      --device=/dev/dri \
      --device=/dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Llama-2-70B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-2-70B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-2-70B -q fp8
    MI325X and MI300X

    For nanoo_fp8 quantized training on MI300X series GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-2-70B -q nanoo_fp8

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Llama 2 70B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_llama2_70b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/llama2_70b.yml
Primus benchmarking

The following run commands are tailored to Llama 3 8B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
      --device /dev/dri \
      --device /dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama3_8B-pretrain.yaml
    MI300X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama3_8B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama3_8B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama3_8B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama3_8B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama3_8B-pretrain.yaml
Standalone benchmarking

The following commands are optimized for Llama 3 8B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.

docker pull rocm/jax-training:maxtext-v26.3

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Llama 3 8B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_llama3_8b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/llama3_8b.yml
Primus benchmarking

The following run commands are tailored to Llama 3 70B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
      --device /dev/dri \
      --device /dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama3_70B-pretrain.yaml
    MI300X
    ./primus-cli direct \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama3_70B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama3_70B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama3_70B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI355X/llama3_70B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
      -- train pretrain \
      --config examples/maxtext/configs/MI300X/llama3_70B-pretrain.yaml
Standalone benchmarking

The following commands are optimized for Llama 3 70B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.

docker pull rocm/jax-training:maxtext-v26.3

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Llama 3 70B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_llama3_70b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/llama3_70b.yml
MAD-integrated benchmarking

The following run command is tailored to Llama 3.1 8B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Llama 3.1 8B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
      --tags jax_maxtext_train_llama-3.1-8b \
      --keep-model-dir \
      --live-output \
      --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_llama-3.1-8b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.

Standalone benchmarking

The following commands are optimized for Llama 3.1 8B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
      --device=/dev/dri \
      --device=/dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Llama-3.1-8B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B -q fp8
    MI325X and MI300X

    For nanoo_fp8 quantized training on MI300X series GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B -q nanoo_fp8

Multi-node training

For multi-node training examples, choose a model from Supported models with an available multi-node training script.

MAD-integrated benchmarking

The following run command is tailored to Llama 3.1 70B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Llama 3.1 70B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
      --tags jax_maxtext_train_llama-3.1-70b \
      --keep-model-dir \
      --live-output \
      --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_llama-3.1-70b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.

Standalone benchmarking

The following commands are optimized for Llama 3.1 70B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
      --device=/dev/dri \
      --device=/dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Llama-3.1-70B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-3.1-70B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-3.1-70B -q fp8

Multi-node training

For multi-node training examples, choose a model from Supported models with an available multi-node training script.

Standalone benchmarking

The following commands are optimized for Llama 3.1 405B (multi-node). See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.

docker pull rocm/jax-training:maxtext-v26.3

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Llama 3.1 405B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_llama3_405b.yml
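
The three arguments described above can be assembled and checked before a job is submitted. The following is a minimal sketch, not part of the MAD repository: the function name `build_sbatch_cmd` and its validation rules are assumptions made for illustration. It only prints the command, so nothing is submitted until you run the output yourself.

```shell
# Hypothetical helper: prints (does not run) the sbatch command for the
# unified multi-node benchmark script. Name and checks are illustrative.
build_sbatch_cmd() (
    num_nodes=${1:-}
    config_file=${2:-}
    # Default image, as stated in the [docker_image] description.
    docker_image=${3:-rocm/jax-training:maxtext-v26.3}

    # Accept only a positive integer without leading zeros.
    case $num_nodes in
        ''|*[!0-9]*|0*)
            echo "error: NUM_NODES must be a positive integer" >&2
            exit 1 ;;
    esac

    # Accept only a YAML configuration path.
    case $config_file in
        *.yml|*.yaml) ;;
        *)
            echo "error: config file must end in .yml or .yaml" >&2
            exit 1 ;;
    esac

    echo "sbatch -N $num_nodes jax_maxtext_multinode_benchmark.sh $config_file $docker_image"
)
```

For example, `build_sbatch_cmd 4 env_scripts/gfx950_llama3_405b.yml` prints the command shown above with the default image appended as the third argument.
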
Primus benchmarking

The following run commands are tailored to Llama 3.3 70B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
        --device /dev/dri \
        --device /dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/llama3.3_70B-pretrain.yaml
    MI300X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/llama3.3_70B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/llama3.3_70B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/llama3.3_70B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/llama3.3_70B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/llama3.3_70B-pretrain.yaml
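
The example configs above differ only in the GPU directory and the model name, so the path can be derived instead of retyped. The following is a minimal sketch under that assumption: `primus_config_path` is a hypothetical helper, not part of primus-cli, and it only knows the two directories (MI355X and MI300X) that appear in the examples.

```shell
# Hypothetical helper: builds the example config path used with --config.
# Assumes the layout examples/maxtext/configs/<GPU>/<model>-pretrain.yaml
# seen in the commands above.
primus_config_path() (
    gpu=${1:-}
    model=${2:-}

    # Only the two GPU directories shown in this guide are accepted.
    case $gpu in
        MI355X|MI300X) ;;
        *)
            echo "error: unsupported GPU directory: $gpu" >&2
            exit 1 ;;
    esac

    if [ -z "$model" ]; then
        echo "error: a model name is required" >&2
        exit 1
    fi

    echo "examples/maxtext/configs/${gpu}/${model}-pretrain.yaml"
)
```

A direct-mode run could then be written as `./primus-cli direct -- train pretrain --config "$(primus_config_path MI300X llama3.3_70B)"`. Check that the resulting file exists in your Primus checkout before launching.
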
MAD-integrated benchmarking

The following run command is tailored to Llama 3.3 70B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Llama 3.3 70B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
        --tags jax_maxtext_train_llama-3.3-70b \
        --keep-model-dir \
        --live-output \
        --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_llama-3.3-70b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.
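
To check the outcome of a run quickly, you can print the header and the last row of the report file. This is a minimal sketch: `latest_perf_row` is a hypothetical helper, it makes no assumption about the column names, and it assumes that MAD appends new results so that the last row is the most recent one.

```shell
# Hypothetical helper: prints the CSV header and the last data row.
# Defaults to ~/MAD/perf.csv; pass another path as the first argument.
latest_perf_row() (
    perf_csv=${1:-$HOME/MAD/perf.csv}

    if [ ! -s "$perf_csv" ]; then
        echo "error: $perf_csv is missing or empty" >&2
        exit 1
    fi

    # Print line 1 as the header, remember every later line, and print
    # the final one only if the file holds at least one data row.
    awk 'NR == 1 { print; next } { last = $0 } END { if (NR > 1) print last }' "$perf_csv"
)
```

Run `latest_perf_row` on the host after `madengine run` finishes, or pass the path explicitly if you cloned MAD somewhere other than your home directory.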

Standalone benchmarking

The following commands are optimized for Llama 3.3 70B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

To run the JAX MaxText benchmark tool independently, first pull the Docker image:

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
        --device=/dev/dri \
        --device=/dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add=SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Llama-3.3-70B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-3.3-70B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Llama-3.3-70B -q fp8
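
The container launch in step 2 depends on both variables from step 1. If HF_HOME is empty, for instance, the bind mount expands to `-v :/hf_cache`, which is not a valid volume specification. The following pre-flight check is a minimal sketch: `check_hf_env` is a hypothetical helper, and the directory test is an added assumption about how you want HF_HOME to behave.

```shell
# Hypothetical pre-flight check to run on the host before `docker run`.
# Returns non-zero and reports every problem it finds.
check_hf_env() (
    status=0

    if [ -z "${MAD_SECRETS_HFTOKEN:-}" ]; then
        echo "error: MAD_SECRETS_HFTOKEN is not set" >&2
        status=1
    fi

    if [ -z "${HF_HOME:-}" ]; then
        echo "error: HF_HOME is not set" >&2
        status=1
    elif [ ! -d "$HF_HOME" ]; then
        # Assumption: the cache directory should already exist on the host.
        echo "error: HF_HOME is not a directory: $HF_HOME" >&2
        status=1
    fi

    exit "$status"
)
```

A typical use is `check_hf_env && docker run ...`, so that the container only starts when both values are in place.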

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Llama 3.3 70B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_llama3.3_70b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/llama3.3_70b.yml
Primus benchmarking

The following run commands are tailored to DeepSeek-V2-Lite (16B). See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
        --device /dev/dri \
        --device /dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/deepseek_v2_16B-pretrain.yaml
    MI300X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/deepseek_v2_16B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/deepseek_v2_16B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/deepseek_v2_16B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/deepseek_v2_16B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/deepseek_v2_16B-pretrain.yaml
MAD-integrated benchmarking

The following run command is tailored to DeepSeek-V2-Lite (16B). See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the DeepSeek-V2-Lite (16B) model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
        --tags jax_maxtext_train_deepseek-v2-lite-16b \
        --keep-model-dir \
        --live-output \
        --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_deepseek-v2-lite-16b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.

Standalone benchmarking

The following commands are optimized for DeepSeek-V2-Lite (16B). See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

To run the JAX MaxText benchmark tool independently, first pull the Docker image:

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
        --device=/dev/dri \
        --device=/dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add=SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m DeepSeek-V2-lite
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite -q fp8
    MI325X and MI300X

    For nanoo_fp8 quantized training on MI300X series GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite -q nanoo_fp8
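
For models that list both quantization options, the choice of `-q` value follows the GPU architecture. The mapping can be captured in a small lookup, sketched here with a hypothetical helper named `quant_flag_for_arch`. The architecture names gfx950 and gfx942 are the ones used in the multi-node section of this guide.

```shell
# Hypothetical helper: maps a GPU architecture to the -q value used by
# jax-maxtext_benchmark_report.sh in this guide.
quant_flag_for_arch() (
    case ${1:-} in
        gfx950) echo "fp8" ;;        # MI355X and MI350X
        gfx942) echo "nanoo_fp8" ;;  # MI325X and MI300X
        *)
            echo "error: no quantization option known for: ${1:-<empty>}" >&2
            exit 1 ;;
    esac
)
```

For example, `./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite -q "$(quant_flag_for_arch gfx942)"` expands to the nanoo_fp8 command above.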

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on DeepSeek-V2-Lite (16B):

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_deepseek2_16b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/deepseek2_16b.yml
Primus benchmarking

The following run commands are tailored to Mixtral 8x7B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
        --device /dev/dri \
        --device /dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml
    MI300X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml
MAD-integrated benchmarking

The following run command is tailored to Mixtral 8x7B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Mixtral 8x7B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
        --tags jax_maxtext_train_mixtral-8x7b \
        --keep-model-dir \
        --live-output \
        --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_mixtral-8x7b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.

Standalone benchmarking

The following commands are optimized for Mixtral 8x7B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

To run the JAX MaxText benchmark tool independently, first pull the Docker image:

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
        --device=/dev/dri \
        --device=/dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add=SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Mixtral-8x7B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B -q fp8
    MI325X and MI300X

    For nanoo_fp8 quantized training on MI300X series GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B -q nanoo_fp8

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Mixtral 8x7B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_mixtral_8x7b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/llama3_8x7b.yml
Primus benchmarking

The following run commands are tailored to Qwen 14B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
        --device /dev/dri \
        --device /dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/qwen3_14B-pretrain.yaml
    MI300X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/qwen3_14B-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/qwen3_14B-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/qwen3_14B-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/qwen3_14B-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/qwen3_14B-pretrain.yaml
MAD-integrated benchmarking

The following run command is tailored to Qwen 14B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Qwen 14B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
        --tags jax_maxtext_train_qwen3-14b \
        --keep-model-dir \
        --live-output \
        --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_qwen3-14b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.
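
In every example in this guide, the container name is the `--tags` value with a `container_ci-` prefix. Assuming that pattern holds for your run, the name can be derived from the tag, which is handy for following a long benchmark from a second terminal. `mad_container_name` below is a hypothetical helper, not part of madengine.

```shell
# Hypothetical helper: derives the MAD container name from a --tags value,
# following the container_ci-<tag> pattern seen in this guide.
mad_container_name() (
    tag=${1:-}
    if [ -z "$tag" ]; then
        echo "error: a MAD tag is required" >&2
        exit 1
    fi
    echo "container_ci-${tag}"
)
```

For example, `docker logs -f "$(mad_container_name jax_maxtext_train_qwen3-14b)"` follows the output of the container launched by the command above.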

Standalone benchmarking

The following commands are optimized for Qwen 14B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

To run the JAX MaxText benchmark tool independently, first pull the Docker image:

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
        --device=/dev/dri \
        --device=/dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add=SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Qwen3-14B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Qwen3-14B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Qwen3-14B -q fp8
    MI325X and MI300X

    For nanoo_fp8 quantized training on MI300X series GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Qwen3-14B -q nanoo_fp8

Multi-node training

The following SLURM scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Qwen 14B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_qwen3_14b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/qwen3_14b.yml
Primus benchmarking

The following run commands are tailored to Qwen 30B A3B. See Supported models to switch to another available model.

Download the Docker image and required packages

  1. Pull the rocm/jax-training:maxtext-v26.3 Docker image from Docker Hub.

    docker pull rocm/jax-training:maxtext-v26.3
  2. Run the Docker container.

    docker run -it \
        --device /dev/dri \
        --device /dev/kfd \
        --network host \
        --ipc host \
        --group-add video \
        --cap-add SYS_PTRACE \
        --security-opt seccomp=unconfined \
        --privileged \
        -v $HOME:$HOME \
        -v $HOME/.ssh:/root/.ssh \
        -v $HF_HOME:/hf_cache \
        -e HF_HOME=/hf_cache \
        -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G \
        --name training_env \
        rocm/jax-training:maxtext-v26.3

    Use these commands if you exit the training_env container and need to return to it.

    docker start training_env
    docker exec -it training_env bash
  3. Clone the Primus repository.

    git clone https://github.com/AMD-AIG-AIMA/Primus.git
    cd Primus
    git checkout main
    git submodule update --init third_party/maxtext/

Run the training job with primus-cli

For detailed usage instructions for primus-cli, see the Primus CLI User Guide.

Use the following examples to run training with primus-cli:

  • Direct mode: run directly on the current host or within an existing Docker container

    MI355X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/qwen3_30B_a3b-pretrain.yaml
    MI300X
    ./primus-cli direct \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/qwen3_30B_a3b-pretrain.yaml
  • Container mode: run in Docker containers

    MI355X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/qwen3_30B_a3b-pretrain.yaml
    MI300X
    ./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/qwen3_30B_a3b-pretrain.yaml
  • Slurm mode: run distributed training on a Slurm cluster

    MI355X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI355X/qwen3_30B_a3b-pretrain.yaml
    MI300X
    # Use a custom config file, where you can specify
    # the Docker image and set environment variables.
    ./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
        -- train pretrain \
        --config examples/maxtext/configs/MI300X/qwen3_30B_a3b-pretrain.yaml
MAD-integrated benchmarking

The following run command is tailored to Qwen 30B A3B. See Supported models to switch to another available model.

  1. Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.

    git clone https://github.com/ROCm/MAD
    cd MAD
    pip install -r requirements.txt
  2. Use this command to run the performance benchmark test on the Qwen 30B A3B model using one GPU with the bf16 data type on the host machine.

    export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
    madengine run \
        --tags jax_maxtext_train_qwen3-30b-a3b \
        --keep-model-dir \
        --live-output \
        --timeout 28800

MAD launches a Docker container named container_ci-jax_maxtext_train_qwen3-30b-a3b. The model's latency and throughput reports are collected in ~/MAD/perf.csv.

Standalone benchmarking

The following commands are optimized for Qwen 30B A3B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.

Download the Docker image and required scripts

To run the JAX MaxText benchmark tool independently, first pull the Docker image:

docker pull rocm/jax-training:maxtext-v26.3

Single node training

  1. Set up environment variables.

    export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
    export HF_HOME=<Location of saved/cached Hugging Face models>

    MAD_SECRETS_HFTOKEN is your Hugging Face access token to access models, tokenizers, and data. See User access tokens.

    HF_HOME is where huggingface_hub will store local data. See huggingface_hub CLI. If you already have downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files typically get cached to ~/.cache/huggingface.

  2. Launch the Docker container.

    docker run -it \
      --device=/dev/dri \
      --device=/dev/kfd \
      --network host \
      --ipc host \
      --group-add video \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --privileged \
      -v $HOME:$HOME \
      -v $HOME/.ssh:/root/.ssh \
      -v $HF_HOME:/hf_cache \
      -e HF_HOME=/hf_cache \
      -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
      --shm-size 64G \
      --name training_env \
      rocm/jax-training:maxtext-v26.3
  3. In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at MAD/scripts/jax-maxtext.

    git clone https://github.com/ROCm/MAD
    cd MAD/scripts/jax-maxtext
  4. Run the setup scripts to install libraries and datasets needed for benchmarking.

    ./jax-maxtext_benchmark_setup.sh -m Qwen3-30B-A3B
  5. To run the training benchmark without quantization, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Qwen3-30B-A3B

    For quantized training, run the script with the appropriate option for your Instinct GPU.

    MI355X and MI350X

    For fp8 quantized training on MI355X and MI350X GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Qwen3-30B-A3B -q fp8
    MI325X and MI300X

    For nanoo_fp8 quantized training on MI300X series GPUs, use the following command:

    ./jax-maxtext_benchmark_report.sh -m Qwen3-30B-A3B -q nanoo_fp8
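
The choice between the three commands in step 5 can be captured in a small helper. This is an illustrative sketch only: the function and dictionary names are hypothetical, and the gfx950 and gfx942 labels are borrowed from the multi-node examples in this document as convenient keys for the two GPU groups. The script name and the -m and -q options are the ones shown above.

```python
# Quantization option per GPU group, as listed in step 5.
# The gfx keys are used here only as labels for the two groups.
QUANT_OPTION = {
    "gfx950": "fp8",        # MI355X and MI350X
    "gfx942": "nanoo_fp8",  # MI325X and MI300X
}


def report_command(model, arch=None):
    """Return the benchmark command as an argument list.

    arch=None builds the non-quantized command; a known arch adds -q <option>.
    """
    command = ["./jax-maxtext_benchmark_report.sh", "-m", model]
    if arch is not None:
        if arch not in QUANT_OPTION:
            raise ValueError(f"no quantization option listed for {arch!r}")
        command += ["-q", QUANT_OPTION[arch]]
    return command


if __name__ == "__main__":
    print(" ".join(report_command("Qwen3-30B-A3B", "gfx942")))
```

Returning an argument list rather than a single string keeps the result usable with subprocess.run without shell quoting.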

Multi-node training

The following Slurm script launches the Docker container and runs the benchmark. Run it outside of any Docker container. The unified multi-node benchmark script accepts a configuration file that specifies the model and training parameters.

sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
<NUM_NODES>

The number of nodes to use for training (for example, 2, 4, 8).

<config_file.yml>

Path to the YAML configuration file containing model and training parameters. Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures.

[docker_image] (optional)

The Docker image to use. If not specified, it defaults to rocm/jax-training:maxtext-v26.3.

For example, to run a multi-node training benchmark on Qwen 30B A3B:

MI355X and MI350X (gfx950)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_qwen3_30b_a3b.yml
MI325X and MI300X (gfx942)
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/qwen3_30b_a3b.yml

Profiling with JAX XPlane Profiler#

MaxText has built-in XPlane profiling support via JAX’s profiler. Traces capture GPU kernel timelines, RCCL collectives, HLO graphs, and more. The output can be viewed in TensorBoard’s Trace Viewer or analyzed with TraceLens.

Key MaxText profiler flags#

The following MaxText config keys control profiling:

profiler=xplane                      # Use xplane format (produces .xplane.pb files)
skip_first_n_steps_for_profiler=2    # Skip compilation/warmup steps
profiler_steps=5                     # Number of steps to profile
upload_all_profiler_results=True     # Save all GPU profiles (not just GPU0)

steps should be greater than skip_first_n_steps_for_profiler + profiler_steps. For example, steps=12 with skip_first_n_steps_for_profiler=2 and profiler_steps=5 gives 2 skipped + 5 profiled + 5 cooldown steps. skip_first_n_steps_for_profiler=2 skips step 0 (compilation) and step 1 (warmup). profiler_steps=5 is typically sufficient; more steps produce larger .xplane.pb files.
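
The step budget can be checked before launching a run. The following is a minimal sketch with a hypothetical helper name; it only encodes the rule stated above (steps greater than skip_first_n_steps_for_profiler + profiler_steps) and is not part of MaxText.

```python
def profiling_window(steps, skip_first_n_steps_for_profiler, profiler_steps):
    """Split a run into (skipped, profiled, remaining) step counts.

    Raises ValueError unless steps is greater than the sum of the two
    profiler settings, which is the condition stated above.
    """
    traced_end = skip_first_n_steps_for_profiler + profiler_steps
    if steps <= traced_end:
        raise ValueError(
            f"steps={steps} must be greater than {traced_end} "
            "(skip_first_n_steps_for_profiler + profiler_steps)"
        )
    return skip_first_n_steps_for_profiler, profiler_steps, steps - traced_end


if __name__ == "__main__":
    skipped, profiled, remaining = profiling_window(12, 2, 5)
    # Steps 0-1 are skipped, steps 2-6 are profiled, steps 7-11 run untraced.
    print(skipped, profiled, remaining)  # prints: 2 5 5
```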

Profiling with MAD or madengine#

The model YAML configs under scripts/jax-maxtext/env_scripts/ include a profiler key (set to "" by default). To enable profiling when running through MAD or madengine, edit the YAML config for your model and set the profiler fields:

profiler: "xplane"
skip_first_n_steps_for_profiler: 2
profiler_steps: 5
upload_all_profiler_results: True
steps: 12

Then run the benchmark as usual:

export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_llama-3.1-8b \
  --keep-model-dir \
  --live-output \
  --timeout 28800

Use --keep-model-dir so the container’s output directory is preserved after the run. Profile output is written under the base_output_directory specified in the YAML.

Example: Profile a model standalone in Docker#

#!/bin/bash
set -e
IMAGE="$1"   # Docker image, e.g. rocm/jax-training:maxtext-v26.3
TAG="$2"     # Short tag for output folder, e.g. v26.3_llama2_7b
PROFILE_DIR="/path/to/profiles/${TAG}"
mkdir -p "${PROFILE_DIR}"
docker run --rm --privileged --network=host \
  --device=/dev/dri --device=/dev/kfd --ipc=host \
  -v "${PROFILE_DIR}:/mnt/profile" \
  "${IMAGE}" bash -c '
    export XLA_PYTHON_CLIENT_MEM_FRACTION=.97
    export LD_LIBRARY_PATH=/usr/local/lib/:/opt/rocm/lib:$LD_LIBRARY_PATH
    export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_enable_command_buffer= <your other XLA flags>"
    export GPU_MAX_HW_QUEUES=2
    cd /workspace/maxtext
    python3 -m MaxText.train src/MaxText/configs/base.yml \
      run_name=profile \
      base_output_directory=/mnt/profile \
      hardware=gpu \
      steps=12 \
      model_name=<your-model> \
      dataset_type=synthetic \
      enable_checkpointing=False \
      enable_goodput_recording=False \
      monitor_goodput=False \
      <your model-specific flags> \
      profiler=xplane \
      skip_first_n_steps_for_profiler=2 \
      profiler_steps=5 \
      upload_all_profiler_results=True
  ' 2>&1 | tee "${PROFILE_DIR}/run.log"
echo "Profile files:"
find "${PROFILE_DIR}" -name "*.xplane.pb" -o -name "*.trace.json.gz" 2>/dev/null

Output structure#

MaxText writes profiles in TensorBoard format:

<base_output_directory>/
└── profile/
    └── tensorboard/
        └── plugins/
            └── profile/
                └── <YYYY_MM_DD_HH_MM_SS>/
                    ├── <hostname>.xplane.pb       # Raw XPlane proto (GPU timelines)
                    ├── <hostname>.trace.json.gz   # Trace viewer data
                    └── *.hlo_proto.pb             # HLO graphs for each compiled module

Viewing traces in TensorBoard#

pip install tensorboard tensorboard-plugin-profile
# Point --logdir at the directory containing the tensorboard/ folder
tensorboard --logdir /path/to/profiles/<TAG>/profile --port 6006

Navigate to Profile > Trace Viewer in the TensorBoard UI. Zoom into a single training step (skip the first profiled step as it may have residual warmup) and look at individual GPU streams to see compute/RCCL overlap.

To keep profile files small, set profiler_steps=5, which keeps .xplane.pb files under approximately 100 MB. Too many steps can produce files over 500 MB that TensorBoard struggles to load. Use enable_checkpointing=False to avoid checkpoint I/O noise in the trace, and dataset_type=synthetic to eliminate data loading variability.
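
Before opening TensorBoard, it can help to check how large the collected profiles are. The following is a minimal sketch, assuming the .xplane.pb files sit somewhere under the directory you pass in. The helper name and the SIZE_WARNING_MB constant are hypothetical; the threshold reuses the 500 MB figure mentioned above.

```python
from pathlib import Path

SIZE_WARNING_MB = 500  # size above which TensorBoard may struggle, per the note above


def xplane_files(base_output_directory):
    """Return (path, size in MB) for each .xplane.pb file, largest first.

    Sizes use 1 MB = 10**6 bytes.
    """
    base = Path(base_output_directory)
    sized = [(path, path.stat().st_size / 1e6) for path in base.rglob("*.xplane.pb")]
    return sorted(sized, key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    # Placeholder path; point it at your own base_output_directory.
    for path, size_mb in xplane_files("/path/to/profiles"):
        note = "  <- may be slow to load" if size_mb > SIZE_WARNING_MB else ""
        print(f"{size_mb:8.1f} MB  {path}{note}")
```

The recursive search means the helper does not depend on the exact timestamped folder name.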

Profiling with rocprofv3#

If you need to collect a trace without the JAX profiler, use rocprofv3:

rocprofv3 --hip-trace --kernel-trace --memory-copy-trace --rccl-trace \
  --output-format pftrace -d ./v3_traces -- <command>

Replace <command> with the command you want to profile, such as ./jax-maxtext_benchmark_report.sh -m Llama-2-7B. Use -d <TRACE_DIRECTORY> to specify where the .pftrace traces are saved. The resulting traces can be opened in Perfetto.

Known issues#

  • You might see NaNs in the losses while using real data (not synthetic data) when setting packing=True and NVTE_CK_IS_V3_ATOMIC_FP32=0. Set NVTE_CK_IS_V3_ATOMIC_FP32=1 for production training when using real data and input sequence packing (packing=True).

  • There is a known slight performance regression for DeepSeek-V2-Lite (16B) in v26.3. This is being tracked and will be addressed in a future release.

  • JAX 0.9.1 Early Access known issues:

    • There is a known performance regression for MoE models (DeepSeek-V2-Lite and Mixtral 8x7B).

    • The trace viewer in profiling may be missing some information in the flame graph.

  • Shardy is a new config in JAX 0.6.0. You might get related errors if it’s not configured correctly. To disable it, set shardy=False during the training run. See the Shardy migration guide to enable it.

Further reading#

Previous versions#

See JAX MaxText training performance testing version history to find documentation for previous releases of the ROCm/jax-training Docker image.

© 2026 Advanced Micro Devices, Inc