# Fine-tune Gemma-3 with LoRA adapters using Unsloth and TRL.
from dotenv import load_dotenv

load_dotenv()

# Unsloth must be imported before transformers/trl so its patches apply.
import unsloth
from unsloth import FastModel, get_chat_template

from datasets import Dataset
from transformers import TextStreamer
from trl import SFTConfig, SFTTrainer


class Gemma3ModelTrainer:
    """Fine-tunes a Gemma-3 model with LoRA adapters via Unsloth + TRL SFT."""

    def __init__(self, model_name):
        """Load the base model and tokenizer immediately.

        Args:
            model_name: Hugging Face model id or local path
                (e.g. "unsloth/gemma-3-270m-it").
        """
        self.model_name = model_name
        self.model = None      # set by _get_model()
        self.tokenizer = None  # set by _get_model()
        self.trainer = None    # set by set_dataset()
        self._get_model()

    def _get_model(self):
        """Load the base model, attach LoRA adapters, set the Gemma-3 chat template."""
        model, tokenizer = FastModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=2048,
            dtype=None,          # auto-detect (bf16 on supported GPUs, else fp16)
            load_in_4bit=False,
            full_finetuning=False,
        )

        self.model = FastModel.get_peft_model(
            model,
            r=128,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            use_gradient_checkpointing="unsloth",
            lora_alpha=128,
            lora_dropout=0,
            bias="none",
            random_state=3407,
            # BUG FIX: the keyword is `use_rslora` (rank-stabilized LoRA);
            # the original passed the misspelled `use_rslor`.
            use_rslora=False,
            loftq_config=None,
        )
        self.tokenizer = get_chat_template(
            tokenizer,
            chat_template="gemma-3",
        )

    def apply_chat_template(self, examples):
        """Render one example's "conversations" list into a training string.

        BUG FIX: the original tokenized the conversation and then decoded it
        back to text. Besides the wasted round-trip, the decoded string starts
        with the `<bos>` token, which the trainer's tokenizer adds again at
        training time — duplicating BOS. Render directly to text instead and
        strip the leading `<bos>` (per the Unsloth Gemma-3 recipe).
        """
        conversations = examples["conversations"]
        text = self.tokenizer.apply_chat_template(conversations, tokenize=False)
        # SFTTrainer's tokenization re-adds <bos>; don't double it.
        return {"text": text.removeprefix("<bos>")}

    def set_dataset(self, data: Dataset):
        """Map the chat template over `data` and build the SFT trainer.

        Args:
            data: A dataset whose rows each carry a "conversations" list in
                chat-message format — TODO confirm against data_transform.
        """
        dataset = data.map(self.apply_chat_template)
        self.trainer = SFTTrainer(
            model=self.model,
            train_dataset=dataset,
            args=SFTConfig(
                dataset_text_field="text",
                per_device_train_batch_size=2,
                gradient_accumulation_steps=4,  # effective batch size: 8
                warmup_steps=5,
                max_steps=100,
                learning_rate=5e-5,
                logging_steps=10,
                optim="adamw_8bit",
                lr_scheduler_type="linear",
                seed=3407,
                weight_decay=0.01,
                report_to="none",
            ),
        )

    def train(self):
        """Run fine-tuning.

        Raises:
            RuntimeError: if set_dataset() has not been called yet
                (clearer than the original's AttributeError on None).
        """
        if self.trainer is None:
            raise RuntimeError("Call set_dataset() before train().")
        self.trainer.train()

    def save(self, name: str = "gemma-3-finetuned", only_one=False):
        """Save the fine-tuned model.

        Args:
            name: Base path/name for the saved artifacts.
            only_one: If False (default), also save the LoRA adapter and
                tokenizer separately under "{name}-lora". The merged 16-bit
                model is always written to `name`.
        """
        if not only_one:
            self.model.save_pretrained(f"{name}-lora")      # adapter weights only
            self.tokenizer.save_pretrained(f"{name}-lora")  # tokenizer files
        self.model.save_pretrained_merged(name, self.tokenizer,
                                          save_method="merged_16bit")

    def chat(self, message):
        """Generate a streamed reply to `message`; return the full decoded text.

        NOTE(review): hard-codes "cuda" — will fail on CPU-only hosts.
        """
        prompt = ("<start_of_turn>user\n" + message
                  + "<end_of_turn>\n<start_of_turn>model\n")
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=1024,
            streamer=TextStreamer(self.tokenizer, skip_prompt=True),
            # Gemma-3 recommended sampling settings.
            temperature=1.0, top_p=0.95, top_k=64,
        )
        return self.tokenizer.batch_decode(outputs)[0]


if __name__ == '__main__':
    # Local import: data_transform is only needed when run as a script.
    from data_transform import get_chat_data

    # Fine-tune the small instruction-tuned Gemma-3 checkpoint end to end:
    # load -> attach dataset -> train -> smoke-test generation -> save.
    gemma_trainer = Gemma3ModelTrainer("unsloth/gemma-3-270m-it")
    chat_dataset = get_chat_data()
    gemma_trainer.set_dataset(chat_dataset)
    gemma_trainer.train()
    gemma_trainer.chat("介绍智能校验的离线特点")
    gemma_trainer.save("gemma-3-270m-finetuned")