
Paper Reproduction: Implementing Text2SQL on ModelArts

Huawei Cloud Developer Community, 2023-03-28
Abstract: The paper proposes a new neural network architecture based on pre-trained BERT, called M-SQL, which splits column-based value extraction into two modules: value extraction and value-column matching.

This article is shared from the Huawei Cloud community post "Implementing Text2SQL on ModelArts" by HWCloudAI.

M-SQL: Multi-Task Representation Learning for Single-Table Text2sql Generation

Although previous work on Text2SQL offers feasible solutions, most of it extracts values based on column representations. When a query contains multiple values that belong to different columns, these column-representation-based methods cannot extract the values accurately. The paper proposes a new neural network architecture based on pre-trained BERT, called M-SQL, which splits column-based value extraction into two modules: value extraction and value-column matching.
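As a concrete illustration of the two-module idea (the question, column names, and spans below are hypothetical, not taken from TableQA), a query can carry two values that belong to different columns; a purely column-based extractor struggles here, while M-SQL first tags the value spans and then matches each span to a column:

```python
# Hypothetical example: one question, two values, each belonging to a different column.
question = "近四周成交量小于3574套并且环比低于69.7%的城市有几个"

# Step 1 (value extraction): tag the value spans in the question.
extracted_values = ["3574", "69.7%"]

# Step 2 (value-column matching): link each extracted value to its column and operator.
value_column_pairs = [("成交量", "<", "3574"), ("环比", "<", "69.7%")]

# A WHERE clause can then be assembled from the matched triples.
where_clause = " AND ".join(f"{col} {op} {val}" for col, op, val in value_column_pairs)
print(where_clause)  # 成交量 < 3574 AND 环比 < 69.7%
```

A column-representation-only model would have to decide, per column, which question tokens are its value; when "3574" and "69.7%" sit in the same question, that decision is exactly what the separate matching module makes explicit.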

Paper: https://ieeexplore.ieee.org/document/9020099

Algorithm details: https://marketplace.huaweicloud.com/markets/aihub/modelhub/detail/?id=4d1c0887-8ef0-4133-bd02-c4d69255377a

Notes:

1. Framework: PyTorch 1.4.0
2. Hardware: GPU: 1 × NVIDIA V100 NV32 (32 GB) | CPU: 8 cores, 64 GB
3. To run the code, click the run (triangle) button in the menu bar at the top of this page, or press Ctrl+Enter to run the code in each cell
4. For detailed JupyterLab usage, see the "ModelArts JupyterLab User Guide"
5. If you run into problems, see the "ModelArts JupyterLab Troubleshooting Guide"

1. Download the code and dataset

Run the code below to download and extract the code and data.

This case uses the TableQA dataset, located in m-sql/TableQA/.

import os
# Download the code and data
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/m-sql.zip
# Extract the archive
os.system('unzip m-sql.zip -d ./')
os.chdir('./m-sql')
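`os.system('unzip ...')` depends on an `unzip` binary being on the PATH and silently ignores failures. A more portable sketch using the standard-library `zipfile` module (it assumes, as m-sql.zip does here, that the archive's top-level folder matches the archive name):

```python
import os
import zipfile

def extract_archive(archive_path, dest_dir='./'):
    """Extract a zip archive with the standard library, raising on a missing or corrupt file."""
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(dest_dir)
    # Return the expected top-level directory, e.g. 'm-sql' for 'm-sql.zip'.
    return os.path.join(dest_dir, os.path.splitext(os.path.basename(archive_path))[0])

# Usage (after the wget step above):
# os.chdir(extract_archive('m-sql.zip'))
```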

2. Training

2.1 Install dependencies

!pip install -r pip-requirements.txt

2.2 Parameters and functions for training

import os
import argparse
import shutil
import sqlite3
import time
import tqdm
import torch
import random as python_random
from transformers import BertTokenizer, BertModel
import logging
import numpy as np
from model import Loss_sw_se, Seq2SQL_v1
# import moxing as mox
from sql_utils.utils_tableqa import load_tableqa, get_loader, get_fields, get_g, get_g_wvi, get_wemb_bert, \
    pred_sw_se, convert_pr_wvi_to_string, generate_sql_i, extract_val, normalize_sql, get_acc, get_acc_x, \
    save_for_evaluation, load_tableqa_data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def construct_hyper_param(parser):
    parser.add_argument("--eval", default='False', type=str)
    parser.add_argument("--no_save", default='False', type=str)
    parser.add_argument("--toy_model", default='False', type=str)
    parser.add_argument("--toy_size", default=16, type=int)
    parser.add_argument('--tepoch', default=1, type=int)
    parser.add_argument('--print_per_step', default=50, type=int)
    parser.add_argument("--bS", default=32, type=int,
 help="Batch size")
    parser.add_argument("--accumulate_gradients", default=1, type=int,
 help="The number of accumulation of backpropagation to effectivly increase the batch size.")
    parser.add_argument('--fine_tune',
                        default='False', type=str,
 help="If present, BERT is trained.")
    parser.add_argument("--data_url", default='./TableQA', type=str,
 help="Saving path of model file, logfile and result file.")
    parser.add_argument("--train_url", default='./data_and_model/', type=str,
 help="Saving path of model file, logfile and result file.")
    parser.add_argument("--vocab_file",
                        default='vocab.txt', type=str,
 help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--max_seq_length",
                        default=512, type=int,
 help="The maximum total input sequence length after WordPiece tokenization. Sequences ")
    parser.add_argument("--num_target_layers",
                        default=1, type=int,
 help="The Number of final layers of BERT to be used in downstream task.")
    parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.')
    parser.add_argument('--seed',
 type=int,
                        default=1,
 help="random seed for initialization")
    parser.add_argument('--do_lower_case', default='False', type=str, help='whether to use lower case.')
    parser.add_argument("--bert_url", default='./pre-trained_weights/chinese_wwm_ext_pytorch/', type=str,
 help="Path or model name of BERT")
    parser.add_argument("--load_weight", default='./trained_model/model/best_model.pth', type=str,
 help="model path to load")
    parser.add_argument('--dr', default=0, type=float, help="Dropout rate.")
    parser.add_argument('--lr', default=1e-3, type=float, help="Learning rate.")
    parser.add_argument('--num_warmup_steps', default=-1, type=int, help="num_warmup_steps")
    parser.add_argument("--split", default='val', type=str, help='prefix of jsonl and db files')
    args, _ = parser.parse_known_args()
    python_random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    args.do_lower_case = args.do_lower_case == 'True'
    args.fine_tune = args.fine_tune == 'True'
    args.no_save = args.no_save == 'True'
    args.eval = args.eval == 'True'
    args.toy_model = args.toy_model == 'True'
    return args
def get_bert(bert_path):
    tokenizer = BertTokenizer.from_pretrained(bert_path)
    model_bert = BertModel.from_pretrained(bert_path)
    bert_config = model_bert.config
    model_bert.to(device)
    return model_bert, tokenizer, bert_config
def update_lr(param_groups, current_step, num_warmup_steps, start_lr):
    if current_step <= num_warmup_steps:
        warmup_frac_done = current_step / num_warmup_steps
        new_lr = start_lr * warmup_frac_done
        for param_group in param_groups:
            param_group['lr'] = new_lr
def get_opt(model, model_bert, fine_tune):
    if fine_tune:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)
        opt_bert = torch.optim.Adam(filter(lambda p: p.requires_grad, model_bert.parameters()),
                                    lr=args.lr_bert, weight_decay=0)
    else:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)
        opt_bert = None
    return opt, opt_bert
def get_models(args, logger, bert_model, trained=False, path_model=None, eval=False):
    # some constants
    if not eval:
        logger.info(f"Batch_size = {args.bS * args.accumulate_gradients}")
        logger.info(f"BERT parameters:")
        logger.info(f"learning rate: {args.lr_bert}")
        logger.info(f"Fine-tune BERT: {args.fine_tune}")
    # Get BERT
    model_bert, tokenizer, bert_config = get_bert(bert_model)
    iS = bert_config.hidden_size * args.num_target_layers
    logger.info(bert_config.to_json_string())
    # Get Seq-to-SQL
    if not eval:
        logger.info(f"Seq-to-SQL: the number of final BERT layers to be used: {args.num_target_layers}")
        logger.info(f"Seq-to-SQL: learning rate = {args.lr}")
    model = Seq2SQL_v1(iS, args.dr)
    model = model.to(device)
    if trained:
        assert path_model is not None
        if torch.cuda.is_available():
            res = torch.load(path_model)
        else:
            res = torch.load(path_model, map_location='cpu')
        model_bert.load_state_dict(res['model_bert'])
        model_bert.to(device)
        model.load_state_dict(res['model'])
        model.to(device)
    return model, model_bert, tokenizer, bert_config
def get_data(path_wikisql, args):
    train_data, train_table, dev_data, dev_table = load_tableqa(path_wikisql, args.toy_model, args.toy_size,
                                                                no_hs_tok=True)
    train_loader, dev_loader = get_loader(train_data, dev_data, args.bS, shuffle_train=True)
    return train_data, train_table, dev_data, dev_table, train_loader, dev_loader
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients, print_per_step, logger,
          current_step, st_pos=0, opt_bert=None):
    model.train()
    model_bert.train()
    torch.autograd.set_detect_anomaly(True)
    ave_loss = 0
    cnt = 0
    for iB, t in enumerate(train_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, train_table, no_hs_t=True, no_sql_t=True)
        # nlu  : natural language utterance
        # nlu_t: tokenized nlu
        # sql_i: canonical form of SQL query
        # sql_q: full SQL query text. Not used.
        # sql_t: tokenized SQL query
        # tb   : table
        # hs_t : tokenized headers. Not used.
        g_sn, g_sc, g_sa, g_wnop, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi, g_tags, g_value_match = get_g_wvi(t, g_wc)
        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = []
        for tt in t_to_tt_idx:
            l_n_t.append(len(tt))
        # wemb_n: natural language embedding
        # wemb_h: header embedding
        # l_n: token lengths of each question
        # l_hpu: header token lengths
        # l_hs: the number of columns (headers) of the tables.
        # score
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
        s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs,
                                      t_to_tt_idx=t_to_tt_idx,
                                      g_sn=g_sn, g_sc=g_sc, g_sa=g_sa, g_wo=g_wo,
                                      g_wnop=g_wnop, g_wc=g_wc, g_wvi=g_wvi,
                                      g_tags=g_tags, g_vm=g_value_match)
        # Calculate loss & step
        loss = Loss_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match,
                          g_sn, g_sc, g_sa, g_wnop, g_wc, g_wo, g_tags, g_value_match)
        if iB % accumulate_gradients == 0:
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                update_lr(opt.param_groups, current_step, args.num_warmup_steps, args.lr)
                opt.step()
                if opt_bert:
                    update_lr(opt_bert.param_groups, current_step, args.num_warmup_steps, args.lr_bert)
                    opt_bert.step()
                current_step += 1
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            loss.backward()
            update_lr(opt.param_groups, current_step, args.num_warmup_steps, args.lr)
            opt.step()
            if opt_bert:
                update_lr(opt_bert.param_groups, current_step, args.num_warmup_steps, args.lr_bert)
                opt_bert.step()
            current_step += 1
        else:
            loss.backward()
        # statistics
        ave_loss += loss.item()
        if iB % print_per_step == 0:
            log = f'[Train Batch {iB}] '
            logs = []
            logs.append(f'average loss: {"%.4f" % (ave_loss / cnt,)}')
            logger.info(log + ', '.join(logs))
        if iB == 150:
            logger.info('Stopping training early; remove this if-branch for a full training run')
            break
    ave_loss /= cnt
    return ave_loss, current_step
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer, max_seq_length,
         num_target_layers, print_per_step, logger, path_db, st_pos=0):
    model.eval()
    model_bert.eval()
    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wnop = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_lx = 0
    cnt_x = 0
    db_conn = sqlite3.connect(path_db)
    cursor = db_conn.cursor()
    results = []
    for iB, t in enumerate(data_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)
        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = []
        for tt in t_to_tt_idx:
            l_n_t.append(len(tt))
        # score
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
        s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs, t_to_tt_idx)
        # prediction
        pr_sn, pr_sc, pr_sa, pr_wn, pr_conn_op, \
        pr_wc, pr_wo, pr_tags, pr_wvi = pred_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match, l_n_t)
        pr_wv_str = convert_pr_wvi_to_string(pr_wvi, nlu_t)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_conn_op, pr_wc, pr_wo, pr_wv_str, nlu)
        value_indexes, value_nums = extract_val(pr_tags, l_n_t)
        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            normalize_sql(pr_sql_i1, tb[b])
            results1 = {}
            results1["sql"] = pr_sql_i1
            results1["gold_sql"] = sql_i[b]
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1['value_indexes'] = value_indexes[b]
            results1['value_nums'] = value_nums[b]
            results1['pr_wc'] = pr_wc[b]
            sn, sc, sa, co, wn, wc, wo, wv, cond, sql = \
                get_acc(sql_i[b], pr_sql_i1, pr_wc[b], pr_wo[b], tb[b], normalized=True)
            cnt_sn += sn
            cnt_sc += sc
            cnt_sa += sa
            cnt_wnop += (co and wn)
            cnt_wc += wc
            cnt_wo += wo
            cnt_wv += wv
            cnt_lx += sql
            results1['correct'] = sql
            execution, res = get_acc_x(sql_i[b], pr_sql_i1, tb[b], cursor)
            cnt_x += execution
            results1['ex_correct'] = execution
            results1['result'] = res
            results.append(results1)
        # print acc
        cnts = [cnt_sn, cnt_sc, cnt_sa, cnt_wnop, cnt_wc,
                cnt_wo, cnt_wv, cnt_lx, cnt_x, (cnt_lx + cnt_x) / 2]
        cnt_desc = [
            's-num', 's-col', 's-col-agg', 'w-num-op', 'w-col',
            'w-col-op', 'w-col-value', 'acc_lx', 'acc_x', 'acc_mx'
        ]
        if iB % print_per_step == 0:
            log = f'[Test Batch {iB}] '
            logs = []
            for k, metric in enumerate(cnts):
                logs.append(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,))
            logger.info(log + ', '.join(logs))
        if iB == 150:
            logger.info('Stopping evaluation early; remove this if-branch to run on the full set')
            break
    acc_sn = cnt_sn / cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wnop = cnt_wnop / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt
    acc_mx = (acc_lx + acc_x) / 2
    acc = [acc_sn, acc_sc, acc_sa, acc_wnop, acc_wc,
           acc_wo, acc_wv, acc_lx, acc_x, acc_mx]
    return acc, results, acc_lx
def print_result(epoch, acc, dname, logger=None):
    if logger:
        logger.info(f'------------ {dname} results ------------')
        if dname == 'dev':
            acc_sn, acc_sc, acc_sa, acc_wnop, acc_wc, \
            acc_wo, acc_wv, acc_lx, acc_x, acc_mx = acc
            logger.info(
                f" Epoch: {epoch}, s-num: {acc_sn:.4f}, s-col: {acc_sc:.4f},"
                f" s-col-agg: {acc_sa:.4f}, w-num-op: {acc_wnop:.4f},"
                f" w-col: {acc_wc:.4f}, w-col-op: {acc_wo:.4f}, w-col-value: {acc_wv:.4f},"
                f" acc_lx: {acc_lx:.4f}, acc_x: {acc_x:.4f}, acc_mx: {acc_mx:.4f}"
            )
        else:
            logger.info(f" Epoch: {epoch}, average loss: {acc}")
def get_logger(log_fp=None):
    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s] %(message)s')
    logger = logging.getLogger(__name__)
    if log_fp:
        handler = logging.FileHandler(log_fp)
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter('[%(asctime)s] %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length, num_target_layers, path_db):
    model.eval()
    model_bert.eval()
    results = []
    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wnop = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_lx = 0
    cnt_x = 0
    db_conn = sqlite3.connect(path_db)
    cursor = db_conn.cursor()
    for iB, t in tqdm.tqdm(enumerate(data_loader)):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)
        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = []
        for tt in t_to_tt_idx:
            l_n_t.append(len(tt))
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
        s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs, t_to_tt_idx)
        # prediction
        pr_sn, pr_sc, pr_sa, pr_wn, pr_conn_op, \
        pr_wc, pr_wo, pr_tags, pr_wvi = pred_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match, l_n_t)
        pr_wv_str = convert_pr_wvi_to_string(pr_wvi, nlu_t)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_conn_op, pr_wc, pr_wo, pr_wv_str, nlu)
        value_indexes, value_nums = extract_val(pr_tags, l_n_t)
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            cnt += 1
            results1 = {}
            normalize_sql(pr_sql_i1, tb[b])
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1["sql"] = pr_sql_i1
            if sql_i[b]:
                results1["gold_sql"] = sql_i[b]
            results1['value_indexes'] = value_indexes[b]
            results1['value_nums'] = value_nums[b]
            results1['pr_wc'] = pr_wc[b]
            if sql_i[b]:
                sn, sc, sa, co, wn, wc, wo, wv, cond, sql = \
                    get_acc(sql_i[b], pr_sql_i1, pr_wc[b], pr_wo[b], tb[b], normalized=True)
                cnt_sn += sn
                cnt_sc += sc
                cnt_sa += sa
                cnt_wnop += (wn and co)
                cnt_wc += wc
                cnt_wo += wo
                cnt_wv += wv
                cnt_lx += sql
                results1['correct'] = sql
                execution, res = get_acc_x(sql_i[b], pr_sql_i1, tb[b], cursor)
                cnt_x += execution
                results1['ex_correct'] = execution
                results1['result'] = res
            results.append(results1)
    cnts = [cnt_sn, cnt_sc, cnt_sa, cnt_wnop, cnt_wc,
            cnt_wo, cnt_wv, cnt_lx, cnt_x, (cnt_x + cnt_lx) / 2]
    if sum(cnts) > 0:
        cnt_desc = [
            's-num', 's-col', 's-col-agg', 'w-num-op', 'w-col',
            'w-col-op', 'w-col-value', 'acc_lx', 'acc_x', 'acc_mx'
        ]
        # Relies on the module-level `logger` created in the __main__ block below.
        logger.info('--------- eval result ---------')
        for k, metric in enumerate(cnts):
            logger.info(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,))
    else:
        cnts = None
        cnt_desc = None
    return results, cnt, cnts, cnt_desc
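The linear warmup that `update_lr` applies to the optimizer's `param_groups` can be checked in isolation. This sketch reproduces the same effective schedule (with the script's default `num_warmup_steps=-1`, the condition never fires and the base learning rate is used throughout):

```python
def warmup_lr(current_step, num_warmup_steps, start_lr):
    """Mirror update_lr: scale the base lr linearly until warmup ends, then use the base lr."""
    if current_step <= num_warmup_steps:
        return start_lr * (current_step / num_warmup_steps)
    return start_lr

# With 4 warmup steps and a base lr of 1e-3, the lr ramps up over the first 4 steps
# (roughly 0.25e-3, 0.5e-3, 0.75e-3, 1e-3) and stays at 1e-3 afterwards.
schedule = [warmup_lr(s, 4, 1e-3) for s in range(1, 6)]
print(schedule)
```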

2.3 Start training

if __name__ == '__main__':
    # Hyper parameters
    parser = argparse.ArgumentParser()
    args = construct_hyper_param(parser)
    save_path = args.train_url
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if not args.eval:
        _model_path = './trained_model/model/'
        shutil.copytree(_model_path, os.path.join(save_path, 'model'))
    t = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    log_fp = os.path.join(save_path, f'{t}.log')
    logger = get_logger(log_fp)
    logger.info(f"BERT-Model: {args.bert_url}")
    trained = args.load_weight is not None and args.load_weight != 'None'
    load_path = None
    if trained:
        load_path = '/home/work/modelarts/inputs/best_model.pt'
        if args.load_weight and args.load_weight.startswith('obs://'):
            if not os.path.exists(load_path):
                # Loading from OBS requires `import moxing as mox` (commented out above).
                mox.file.copy_parallel(args.load_weight, load_path)
                print('copy %s to %s' % (args.load_weight, load_path))
            else:
                print(load_path, 'already exists')
        else:
            load_path = args.load_weight
    train_input_dir = args.data_url
    bert_model = args.bert_url
    # Paths
    path_wikisql = train_input_dir
    path_val_db = os.path.join(train_input_dir, 'val.db')
    path_save_for_evaluation = save_path
    # Build & Load models
    if args.eval and not trained:
        print('in eval mode, "--load_weight" must be provided!')
        exit(-1)
    if not trained:
        model, model_bert, tokenizer, bert_config = get_models(args, logger, bert_model, eval=args.eval)
    else:
        path_model = load_path
        model, model_bert, tokenizer, bert_config = get_models(args, logger, bert_model,
                                                               trained=True, path_model=path_model,
                                                               eval=args.eval)
    if not args.eval:
        train_data, train_table, dev_data, dev_table, train_loader, dev_loader = get_data(path_wikisql, args)
        opt, opt_bert = get_opt(model, model_bert, args.fine_tune)
        acc_lx_t_best = -1
        epoch_best = -1
        current_step = 1
        for epoch in range(args.tepoch):
            # train
            logger.info(f'Training Epoch {epoch}')
            ave_loss_train, current_step = train(train_loader,
                                                 train_table,
                                                 model,
                                                 model_bert,
                                                 opt,
                                                 bert_config,
                                                 tokenizer,
                                                 args.max_seq_length,
                                                 args.num_target_layers,
                                                 args.accumulate_gradients,
                                                 args.print_per_step,
                                                 logger=logger,
                                                 current_step=current_step,
                                                 opt_bert=opt_bert,
                                                 st_pos=0)
            # check DEV
            with torch.no_grad():
                logger.info(f'Testing on dev Epoch {epoch}:')
                acc_dev, results_dev, \
                    dev_acc_lx = test(dev_loader,
                                      dev_table,
                                      model,
                                      model_bert,
                                      bert_config,
                                      tokenizer,
                                      args.max_seq_length,
                                      args.num_target_layers,
                                      args.print_per_step,
                                      logger=logger,
                                      path_db=path_val_db,
                                      st_pos=0)
            print_result(epoch, ave_loss_train, 'train', logger=logger)
            print_result(epoch, acc_dev, 'dev', logger=logger)
            # save results for the official evaluation
            path_save_file = save_for_evaluation(path_save_for_evaluation,
                                                 results_dev, 'dev', epoch=epoch)
            # mox.file.copy_parallel(path_save_file,
            #                        args.train_url + f'results_dev-{epoch}.jsonl')
            # save best model
            # Based on Dev Set logical accuracy lx
            if dev_acc_lx > acc_lx_t_best:
                acc_lx_t_best = dev_acc_lx
                epoch_best = epoch
                # save model
                if not args.no_save:
                    state = {'model': model.state_dict(),
                             'model_bert': model_bert.state_dict()}
                    torch.save(state, os.path.join(save_path, 'model', 'best_model.pth'))
            logger.info(f" Best Dev lx acc: {acc_lx_t_best} at epoch: {epoch_best}")
    else:
        try:
            dev_data, dev_table = load_tableqa_data(path_wikisql, mode=args.split, no_hs_tok=True)
        except Exception:
            logger.error('Input file not found!')
            exit(-1)
        dev_loader = torch.utils.data.DataLoader(
            batch_size=args.bS,
            dataset=dev_data,
            shuffle=False,
            num_workers=1,
            collate_fn=lambda x: x
        )
        with torch.no_grad():
            results, cnt, cnts, cnt_desc \
                = predict(dev_loader,
                          dev_table,
                          model,
                          model_bert,
                          bert_config,
                          tokenizer,
                          args.max_seq_length,
                          args.num_target_layers,
                          os.path.join(train_input_dir, args.split + '.db'))
        save_for_evaluation(os.path.join(save_path, 'pred_results.jsonl'),
                            results, args.split, 'pred', use_filename=True)
        if cnts:
            with open(os.path.join(save_path, 'eval_result.txt'), 'w') as f_eval:
                f_eval.write('--------- eval result ---------\n')
                for k, metric in enumerate(cnts):
                    f_eval.write(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,) + '\n')
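For reference, the canonical query dictionaries that `generate_sql_i` produces follow the TableQA convention: `sel`/`agg` lists, a `cond_conn_op` index, and `conds` triples of (column, operator, value). Below is a hedged sketch of rendering such a dictionary back into SQL text; the operator tables are assumed TableQA-style defaults, not taken from this repo, and the repo's own mappings may differ:

```python
# Assumed TableQA-style operator tables (illustrative; verify against the repo's utils).
AGG_OPS = ['', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM']
COND_OPS = ['>', '<', '==', '!=']
CONN_OPS = ['', 'AND', 'OR']

def sql_from_canonical(sql_i, headers):
    """Render a canonical TableQA query dict as a SQL string (for inspection only)."""
    cols = []
    for agg, sel in zip(sql_i['agg'], sql_i['sel']):
        col = headers[sel]
        cols.append(f'{AGG_OPS[agg]}({col})' if agg else col)
    conds = [f'{headers[c]} {COND_OPS[o]} "{v}"' for c, o, v in sql_i['conds']]
    conn = f' {CONN_OPS[sql_i["cond_conn_op"]]} '
    return f'SELECT {", ".join(cols)} FROM table WHERE {conn.join(conds)}'

# COUNT over the city column, two AND-joined conditions on different columns.
example = {'agg': [4], 'sel': [0], 'cond_conn_op': 1,
           'conds': [[1, 1, '3574'], [2, 1, '69.7%']]}
print(sql_from_canonical(example, ['城市', '成交量', '环比']))
```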

3. Model testing

import json
from trained_model.model.customize_service import *
if __name__ == '__main__':
    model_path = r'./outputs/model/best_model.pth'
    my_model = ModelClass('', model_path)
    data = {
 "question": "近四周成交量小于3574套并且环比低于69.7%的城市有几个",
 "table_id": "252c7b6b302e11e995ee542696d6e445"
 }
    data = my_model._preprocess(data)
    result = my_model._inference(data)
    print(json.dumps(dict(result), ensure_ascii=False, indent=2))

