导图社区 模型训练代码
CV图像分类思路加代码思维导图,包括数据、模型、优化、训练指标、迭代和结果、保存等内容。
编辑于2021-11-08 20:21:16模型训练代码
数据
数据预处理
工具
pandas
cv2
数据集划分
sklearn.model_selection.train_test_split
数据读取
工具包
torch.utils.data.Dataset
os
文件读取
scipy
.mat数据读取
数据分析
numpy
pandas
子主题
模块
__init__
初始化参数
__getitem__
Dataloader循环时得到一组图像、标签数据
__len__
数据集长度
__get_info__
获得全部图像、数据组
数据增强
transforms.Compose([a,b])
组合处理
transforms.CenterCrop(size)
中心切割
transforms.RandomCrop(size, padding=0)
随即切割
transforms.RandomHorizontalFlip
水平反转0.5
transforms.RandomSizedCrop(size, interpolation=2)
先随机切割后在变换成size大小
torchvision.transforms.Normalize(mean, std)
正则化要放在transforms.toTensor()之后
Mixup
将两张图片融合到一起
标签平滑
torch
dataset
dataloader
keras
模型
模型搭建
迁移
代码
num_ftrs = model.fc.in_features
fc表示最后全连接层,in_features最后一层输入
model.fc = nn.Linear(num_ftrs, train_data.cls_num) # 102
自己建立最后一层赋值
pretrained_state_dict = torch.load(path_state_dict, map_location="cpu") model.load_state_dict(pretrained_state_dict)
path_state_dict为与训练模型参数路径
mode.load_state_dict读取该参数
从头搭建
图像分类
各类优秀的backbonehttps://rwightman.github.io/pytorch-image-models/models/
优化
损失函数
交叉熵损失函数
nn.CrossEntropyLoss()
softmax
用exp函数使得输出归一化在【0,1】之间
优化器
随机梯度下降,带动量
torch.optim.SGD
keras.optimizers.SGD
训练指标
混淆矩阵
召回率
该样本分类正确的个数
准确率
所有该分类中正确的个数
泛化能力好
低方差
低偏差
模型选择
方差
刻画了数据扰动所造成的影响
偏差
度量了学习算法的期望预测与真是结果的偏差
噪声
表示任何算法所能达到期望泛化误差的下线
模型评估
迭代和结果保存
log
logger类
定义路径
初始化logger
子主题
绘制loss和acc曲线
绘制混淆矩阵
设置随机种子
检查路径是否存在
python通用方法
argparse
1.parser = argparse.ArgumentParser(description='Training')
2.parser.add_argument
3.args = parser.parse_args()
pickle
h5py
读取h5py文件的数据 with h5py.File(path_to_digit_struct_mat_file, 'r') as digit_struct_mat_file: attrs = get_attrs(digit_struct_mat_file, index) length = len(attrs['label']) attrs_left, attrs_top, attrs_width, attrs_height = map(lambda x: [int(i) for i in x], [attrs['left'], attrs['top'], attrs['width'], attrs['height']]) min_left, min_top, max_right, max_bottom = (min(attrs_left), min(attrs_top), max(map(lambda x, y: x + y, attrs_left, attrs_width)), max(map(lambda x, y: x + y, attrs_top, attrs_height))) def get_attrs(digit_struct_mat_file, index): """ Returns a dictionary which contains keys: label, left, top, width and height, each key has multiple values. """ attrs = {} f = digit_struct_mat_file item = f['digitStruct']['bbox'][index].item() for key in ['label', 'left', 'top', 'width', 'height']: attr = f[item][key] values = [f[attr[()][i].item()][()][0][0] # 此处[()]是因为h5py的要求 不然有warning for i in range(len(attr))] if len(attr) > 1 else [attr.value[0][0]] attrs[key] = values return attrs
通用工具
setup_seed
确定随机化种子
根据时间戳创建结果文件
pandas
数据读取
pd.read_csv()
数据处理
DataFrame.groupby(by='根据什么分组').方法(可以是sum,count,mean等)
head(n)显示前n行数据
columns。显示列名称
sort_values(['根据什么'],ascending=False)是否是升序
lambda结合.apply(lambda方法)对某列\数据使用
子主题
glob
通用技巧
f(*变量)
表示f函数可以接收多个变量俗称---动态变量