jjzjj

人工智能(Pytorch)搭建LSTM网络实现简单案例

微学AI 2023-06-10 原文

 本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052

目录

一、Pytorch搭建神经网络的简单步骤

二、LSTM网络

三、Pytorch搭建LSTM网络的代码实战

 一、Pytorch搭建神经网络

PyTorch是一个基于Python的深度学习框架,提供了自动求导机制、强大的GPU支持和动态图等特性。PyTorch搭建神经网络的一般步骤

1.导入必要的库和数据

import torch
import torch.nn as nn
import torch.optim as optim

# 加载数据并进行预处理
train_data = ...
test_data = ...

2.定义神经网络模型

# 建立Net网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)  # 输入层到隐藏层
        self.fc2 = nn.Linear(256, 10)   # 隐藏层到输出层

    def forward(self, x):
        x = x.view(-1, 784)        # 将样本拉平成一维向量
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        
net = Net()

3.定义损失函数和优化器

 # 使用交叉熵作为损失函数,适合分类问题
criterion = nn.CrossEntropyLoss()   

# 使用随机梯度下降进行参数优化
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)    

4.训练模型

num_epochs =10 # 训练次数

for epoch in range(num_epochs):
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

5.测试模型

correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total)

PyTorch搭建神经网络的具体步骤都是按照以上模板进行搭建的,大家可以记住以上步骤。

二、LSTM网络

LSTM网络是一种特殊的循环神经网络,它能够学习处理序列中的长期依赖性,而不会受到梯度消失或梯度爆炸的影响。LSTM中的关键组成部分是门控机制,它允许网络选择性地丢弃或保留信息。每个门控单元都包括一个sigmoid激活函数和一个逐元素乘积。

LSTM的内部状态由三个向量组成:记忆细胞状态、隐藏状态和输入。在每个时间步,LSTM接收输入​和前一个时刻的记忆状态和隐藏状态​。然后,它使用门控机制来计算以下内容:

 1.遗忘门:决定从前一个时间步的记忆状态中删除哪些信息。

其中是遗忘门的权重和偏差,是sigmoid激活函数。

 2.输入门:决定将哪些新信息添加到记忆状态中。

 3.候选记忆状态:包括新的候选信息。

 4.更新记忆状态:根据遗忘门、输入门和候选记忆状态更新记忆状态

其中⊙表示逐元素乘积。

 5.输出门:决定从记忆状态中输出哪些信息。

 6.隐藏状态:根据输出门和记忆状态计算隐藏状态。

LSTM通过门控机制实现了选择性地保留或丢弃信息,从而学习处理长序列的能力。在训练过程中,LSTM网络通过反向传播算法自动调整门控单元的参数,使其能够更好地适应数据。

三、Pytorch搭建LSTM网络的代码实战

import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(1), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(1), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[-1, :, :])
        return out

# 准备数据
input_size = 10   # 输入特征数
hidden_size = 20  # 隐藏层特征数
num_layers = 2    # LSTM层数
output_size = 2   # 输出类别数
batch_size = 3    # 批大小
sequence_length = 5  # 序列长度

# 随机生成一些数据
x = torch.randn(sequence_length, batch_size, input_size)
y = torch.randint(output_size, (batch_size,))

# 定义优化器和损失函数
model = LSTM(input_size, hidden_size, num_layers, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 开始训练
num_epochs = 100
for epoch in range(num_epochs):
    outputs = model(x)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))


# 预测新数据
with torch.no_grad():
    test_x = torch.randn(sequence_length, batch_size, input_size)
    outputs = model(test_x)
    _, predicted = torch.max(outputs.data, 1)
    print(predicted)

运行结果:

Epoch [10/100], Loss: 0.0209
Epoch [20/100], Loss: 0.0007
Epoch [30/100], Loss: 0.0002
Epoch [40/100], Loss: 0.0002
Epoch [50/100], Loss: 0.0001
Epoch [60/100], Loss: 0.0001
Epoch [70/100], Loss: 0.0001
Epoch [80/100], Loss: 0.0001
Epoch [90/100], Loss: 0.0001
Epoch [100/100], Loss: 0.0001
tensor([0, 0, 0])

Process finished with exit code 0

