ChatGLM3-6B-128K模型微调实战:使用自定义数据集训练专属AI助手

张开发
2026/4/11 14:21:08 15 分钟阅读

分享文章

ChatGLM3-6B-128K模型微调实战:使用自定义数据集训练专属AI助手
ChatGLM3-6B-128K模型微调实战使用自定义数据集训练专属AI助手想让AI助手真正理解你的业务场景通过微调ChatGLM3-6B-128K你可以打造一个真正懂你的专属智能助手。本文将手把手教你如何从数据准备到模型训练完整掌握微调全流程。1. 准备工作与环境搭建在开始微调之前我们需要先准备好训练环境和基础模型。ChatGLM3-6B-128K是专门为处理长文本设计的模型支持高达128K的上下文长度非常适合需要处理大量文本数据的场景。1.1 环境要求与依赖安装首先确保你的环境满足以下要求GPU内存建议至少24GB显存如RTX 4090或A100系统内存建议32GB以上Python版本3.8或更高版本CUDA版本11.7或更高安装必要的依赖包pip install torch2.0.1 transformers4.33.0 datasets2.14.0 pip install peft0.4.0 accelerate0.23.0 trl0.5.0 pip install sentencepiece protobuf tqdm1.2 获取基础模型你可以从Hugging Face下载ChatGLM3-6B-128K基础模型from transformers import AutoTokenizer, AutoModel model_name THUDM/chatglm3-6b-128k tokenizer AutoTokenizer.from_pretrained(model_name, trust_remote_codeTrue) model AutoModel.from_pretrained(model_name, trust_remote_codeTrue)2. 数据准备与格式化数据质量直接决定微调效果特别是对于长文本模型数据格式的正确性至关重要。2.1 数据格式要求ChatGLM3-6B-128K使用特定的对话格式每个样本应该包含多轮对话# 正确的数据格式示例 conversation [ { role: user, content: 请帮我分析这篇技术文档的主要观点 }, { role: assistant, content: 这篇文档主要讨论了人工智能在医疗领域的应用包括... } ]2.2 创建训练数据集假设我们正在创建一个技术文档分析助手下面是一个数据准备的例子import json from datasets import Dataset def create_training_data(): samples [] # 示例训练数据 training_examples [ { instruction: 分析这篇AI论文的核心贡献, input: 论文标题基于Transformer的医疗影像分析..., output: 该论文的主要贡献是提出了一个新型的注意力机制... }, { instruction: 总结这段技术文档的要点, input: 文档内容深度学习模型在自然语言处理中的应用..., output: 本文档重点介绍了BERT、GPT等模型在NLP任务中的表现... } ] for example in training_examples: # 转换为ChatGLM3的对话格式 conversation [ {role: user, content: f{example[instruction]}\n{example[input]}}, {role: assistant, content: example[output]} ] samples.append({conversations: conversation}) return Dataset.from_list(samples) # 创建数据集 dataset create_training_data() dataset.save_to_disk(./training_data)2.3 处理长文本数据ChatGLM3-6B-128K的优势在于处理长文本这里有一些处理技巧def process_long_text(text, max_length120000): 处理长文本数据确保不超过模型限制 # 计算token数量 tokens tokenizer.encode(text) if len(tokens) max_length: # 对于超长文本可以采取分段处理 chunks [tokens[i:imax_length] for i in range(0, len(tokens), max_length)] processed_texts [tokenizer.decode(chunk) for chunk in chunks] return processed_texts else: return [text] # 示例处理长文档 long_document 你的长文本内容... # 这里替换为实际的长文本 processed_chunks process_long_text(long_document)3. 模型微调配置3.1 训练参数设置from transformers import TrainingArguments training_args TrainingArguments( output_dir./chatglm3-finetuned, per_device_train_batch_size1, # 根据GPU内存调整 gradient_accumulation_steps8, num_train_epochs3, learning_rate2e-5, fp16True, logging_steps10, save_steps500, save_total_limit2, prediction_loss_onlyTrue, remove_unused_columnsFalse, )3.2 使用LoRA进行高效微调为了节省显存并提高训练效率我们使用LoRALow-Rank Adaptation技术from peft import LoraConfig, get_peft_model lora_config LoraConfig( r8, lora_alpha32, target_modules[query_key_value, dense, dense_h_to_4h, dense_4h_to_h], lora_dropout0.1, biasnone, task_typeCAUSAL_LM, ) # 应用LoRA到模型 model get_peft_model(model, lora_config) model.print_trainable_parameters()4. 开始训练4.1 准备训练器from transformers import DataCollatorForSeq2Seq # 数据整理器 data_collator DataCollatorForSeq2Seq( tokenizer, modelmodel, label_pad_token_id-100, pad_to_multiple_of8 ) # 创建训练器 from transformers import Trainer trainer Trainer( modelmodel, argstraining_args, train_datasetdataset, data_collatordata_collator, tokenizertokenizer, )4.2 启动训练# 开始训练 trainer.train() # 保存最终模型 trainer.save_model() tokenizer.save_pretrained(./chatglm3-finetuned)5. 模型评估与测试5.1 加载微调后的模型from peft import PeftModel # 加载微调后的模型 model AutoModel.from_pretrained( THUDM/chatglm3-6b-128k, trust_remote_codeTrue ) model PeftModel.from_pretrained(model, ./chatglm3-finetuned)5.2 测试模型效果def test_model(query): # 将模型设置为评估模式 model.eval() # 生成回复 with torch.no_grad(): response, history model.chat( tokenizer, query, history[], max_length4096, temperature0.7 ) return response # 测试示例 test_query 请分析这篇技术文档的创新点... response test_model(test_query) print(模型回复:, response)5.3 批量测试与评估def evaluate_model(test_dataset): results [] for example in test_dataset: prompt example[instruction] \n example[input] expected example[output] actual test_model(prompt) results.append({ prompt: prompt, expected: expected, actual: actual, match: expected.strip() actual.strip() }) accuracy sum(1 for r in results if r[match]) / len(results) print(f测试准确率: {accuracy:.2%}) return results # 运行评估 test_results evaluate_model(test_dataset)6. 实际应用与部署6.1 创建推理APIfrom fastapi import FastAPI from pydantic import BaseModel app FastAPI() class ChatRequest(BaseModel): message: str history: list [] app.post(/chat) async def chat_endpoint(request: ChatRequest): response, history model.chat( tokenizer, request.message, historyrequest.history ) return {response: response, history: history} # 启动服务 if __name__ __main__: import uvicorn uvicorn.run(app, host0.0.0.0, port8000)6.2 集成到现有系统class CustomChatGLM: def __init__(self, model_path): self.tokenizer AutoTokenizer.from_pretrained( model_path, trust_remote_codeTrue ) self.model AutoModel.from_pretrained( model_path, trust_remote_codeTrue ).half().cuda() # 使用半精度减少显存占用 def process_long_document(self, document, question): 处理长文档问答 # 结合文档内容和问题 prompt f根据以下文档内容回答问题\n{document}\n\n问题{question} response, _ self.model.chat( self.tokenizer, prompt, history[], max_length8192 # 根据需要调整 ) return response # 使用示例 chat_bot CustomChatGLM(./chatglm3-finetuned) answer chat_bot.process_long_document(long_document, 文档的主要观点是什么)7. 优化建议与注意事项7.1 性能优化技巧梯度累积在小批量大小下使用梯度累积来模拟大批量训练混合精度训练使用fp16减少显存占用并加速训练梯度检查点对于极长序列启用梯度检查点节省显存# 启用梯度检查点 model.gradient_checkpointing_enable()7.2 常见问题解决问题1显存不足解决方案减少批次大小增加梯度累积步数使用LoRA问题2过拟合解决方案增加数据集大小使用早停添加正则化问题3长文本处理效果不佳解决方案确保数据格式正确调整位置编码参数8. 总结通过本教程我们完整走过了ChatGLM3-6B-128K模型微调的全流程。从数据准备、模型配置到训练评估每个步骤都提供了具体的代码示例和实践建议。微调后的模型在特定领域表现会有显著提升特别是在处理长文本任务时ChatGLM3-6B-128K的128K上下文长度优势明显。在实际应用中你可以根据具体需求调整训练参数和数据格式获得更好的效果。记得在实际部署前充分测试模型性能特别是在处理边界情况和长文本时的表现。微调是一个迭代的过程可能需要多次尝试和调整才能获得最优结果。获取更多AI镜像想探索更多AI镜像和应用场景访问 CSDN星图镜像广场提供丰富的预置镜像覆盖大模型推理、图像生成、视频生成、模型微调等多个领域支持一键部署。

更多文章