jjzjj

sentence-transformers(SBert)中文文本相似度预测(附代码)

我先润了 2023-10-02 原文

前言

训练模型

  1. 创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
  2. 获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
  3. 计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
  4. 模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。

首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。

这部分在详细代码里注释得很全。

后端部分

使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。

from sentence_transformers import SentenceTransformer, util
# 后端接口
from flask import Flask, jsonify, request
import re
# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容
app = Flask(__name__)
# 使通过jsonify返回的中文显示正常,否则显示为ASCII码
app.config["JSON_AS_ASCII"] = False
model_path = 'D:/xxx模型路径/'
model = SentenceTransformer(model_path)
@app.route("/evaluate",methods=['POST'])
def evalute_sentence():
    s1 = request.json.get("s1")
    s2 = request.json.get("s2")
    if s1 and s2:
        embedding1 = model.encode(s1, convert_to_tensor=True)
        embedding2 = model.encode(s2, convert_to_tensor=True)
        similarity = util.cos_sim(embedding1, embedding2).tolist()
        return jsonify({"code": 200, "msg": "预测成功", "data": similarity})
    else:
        return jsonify({"code": 400, "msg": "缺少字段"})
if __name__ == '__main__':
    app.run(debug=True)

前端部分

框架使用Vue2,UI框架使用elementui。组件校验用户输入的表单(内容为中文,字数限制32个字,两个句子不为空),只有符合规则的字段才能提交表单。将数据通过Axios调用接口传递给后端,再根据后端接口响应状态进行相应的处理,如果返回状态码200,说明接口调用成功,展示返回的预测值,否则调用失败,页面弹出失败消息提示。

<template>
  <div class="recommend">
    <el-card class="box">
      <h2 class="title">中文文本相似度预测</h2>
      <el-form :model="evaluateForm" :rules="evaluateRules" ref="evaluateForm" class="form">
        <el-form-item prop="s1">
          <el-input
            placeholder="请输入句子一"
            maxlength="32"
            show-word-limit
            v-model="evaluateForm.s1"
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item prop="s2">
          <el-input
            maxlength="32"
            placeholder="请输入句子二"
            v-model="evaluateForm.s2"
            show-word-limit
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item class="btn-container">
          <el-button
            type="primary"
            @click="submitForm('evaluateForm')"
            class="btn"
            id="queryButton"
          >开始预测</el-button>
        </el-form-item>
      </el-form>
      <div v-show="result" style="margin-top: 20px">
        <el-progress
          :text-inside="true"
          :stroke-width="26"
          :percentage="result*100 ? result*100 : 0"
          class="el-bg-inner-running"
        ></el-progress>
        <p>预测结果:{{result}}</p>
      </div>
    </el-card>
  </div>
</template>

<script>
import api from "@/api/index"
export default {
  data () {
    return {
      evaluateForm: {
        s1: "",
        s2: ""
      },
      evaluateRules: { // 评估表单校验规则
        s1: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
        s2: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
      },
      result: undefined,
    }
  },
  methods: {
    postEvaluate () { // 调用接口
      api.postEvaluate(this.evaluateForm)
        .then((res) => {
          if (!res) {
            return
          }
          console.log("res", res)
          if (res.data.code !== 200) {
            this.$message({
              message: "请求失败",
              type: "error"
            })
            return
          }
          let data = res.data.data[0]
          this.result = data[0]
          console.log("this.result", this.result)
          this.$message({
            message: "预测成功!",
            type: "success"
          })

        })
        .catch((error) => {
          this.$message.error('资源获取错误!')
        })
    },
    submitForm (formName) { // 提交表单
      this.$refs[formName].validate((valid) => {
        if (valid) {
          this.postEvaluate()
        } else {
          this.$message({
            message: "请按要求填写",
            type: "warning"
          })
          console.log('error in submit form')
          return false
        }
      })
      document.getElementById("queryButton").blur()
    },
  }

}
</script>

