12 changes: 9 additions & 3 deletions docs/tutorials/posttraining/sft.md
@@ -75,10 +75,10 @@ export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs:
### Option 2: Converting a Hugging Face checkpoint
If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

-1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
+1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:

```sh
-export PRE_TRAINED_MODEL_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/0/items
+export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint
```

2. **Run the Conversion Script:** Execute the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face; for the specific models and versions currently supported, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
@@ -89,10 +89,16 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu #
python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
model_name=${PRE_TRAINED_MODEL} \
hf_access_token=${HF_TOKEN} \
-base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint \
+base_output_directory=${PRE_TRAINED_MODEL_CKPT_DIRECTORY} \
scan_layers=True skip_jax_distributed_system=True
```

+3. **Use the Converted Checkpoint:** Set the following environment variable to use the converted checkpoint:
+
+```sh
+export PRE_TRAINED_MODEL_CKPT_PATH=${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items
+```
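Taken together, the steps above wire the conversion output directory into the checkpoint path. A minimal sketch of that wiring, using a hypothetical bucket name (`gs://my-bucket/output` is illustrative, not from the tutorial):

```shell
# Hypothetical value for illustration only
export BASE_OUTPUT_DIRECTORY=gs://my-bucket/output
# Directory the conversion script writes to (step 1)
export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint
# Path to the converted checkpoint items (step 3)
export PRE_TRAINED_MODEL_CKPT_PATH=${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items
echo "${PRE_TRAINED_MODEL_CKPT_PATH}"
# → gs://my-bucket/output/maxtext-checkpoint/0/items
```

Note that `/0/items` is appended because the conversion script saves the checkpoint at step 0 under an Orbax-style `items` subdirectory.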

## Run SFT on Hugging Face Dataset
Now you are ready to run SFT using the following command:

18 changes: 15 additions & 3 deletions docs/tutorials/posttraining/sft_on_multi_host.md
@@ -107,10 +107,10 @@ export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-b
### Option 2: Converting a Hugging Face checkpoint
If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

-1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
+1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:

```bash
-export MODEL_CHECKPOINT_PATH=${OUTPUT_PATH}/${WORKLOAD_NAME}/maxtext-checkpoint/0/items
+export MODEL_CHECKPOINT_DIRECTORY=${OUTPUT_PATH}/maxtext-checkpoint
```

2. **Run the Conversion Script:** Execute the following command on a CPU machine; it downloads the specified Hugging Face model, converts its weights into the MaxText format, and saves the result to the specified GCS bucket. The conversion script only supports official versions of models from Hugging Face; for the specific models and versions currently supported, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
@@ -122,7 +122,19 @@ USE_OCDBT=<Flag to use ocdbt> # True to run SFT with McJAX, False to run SFT wit
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# For large models, it is recommended to set `--lazy_load_tensors` flag to reduce memory usage during conversion
-python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml model_name=$MODEL_NAME hf_access_token=$HF_TOKEN base_output_directory=$OUTPUT_PATH/$WORKLOAD_NAME/maxtext-checkpoint scan_layers=True checkpoint_storage_use_zarr3=$USE_ZARR3 checkpoint_storage_use_ocdbt=$USE_OCDBT skip_jax_distributed_system=True --lazy_load_tensors=True
+python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+model_name=$MODEL_NAME \
+hf_access_token=$HF_TOKEN \
+base_output_directory=$MODEL_CHECKPOINT_DIRECTORY \
+scan_layers=True \
+checkpoint_storage_use_zarr3=$USE_ZARR3 checkpoint_storage_use_ocdbt=$USE_OCDBT \
+skip_jax_distributed_system=True --lazy_load_tensors=True
```

+3. **Use the Converted Checkpoint:** Set the following environment variable to use the converted checkpoint:
+
+```bash
+export MODEL_CHECKPOINT_PATH=${MODEL_CHECKPOINT_DIRECTORY}/0/items
+```
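The multi-host path wiring follows the same pattern. A minimal sketch using a hypothetical bucket name (`gs://my-bucket/output` is illustrative, not from the tutorial):

```shell
# Hypothetical value for illustration only
export OUTPUT_PATH=gs://my-bucket/output
# Directory the conversion script writes to (step 1)
export MODEL_CHECKPOINT_DIRECTORY=${OUTPUT_PATH}/maxtext-checkpoint
# Path passed to the SFT workload (step 3)
export MODEL_CHECKPOINT_PATH=${MODEL_CHECKPOINT_DIRECTORY}/0/items
echo "${MODEL_CHECKPOINT_PATH}"
# → gs://my-bucket/output/maxtext-checkpoint/0/items
```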

## 6. Submit workload on GKE cluster