jjzjj

对 ChatGLM-6B 做 LoRA Fine-tuning

AI探险家 2023-11-17 原文

对 ChatGLM-6B 做 LoRA Fine-tuning

ChatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。

声明:

本文提供的所有技术信息,都基于 THUDM/chatglm-6b 的历史版本:
096f3de6b4959ce38bef7bb05f3129c931a3084e

源码地址:

搭建依赖环境

安装 PyTorch 环境:

pip install torch torchvision torchaudio

按照 ChatGLM-6B 的官方指导,安装软件依赖环境:

pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels  

为了做 LoRA,还要安装 peft

pip install peft

加载模型和 Tokenizer

from transformers import AutoTokenizer, AutoModel

checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)

分析模型结构

模型加载完后,我们可以打印这个 modeltokenizer,建立对模型的基本认知。

首先打印model

print(model)

得到如下结果:

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(150528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=150528, bias=False)
)

简单分析这个模型结构,至少可以得到如下一些信息:

  • 模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning
  • 从 Word Embedding 层可以看出,词汇表大小是 150528
  • LoRA 可以操作的目标是:query_key_value

再打印tokenizer:

print(tokenizer)

得到如下结果(为了便于阅读,已对结果做了分行处理):

ChatGLMTokenizer(
	name_or_path='THUDM/chatglm-6b', 
	vocab_size=150344, 
	model_max_length=2048, 
	is_fast=False, 
	padding_side='left', 
	truncation_side='right', 
	special_tokens={
		'bos_token': '<sop>', 
		'eos_token': '</s>', 
		'unk_token': '<unk>', 
		'pad_token': '<pad>', 
		'mask_token': '[MASK]'
	}
)

这里有几个可以关注的点:

  • 词汇表大小vocab_size150344
  • 不是一个 fast Tokenizer(is_fast 的值是 False
  • 特殊 token 包括:bos eos padmask

为什么 model 中的词汇表大小是 150528,而 tokenizer 中定义的词汇表大小却是 150344 呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。

配置 LoRA

借助 peft 库,我们可以很方便地对模型注入 LoRA。

from peft import LoraConfig, get_peft_model, TaskType

def load_lora_config(model):
	config = LoraConfig(
	    task_type=TaskType.CAUSAL_LM, 
	    inference_mode=False,
	    r=8, 
	    lora_alpha=32, 
	    lora_dropout=0.1,
	    target_modules=["query_key_value"]
	)
	return get_peft_model(model, config)

model = load_lora_config(model)

打印可训练的参数量:

model.print_trainable_parameters()

得到如下结果:

trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348

可以看到,总的参数量是 6,258,876,416,可训练的参数量是 3,670,016,占比 0.0586% 左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 CAUSAL_LM

构建数据集

定义常量

构建之前,我们先定义几个特殊 Token 常量:

bos = tokenizer.bos_token_id
eop = tokenizer.eop_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.mask_token_id
gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]

将这几个值打印出来:

print("bos = ", bos)
print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)

得到如下结果:

bos =  150004
eop =  150005
pad =  20003
mask =  150000
gmask =  150001

我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成:

bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001

除了上面定义的 Token 常量,我们还需要定义模型训练绑定的设备名,以及最大输入长度和最大输出长度等,如下:

device = "cuda"
max_src_length = 200
max_dst_length = 500

开发者可以结合自己的显卡性能和要处理的数据集特点来确定这些最大长度。

测试 Tokenizer 的编解码

我们可以先做个简单的测试:

text = "AI探险家"
print(tokenizer.encode(text, add_special_tokens = True))
print(tokenizer.encode(text, add_special_tokens = False))

输出结果是:

[26738, 98715, 83920, 150001, 150004]
[26738, 98715, 83920]

从这个结果可以看出,“AI探险家”这几个字的裸编码是 [26738, 98715, 83920]。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果:

print(tokenizer.decode([26738]))
print(tokenizer.decode([98715]))
print(tokenizer.decode([83920]))

输出结果是:

AI
探险
家

观察这个结果,读者应该能对词汇表建立基本的认知了。读者如果有兴趣,还可以分别针对 “A” “I” “探” “险” 这几个字分别编码,看看编码结果是什么。

另外,当 add_special_tokens = True 时,编码结果会在末尾添加 150001150004,也就是 gmaskbos。请注意,我们的训练数据,要按照如下编码要求进行构造:

[token, ..., token, gmask, bos, token, ... token, eop]

