jjzjj

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

落花雨时 2023-04-18 原文

文章目录

一、概述

🔥本项目使用Pytroch,并基于ResNet50模型,实现了对天气图片的识别,过程详细,十分适合基础阶段的同学阅读。

项目目录结构

核心步骤

  • 数据处理
  • 准备配置文件
  • 构建自定义DataSetDataloader
  • 构建模型
  • 训练模型
  • 编写预测模块
  • 效果展示

二、代码编写

1. 数据处理

本项目数据来源:
https://www.heywhale.com/mw/dataset/60d9bd7c056f570017c305ee/file
http://vcc.szu.edu.cn/research/2017/RSCM.html

由于数据是直接下载,且目录分的很规整,本项目的数据处理部分较为简单,直接手动复制,合并两个数据集即可。

数据概览

总数据量约7万张

2. 准备配置文件

配置文件的主要存储一些各个模块通用的一些全局变量,如各种文件的存放位置等等(本人Java程序员出身,一些Python的代码规范不太熟悉,望见谅)。

config.py

import time

import torch
# 项目配置文件

class Common:
    '''
    通用配置
    '''
    basePath = "D:/Data/weather/source/all/"  # 图片文件基本路径
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 设备配置
    imageSize = (224,224) # 图片大小
    labels = ["cloudy","haze","rainy","shine","snow","sunny","sunrise","thunder"] # 标签名称/文件夹名称


class Train:
    '''
    训练相关配置
    '''
    batch_size = 128
    num_workers = 0  # 对于Windows用户,这里应设置为0,否则会出现多线程错误
    lr = 0.001
    epochs = 40
    logDir = "./log/" + time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime()) # 日志存放位置
    modelDir = "./model/" # 模型存放位置

3. 自定义DataSet和DataLoader

dada_loader.py

# 自定义数据加载器
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from config import Common
from config import Train
import os
from PIL import Image
import torch.utils.data as Data
import numpy

# 定义数据处理transform
transform = transforms.Compose([
    transforms.Resize(Common.imageSize),
    transforms.ToTensor()
])



def loadDataFromDir():
    '''
    从文件夹中获取数据
    '''
    images = []
    labels = []
    # 1. 获取根文件夹下所有分类文件夹
    for d in os.listdir(Common.basePath):
        for imagePath in os.listdir(Common.basePath + d):  # 2. 获取某一类型下所有的图片名称
            # 3. 读取文件
            image = Image.open(Common.basePath + d + "/" + imagePath).convert('RGB')
            print("加载数据" + str(len(images)) + "条")

            # 4. 添加到图片列表中
            images.append(transform(image))
            # 5. 构造label
            categoryIndex = Common.labels.index(d)  # 获取分类下标
            label = [0] * 8  # 初始化label
            label[categoryIndex] = 1  # 根据下标确定目标值
            label = torch.tensor(label,dtype=torch.float)  # 转为tensor张量
            # 6. 添加到目标值列表
            labels.append(label)
            # 7. 关闭资源
            image.close()
    # 返回图片列表和目标值列表
    return images, labels


class WeatherDataSet(Dataset):
    '''
    自定义DataSet
    '''

    def __init__(self):
        '''
        初始化DataSet
        :param transform: 自定义转换器
        '''
        images, labels = loadDataFromDir()  # 在文件夹中加载图片
        self.images = images
        self.labels = labels

    def __len__(self):
        '''
        返回数据总长度
        :return:
        '''
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label


def splitData(dataset):
    '''
    分割数据集
    :param dataset:
    :return:
    '''
    # 求解一下数据的总量
    total_length = len(dataset)

    # 确认一下将80%的数据作为训练集, 剩下的20%的数据作为测试集
    train_length = int(total_length * 0.8)
    validation_length = total_length - train_length

    # 利用Data.random_split()直接切分数据集, 按照80%, 20%的比例进行切分
    train_dataset,validation_dataset = Data.random_split(dataset=dataset, lengths=[train_length, validation_length])
    return train_dataset, validation_dataset



# 1. 分割数据集
train_dataset, validation_dataset = splitData(WeatherDataSet())
# 2. 训练数据集加载器
trainLoader = DataLoader(train_dataset, batch_size=Train.batch_size, shuffle=True, num_workers=Train.num_workers)
# 3. 验证集数据加载器
valLoader = DataLoader(validation_dataset, batch_size=Train.batch_size, shuffle=False,
                       num_workers=Train.num_workers)

主要步骤:

  1. 读取图片使用的是Python自带的PIL

PIL教程:https://blog.csdn.net/weixin_43790276/article/details/108478270

  1. 由于使用的是残差网络,其图片尺寸必须是3*224*224,故需要使用Pytroch的transforms工具进行处理

