from dotenv import load_dotenv
load_dotenv()

from data.data_transform import get_chat_data2
from train.train import Gemma3ModelTrainer, Qwen3ModelTrainer
from config import MODEL_CONFIGS, TRAINING_CONFIG


def main():
    """Main function; demonstrates how to use the trainer classes."""

    # Example 1: train the Qwen3 model
    print("Starting Qwen3 model training...")
    qwen_trainer = Qwen3ModelTrainer(
        model_name=MODEL_CONFIGS["qwen3-600m"]["uri"],
        mode="lora"
    )

    dataset = get_chat_data2()
    qwen_trainer.set_dataset(dataset)
    qwen_trainer.train()

    # Test chat (prompt: "Why does submitting to the exchange system still
    # report an error after using the validation service?")
    response = qwen_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?")
    print(f"Qwen3 model reply: {response}")

    # Save the model
    qwen_trainer.save("qwen3-600m-finetuned")

    # Example 2: train the Gemma3 model (commented out)
    """
    print("Starting Gemma3 model training...")
    gemma_trainer = Gemma3ModelTrainer(
        model_name=MODEL_CONFIGS["gemma-3-270m-it"]["model_name"],
        mode="lora"
    )

    gemma_trainer.set_dataset(dataset)
    gemma_trainer.train()

    # Test chat
    response = gemma_trainer.chat("为什么使用了校验服务后提交到交易所系统还是报错?")
    print(f"Gemma3 model reply: {response}")

    # Save the model
    gemma_trainer.save("gemma-3-270m-finetuned")
    """
if __name__ == '__main__':
    main()