因此,前半部分文本的编码可以直接让 add_special_tokens = True,后半部分文本的编码则让 add_special_tokens = False,最后再拼接一个 eop

定义 Prompt

我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:

PROMPT_PATTERN = "问:{}\n答: "

{}里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 CUDA out of memory 这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:

  • 截断末尾超出部分的编码
  • 截断前面超出部分的编码
  • 丢掉训练样本

每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。
为了不把 PROMPT_PATTERN 中的 \n答: 这几个字截断掉,我们将整个 PROMPT_PATTERN 拆成两部分:

PROMPT_PATTERN = "问:{}"
SEP_PATTERN = "\n答: "

基于这份 Prompt 模板,我们定义下面三个辅助方法:

def create_prompt(question):
    return PROMPT_PATTERN.format(question), SEP_PATTERN


def create_prompt_ids(tokenizer, question, max_src_length):
    prompt, sep = create_prompt(question)
    sep_ids = tokenizer.encode(
        sep, 
        add_special_tokens = True
    )
    sep_len = len(sep_ids)
    special_tokens_num = 2
    prompt_ids = tokenizer.encode(
        prompt, 
        max_length = max_src_length - (sep_len - special_tokens_num),
        truncation = True,
        add_special_tokens = False
    )

    return prompt_ids + sep_ids


def create_inputs_and_labels(tokenizer, question, answer, device):
    prompt = create_prompt_ids(tokenizer, question, max_src_length)
    completion = tokenizer.encode(
        answer, 
        max_length = max_dst_length,
        truncation = True,
        add_special_tokens = False
    )

    inputs = prompt + completion + [eop]
    labels = [-100] * len(prompt) + completion + [eop] 
    
    inputs = torch.tensor(inputs, dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return inputs, labels

值得注意的两点:

  • create_prompt_ids 这个函数实现可以看出,我们编码分隔符 SEP_PATTERN 时自动添加了前面所述的 2 个特殊 Token。
  • create_inputs_and_labels 的函数实现中,我们将 labels 无需处理的部分用数值 -100 来表示。因为 ChatGLMForConditionalGeneration 内部在计算损失函数的时候,用的是 torch.nn.CrossEntropyLoss。该函数的参数之一 ignore_index 默认值是 -100。这就让我们在计算损失函数时,无需考虑非标识部分的数值。

构建 Attention Mask 和 Position IDs

def get_attention_mask(tokenizer, input_ids, device):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)
    attention_mask = torch.ones((seq_len, seq_len), device=device)
    attention_mask.tril_()
    attention_mask[..., :context_len] = 1
    attention_mask.unsqueeze_(0)
    attention_mask = (attention_mask < 0.5).bool()
    return attention_mask


def get_position_ids(tokenizer, input_ids, device, position_encoding_2d=True):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)

    mask_token = mask if mask in seq else gmask
    use_gmask = False if mask in seq else gmask

    mask_position = seq.index(mask_token)

    if position_encoding_2d:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
        block_position_ids = torch.cat((
            torch.zeros(context_len, dtype=torch.long, device=device),
            torch.arange(seq_len - context_len, dtype=torch.long, device=device) + 1
        ))
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
    
    return position_ids

在这个通用实现中,我们针对 maskgmask 两种情况做了区分,同时也对是否执行 position_encoding_2d 分情况处理。本文的 QA 任务采用的是 gmask,并且使用 position_encoding_2d = True

我们可以构建下面的问答,来验证下这几个函数的输出:

test_data = {
	"question": "AI探险家帅不帅?",
	"answer": "非常帅!"
}

inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)
attention_mask = get_attention_mask(tokenizer, inputs, device=device)
position_ids = get_position_ids(tokenizer, inputs, device=device)

print("inputs: \n", inputs.tolist())
print("\nlabels: \n", labels.tolist())
print("\nposition_ids: \n", position_ids.tolist())
print("\nattention_mask: \n", attention_mask.tolist())

输出结果(为了便于阅读,已对输出进行格式化操作):

inputs: 
 [20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]

labels: 
 [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]

position_ids: 
 [
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5]
 ]

attention_mask: 
 [[
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]

结合论文观察数据,基本符合预期。

创建数据集

我们先定义具有如下格式的训练数据:

train_data = [
	{"question": "问题1", "answer": "答案1"},
	{"question": "问题2", "answer": "答案2"},
]

定义好格式后,我们先创建一个 QADataset 类,如下:

from torch.utils.data import Dataset

class QADataset(Dataset):
    def __init__(self, data, tokenizer) -> None:
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
 

    def __getitem__(self, index):
        item_data = self.data[index]
        tokenizer = self.tokenizer
        input_ids, labels = create_inputs_and_labels(
            tokenizer, 
            device=device,
            **item_data
        )
        
        attention_mask = get_attention_mask(tokenizer, input_ids, device)
        position_ids = get_position_ids(tokenizer, input_ids, device)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }
        

    def __len__(self):
        return len(self.data)