<style lang="scss" scoped>
.recommend {
  width: 100%;
  height: 100%;
  text-align: center;
  display: flex;
  text-align: center;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  overflow: hidden;
  background: #00416a 0 / cover fixed; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
  .box {
    width: 48%;
    height: 60%;
    position: relative;
    background: hsla(0, 0%, 100%, 0.3);
    z-index: 5;
    padding: 10px 20px;
    // display: flex;
    // flex-direction: column;
    // justify-content: center;
    box-sizing: border-box;
    &::before {
      content: '';
      position: absolute;
      top: 0;
      right: 0;
      bottom: 0;
      left: 0;
      filter: blur(20px);
    }
    .title {
      color: #143b54;
    }
    .btn-container {
      margin: 10px auto;
      .btn {
        width: 100%;
        border-radius: 20px;
      }
    }
  }
}
::v-deep .el-card {
  border: 0;
  box-shadow: 0 5px 16px 0 rgb(0 0 0 / 30%);
}
::v-deep .el-progress-bar__outer {
  border: 0;
  background-color: transparent;
  // background-color: #abcbe0;
}
::v-deep .el-bg-inner-running .el-progress-bar__inner {
  background: #9cecfb; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
}
</style>

预训练模型比较

paraphrase-multilingual-MiniLM-L12-v2
参数设置:epochs=1,batch_size=16
特点:作为sbert官方多语言预训练模型,已带有BERT层和池化层,可直接用数据评估,但未经纯中文文本训练,准确率较低

chinese-electra-180g-small-discriminator
参数设置:epochs=1, batch_size=16
特点:运行时间快,准确率尚可

chinese-electra-180g-small-discriminator
参数设置:epochs=20, batch_size=16
特点:20次迭代比1次迭代有效果,但差别不大

chinese-electra-180g-small-discriminator
参数设置:epochs=1,batch_size=8
特点:比batch_size=16时效果更好

chinese-roberta-wwm-ext
参数设置:epochs=1,batch_size=8
特点:迭代1次和20次准确率无差别,稳定且效果在所有模型中最好,缺点是体积大运行速度慢

最后

代码已上传至sbert中文文本相似度预测,欢迎star!