transforms教程:https://blog.csdn.net/qq_38410428/article/details/94719553

  1. 自定义DataSet(继承DataSet类,并实现重写三个核心方法)
  2. 分割数据
  3. 创建验证集和训练集各自的加载器

4. 构建模型

model.py

import torch
from torch import nn
import torchvision.models as models
from config import Common, Train

# 引入rest50模型
net = models.resnet50()
net.load_state_dict(torch.load("./model/resnet50-11ad3fa6.pth"))


class WeatherModel(nn.Module):
    def __init__(self, net):
        super(WeatherModel, self).__init__()
        # resnet50
        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(1000, 8)
        self.output = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.net(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc(x)
        x = self.output(x)
        return x


model = WeatherModel(net)

主要步骤:

  1. 引入Pytorch官方的残差网络预训练模型

关于新版本的引入方法:https://blog.csdn.net/Sihang_Xie/article/details/125646287

  1. 添加自己的全连接输出层
  2. 创建模型

5. 训练模型

train.py

# 训练部分
import time
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from config import Common, Train
from model import model as weatherModel
from data_loader import trainLoader, valLoader
from torch import optim

# 1. 获取模型
model = weatherModel
model.to(Common.device)
# 2. 定义损失函数
criterion = nn.CrossEntropyLoss()
# 3. 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. 创建writer
writer = SummaryWriter(log_dir=Train.logDir, flush_secs=500)


def train(epoch):
    '''
    训练函数
    '''
    # 1. 获取dataLoader
    loader = trainLoader
    # 2. 调整为训练状态
    model.train()
    print()
    print('========== Train Epoch:{} Start =========='.format(epoch))
    epochLoss = 0  # 每个epoch的损失
    epochAcc = 0  # 每个epoch的准确率
    correctNum = 0  # 正确预测的数量
    for data, label in loader:
        data, label = data.to(Common.device), label.to(Common.device)  # 加载到对应设备
        batchAcc = 0  # 单批次正确率
        batchCorrectNum = 0  # 单批次正确个数
        optimizer.zero_grad()  # 清空梯度
        output = model(data)  # 获取模型输出
        loss = criterion(output, label)  # 计算损失
        loss.backward()  # 反向传播梯度
        optimizer.step()  # 更新参数
        epochLoss += loss.item() * data.size(0)  # 计算损失之和
        # 计算正确预测的个数
        labels = torch.argmax(label, dim=1)
        outputs = torch.argmax(output, dim=1)
        for i in range(0, len(labels)):
            if labels[i] == outputs[i]:
                correctNum += 1
                batchCorrectNum += 1
        batchAcc = batchCorrectNum / data.size(0)
        print("Epoch:{}\t TrainBatchAcc:{}".format(epoch, batchAcc))

    epochLoss = epochLoss / len(trainLoader.dataset)  # 平均损失
    epochAcc = correctNum / len(trainLoader.dataset)  # 正确率
    print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))
    writer.add_scalar("train_loss", epochLoss, epoch)  # 写入日志
    writer.add_scalar("train_acc", epochAcc, epoch)  # 写入日志
    return epochAcc

def val(epoch):
    '''
    验证函数
    :param epoch: 轮次
    :return:
    '''
    # 1. 获取dataLoader
    loader = valLoader
    # 2. 初始化损失、准确率列表
    valLoss = []
    valAcc = []
    # 3. 调整为验证状态
    model.eval()
    print()
    print('========== Val Epoch:{} Start =========='.format(epoch))
    epochLoss = 0  # 每个epoch的损失
    epochAcc = 0  # 每个epoch的准确率
    correctNum = 0  # 正确预测的数量
    with torch.no_grad():
        for data, label in loader:
            data, label = data.to(Common.device), label.to(Common.device)  # 加载到对应设备
            batchAcc = 0  # 单批次正确率
            batchCorrectNum = 0  # 单批次正确个数
            output = model(data)  # 获取模型输出
            loss = criterion(output, label)  # 计算损失
            epochLoss += loss.item() * data.size(0)  # 计算损失之和
            # 计算正确预测的个数
            labels = torch.argmax(label, dim=1)
            outputs = torch.argmax(output, dim=1)
            for i in range(0, len(labels)):
                if labels[i] == outputs[i]:
                    correctNum += 1
                    batchCorrectNum += 1
            batchAcc = batchCorrectNum / data.size(0)
            print("Epoch:{}\t ValBatchAcc:{}".format(epoch, batchAcc))

        epochLoss = epochLoss / len(valLoader.dataset)  # 平均损失
        epochAcc = correctNum / len(valLoader.dataset)  # 正确率
        print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))
        writer.add_scalar("val_loss", epochLoss, epoch)  # 写入日志
        writer.add_scalar("val_acc", epochAcc, epoch)  # 写入日志
    return epochAcc

