batch = tree_map(lambda x: x.to(device), next(data_generator))# 获取一个批次的数据
为什么这里batch.Size([256,64,24])?
config里面这些参数是怎么设计的?有什么用?它和缩放参数很像,能用在缩放上面吗?另外缩放之后Y_bound怎么设计?
1 | Y_bound=[242.382], # eta |
32*32的图像
tokens=64
block_sz=2
图像被分成多个大小为2*2的小块
self.index 是干什么的?
图像处理流程
32
1 | R = img[:, :, 0] |
- 32*32的图
- 分别提取R, G, B, 尺寸都是32*32
- 使用公式转到YCbCr颜色空间
- cb,cr降采样,从32X32降到16X16
1 | y_blocks = split_into_blocks(img_y, self.block_sz) # Y component, (64, 64) --> (256, 4, 4) |
- 将图像分成若干个小块(Y256, cb和cr64),每个小块大小为2*2
函数调用
train.py 使用datasets.py中的get_dataset函数拿到数据集,然后通过DataKiader初始化数据集迭代器
分析get_dataset:
- 根据不同数据集的名称返回不同的类,类中有个self.train的属性,存储了数据集实例
分析程序流程
仔细!别急!
- 加载配置文件
- 初始化设置
- 多进程
- 种子
- 混合精读
- log
- 加载数据
- get_dataset加载数据集
- get_split根据是训练集还是测试集、有标签还是无标签切割数据集
- DataLoader生成一个迭代器
- 初始化训练状态、accelerator
- 将数据加载器和模型与加速器结合,以进行分布式训练。
- 损失重加权变量的设置
- train.py 74行