然后创建一个 Data Collator:

def collate_fn(batch):
    input_ids = []
    attention_mask = []
    labels = []
    position_ids = []
    
    for obj in batch:
        input_ids.append(obj['input_ids'])
        labels.append(obj['labels'])
        attention_mask.append(obj['attention_mask'])
        position_ids.append(obj['position_ids'])
        
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask), 
        'labels': torch.stack(labels),
        'position_ids':torch.stack(position_ids)
    }

开始训练

from transformers import TrainingArguments, Trainer
model.to(device)

training_args = TrainingArguments(
    "output",
    fp16 =True,
    save_steps = 500,
    save_total_limit = 3,
    gradient_accumulation_steps=1,
    per_device_train_batch_size = 1,
    learning_rate = 1e-4,
    max_steps=1500,
    logging_steps=50,
    remove_unused_columns=False,
    seed=0,
    data_seed=0,
    group_by_length=False,
    dataloader_pin_memory=False
)

class ModifiedTrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            labels=inputs["labels"],
        ).loss


train_dataset = QADataset(train_data, tokenizer=tokenizer)
trainer = ModifiedTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=collate_fn,
    tokenizer=tokenizer
)

trainer.train()

预测

response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)

保存训练模型

import os

def save_tuned_parameters(model, path):
    saved_params = {
        k: v.to(device)
        for k, v in model.named_parameters()
        if v.requires_grad
    }
    torch.save(saved_params, path)

save_tuned_parameters(model, os.path.join("/path/to/output", "chatglm-6b-lora.pt"))

重载训练后的模型

checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)

model = load_lora_config(model)
model.load_state_dict(torch.load(f"/path/to/output/chatglm-6b-lora.pt"), strict=False)

model.half().cuda().eval()
response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)