有关sentence-transformers(SBert)中文文本相似度预测(附代码)的更多相关文章

  1. ruby - 使用 ruby​​ 将 HTML 转换为纯文本并维护结构/格式 - 2

    我想将html转换为纯文本。不过,我不想只删除标签,我想智能地保留尽可能多的格式。为插入换行符标签,检测段落并格式化它们等。输入非常简单,通常是格式良好的html(不是整个文档,只是一堆内容,通常没有anchor或图像)。我可以将几个正则表达式放在一起,让我达到80%,但我认为可能有一些现有的解决方案更智能。 最佳答案 首先,不要尝试为此使用正则表达式。很有可能你会想出一个脆弱/脆弱的解决方案,它会随着HTML的变化而崩溃,或者很难管理和维护。您可以使用Nokogiri快速解析HTML并提取文本:require'nokogiri'h

  2. ruby - 如何在 buildr 项目中使用 Ruby 代码? - 2

    如何在buildr项目中使用Ruby?我在很多不同的项目中使用过Ruby、JRuby、Java和Clojure。我目前正在使用我的标准Ruby开发一个模拟应用程序,我想尝试使用Clojure后端(我确实喜欢功能代码)以及JRubygui和测试套件。我还可以看到在未来的不同项目中使用Scala作为后端。我想我要为我的项目尝试一下buildr(http://buildr.apache.org/),但我注意到buildr似乎没有设置为在项目中使用JRuby代码本身!这看起来有点傻,因为该工具旨在统一通用的JVM语言并且是在ruby中构建的。除了将输出的jar包含在一个独特的、仅限ruby​​

  3. ruby-on-rails - Rails 源代码 : initialize hash in a weird way? - 2

    在rails源中:https://github.com/rails/rails/blob/master/activesupport/lib/active_support/lazy_load_hooks.rb可以看到以下内容@load_hooks=Hash.new{|h,k|h[k]=[]}在IRB中,它只是初始化一个空哈希。和做有什么区别@load_hooks=Hash.new 最佳答案 查看rubydocumentationforHashnew→new_hashclicktotogglesourcenew(obj)→new_has

  4. ruby-on-rails - 浏览 Ruby 源代码 - 2

    我的主要目标是能够完全理解我正在使用的库/gem。我尝试在Github上从头到尾阅读源代码,但这真的很难。我认为更有趣、更温和的踏脚石就是在使用时阅读每个库/gem方法的源代码。例如,我想知道RubyonRails中的redirect_to方法是如何工作的:如何查找redirect_to方法的源代码?我知道在pry中我可以执行类似show-methodmethod的操作,但我如何才能对Rails框架中的方法执行此操作?您对我如何更好地理解Gem及其API有什么建议吗?仅仅阅读源代码似乎真的很难,尤其是对于框架。谢谢! 最佳答案 Ru

  5. ruby - 模块嵌套代码风格偏好 - 2

    我的假设是moduleAmoduleBendend和moduleA::Bend是一样的。我能够从thisblog找到解决方案,thisSOthread和andthisSOthread.为什么以及什么时候应该更喜欢紧凑语法A::B而不是另一个,因为它显然有一个缺点?我有一种直觉,它可能与性能有关,因为在更多命名空间中查找常量需要更多计算。但是我无法通过对普通类进行基准测试来验证这一点。 最佳答案 这两种写作方法经常被混淆。首先要说的是,据我所知,没有可衡量的性能差异。(在下面的书面示例中不断查找)最明显的区别,可能也是最著名的,是你的

  6. ruby - 寻找通过阅读代码确定编程语言的ruby gem? - 2

    几个月前,我读了一篇关于ruby​​gem的博客文章,它可以通过阅读代码本身来确定编程语言。对于我的生活,我不记得博客或gem的名称。谷歌搜索“ruby编程语言猜测”及其变体也无济于事。有人碰巧知道相关gem的名称吗? 最佳答案 是这个吗:http://github.com/chrislo/sourceclassifier/tree/master 关于ruby-寻找通过阅读代码确定编程语言的rubygem?,我们在StackOverflow上找到一个类似的问题:

  7. ruby - Net::HTTP 获取源代码和状态 - 2

    我目前正在使用以下方法获取页面的源代码:Net::HTTP.get(URI.parse(page.url))我还想获取HTTP状态,而无需发出第二个请求。有没有办法用另一种方法做到这一点?我一直在查看文档,但似乎找不到我要找的东西。 最佳答案 在我看来,除非您需要一些真正的低级访问或控制,否则最好使用Ruby的内置Open::URI模块:require'open-uri'io=open('http://www.example.org/')#=>#body=io.read[0,50]#=>"["200","OK"]io.base_ur

  8. 程序员如何提高代码能力? - 2

    前言作为一名程序员,自己的本质工作就是做程序开发,那么程序开发的时候最直接的体现就是代码,检验一个程序员技术水平的一个核心环节就是开发时候的代码能力。众所周知,程序开发的水平提升是一个循序渐进的过程,每一位程序员都是从“菜鸟”变成“大神”的,所以程序员在程序开发过程中的代码能力也是根据平时开发中的业务实践来积累和提升的。提高代码能力核心要素程序员要想提高自身代码能力,尤其是新晋程序员的代码能力有很大的提升空间的时候,需要针对性的去提高自己的代码能力。提高代码能力其实有几个比较关键的点,只要把握住这些方面,就能很好的、快速的提高自己的一部分代码能力。1、多去阅读开源项目,如有机会可以亲自参与开源

  9. 亚特兰蒂斯的回声(中文版): chatGPT 的杰作 - 2

    英文版英文链接关注公众号在“亚特兰蒂斯的回声”中踏上一段难忘的冒险之旅,深入未知的海洋深处。足智多谋的考古学家AriaSeaborne偶然发现了一件古代神器,揭示了一张通往失落之城亚特兰蒂斯的隐藏地图。在她神秘的导师内森·兰登教授的指导和勇敢的冒险家亚历克斯·默瑟的帮助下,阿丽亚开始了一段危险的旅程,以揭开这座传说中城市的真相。他们的冒险之旅带领他们穿越险恶的大海、神秘的岛屿和充满陷阱和谜语的致命迷宫。随着Aria潜在的魔法能力的觉醒,她被睿智勇敢的QueenNeria的幻象所指引,她让她为即将到来的挑战做好准备。三人组揭开亚特兰蒂斯令人惊叹的隐藏文明,并了解到邪恶的巫师马拉卡勋爵试图利用其古

  10. 7个大一C语言必学的程序 / C语言经典代码大全 - 2

    嗨~大家好,这里是可莉!今天给大家带来的是7个C语言的经典基础代码~那一起往下看下去把【程序一】打印100到200之间的素数#includeintmain(){ inti; for(i=100;i 【程序二】输出乘法口诀表#includeintmain(){inti;for(i=1;i 【程序三】判断1000年---2000年之间的闰年#includeintmain(){intyear;for(year=1000;year 【程序四】给定两个整形变量的值,将两个值的内容进行交换。这里提供两种方法来进行交换,第一种为创建临时变量来进行交换,第二种是不创建临时变量而直接进行交换。1.创建临时变量来

随机推荐