有关人工智能(Pytorch)搭建LSTM网络实现简单案例的更多相关文章

  1. ruby - 如何根据特征实现 FactoryGirl 的条件行为 - 2

    我有一个用户工厂。我希望默认情况下确认用户。但是鉴于unconfirmed特征,我不希望它们被确认。虽然我有一个基于实现细节而不是抽象的工作实现,但我想知道如何正确地做到这一点。factory:userdoafter(:create)do|user,evaluator|#unwantedimplementationdetailshereunlessFactoryGirl.factories[:user].defined_traits.map(&:name).include?(:unconfirmed)user.confirm!endendtrait:unconfirmeddoenden

  2. ruby - 简单获取法拉第超时 - 2

    有没有办法在这个简单的get方法中添加超时选项?我正在使用法拉第3.3。Faraday.get(url)四处寻找,我只能先发起连接后应用超时选项,然后应用超时选项。或者有什么简单的方法?这就是我现在正在做的:conn=Faraday.newresponse=conn.getdo|req|req.urlurlreq.options.timeout=2#2secondsend 最佳答案 试试这个:conn=Faraday.newdo|conn|conn.options.timeout=20endresponse=conn.get(url

  3. ruby - 用 Ruby 编写一个简单的网络服务器 - 2

    我想在Ruby中创建一个用于开发目的的极其简单的Web服务器(不,不想使用现成的解决方案)。代码如下:#!/usr/bin/rubyrequire'socket'server=TCPServer.new('127.0.0.1',8080)whileconnection=server.acceptheaders=[]length=0whileline=connection.getsheaders想法是从命令行运行这个脚本,提供另一个脚本,它将在其标准输入上获取请求,并在其标准输出上返回完整的响应。到目前为止一切顺利,但事实证明这真的很脆弱,因为它在第二个请求上中断并出现错误:/usr/b

  4. ruby-on-rails - 简单的 Ruby on Rails 问题——如何将评论附加到用户和文章? - 2

    我意识到这可能是一个非常基本的问题,但我现在已经花了几天时间回过头来解决这个问题,但出于某种原因,Google就是没有帮助我。(我认为部分问题在于我是一个初学者,我不知道该问什么......)我也看过O'Reilly的RubyCookbook和RailsAPI,但我仍然停留在这个问题上.我找到了一些关于多态关系的信息,但它似乎不是我需要的(尽管如果我错了请告诉我)。我正在尝试调整MichaelHartl'stutorial创建一个包含用户、文章和评论的博客应用程序(不使用脚手架)。我希望评论既属于用户又属于文章。我的主要问题是:我不知道如何将当前文章的ID放入评论Controller。

  5. ruby - 使用 Ruby 通过 Outlook 发送消息的最简单方法是什么? - 2

    我的工作要求我为某些测试自动生成电子邮件。我一直在四处寻找,但未能找到可以快速实现的合理解决方案。它需要在outlook而不是其他邮件服务器中,因为我们有一些奇怪的身份验证规则,我们需要保存草稿而不是仅仅发送邮件的选项。显然win32ole可以做到这一点,但我找不到任何相当简单的例子。 最佳答案 假设存储了Outlook凭据并且您设置为自动登录到Outlook,WIN32OLE可以很好地完成此操作:require'win32ole'outlook=WIN32OLE.new('Outlook.Application')message=

  6. 华为OD机试用Python实现 -【明明的随机数】 2023Q1A - 2

    华为OD机试题本篇题目:明明的随机数题目输入描述输出描述:示例1输入输出说明代码编写思路最近更新的博客华为od2023|什么是华为od,od薪资待遇,od机试题清单华为OD机试真题大全,用Python解华为机试题|机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南华为o

  7. 基于C#实现简易绘图工具【100010177】 - 2

    C#实现简易绘图工具一.引言实验目的:通过制作窗体应用程序(C#画图软件),熟悉基本的窗体设计过程以及控件设计,事件处理等,熟悉使用C#的winform窗体进行绘图的基本步骤,对于面向对象编程有更加深刻的体会.Tutorial任务设计一个具有基本功能的画图软件**·包括简单的新建文件,保存,重新绘图等功能**·实现一些基本图形的绘制,包括铅笔和基本形状等,学习橡皮工具的创建**·设计一个合理舒适的UI界面**注明:你可能需要先了解一些关于winform窗体应用程序绘图的基本知识,以及关于GDI+类和结构的知识二.实验环境Windows系统下的visualstudio2017C#窗体应用程序三.

  8. 网络编程套接字 - 2

    网络编程套接字网络编程基础知识理解源`IP`地址和目的`IP`地址理解源MAC地址和目的MAC地址认识端口号理解端口号和进程ID理解源端口号和目的端口号认识`TCP`协议认识`UDP`协议网络字节序socket编程接口`sockaddr``UDP`网络程序服务器端代码逻辑:需要用到的接口服务器端代码`udp`客户端代码逻辑`udp`客户端代码`TCP`网络程序服务器代码逻辑多个版本服务器单进程版本多进程版本多线程版本线程池版本服务器端代码客户端代码逻辑客户端代码TCP协议通讯流程TCP协议的客户端/服务器程序流程三次握手(建立连接)数据传输四次挥手(断开连接)TCP和UDP对比网络编程基础知识

  9. 「Python|Selenium|场景案例」如何定位iframe中的元素? - 2

    本文主要介绍在使用Selenium进行自动化测试或者任务时,对于使用了iframe的页面,如何定位iframe中的元素文章目录场景描述解决方案具体代码场景描述当我们在使用Selenium进行自动化测试的时候,可能会遇到一些界面或者窗体是使用HTML的iframe标签进行承载的。对于iframe中的标签,如果直接查找是无法找到的,会抛出没有找到元素的异常。比如近在咫尺的例子就是,CSDN的登录窗体就是使用的iframe,大家可以尝试通过F12开发者模式查看到的tag_name,class_name,id或者xpath来定位中的页面元素,会抛出NoSuchElementException异常。解决

  10. postman——集合——执行集合——测试脚本——pm对象简单示例02 - 2

    //1.验证返回状态码是否是200pm.test("Statuscodeis200",function(){pm.response.to.have.status(200);});//2.验证返回body内是否含有某个值pm.test("Bodymatchesstring",function(){pm.expect(pm.response.text()).to.include("string_you_want_to_search");});//3.验证某个返回值是否是100pm.test("Yourtestname",function(){varjsonData=pm.response.json

随机推荐