Files
gemma3-finetuning/train/train.py
2025-09-03 00:20:29 +08:00

252 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from abc import abstractmethod, ABCMeta
from dotenv import load_dotenv
load_dotenv()
import unsloth
from unsloth import FastModel, get_chat_template
from unsloth_zoo.dataset_utils import train_on_responses_only
from trl import SFTTrainer, SFTConfig
from transformers import TextStreamer
from datasets import Dataset
class BaseModelTrainer(metaclass=ABCMeta):
    """
    Base fine-tuning trainer supporting full fine-tuning and LoRA modes.

    Subclasses supply the model-specific chat handling by implementing
    ``apply_chat_template``, ``set_dataset`` and ``_format_chat_prompt``.
    """

    # Shared SFT hyper-parameters, expanded straight into SFTConfig.
    TRAINING_CONFIG = {
        "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "warmup_steps": 5,
        "max_steps": 300,
        "learning_rate": 2e-4,
        "logging_steps": 10,
        "optim": "adamw_8bit",
        "lr_scheduler_type": "linear",
        "seed": 3407,
        "weight_decay": 0.01,
        "report_to": "none",
    }

    def __init__(self, model_name, mode="lora"):
        """
        Initialize the trainer and load the model immediately.

        :param model_name: pretrained model name or path
        :param mode: training mode, "full" (full fine-tuning) or "lora" (LoRA)
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.trainer = None
        self._full_finetuning = (mode == "full")
        self._get_model()

    def _get_model(self):
        """Load the pretrained model and configure it for fine-tuning."""
        model, tokenizer = FastModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=2048,
            dtype=None,  # let unsloth auto-detect the dtype
            load_in_4bit=False,
            load_in_8bit=False,
            full_finetuning=self._full_finetuning,
        )
        if not self._full_finetuning:
            # LoRA mode: wrap the base model with adapter layers.
            self.model = self._init_lora_model(model)
        else:
            self.model = model
        # The chat template is installed by subclasses.
        self.tokenizer = tokenizer

    @staticmethod
    def _init_lora_model(model):
        """Attach LoRA adapters to the base model."""
        return FastModel.get_peft_model(
            model,
            r=128,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            use_gradient_checkpointing="unsloth",
            lora_alpha=128,
            lora_dropout=0,
            bias="none",
            random_state=3407,
            # BUG FIX: the keyword is ``use_rslora`` (rank-stabilized LoRA);
            # the original ``use_rslor`` raises TypeError in get_peft_model.
            use_rslora=False,
            loftq_config=None,
        )

    @abstractmethod
    def apply_chat_template(self, examples):
        """Render one conversation example into a training-text record."""
        return {"text": ""}

    @abstractmethod
    def set_dataset(self, data: "Dataset"):
        """Prepare the training dataset and build ``self.trainer``."""
        pass

    def _get_sft_config(self):
        """Build the SFTConfig from the shared TRAINING_CONFIG."""
        return SFTConfig(
            dataset_text_field="text",
            **self.TRAINING_CONFIG,
        )

    def train(self):
        """Run training; ``set_dataset`` must have been called first."""
        if not self.trainer:
            raise ValueError("训练器未初始化请先调用set_dataset方法")
        self.trainer.train()

    def save(self, name: str = "gemma-3-finetuned", save_gguf=False):
        """
        Save the fine-tuned model.

        :param name: output name/path for the checkpoint
        :param save_gguf: whether to export in GGUF format instead
        """
        if save_gguf:
            self._save_gguf_format(name)
        else:
            if self._full_finetuning:
                self.model.save_pretrained(name)
                # FIX: also persist the tokenizer so the full-finetune
                # checkpoint is self-contained (matches the LoRA path).
                self.tokenizer.save_pretrained(name)
            else:
                self._save_lora_model(name)

    def _save_gguf_format(self, name):
        """Export the model in GGUF format with q4_k_m quantization."""
        self.model.save_pretrained_gguf(f"{name}-gguf", self.tokenizer, quantization_method="q4_k_m")

    def _save_lora_model(self, name):
        """Save LoRA adapters + tokenizer, then a merged 16-bit checkpoint."""
        lora_name = f"{name}-lora"
        self.model.save_pretrained(lora_name)
        self.tokenizer.save_pretrained(lora_name)
        self.model.save_pretrained_merged(name, self.tokenizer, save_method="merged_16bit")

    def chat(self, message: str, max_new_tokens: int = 1024, temperature: float = 1.0, top_p: float = 0.95,
             top_k: int = 3) -> str:
        """
        Chat with the model, streaming tokens to stdout while generating.

        :param message: user input message
        :param max_new_tokens: maximum number of tokens to generate
        :param temperature: sampling temperature
        :param top_p: nucleus (top-p) sampling parameter
        :param top_k: top-k sampling parameter
        :return: decoded model output (includes the prompt)
        """
        if not self.model or not self.tokenizer:
            raise ValueError("模型或分词器未初始化")
        # Subclasses supply the model-specific prompt format.
        prompt = self._format_chat_prompt(message)
        # NOTE(review): assumes a CUDA device is available — confirm before CPU-only use.
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        return self.tokenizer.batch_decode(outputs)[0]

    @abstractmethod
    def _format_chat_prompt(self, message: str) -> str:
        """Format a user message into the model's chat prompt; subclass-specific."""
        pass
