from abc import abstractmethod, ABCMeta
from dotenv import load_dotenv

load_dotenv()

# Import unsloth before trl/transformers so its runtime patches are applied.
import unsloth
from unsloth import FastModel, get_chat_template
from unsloth_zoo.dataset_utils import train_on_responses_only
from trl import SFTTrainer, SFTConfig
from transformers import TextStreamer
from datasets import Dataset


class BaseModelTrainer(metaclass=ABCMeta):
    """
    Base trainer for model fine-tuning, supporting both full fine-tuning and
    LoRA fine-tuning modes.
    """
    # Training configuration parameters
    TRAINING_CONFIG = {
        "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "warmup_steps": 5,
        "max_steps": 300,
        "learning_rate": 2e-4,
        "logging_steps": 10,
        "optim": "adamw_8bit",
        "lr_scheduler_type": "linear",
        "seed": 3407,
        "weight_decay": 0.01,
        "report_to": "none"
    }
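    # Note (added): the effective batch size is per_device_train_batch_size *
    # gradient_accumulation_steps = 2 * 4 = 8 per device, and max_steps=300
    # caps the run regardless of dataset size.
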
    def __init__(self, model_name, mode="lora"):
        """
        Initialize the model trainer.

        :param model_name: pretrained model name or path
        :param mode: training mode, either "full" (full fine-tuning) or
            "lora" (LoRA fine-tuning)
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.trainer = None
        self._full_finetuning = (mode == "full")
        self._get_model()

    def _get_model(self):
        """Load the pretrained model and configure fine-tuning parameters."""
        model, tokenizer = FastModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=2048,
            dtype=None,
            load_in_4bit=False,
            load_in_8bit=False,
            full_finetuning=self._full_finetuning
        )

        if not self._full_finetuning:
            self.model = self._init_lora_model(model)
        else:
            self.model = model

        # The chat template is configured by subclasses.
        self.tokenizer = tokenizer

    @staticmethod
    def _init_lora_model(model):
        """Initialize the LoRA fine-tuning model."""
        # lora_alpha == r gives an effective LoRA scaling factor (alpha / r) of 1.
        return FastModel.get_peft_model(
            model,
            r=128,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            use_gradient_checkpointing="unsloth",
            lora_alpha=128,
            lora_dropout=0,
            bias="none",
            random_state=3407,
            use_rslora=False,
            loftq_config=None
        )

    @abstractmethod
    def apply_chat_template(self, examples):
        """Apply the chat template to a conversation example and render it as text."""
        return {"text": ""}

    @abstractmethod
    def set_dataset(self, data: Dataset):
        """Set the training dataset and initialize the trainer; implemented by subclasses."""
        pass

    def _get_sft_config(self):
        """Build the SFT training configuration."""
        return SFTConfig(
            dataset_text_field="text",
            **self.TRAINING_CONFIG
        )

    def train(self):
        """Start model training."""
        if not self.trainer:
            raise ValueError("Trainer not initialized; call set_dataset first")
        self.trainer.train()

    def save(self, name: str = "gemma-3-finetuned", save_gguf=False):
        """
        Save the trained model.

        :param name: name to save the model under
        :param save_gguf: whether to save in GGUF format
        """
        if save_gguf:
            self._save_gguf_format(name)
        else:
            if self._full_finetuning:
                self.model.save_pretrained(name)
            else:
                self._save_lora_model(name)

    def _save_gguf_format(self, name):
        """Export the model in GGUF format with q4_k_m quantization."""
        self.model.save_pretrained_gguf(f"{name}-gguf", self.tokenizer, quantization_method="q4_k_m")

    def _save_lora_model(self, name):
        """Save the LoRA adapter, then a merged 16-bit copy of the full model."""
        lora_name = f"{name}-lora"
        self.model.save_pretrained(lora_name)
        self.tokenizer.save_pretrained(lora_name)
        self.model.save_pretrained_merged(name, self.tokenizer, save_method="merged_16bit")

    def chat(self, message: str, max_new_tokens: int = 1024, temperature: float = 1.0, top_p: float = 0.95,
             top_k: int = 3) -> str:
        """
        Chat with the model.

        :param message: user input message
        :param max_new_tokens: maximum number of tokens to generate
        :param temperature: sampling temperature
        :param top_p: top-p (nucleus) sampling parameter
        :param top_k: top-k sampling parameter
        :return: the model's generated reply
        """
        if not self.model or not self.tokenizer:
            raise ValueError("Model or tokenizer not initialized")

        # Subclasses provide the concrete chat prompt template.
        prompt = self._format_chat_prompt(message)

        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        streamer = TextStreamer(self.tokenizer, skip_prompt=True)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        return self.tokenizer.batch_decode(outputs)[0]

    @abstractmethod
    def _format_chat_prompt(self, message: str) -> str:
        """Format the chat prompt template; implemented by subclasses."""
        pass

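
# Note (added): set_dataset in the subclasses below expects each example to
# hold one chat under a "conversations" key, as a list of role/content
# messages consumed by tokenizer.apply_chat_template, e.g.
#
#     {"conversations": [
#         {"role": "user", "content": "..."},
#         {"role": "assistant", "content": "..."},
#     ]}
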
class Gemma3ModelTrainer(BaseModelTrainer):

    def __init__(self, model_name, mode="lora"):
        """
        Initialize the Gemma-3 model trainer.

        :param model_name: pretrained model name or path
        :param mode: training mode, either "full" (full fine-tuning) or
            "lora" (LoRA fine-tuning)
        """
        super().__init__(model_name, mode)

    def _get_model(self):
        """Load the pretrained model and configure the chat template."""
        super()._get_model()
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template="gemma-3",
        )

    def apply_chat_template(self, examples):
        """Apply the chat template to a conversation example and render it as text."""
        conversations = examples["conversations"]
        # Render the template directly to text instead of tokenizing and
        # decoding again.
        text = self.tokenizer.apply_chat_template(conversations, tokenize=False)
        return {"text": text}

    def set_dataset(self, data: Dataset):
        """Set the training dataset and initialize the trainer."""
        _dataset = data.map(self.apply_chat_template)
        sft_config = self._get_sft_config()
        _trainer = SFTTrainer(
            model=self.model,
            train_dataset=_dataset,
            args=sft_config
        )
        # Mask the loss on user turns so only model responses are trained on.
        self.trainer = train_on_responses_only(
            _trainer,
            instruction_part="<start_of_turn>user\n",
            response_part="<start_of_turn>model\n",
            num_proc=1
        )

    def _format_chat_prompt(self, message: str) -> str:
        """Format the Gemma-3 chat prompt template."""
        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"


class Qwen3ModelTrainer(BaseModelTrainer):

    def __init__(self, model_name, mode="lora"):
        """
        Initialize the Qwen3 model trainer.

        :param model_name: pretrained model name or path
        :param mode: training mode, either "full" (full fine-tuning) or
            "lora" (LoRA fine-tuning)
        """
        super().__init__(model_name, mode)

    def _get_model(self):
        """Load the pretrained model and configure the chat template."""
        super()._get_model()
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template="qwen3",
        )

    def apply_chat_template(self, examples):
        """Apply the chat template to a conversation example and render it as text."""
        conversations = examples["conversations"]
        # Render the template directly to text instead of tokenizing and
        # decoding again.
        text = self.tokenizer.apply_chat_template(conversations, tokenize=False)
        return {"text": text}

    def set_dataset(self, data: Dataset):
        """Set the training dataset and initialize the trainer."""
        _dataset = data.map(self.apply_chat_template)
        sft_config = self._get_sft_config()
        _trainer = SFTTrainer(
            model=self.model,
            train_dataset=_dataset,
            args=sft_config
        )
        # Mask the loss on user turns so only assistant responses are trained on.
        self.trainer = train_on_responses_only(
            _trainer,
            instruction_part="<|im_start|>user\n",
            response_part="<|im_start|>assistant\n",
            num_proc=1
        )

    def _format_chat_prompt(self, message: str) -> str:
        """Format the Qwen3 chat prompt template."""
        return f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
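

# Usage sketch (added; not part of the original module). A minimal end-to-end
# run, assuming a CUDA GPU and access to a Gemma-3 checkpoint; the model name
# "unsloth/gemma-3-4b-it" and the one-example toy dataset are illustrative
# placeholders, not tested settings.
if __name__ == "__main__":
    toy_data = Dataset.from_list([
        {
            "conversations": [
                {"role": "user", "content": "Say hello."},
                {"role": "assistant", "content": "Hello!"},
            ]
        }
    ])

    trainer = Gemma3ModelTrainer("unsloth/gemma-3-4b-it", mode="lora")
    trainer.set_dataset(toy_data)
    trainer.train()
    trainer.save("gemma-3-finetuned")
    print(trainer.chat("Say hello."))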