if __name__ == '__main__':
    maxAcc = 0.75
    for epoch in range(1,Train.epochs + 1):
        trainAcc = train(epoch)
        valAcc = val(epoch)
        if valAcc > maxAcc:
            maxAcc = valAcc
            # 保存最大模型
            torch.save(model, Train.modelDir + "weather-" + time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime()) + ".pth")
    # 保存模型
    torch.save(model,Train.modelDir+"weather-"+time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime())+".pth")

主要步骤

  1. 加载模型
  2. 准备损失函数及优化器
  3. 创建tensorboard的writer

关于tensorboard的使用:https://blog.csdn.net/weixin_43637851/article/details/116003280

  1. 编写训练函数及验证函数,同时记录损失和正确率

验证函数和训练函数的区别就是是否需要更新参数

  1. 循环训练epochs次,不断保存正确率最大的模型,以及最后一次的训练模型
  2. 开始训练
  3. 不断调参(我就只训了3次),知道有一个比较满意的效果

训练过程中电脑的状态:

查看训练日志(tensorboard)


保存的模型

6. 编写预测模块

pridect.py

import torch
import torchvision.transforms as transforms
from PIL import Image
from config import Common
def pridect(imagePath, modelPath):
    '''
    预测函数
    :param imagePath: 图片路径
    :param modelPath: 模型路径
    :return:
    '''
    # 1. 读取图片
    image = Image.open(imagePath)
    # 2. 进行缩放
    image = image.resize(Common.imageSize)
    image.show()
    # 3. 加载模型
    model = torch.load(modelPath)
    model = model.to(Common.device)
    # 4. 转为tensor张量
    transform = transforms.ToTensor()
    x = transform(image)
    x = torch.unsqueeze(x, 0)  # 升维
    x = x.to(Common.device)
    # 5. 传入模型
    output = model(x)
    # 6. 使用argmax选出最有可能的结果
    output = torch.argmax(output)
    print("预测结果:",Common.labels[output.item()])

if __name__ == '__main__':
    pridect("D:/Download/76ee4c5e833499949eac41561dcb487d.jpeg","./model/weather-2022-10-14-07-36-57.pth")

三、效果展示

去网上随便找的图片:


四、源码地址

https://github.com/mengxianglong123/weather-recognition

欢迎交流学习🥰🥰🥰

