Files
gemma3-finetuning/main.py
2025-09-03 00:20:29 +08:00

49 lines
1.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dotenv import load_dotenv
load_dotenv()
from data.data_transform import get_chat_data2
from train.train import Gemma3ModelTrainer, Qwen3ModelTrainer
from config import MODEL_CONFIGS, TRAINING_CONFIG
def main():
"""主函数,演示如何使用训练器类"""
# 示例1: 使用Qwen3模型训练
print("开始训练Qwen3模型...")
qwen_trainer = Qwen3ModelTrainer(
model_name=MODEL_CONFIGS["qwen3-600m"]["uri"],
mode="lora"
)
dataset = get_chat_data2()
qwen_trainer.set_dataset(dataset)
qwen_trainer.train()
# 测试对话
response = qwen_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?")
print(f"Qwen3模型回复: {response}")
# 保存模型
qwen_trainer.save("qwen3-600m-finetuned")
# 示例2: 使用Gemma3模型训练注释状态
"""
print("开始训练Gemma3模型...")
gemma_trainer = Gemma3ModelTrainer(
model_name=MODEL_CONFIGS["gemma-3-270m-it"]["model_name"],
mode="lora"
)
gemma_trainer.set_dataset(dataset)
gemma_trainer.train()
# 测试对话
response = gemma_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?")
print(f"Gemma3模型回复: {response}")
# 保存模型
gemma_trainer.save("gemma-3-270m-finetuned")
"""
if __name__ == '__main__':
main()