class Gemma3ModelTrainer(BaseModelTrainer):
    """Trainer specialised for Gemma-3 models (uses the gemma-3 chat template)."""

    def __init__(self, model_name, mode="lora"):
        """
        Create a Gemma-3 trainer.

        :param model_name: pretrained model name or path
        :param mode: "full" for full fine-tuning, "lora" for LoRA fine-tuning
        """
        super().__init__(model_name, mode)

    def _get_model(self):
        """Load the model, then install the Gemma-3 chat template on the tokenizer."""
        super()._get_model()
        self.tokenizer = get_chat_template(self.tokenizer, chat_template="gemma-3")

    def apply_chat_template(self, examples):
        """Render one conversation example into a single training string."""
        token_ids = self.tokenizer.apply_chat_template(examples["conversations"])
        return {"text": self.tokenizer.decode(token_ids)}

    def set_dataset(self, data: Dataset):
        """Map the chat template over *data* and build the response-masked trainer."""
        rendered = data.map(self.apply_chat_template)
        base_trainer = SFTTrainer(
            model=self.model,
            train_dataset=rendered,
            args=self._get_sft_config(),
        )
        # Mask user turns so the loss is computed on model responses only.
        self.trainer = train_on_responses_only(
            base_trainer,
            instruction_part="<start_of_turn>user\n",
            response_part="<start_of_turn>model\n",
            num_proc=1,
        )

    def _format_chat_prompt(self, message: str) -> str:
        """Wrap *message* in Gemma-3 user/model turn markers."""
        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
class Qwen3ModelTrainer(BaseModelTrainer):
    """Trainer specialised for Qwen-3 models (uses the qwen3 chat template)."""

    def __init__(self, model_name, mode="lora"):
        """
        Create a Qwen-3 trainer.

        :param model_name: pretrained model name or path
        :param mode: "full" for full fine-tuning, "lora" for LoRA fine-tuning
        """
        super().__init__(model_name, mode)

    def _get_model(self):
        """Load the model, then install the Qwen-3 chat template on the tokenizer."""
        super()._get_model()
        self.tokenizer = get_chat_template(self.tokenizer, chat_template="qwen3")

    def apply_chat_template(self, examples):
        """Render one conversation example into a single training string."""
        token_ids = self.tokenizer.apply_chat_template(examples["conversations"])
        return {"text": self.tokenizer.decode(token_ids)}

    def set_dataset(self, data: Dataset):
        """Map the chat template over *data* and build the response-masked trainer."""
        rendered = data.map(self.apply_chat_template)
        base_trainer = SFTTrainer(
            model=self.model,
            train_dataset=rendered,
            args=self._get_sft_config(),
        )
        # Mask user turns so the loss is computed on assistant responses only.
        self.trainer = train_on_responses_only(
            base_trainer,
            instruction_part="<|im_start|>user\n",
            response_part="<|im_start|>assistant\n",
            num_proc=1,
        )

    def _format_chat_prompt(self, message: str) -> str:
        """Wrap *message* in Qwen-3 ChatML user/assistant turn markers."""
        return f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"