软件调优(六):Checkpoint 文件优化

对 Checkpoint 进行脚本处理

  • torch-checkpoint-convert-to-bf16:这个脚本会将两类 checkpoint 文件: torch 的 ".bin" 文件和 safetensor 的 ".safetensor" 文件中的权重转换为 bf16 ,并在名为 bf16 的子目录下创建一个新的 checkpoint。

    注意事项:该脚本假设所有权重均为浮点张量,因此仅适用于标准的 fp32/fp16 浮点 checkpoint

  • torch-checkpoint-shrink.py:这个脚本用于修复某些 ".pt" 文件后缀的 checkpoint:由于某些原因,这些 checkpoint 在保存时,张量对应的底层 storage 比当时实际使用的 view 更大。它会克隆当前 view,并重新保存张量,使其只保留当前 view 所需的 storage

    注意事项:这个脚本会直接覆盖原 checkpoint 文件,因此使用前最好先备份

脚本

torch-checkpoint-convert-to-bf16

torch-checkpoint-convert-to-bf16

#!/bin/bash

# 这个脚本会将两类 checkpoint 文件: torch 的 *.bin 文件和 safetensor 的 *.safetensor 文件
# 中的权重转换为 bf16,并在名为 bf16 的子目录下创建一个新的 checkpoint
#
# 注意事项:该脚本假设所有权重均为浮点张量,因此仅适用于标准的 fp32/fp16 浮点 checkpoint
# 
# 使用方法:
# cd checkpoint
# bash torch-checkpoint-convert-to-bf16

# 设置目标目录
target_dir=bf16

echo "creating a new checkpoint under dir $target_dir"
mkdir -p $target_dir

# 复制 config 和其他文件,可按需调整;也可以执行 `cp * $target_dir`
cp *json *model $target_dir

# 转换 *bin 文件
echo "converting *bin torch files"
python -c "import torch, sys; [torch.save({k:v.to(torch.bfloat16) for k,v in torch.load(f).items()}, f'{sys.argv[1]}/{f}') for f in sys.argv[2:]]" $target_dir *bin

# 转换 *safetensors 文件,来源是原始的 *bin 文件
if compgen -G "*.safetensors" > /dev/null; then
    echo "converting *safetensors files"
    cd $target_dir
    python -c "import re, sys, torch; from safetensors.torch import save_file; [save_file(torch.load(f), re.sub(r'.*?(model.*?)\.bin',r'\1.safetensors',f), metadata={'format': 'pt'}) for f in sys.argv[1:]]" *bin
    if test -e "pytorch_model.bin.index.json"; then
        cp pytorch_model.bin.index.json model.safetensors.index.json
        perl -pi -e 's|pytorch_||; s|\.bin|.safetensors|' model.safetensors.index.json
    fi
    cd - > /dev/null
fi

echo "the dir $target_dir now contains a copy of the original checkpoint with bf16 weights"

torch-checkpoint-shrink

torch-checkpoint-shrink.py

#!/usr/bin/env python

# 这个脚本用于修复某些 ".pt" 文件后缀的 checkpoint:由于某些原因,这些 checkpoint 在保存时,
# 张量对应的底层 storage 比当时实际使用的 view 更大。
# 它会克隆当前 view,并重新保存张量,使其只保留当前 view 所需的 storage。
#
# 注意事项:这个脚本会直接覆盖原 checkpoint 文件,因此使用前最好先备份
#
#
# 示例:
#
# 1. 处理 checkpoint 中的所有文件
# ./torch-checkpoint-shrink.py --checkpoint_dir ./checkpoints/global_step10
#
# 2. 只处理 checkpoint 中匹配多个模式的指定文件
# ./torch-checkpoint-shrink.py --checkpoint_dir ./checkpoints/global_step10 --patterns 'layer*pt' 'zero*pt'

import argparse
import torch
import glob
import os
import collections.abc
from fnmatch import fnmatch

debug = 0

# 加载到 CPU
device = torch.device('cpu')

def get_pt_files(checkpoint_dir, patterns):

    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    pt_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*.pt")))

    if len(pt_files) == 0:
        raise FileNotFoundError(
            f"can't find '*.pt' files in directory '{checkpoint_dir}'")

    # 按模式过滤文件,只匹配文件名部分,不包含任何父目录
    pt_files = [f for f in pt_files for p in patterns if fnmatch(os.path.basename(f), p)];

    return pt_files


# 当从检查点(checkpoint)加载模型时,张量可能共享底层存储(storage)。例如:
# """
#   weight = torch.randn(1000, 1000)
#   bias = weight[0]  # 与 weight 共享存储
# """
# 通过 .clone() 为每个张量创建独立的内存副本,丢弃未使用的 storage 部分
def shrink_dict_values(d, prefix=""):
    for k, v in d.items():
        k_full = f"{prefix}.{k}" if len(prefix) else k
        if isinstance(v, collections.abc.Mapping):
            shrink_dict_values(v, k_full)
        else:
            if debug:
                print(f"{k_full}")
            if v is not None and torch.is_tensor(v):
                d[k] = v.clone() # 丢弃任何未使用的 storage

def shrink_pt_file(f):
    print(f"-> {f}")
    size_before = os.path.getsize(f)
    sd = torch.load(f, map_location=device)     # 加载 .pt 文件
    shrink_dict_values(sd)
    torch.save(sd, f)                           # 覆盖原先的 .pt 文件
    size_after = os.path.getsize(f)
    size_delta = size_before - size_after       # 统计节省的内存空间    
    if debug:
        print(f"before {size_before / 2**20:.2f}MB, after {size_after / 2**20:.2f}MB, saved {size_delta / 2**20:.2f}MB")
    return size_before, size_after, size_delta

def checkpoint_shrink(checkpoint_dir, patterns):
    """
    参数:
        - ``ds_checkpoint_dir``:deepspeed checkpoint 文件夹路径,也就是 optimizer 文件所在的位置
    """
    print(f"Processing zero checkpoint '{checkpoint_dir}'")
    pt_files = get_pt_files(checkpoint_dir, patterns)
    before, after, delta = 0, 0, 0
    for f in pt_files:
        size_before, size_after, size_delta = shrink_pt_file(f)
        before += size_before
        after  += size_after
        delta  += size_delta
    print(f"Done. Before {before / 2**20:.2f}MB, after {after / 2**20:.2f}MB, saved {delta / 2**20:.2f}MB")

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoints/global_step10")
    parser.add_argument("--patterns", nargs='+', default="*.pt", required=False, type=str, help="one or more patterns of checkpoint files - make sure to quote those! by default all *.pt files")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    debug = args.debug

    checkpoint_shrink(args.checkpoint_dir, args.patterns)

评论