jjzjj

【深度学习前沿应用】图像分类Fine-Tuning

灵彧universe 2023-03-28 原文

【深度学习前沿应用】图像分类Fine-Tuning


作者简介:在校大学生一枚,华为云享专家,阿里云星级博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学与产业实践资源建设专家委员会(TIPCC)志愿者,以及编程爱好者,期待和大家一起学习,一起进步~ . 博客主页ぃ灵彧が的学习日志 . 本文专栏机器学习 . 专栏寄语:若你决定灿烂,山无遮,海无拦 .

(文章目录)


前言

1. 什么是预训练-微调模式?

在计算机视觉领域,预训练-微调模式已经沿用了多年,即在大规模图片数据集预训练模型参数,然后将训练好的参数在新的小数据集任务上进行微调,从而产生泛化性能更好的模型。


2. 什么是ResNet?

ResNet为常用的预训练模型之一,其核心操作为卷积与残差连接。卷积层为3×3的滤波器,并遵循两个简单的设计规则:①对于相同的输出特征图尺寸,每层具有相同数量的滤波器;②如果特征图尺寸减半,则滤波器数量加倍,以保持每层的时间复杂度。直接用步长为2的卷积层进行下采样,网络以全局平均池化层和伴随softmax的1000维全连接层结束,其中,卷积层数为34,因此也称为ResNet34(如下图1所示)。

本小节将使用ResNet34预训练-微调框架,实现猫脸12分类。对于给定的猫脸,判断其所属类型。


一、数据加载及预处理

