| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""Convert pretrained Pixtral vision model weights to checkpoint and verify the checkpoint loading.

Usage:

    PYTHONPATH=$(pwd) python cosmos1/scripts/convert_pixtral_ckpt.py
"""
|
|
import argparse
import json
import os
import shutil
from glob import glob

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
|
|
|
|
def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str) -> None:
    """Convert pretrained Pixtral vision-language weights into a single checkpoint.

    Downloads the original ``mistralai/Pixtral-12B-2409`` release from Hugging Face,
    splits the consolidated safetensors state dict into vision-encoder, projector,
    and LLM weights, verifies the vision-encoder configuration against the expected
    hyperparameters, and writes ``model.pt`` plus ``config.json`` under
    ``checkpoint_dir/checkpoint_name``.

    Args:
        checkpoint_dir (str): Path to the checkpoint directory.
        checkpoint_name (str): Name of the checkpoint (output sub-directory).
        vit_type (str): Type of ViT used in the Pixtral model; recorded in the
            saved ``config.json`` under the ``vision_encoder`` key.

    This function performs the following steps:
        0. Download the checkpoint from Hugging Face.
        1. Load the original Pixtral checkpoint.
        2. Split the checkpoint into vision encoder, projector, and LLM weights.
        3. Reorganize the weights to match the expected format.
        4. Extract and verify the vision encoder configuration.
        5. Save the converted checkpoint and configuration, then clean up the
           downloaded files.
    """
    save_dir = os.path.join(checkpoint_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)

    # Idempotency guard: skip the (expensive) download/convert if already done.
    save_path = os.path.join(save_dir, "model.pt")
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f"Checkpoint {save_path} already exists and is not empty")
        return

    # Step 0: download params.json + consolidated.safetensors from Hugging Face.
    pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409")
    os.makedirs(pixtral_ckpt_dir, exist_ok=True)
    repo_id = "mistralai/Pixtral-12B-2409"
    print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...")
    snapshot_download(
        repo_id=repo_id,
        allow_patterns=["params.json", "consolidated.safetensors"],
        local_dir=pixtral_ckpt_dir,
        local_dir_use_symlinks=False,
    )

    # Load/convert in bfloat16; restore the process-wide default dtype even if
    # the conversion raises (BUGFIX: previously an exception would leave the
    # global default dtype as bfloat16 for the rest of the process).
    orig_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        # Step 1: load the consolidated safetensors state dict.
        ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors"))
        assert len(ckpt_files) == 1, "ckpt_dir should contain only one file"
        ckpt = load_file(ckpt_files[0])

        # Step 2a: vision-encoder weights, with the module prefix removed.
        # BUGFIX: the original used key.lstrip(prefix), which strips a *character
        # set* rather than a prefix — a key whose remainder starts with any
        # character occurring in the prefix (e.g. "vision_encoder.rope...")
        # would be silently corrupted. Slice the prefix off instead.
        vit_key_prefix = "vision_encoder."
        vit_ckpt = {
            key[len(vit_key_prefix):]: value
            for key, value in ckpt.items()
            if key.startswith(vit_key_prefix)
        }

        # Step 2b: projector weights, renaming the two linear layers to their
        # positions inside the target nn.Sequential (0 = w_in, 2 = w_out).
        projector_key_prefix = "vision_language_adapter."
        substring_replacement_map = {
            "w_in.": "projector.0.",
            "w_out.": "projector.2.",
        }
        projector_ckpt = {}
        for key, value in ckpt.items():
            if key.startswith(projector_key_prefix):
                key = key[len(projector_key_prefix):]  # same prefix bugfix as above
                for old, new in substring_replacement_map.items():
                    key = key.replace(old, new)
                projector_ckpt[key] = value

        # Step 2c: everything that is neither vision encoder nor projector is LLM.
        llm_ckpt = {
            key: value
            for key, value in ckpt.items()
            if not key.startswith((vit_key_prefix, projector_key_prefix))
        }

        # Step 3: reassemble all weights under the target module names.
        vlm_ckpt = {}
        for key, value in llm_ckpt.items():
            vlm_ckpt["model." + key] = value
        for key, value in projector_ckpt.items():
            vlm_ckpt["mm_projector." + key] = value
        for key, value in vit_ckpt.items():
            vlm_ckpt["vision_encoder." + key] = value

        # Step 4: load the original config ...
        config_path = os.path.join(pixtral_ckpt_dir, "params.json")
        with open(config_path, "r") as f:
            pixtral_config = json.load(f)

        # ... map the vision-encoder section to our naming scheme ...
        vision_encoder_config = {
            "dim": pixtral_config["vision_encoder"]["hidden_size"],
            "num_channels": pixtral_config["vision_encoder"]["num_channels"],
            "image_size": pixtral_config["vision_encoder"]["image_size"],
            "patch_size": pixtral_config["vision_encoder"]["patch_size"],
            "rope_theta": pixtral_config["vision_encoder"]["rope_theta"],
            "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"],
            "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"],
            "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
            # Pixtral's ViT uses multi-head (not grouped-query) attention,
            # so kv heads == attention heads.
            "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
            "norm_type": "rmsnorm",
            "norm_eps": pixtral_config["norm_eps"],
            "image_token_id": pixtral_config["vision_encoder"]["image_token_id"],
        }

        # ... and verify it against the expected pixtral-12b-vit hyperparameters.
        vit_config = dict(
            dim=1024,
            num_channels=3,
            image_size=1024,
            patch_size=16,
            rope_theta=10000,
            ffn_hidden_size=4096,
            n_layers=24,
            n_heads=16,
            n_kv_heads=16,
            norm_type="rmsnorm",
            norm_eps=1e-5,
            image_token_id=10,
        )
        for key, value in vit_config.items():
            assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}"

        # Build the LLM config, renaming "hidden_dim" -> "ffn_hidden_size" and
        # replacing the nested vision_encoder config with the vit_type tag.
        llm_config_keys = [
            "dim",
            "n_layers",
            "head_dim",
            "hidden_dim",
            "n_heads",
            "n_kv_heads",
            "rope_theta",
            "norm_eps",
            "vocab_size",
        ]
        assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch"
        replace_map = {
            "hidden_dim": "ffn_hidden_size",
        }
        llm_config = {}
        for k, v in pixtral_config.items():
            if k in llm_config_keys:
                llm_config[replace_map.get(k, k)] = v
            elif k == "vision_encoder":
                llm_config["vision_encoder"] = vit_type
            else:
                raise ValueError(f"Unknown key: {k}")

        # Step 5: save the converted checkpoint and the config.
        ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt}
        torch.save(ckpt_to_save, save_path)
        print(f"Model saved to {save_path}")

        config_path = os.path.join(save_dir, "config.json")
        with open(config_path, "w") as f:
            json.dump(llm_config, f)
    finally:
        torch.set_default_dtype(orig_dtype)

    # Remove the raw download; only the converted checkpoint is kept.
    shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True)
    print(f"Removed {pixtral_ckpt_dir}")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse the three conversion options and run the converter.
    arg_parser = argparse.ArgumentParser(
        description="Convert pretrained Pixtral vision model weights to checkpoint and verify accuracy"
    )
    arg_parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="checkpoints",
        help="Path to the checkpoint directory",
    )
    arg_parser.add_argument("--checkpoint_name", type=str, default="Pixtral-12B", help="Name of the checkpoint")
    arg_parser.add_argument("--vit_type", default="pixtral-12b-vit", help="Type of ViT used in the Pixtral model")
    cli_args = arg_parser.parse_args()
    convert_pixtral_checkpoint(
        checkpoint_dir=cli_args.checkpoint_dir,
        checkpoint_name=cli_args.checkpoint_name,
        vit_type=cli_args.vit_type,
    )
|
|