Files
gemma3-finetuning/README.md
2025-09-03 00:20:29 +08:00

4.3 KiB
Raw Permalink Blame History

大语言模型微调框架

一个基于Unsloth和Transformers的大语言模型微调框架支持多种模型的全量微调和LoRA微调。

功能特性

  • 支持Gemma3、Qwen3等多种大语言模型
  • 支持全量微调(Full Fine-tuning)和LoRA微调两种模式
  • 统一的训练器接口,易于扩展新模型
  • 配置集中管理,便于调整训练参数
  • 自动应用模型特定的聊天模板
  • 支持模型保存和对话测试

安装依赖

pip install -r requirements.txt

项目结构

fine-tuning/
├── config.py          # 配置文件,集中管理模型和训练参数
├── data_transform.py   # 数据预处理和转换工具
├── main.py            # 主程序入口,演示如何使用训练器
├── train.py           # 核心训练器类定义
├── test.py            # 测试脚本
├── test_trainers.py   # 训练器测试
├── train_data.json    # 训练数据
├── README.md          # 项目说明文档
└── requirements.txt   # 依赖包列表

使用方法

1. 基础使用

from train import Gemma3ModelTrainer, Qwen3ModelTrainer
from data_transform import get_chat_data2

# 使用Gemma3模型训练器
gemma_trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "lora")
dataset = get_chat_data2()
gemma_trainer.set_dataset(dataset)
gemma_trainer.train()
gemma_trainer.chat("你好,介绍一下你自己")
gemma_trainer.save("gemma-3-4b-finetuned")

# 使用Qwen3模型训练器
qwen_trainer = Qwen3ModelTrainer("Qwen/Qwen3-600m", "lora")
qwen_trainer.set_dataset(dataset)
qwen_trainer.train()
qwen_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?")
qwen_trainer.save("qwen3-600m-finetuned")

2. 训练模式选择

支持两种训练模式:

  • "lora" (默认): LoRA微调参数高效训练速度快
  • "full": 全量微调,所有参数都会更新
# LoRA微调
trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "lora")

# 全量微调  
trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "full")

3. 使用配置文件

所有配置现在集中在config.py文件中管理:

# 模型配置
MODEL_CONFIGS = {
    "gemma-3-270m-it": {
        "model_name": "unsloth/gemma-3-270m-it",
        "chat_template": "gemma-3",
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    },
    "qwen3-600m": {
        "model_name": "Qwen/Qwen3-600m",
        "chat_template": "qwen3",
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    }
}

# 训练配置
TRAINING_CONFIG = {
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "warmup_steps": 5,
    "max_steps": 300,
    "learning_rate": 2e-4,
    "logging_steps": 10,
    "optim": "adamw_8bit",
    "lr_scheduler_type": "linear",
    "seed": 3407,
    "weight_decay": 0.01,
    "report_to": "none"
}

类结构

BaseModelTrainer (基类)

  • 抽象基类,定义通用接口
  • 包含训练配置和通用方法
  • 支持LoRA和全量微调

Gemma3ModelTrainer (Gemma3专用)

  • 使用Gemma3特定的聊天模板
  • 适配Gemma3的对话格式
  • 支持Gemma3模型的所有功能

Qwen3ModelTrainer (Qwen3专用)

  • 使用Qwen3特定的聊天模板
  • 适配Qwen3的对话格式
  • 支持Qwen3模型的所有功能

快速开始

1. 准备数据

确保train_data.json文件包含正确的训练数据格式:

[
  {
    "instruction": "用户问题",
    "output": "模型回答"
  }
]

2. 运行训练

python main.py

3. 使用自定义配置

config.py中可以修改以下配置:

  • MODEL_CONFIGS: 模型相关配置
  • TRAINING_CONFIG: 训练超参数
  • LORA_CONFIG: LoRA微调参数
  • GENERATION_CONFIG: 生成参数

扩展新模型

  1. train.py中创建新的训练器类,继承BaseModelTrainer
  2. 实现必要的抽象方法
  3. config.py中添加模型配置
  4. main.py中使用新的训练器

测试

运行测试脚本验证功能:

python test_trainers.py

注意事项

  1. 确保有足够的GPU内存进行训练
  2. 全量微调需要更多显存建议使用LoRA模式
  3. 模型路径需要正确设置
  4. 训练数据需要符合对话格式要求