有关对 ChatGLM-6B 做 LoRA Fine-tuning的更多相关文章

  1. 类ChatGPT国产大模型ChatGLM-6B,单卡即可运行 - 2

    2023年3月14日GPT4又发布了,在ChatGPT发展如火如荼的当下,我们更应该关注国内的进展,今天将分享一个清华大学基于GLM-130B模型开发的类似ChatGPT的ChatGLM-6B模型,ChatGLM-6B是一个开源的、支持中英双语的对话语言模型,基于 GeneralLanguageModel(GLM) 架构,具有62亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4量化级别下最低只需6GB显存)。ChatGLM-6B使用了和ChatGPT相似的技术,针对中文问答和对话进行了优化。经过约1T标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的

  2. c++ - libjpeg 版本6b jpeg_stdio_src 与 jpeg_mem_src - 2

    我使用的是Libjpeg版本6b。在版本8中,他们有一个很好的函数可以从内存中读取数据,称为jpeg_mem_src(...),不幸的是。6b没有这个功能。我可以用什么直接从内存中读取压缩数据?我所看到的只是从硬盘读取的jpeg_stdio_src。 最佳答案 自己写.../*ReadJPEGimagefromamemorysegment*/staticvoidinit_source(j_decompress_ptrcinfo){}staticbooleanfill_input_buffer(j_decompress_ptrcinf

  3. 【ChatGLM 开发实战】ChatGLM 定制开发: CUDA 模型指定运行的 GPU 设备 - 2

      目录ChatGLM简介系统配置options.py设备获取 device.py模型初始化model.py运

  4. 论文学习——Tune-A-Video - 2

    Tune-A-Video:One-ShotTuningofImageDiffusionModelsforText-to-VideoGenerationAbstract本文提出了一种方法,站在巨人的肩膀上——在大规模图像数据集上pretrain并表现良好的texttoimage生成模型——加入新结构并进行微调,训练出一套oneshot的texttovideo生成器。这样做的优点在于利用已经非常成功、风格多样的图像扩散生成模型,在其基础上进行扩展,同时其训练时间很短,大大降低了训练开销。作为oneshot方法,tuneavideo还需要额外信息,一个文本-视频对儿作为demo。作者对于T2I(te

  5. 【ChatGPT】预训练模型微调及其应用(ChatGLM-6B、duckduckgo_search、GPT在科研的应用等) - 2

    noteinstructGPT(基于提示学习的系列模型)——>GPT3.5(大规模预训练语言模型)——>ChatGPT模型(高质量数据标注+反馈学习)。chatGPT三大技术:情景学习、思维链、自然指令学习。GPT4飞跃式提升:多模态、输入字符数量、推理能力、文本创造,如poem、解释图片含义、图表计算等,2022年8月完成训练。论文:https://cdn.openai.com/papers/gpt-4.pdfChatGPTPlus:集成GPT-4的ChatGPT升级版,https://chat.openai.com/chat可以利用chatGPT获取更高质量数据文章目录note一、预训练模

  6. 类ChatGPT逐行代码解读(1/2):从零起步实现Transformer、ChatGLM-6B - 2

    前言最近一直在做类ChatGPT项目的部署微调,关注比较多的是两个:一个LLaMA,一个ChatGLM,会发现有不少模型是基于这两个模型去做微调的,说到微调,那具体怎么微调呢,因此又详细了解了一下微调代码,发现微调LLM时一般都会用到Huggingface实现的Transformers库的Trainer类从而发现,如果大家想从零复现ChatGPT,便得从实现Transformer开始,因此便开启了本文:如何从零起步实现Transformer、ChatGLM(至于LLaMA已在之前的博客里解读过),主要分为两个大部分按照transformer的每一步的原理逐步逐行从零实现,先编码器后解码器,特别

  7. 使用 Docker 和 Alpaca LoRA 对 LLaMA 65B 大模型进行 Fine-Tune - 2

    这篇文章中,我们来聊聊如何使用两张显卡来进行LLaMA65B大模型的微调工作,以及如何在一张普通的4090家用显卡上,只花几个小时,就能够完成7B模型的微调。写在前面在之前的几篇文章里,我们介绍过三种方式运行Meta开源模型LLaMA的7B、13B版本:《模型杂谈:使用IN8量化推理运行Meta“开源泄露”的大模型(LLaMA)》《模型杂谈:快速上手元宇宙大厂Meta“开源泄露”的大模型(LLaMA)》不过,在之前的尝试中我们不难发现,如果没有我们“限定的数据”,模型效果其实不是特别好,尤其是相对小参数量的7B模型。同时,这也让我们对65B的模型更加充满了兴趣。当然,想要在极少量资源的显卡上完

  8. ChatGLM 本地部署的详细教程 - 2

    ChatGLM是一个基于GPT模型的开源聊天机器人框架,可以在本地部署和使用。以下是ChatGLM本地部署的详细教程:1.确认环境:ChatGLM需要在Linux系统上运行,需要安装Python3.6或更高版本、CUDA10.1或更高版本、cuDNN7或更高版本。确保系统已经安装了这些依赖项。2.下载代码:从ChatGLM的GitHub仓库(https://github.com/cooelf/ChatGLM)下载代码。3.安装依赖项:在代码目录下运行以下命令安装依赖项:```pipinstall-rrequirements.txt```4.下载预训练模型:ChatGLM使用预训练的GPT模型来

  9. 使用STM32驱动3WE6B61B/CG24N9Z5L电磁阀(一) - 2

    一.方案设计--光耦选型已经完成STM32单片机通过modbusrtu控制16路电磁阀,接下来进行电路设计。设计思路为单片机控制控制光耦,光耦进行放大和隔离后,驱动mos管输出24VPWM波形,进行驱动。1.1光耦简介光耦是隔离传输器件,原边给定信号,副边回路就会输出经过隔离的信号。对于光耦的隔离容易理解,此处不做讨论。以一个简单的图(图.1)说明光耦的工作:原边输入信号Vin,施加到原边的发光二极管和Ri上产生光耦的输入电流If,If驱动发光二极管,使得副边的光敏三极管导通,回路VCC、RL产生Ic,Ic经过RL产生Vout,达到传递信号的目的。原边副边直接的驱动关联是CTR(电流传输比),

  10. 对 ChatGLM-6B 做 LoRA Fine-tuning - 2

    对ChatGLM-6B做LoRAFine-tuning搭建依赖环境加载模型和Tokenizer分析模型结构配置LoRA构建数据集定义常量测试Tokenizer的编解码定义Prompt构建AttentionMask和PositionIDs创建数据集开始训练预测保存训练模型重载训练后的模型ChatGLM-6B是一个支持中英双语的对话语言模型,基于GLM(GeneralLanguageModel)。它只有62亿个参数,量化后最低(INT4量化)只需要6GB的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做Fine-tuni

随机推荐