4.3 KiB
4.3 KiB
大语言模型微调框架
一个基于Unsloth和Transformers的大语言模型微调框架,支持多种模型的全量微调和LoRA微调。
功能特性
- ✅ 支持Gemma3、Qwen3等多种大语言模型
- ✅ 支持全量微调(Full Fine-tuning)和LoRA微调两种模式
- ✅ 统一的训练器接口,易于扩展新模型
- ✅ 配置集中管理,便于调整训练参数
- ✅ 自动应用模型特定的聊天模板
- ✅ 支持模型保存和对话测试
安装依赖
pip install -r requirements.txt
项目结构
fine-tuning/
├── config.py # 配置文件,集中管理模型和训练参数
├── data_transform.py # 数据预处理和转换工具
├── main.py # 主程序入口,演示如何使用训练器
├── train.py # 核心训练器类定义
├── test.py # 测试脚本
├── test_trainers.py # 训练器测试
├── train_data.json # 训练数据
├── README.md # 项目说明文档
└── requirements.txt # 依赖包列表
使用方法
1. 基础使用
from train import Gemma3ModelTrainer, Qwen3ModelTrainer
from data_transform import get_chat_data2
# 使用Gemma3模型训练器
gemma_trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "lora")
dataset = get_chat_data2()
gemma_trainer.set_dataset(dataset)
gemma_trainer.train()
gemma_trainer.chat("你好,介绍一下你自己")
gemma_trainer.save("gemma-3-4b-finetuned")
# 使用Qwen3模型训练器
qwen_trainer = Qwen3ModelTrainer("Qwen/Qwen3-600m", "lora")
qwen_trainer.set_dataset(dataset)
qwen_trainer.train()
qwen_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?")
qwen_trainer.save("qwen3-600m-finetuned")
2. 训练模式选择
支持两种训练模式:
"lora"
(默认): LoRA微调,参数高效,训练速度快"full"
: 全量微调,所有参数都会更新
# LoRA微调
trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "lora")
# 全量微调
trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "full")
3. 使用配置文件
所有配置现在集中在config.py
文件中管理:
# 模型配置
MODEL_CONFIGS = {
"gemma-3-270m-it": {
"model_name": "unsloth/gemma-3-270m-it",
"chat_template": "gemma-3",
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
},
"qwen3-600m": {
"model_name": "Qwen/Qwen3-600m",
"chat_template": "qwen3",
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
}
}
# 训练配置
TRAINING_CONFIG = {
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 4,
"warmup_steps": 5,
"max_steps": 300,
"learning_rate": 2e-4,
"logging_steps": 10,
"optim": "adamw_8bit",
"lr_scheduler_type": "linear",
"seed": 3407,
"weight_decay": 0.01,
"report_to": "none"
}
类结构
BaseModelTrainer (基类)
- 抽象基类,定义通用接口
- 包含训练配置和通用方法
- 支持LoRA和全量微调
Gemma3ModelTrainer (Gemma3专用)
- 使用Gemma3特定的聊天模板
- 适配Gemma3的对话格式
- 支持Gemma3模型的所有功能
Qwen3ModelTrainer (Qwen3专用)
- 使用Qwen3特定的聊天模板
- 适配Qwen3的对话格式
- 支持Qwen3模型的所有功能
快速开始
1. 准备数据
确保train_data.json
文件包含正确的训练数据格式:
[
{
"instruction": "用户问题",
"output": "模型回答"
}
]
2. 运行训练
python main.py
3. 使用自定义配置
在config.py
中可以修改以下配置:
MODEL_CONFIGS
: 模型相关配置TRAINING_CONFIG
: 训练超参数LORA_CONFIG
: LoRA微调参数GENERATION_CONFIG
: 生成参数
扩展新模型
- 在
train.py
中创建新的训练器类,继承BaseModelTrainer
- 实现必要的抽象方法
- 在
config.py
中添加模型配置 - 在
main.py
中使用新的训练器
测试
运行测试脚本验证功能:
python test_trainers.py
注意事项
- 确保有足够的GPU内存进行训练
- 全量微调需要更多显存,建议使用LoRA模式
- 模型路径需要正确设置
- 训练数据需要符合对话格式要求