Files
gemma3-finetuning/train.py
lychang 560097cb27 init
2025-08-26 09:21:48 +08:00

99 lines
3.5 KiB
Python

from dotenv import load_dotenv
load_dotenv()
import unsloth
from transformers import TextStreamer
from datasets import Dataset
from unsloth import FastModel, get_chat_template
from trl import SFTTrainer, SFTConfig
class Gemma3ModelTrainer:
def __init__(self, model_name):
self.model_name = model_name
self.model = None
self.tokenizer = None
self.trainer = None
self._get_model()
def _get_model(self):
model, tokenizer = FastModel.from_pretrained(
model_name=self.model_name,
max_seq_length=2048,
dtype=None,
load_in_4bit=False,
full_finetuning=False)
self.model = FastModel.get_peft_model(
model,
r=128,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
use_gradient_checkpointing="unsloth",
lora_alpha=128,
lora_dropout=0,
bias="none",
random_state=3407,
use_rslor=False,
loftq_config=None
)
self.tokenizer = get_chat_template(
tokenizer,
chat_template="gemma-3",
)
def apply_chat_template(self, examples):
conversations = examples["conversations"]
texts = self.tokenizer.apply_chat_template(conversations)
return {"text": self.tokenizer.decode(texts)}
def set_dataset(self, data: Dataset):
dataset = data.map(self.apply_chat_template)
self.trainer = SFTTrainer(model=self.model,
train_dataset=dataset,
args=SFTConfig(
dataset_text_field="text",
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
max_steps=100,
learning_rate=5e-5,
logging_steps=10,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
weight_decay=0.01,
report_to="none"
)
)
def train(self):
self.trainer.train()
def save(self, name: str="gemma-3-finetuned", only_one=False):
if not only_one:
self.model.save_pretrained(f"{name}-lora") # 保存适配器权重
self.tokenizer.save_pretrained(f"{name}-lora") # 保存分词器
self.model.save_pretrained_merged(name, self.tokenizer, save_method="merged_16bit")
def chat(self, message):
outputs = self.model.generate(
**self.tokenizer(["<start_of_turn>user\n"+message+"<end_of_turn>\n<start_of_turn>model\n"], return_tensors="pt").to("cuda"),
max_new_tokens=1024,
streamer=TextStreamer(self.tokenizer, skip_prompt=True),
temperature=1.0, top_p=0.95, top_k=64,
)
return self.tokenizer.batch_decode(outputs)[0]
if __name__ == '__main__':
    # Fine-tune the small instruction-tuned Gemma-3 checkpoint on the
    # project's chat dataset, try one sample prompt, then persist the result.
    gemma_trainer = Gemma3ModelTrainer("unsloth/gemma-3-270m-it")

    from data_transform import get_chat_data
    gemma_trainer.set_dataset(get_chat_data())

    gemma_trainer.train()
    gemma_trainer.chat("介绍智能校验的离线特点")
    gemma_trainer.save("gemma-3-270m-finetuned")