import os
import argparse

import torch
from tqdm.auto import tqdm
from safetensors.torch import load_file, save_file


def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
    """Return the largest representable value of the given fp format (E4M3 by default: 448)."""
    _bits = torch.tensor(bits)
    _mantissa_bit = torch.tensor(mantissa_bit)
    _sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
    E = _bits - _sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    return maxval


def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
    """
    Fake-quantize x to the given floating-point format. Default is E4M3.

    Returns the quantize-dequantized tensor and the per-element power-of-two scales.
    """
    bits = torch.tensor(bits)
    mantissa_bit = torch.tensor(mantissa_bit)
    sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
    E = bits - sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    minval = -maxval if sign_bits == 1 else torch.zeros_like(maxval)
    # clamp to the representable range, then round to the nearest representable value
    input_clamp = torch.min(torch.max(x, minval), maxval)
    log_scales = torch.clamp(
        (torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
    # quantize / dequantize
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out, log_scales


def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
    """Divide x by a per-tensor scale, then fake-quantize the result to fp8."""
    for _ in range(len(x.shape) - 1):
        scale = scale.unsqueeze(-1)
    new_x = x / scale
    quant_dequant_x, log_scales = quantize_to_fp8(
        new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
    return quant_dequant_x, scale, log_scales


def parse_args():
    parser = argparse.ArgumentParser(
        description="Convert safetensors to fp8 scaled",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--file",
        type=str,
        required=True,
        help="Input .safetensors file to convert",
    )
    parser.add_argument(
        "--base_dtype",
        type=str,
        default="bf16",
        choices=["fp16", "bf16", "fp32"],
        help="dtype to use for anything that can't be converted to fp8",
    )
    # parser.add_argument(
    #     "--ban_list",
    #     nargs="*",
    #     default=[],
    #     help="List of banned keys to keep in base dtype instead of converting to fp8 (zero or more strings)",
    # )
    args = parser.parse_args()
    return args


def main(args):
    input_path = os.path.normpath(args.file)
    output_path = os.path.splitext(input_path)[0] + "_fp8_scaled.safetensors"
    orig_state_dict = load_file(input_path)
    new_state_dict = {}

    if args.base_dtype == "fp16":
        model_dtype = torch.float16
    elif args.base_dtype == "bf16":
        model_dtype = torch.bfloat16
    elif args.base_dtype == "fp32":
        model_dtype = torch.float32
    else:
        raise ValueError(f"unknown dtype: {args.base_dtype}")

    # Keys containing any of these substrings are kept in the base dtype instead of
    # being converted to fp8. Uncomment the variant that matches the model being converted.
    # ban_list = ["text", "time", "head"]
    # ban_list = args.ban_list
    # ban_list = ["norm", "embedder", "pad_token", "modulation", "final_layer"]
    # ban_list = ["norm", "embedder", "pad_token", "modulation", "final_layer", "to_q", "to_k", "to_v"]  # for transformer, output will be ~6gb(bf16) <- ex. z_image_turbo_bf16_fp8_scaled_1.safetensors
    # ban_list = ["norm", "embedder", "pad_token", "modulation", "final_layer", "attention"]  # for transformer, output will be ~8gb(bf16) <- ex. z_image_turbo_bf16_fp8_scaled_2.safetensors
    ban_list = ["norm", "embed_tokens"]  # for text encoder

    maxval = get_fp_maxval()
    for key in tqdm(orig_state_dict.keys()):
        # decide whether to convert based on shape and banned keys:
        # only 2D tensors (linear weights) that match no ban-list entry are quantized
        convert = False
        if orig_state_dict[key].dim() == 2:
            convert = True
        for ban in ban_list:
            if ban in key:
                convert = False
        scale_key = key.rsplit(".", 1)[0] + ".scale_weight"

        if convert:
            weight = orig_state_dict[key]
            # per-tensor scale so the largest |weight| maps to the fp8 max value
            scale = torch.max(torch.abs(weight.flatten())) / maxval
            linear_weight, scale, log_scales = fp8_tensor_quant(weight, scale)
            linear_weight = linear_weight.to(dtype=torch.float8_e4m3fn)
            new_state_dict[scale_key] = scale
            new_state_dict[key] = linear_weight
        else:
            if orig_state_dict[key].dim() == 2:
                new_state_dict[scale_key] = torch.ones(1)
            new_state_dict[key] = orig_state_dict[key].to(dtype=model_dtype)

    # marker tensor so loaders that support scaled fp8 (e.g. ComfyUI) can detect this format
    new_state_dict["scaled_fp8"] = torch.zeros(2).to(dtype=torch.float8_e4m3fn)
    save_file(new_state_dict, output_path)


if __name__ == "__main__":
    args = parse_args()
    print(args.base_dtype)
    main(args)
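
# --- Example usage (a minimal sketch; the script and checkpoint names below are hypothetical) ---
#
#   python convert_fp8_scaled.py --file model_bf16.safetensors --base_dtype bf16
#
# The converted checkpoint is written next to the input as model_bf16_fp8_scaled.safetensors.
# To sanity-check the result, load it back and confirm that fp8 weights are present and that
# each quantized layer got a ".scale_weight" companion tensor:
#
#   import torch
#   from safetensors.torch import load_file
#   sd = load_file("model_bf16_fp8_scaled.safetensors")
#   fp8_keys = [k for k in sd if sd[k].dtype == torch.float8_e4m3fn and k != "scaled_fp8"]
#   scale_keys = [k for k in sd if k.endswith(".scale_weight")]
#   print(len(fp8_keys), len(scale_keys))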