本实验数据集来源于网络开源数据集(https://aistudio.baidu.com/aistudio/datasetdetail/10954),该数据集中包含12类猫图片,总计数据量为2160,部分图片展示如下图1所示。


(一)、数据加载及预处理

首先将该数据集挂载到当前项目中,然后读取数据文件,将数据按照8:2划分为训练集与验证集

  1. 导入相关包
import os import time import os.path as osp import zipfile import numpy as np import paddle import paddle.nn as nn import pandas as pd import paddle.nn.functional as F from PIL import Image from paddle.io import Dataset, DataLoader from paddle.optimizer import Adam from paddle.vision import Compose, ToTensor, Resize from paddle.vision.models import resnet34 from paddle.metric import Accuracy from sklearn.model_selection import StratifiedShuffleSplit
  1. 将train划分为训练集和验证集
info = pd.read_csv(osp.join('./data', 'train_list.txt'), sep='\t', header=None) images, labels = info.iloc[:, 0], info.iloc[:, 1] split = StratifiedShuffleSplit(test_size=0.2) train_idx, valid_idx = next(split.split(images, labels)) info_tr = info.iloc[train_idx, :] info_va = info.iloc[valid_idx, :] info_tr.to_csv('data/train.csv', header=False, index=False) info_va.to_csv('data/valid.csv', header=False, index=False)

(二)、数据集封装

class CatDataset(Dataset): train_file = 'cat_12_train.zip' test_file = 'cat_12_test.zip' train_label = 'train_list.txt' def __init__(self, root, mode, transform=None): super(CatDataset, self).__init__() self.root = root self.mode = mode self.transform = transform if not osp.isfile(osp.join(root, self.train_file)) or \ not osp.isfile(osp.join(root, self.train_label)) or \ not osp.isfile(osp.join(root, self.test_file)): raise ValueError('wrong data path') if not osp.isdir(osp.join(self.root, 'cat_12_train')): with zipfile.ZipFile(osp.join(root, self.train_file)) as f: f.extractall(root) with zipfile.ZipFile(osp.join(root, self.test_file)) as f: f.extractall(root) if mode == 'train': info = pd.read_csv(osp.join(root, 'train_list.txt'), sep='\t', header=None) self.images = info.iloc[:, 0].to_list() self.labels = paddle.to_tensor( info.iloc[:, 1].to_list() ) elif mode == 'train_': info = pd.read_csv(osp.join(root, 'train.csv'), header=None) self.images = info.iloc[:, 0].to_list() self.labels = paddle.to_tensor( info.iloc[:, 1].to_list() ) pass elif mode == 'valid_': info = pd.read_csv(osp.join(root, 'valid.csv'), header=None) self.images = info.iloc[:, 0].to_list() self.labels = paddle.to_tensor( info.iloc[:, 1].to_list() ) else: images = os.listdir(os.path.join(root, 'cat_12_test')) self.images = ['cat_12_test/'+image for image in images] self.labels = None def __getitem__(self, idx): image = Image.open(osp.join(self.root, self.images[idx])) if image.mode != 'RGB': image = image.convert('RGB') if self.transform is not None: image = self.transform(image) if self.mode == 'test': return image, else: label = self.labels[idx] return image, label def __len__(self): return len(self.images)

(三)、样本分类与统计

paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu') transform = Compose([ Resize([224, 224]), ToTensor() ]) train_ds = CatDataset('./data', 'train_', transform) valid_ds = CatDataset('./data', 'valid_', transform) train_dl = DataLoader(train_ds, batch_size=64, shuffle=True) valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=False) print('训练集样本数:',train_ds.__len__()) print('验证集样本数:',valid_ds.__len__())

二、预训练模型加载

paddle,vision是飞桨在视觉领域的高层API,内部封装了常用的数据集以及常用预测训练模型,如LeNet、VGG系列、ResNet系列及MobileNet系列等。本实验使用resnet34为例,演示如何进行图像分类的微调。

准备好数据集之后,加载预训练模型,调用net=resnet34(pretrained=True),设置参数pretrained为True,便可使用预训练好的参数,否则,需要从头开始训练参数(首次加载预训练参数时需要从相关专业网络中下载):


加载预训练模型,并设置类别数目为12(猫的分类)

net = resnet34(pretrained=True, num_classes=12)

三、模型微调

加载好预训练的模型之后,定义模型的优化器、评价指标等,输入领域数据,执行微调:

(一)、定义优化器

optimizer = Adam( parameters=net.parameters(), learning_rate=1e-5 )

(二)、定义损失函数

loss_fn = nn.CrossEntropyLoss()

(三)、定义准确率评价指标

metric_fn = Accuracy()

(四)、微调20轮

for epoch in range(20): t0 = time.time() net.train() for data, label in train_dl: logit = net(data) loss = loss_fn(logit, label.astype('int64')) optimizer.clear_grad() loss.backward() optimizer.step() # 验证 net.eval() loss_tr = 0. for data, label in train_dl: logit = net(data) label = label.astype('int64') loss_tr += loss_fn(logit, label).cpu().numpy()[0] loss_tr /= len(train_dl) loss_va = 0. for data, label in valid_dl: label = label.astype('int64') logit = net(data) loss_va += loss_fn(logit, label).cpu().numpy()[0] metric_fn.update( metric_fn.compute(logit, label) ) loss_va /= len(valid_dl) acc_va = metric_fn.accumulate() metric_fn.reset() t = time.time() - t0 print('[Epoch {:3d} {:.2f}s] train loss({:.4f}); valid loss({:.4f}), acc({:.2f})' .format(epoch, t, loss_tr, loss_va, acc_va)) 训练过程部分输出如下图2所示:


四、模型预测

import matplotlib.image as mpimg import matplotlib.pyplot as plt def show_image(file_name): img = mpimg.imread('data/'+file_name) plt.figure(figsize=(10,10)) plt.imshow(img) plt.show() test_ds = CatDataset('./data', mode='test', transform=transform) test_dl = DataLoader(test_ds, batch_size=32, shuffle=False) test_pred = [] with paddle.no_grad(): for data, in test_dl: logit = net(data) pred = paddle.argmax( F.softmax(logit, axis=-1), axis=-1 ) test_pred.append(pred.cpu().numpy()) test_pred = np.concatenate(test_pred, axis=0) for image, pred in zip(test_ds.images, test_pred.astype(np.int)): img = mpimg.imread('data/'+image) plt.figure(figsize=(10,10)) plt.imshow(img) plt.show() print('图片路径:%s, 图片预测类型:%d\n' % (image.split('/')[1], pred)) 预测结果部分输出如下图3、4、5、6所示


总结

本系列文章内容为根据清华社出版的《机器学习实践》所作的相关笔记和感悟,其中代码均为基于百度飞桨开发,若有任何侵权和不妥之处,请私信于我,定积极配合处理,看到必回!!!

最后,引用本次活动的一句话,来作为文章的结语~( ̄▽ ̄~)~:

【**学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。**】

有关【深度学习前沿应用】图像分类Fine-Tuning的更多相关文章

  1. ruby-on-rails - 添加回形针新样式不影响旧上传的图像 - 2

    我有带有Logo图像的公司模型has_attached_file:logo我用他们的Logo创建了许多公司。现在,我需要添加新样式has_attached_file:logo,:styles=>{:small=>"30x15>",:medium=>"155x85>"}我是否应该重新上传所有旧数据以重新生成新样式?我不这么认为……或者有什么rake任务可以重新生成样式吗? 最佳答案 参见Thumbnail-Generation.如果rake任务不适合你,你应该能够在控制台中使用一个片段来调用重新处理!关于相关公司

  2. 世界前沿3D开发引擎HOOPS全面讲解——集3D数据读取、3D图形渲染、3D数据发布于一体的全新3D应用开发工具 - 2

    无论您是想搭建桌面端、WEB端或者移动端APP应用,HOOPSPlatform组件都可以为您提供弹性的3D集成架构,同时,由工业领域3D技术专家组成的HOOPS技术团队也能为您提供技术支持服务。如果您的客户期望有一种在多个平台(桌面/WEB/APP,而且某些客户端是“瘦”客户端)快速、方便地将数据接入到3D应用系统的解决方案,并且当访问数据时,在各个平台上的性能和用户体验保持一致,HOOPSPlatform将帮助您完成。利用HOOPSPlatform,您可以开发在任何环境下的3D基础应用架构。HOOPSPlatform可以帮您打造3D创新型产品,HOOPSSDK包含的技术有:快速且准确的CAD

  3. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  4. CAN协议的学习与理解 - 2

    最近在学习CAN,记录一下,也供大家参考交流。推荐几个我觉得很好的CAN学习,本文也是在看了他们的好文之后做的笔记首先是瑞萨的CAN入门,真的通透;秀!靠这篇我竟然2天理解了CAN协议!实战STM32F4CAN!原文链接:https://blog.csdn.net/XiaoXiaoPengBo/article/details/116206252CAN详解(小白教程)原文链接:https://blog.csdn.net/xwwwj/article/details/105372234一篇易懂的CAN通讯协议指南1一篇易懂的CAN通讯协议指南1-知乎(zhihu.com)视频推荐CAN总线个人知识总

  5. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  6. ruby-on-rails - 在 Ruby (on Rails) 中使用 imgur API 获取图像 - 2

    我正在尝试使用Ruby2.0.0和Rails4.0.0提供的API从imgur中提取图像。我已尝试按照Ruby2.0.0文档中列出的各种方式构建http请求,但均无济于事。代码如下:require'net/http'require'net/https'defimgurheaders={"Authorization"=>"Client-ID"+my_client_id}path="/3/gallery/image/#{img_id}.json"uri=URI("https://api.imgur.com"+path)request,data=Net::HTTP::Get.new(path

  7. python ffmpeg 使用 pyav 转换 一组图像 到 视频 - 2

    2022/8/4更新支持加入水印水印必须包含透明图像,并且水印图像大小要等于原图像的大小pythonconvert_image_to_video.py-f30-mwatermark.pngim_dirout.mkv2022/6/21更新让命令行参数更加易用新的命令行使用方法pythonconvert_image_to_video.py-f30im_dirout.mkvFFMPEG命令行转换一组JPG图像到视频时,是将这组图像视为MJPG流。我需要转换一组PNG图像到视频,FFMPEG就不认了。pyav内置了ffmpeg库,不需要系统带有ffmpeg工具因此我使用ffmpeg的python包装p

  8. ruby - 我正在学习编程并选择了 Ruby。我应该升级到 Ruby 1.9 吗? - 2

    我完全不是程序员,正在学习使用Ruby和Rails框架进行编程。我目前正在使用Ruby1.8.7和Rails3.0.3,但我想知道我是否应该升级到Ruby1.9,因为我真的没有任何升级的“遗留”成本。缺点是什么?我是否会遇到与普通gem的兼容性问题,或者甚至其他我不太了解甚至无法预料的问题? 最佳答案 你应该升级。不要坚持从1.8.7开始。如果您发现不支持1.9.2的gem,请避免使用它们(因为它们很可能不被维护)。如果您对gem是否兼容1.9.2有任何疑问,您可以在以下位置查看:http://www.railsplugins.or

  9. ruby - 是否有将图像文件转换为 ASCII 艺术的命令行程序或库? - 2

    有这样的事吗?我想在Ruby程序中使用它。 最佳答案 试试这个http://csl.sublevel3.org/jp2a/此外,Imagemagick可能还有一些东西 关于ruby-是否有将图像文件转换为ASCII艺术的命令行程序或库?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/6510445/

  10. ruby-on-rails - 使用 Dragonfly 从 URL 分配图像 - 2

    我正在使用Dragonfly在Rails3.1应用程序上处理图像。我正在努力通过url将图像分配给模型。我有一个很好的表格:{:multipart=>true}do|f|%>RemovePicture?Dragonfly的文档指出:Dragonfly提供了一个直接从url分配的访问器:@album.cover_image_url='http://some.url/file.jpg'但是当我在控制台中尝试时:=>#ruby-1.9.2-p290>picture.image_url="http://i.imgur.com/QQiMz.jpg"=>"http://i.imgur.com/QQ

随机推荐