有关Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)的更多相关文章

  1. postman接口测试工具-基础使用教程 - 2

    1.postman介绍Postman一款非常流行的API调试工具。其实,开发人员用的更多。因为测试人员做接口测试会有更多选择,例如Jmeter、soapUI等。不过,对于开发过程中去调试接口,Postman确实足够的简单方便,而且功能强大。2.下载安装官网地址:https://www.postman.com/下载完成后双击安装吧,安装过程极其简单,无需任何操作3.使用教程这里以百度为例,工具使用简单,填写URL地址即可发送请求,在下方查看响应结果和响应状态码常用方法都有支持请求方法:getpostputdeleteGet、Post、Put与Delete的作用get:请求方法一般是用于数据查询,

  2. 软件测试基础 - 2

    Ⅰ软件测试基础一、软件测试基础理论1、软件测试的必要性所有的产品或者服务上线都需要测试2、测试的发展过程3、什么是软件测试找bug,发现缺陷4、测试的定义使用人工或自动的手段来运行或者测试某个系统的过程。目的在于检测它是否满足规定的需求。弄清预期结果和实际结果的差别。5、测试的目的以最小的人力、物力和时间找出软件中潜在的错误和缺陷6、测试的原则28原则:20%的主要功能要重点测(eg:支付宝的支付功能,其他功能都是次要的)80%的错误存在于20%的代码中7、测试标准8、测试的基本要求功能测试性能测试安全性测试兼容性测试易用性测试外观界面测试可靠性测试二、质量模型衡量一个优秀软件的维度①功能性功

  3. LC滤波器设计学习笔记(一)滤波电路入门 - 2

    目录前言滤波电路科普主要分类实际情况单位的概念常用评价参数函数型滤波器简单分析滤波电路构成低通滤波器RC低通滤波器RL低通滤波器高通滤波器RC高通滤波器RL高通滤波器部分摘自《LC滤波器设计与制作》,侵权删。前言最近需要学习放大电路和滤波电路,但是由于只在之前做音乐频谱分析仪的时候简单了解过一点点运放,所以也是相当从零开始学习了。滤波电路科普主要分类滤波器:主要是从不同频率的成分中提取出特定频率的信号。有源滤波器:由RC元件与运算放大器组成的滤波器。可滤除某一次或多次谐波,最普通易于采用的无源滤波器结构是将电感与电容串联,可对主要次谐波(3、5、7)构成低阻抗旁路。无源滤波器:无源滤波器,又称

  4. CAN协议的学习与理解 - 2

    最近在学习CAN,记录一下,也供大家参考交流。推荐几个我觉得很好的CAN学习,本文也是在看了他们的好文之后做的笔记首先是瑞萨的CAN入门,真的通透;秀!靠这篇我竟然2天理解了CAN协议!实战STM32F4CAN!原文链接:https://blog.csdn.net/XiaoXiaoPengBo/article/details/116206252CAN详解(小白教程)原文链接:https://blog.csdn.net/xwwwj/article/details/105372234一篇易懂的CAN通讯协议指南1一篇易懂的CAN通讯协议指南1-知乎(zhihu.com)视频推荐CAN总线个人知识总

  5. 深度学习部署:Windows安装pycocotools报错解决方法 - 2

    深度学习部署:Windows安装pycocotools报错解决方法1.pycocotools库的简介2.pycocotools安装的坑3.解决办法更多Ai资讯:公主号AiCharm本系列是作者在跑一些深度学习实例时,遇到的各种各样的问题及解决办法,希望能够帮助到大家。ERROR:Commanderroredoutwithexitstatus1:'D:\Anaconda3\python.exe'-u-c'importsys,setuptools,tokenize;sys.argv[0]='"'"'C:\\Users\\46653\\AppData\\Local\\Temp\\pip-instal

  6. ES基础入门 - 2

    ES一、简介1、ElasticStackES技术栈:ElasticSearch:存数据+搜索;QL;Kibana:Web可视化平台,分析。LogStash:日志收集,Log4j:产生日志;log.info(xxx)。。。。使用场景:metrics:指标监控…2、基本概念Index(索引)动词:保存(插入)名词:类似MySQL数据库,给数据Type(类型)已废弃,以前类似MySQL的表现在用索引对数据分类Document(文档)真正要保存的一个JSON数据{name:"tcx"}二、入门实战{"name":"DESKTOP-1TSVGKG","cluster_name":"elasticsear

  7. ruby - 我正在学习编程并选择了 Ruby。我应该升级到 Ruby 1.9 吗? - 2

    我完全不是程序员,正在学习使用Ruby和Rails框架进行编程。我目前正在使用Ruby1.8.7和Rails3.0.3,但我想知道我是否应该升级到Ruby1.9,因为我真的没有任何升级的“遗留”成本。缺点是什么?我是否会遇到与普通gem的兼容性问题,或者甚至其他我不太了解甚至无法预料的问题? 最佳答案 你应该升级。不要坚持从1.8.7开始。如果您发现不支持1.9.2的gem,请避免使用它们(因为它们很可能不被维护)。如果您对gem是否兼容1.9.2有任何疑问,您可以在以下位置查看:http://www.railsplugins.or

  8. ruby - 我如何学习 ruby​​ 的正则表达式? - 2

    如何学习ruby​​的正则表达式?(对于假人) 最佳答案 http://www.rubular.com/在Ruby中使用正则表达式时是一个很棒的工具,因为它可以立即将结果可视化。 关于ruby-我如何学习ruby​​的正则表达式?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/1881231/

  9. 深度学习12. CNN经典网络 VGG16 - 2

    深度学习12.CNN经典网络VGG16一、简介1.VGG来源2.VGG分类3.不同模型的参数数量4.3x3卷积核的好处5.关于学习率调度6.批归一化二、VGG16层分析1.层划分2.参数展开过程图解3.参数传递示例4.VGG16各层参数数量三、代码分析1.VGG16模型定义2.训练3.测试一、简介1.VGG来源VGG(VisualGeometryGroup)是一个视觉几何组在2014年提出的深度卷积神经网络架构。VGG在2014年ImageNet图像分类竞赛亚军,定位竞赛冠军;VGG网络采用连续的小卷积核(3x3)和池化层构建深度神经网络,网络深度可以达到16层或19层,其中VGG16和VGG

  10. 【网络】-- 网络基础 - 2

    (本文是网络的宏观的概念铺垫)目录计算机网络背景网络发展认识"协议"网络协议初识协议分层OSI七层模型TCP/IP五层(或四层)模型报头以太网碰撞路由器IP地址和MAC地址IP地址与MAC地址总结IP地址MAC地址计算机网络背景网络发展        是最开始先有的计算机,计算机后来因为多项技术的水平升高,逐渐的计算机变的小型化、高效化。后来因为计算机其本身的计算能力比较的快速:独立模式:计算机之间相互独立。    如:有三个人,每个人做的不同的事物,但是是需要协作的完成。    而这三个人所做的事是需要进行协作的,然而刚开始因为每一台计算机之间都是互相独立的。所以前面的人处理完了就需要将数据

随机推荐