from dotenv import load_dotenv load_dotenv() from data.data_transform import get_chat_data2 from train.train import Gemma3ModelTrainer, Qwen3ModelTrainer from config import MODEL_CONFIGS, TRAINING_CONFIG def main(): """主函数,演示如何使用训练器类""" # 示例1: 使用Qwen3模型训练 print("开始训练Qwen3模型...") qwen_trainer = Qwen3ModelTrainer( model_name=MODEL_CONFIGS["qwen3-600m"]["uri"], mode="lora" ) dataset = get_chat_data2() qwen_trainer.set_dataset(dataset) qwen_trainer.train() # 测试对话 response = qwen_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?") print(f"Qwen3模型回复: {response}") # 保存模型 qwen_trainer.save("qwen3-600m-finetuned") # 示例2: 使用Gemma3模型训练(注释状态) """ print("开始训练Gemma3模型...") gemma_trainer = Gemma3ModelTrainer( model_name=MODEL_CONFIGS["gemma-3-270m-it"]["model_name"], mode="lora" ) gemma_trainer.set_dataset(dataset) gemma_trainer.train() # 测试对话 response = gemma_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?") print(f"Gemma3模型回复: {response}") # 保存模型 gemma_trainer.save("gemma-3-270m-finetuned") """ if __name__ == '__main__': main()