Files
gemma3-finetuning/test/test_trainers.py
2025-09-03 00:20:29 +08:00

81 lines
2.5 KiB
Python

#!/usr/bin/env python3
"""
测试Gemma3和Qwen3模型训练器
"""
from train import Gemma3ModelTrainer, Qwen3ModelTrainer
from datasets import Dataset
def test_gemma3_trainer():
"""测试Gemma3模型训练器"""
print("测试Gemma3ModelTrainer...")
# 创建一个简单的测试数据集
test_data = {
"conversations": [
[
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好!有什么可以帮助你的吗?"}
]
]
}
dataset = Dataset.from_dict(test_data)
try:
trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "lora")
trainer.set_dataset(dataset)
print("✓ Gemma3ModelTrainer 初始化成功")
# 测试聊天模板应用
result = trainer.apply_chat_template({"conversations": test_data["conversations"]})
print("✓ Gemma3聊天模板应用成功")
return True
except Exception as e:
print(f"✗ Gemma3ModelTrainer 测试失败: {e}")
return False
def test_qwen3_trainer():
"""测试Qwen3模型训练器"""
print("测试Qwen3ModelTrainer...")
# 创建一个简单的测试数据集
test_data = {
"conversations": [
[
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好!有什么可以帮助你的吗?"}
]
]
}
dataset = Dataset.from_dict(test_data)
try:
trainer = Qwen3ModelTrainer("Qwen/Qwen3-600m", "lora")
trainer.set_dataset(dataset)
print("✓ Qwen3ModelTrainer 初始化成功")
# 测试聊天模板应用
result = trainer.apply_chat_template({"conversations": test_data["conversations"]})
print("✓ Qwen3聊天模板应用成功")
return True
except Exception as e:
print(f"✗ Qwen3ModelTrainer 测试失败: {e}")
return False
if __name__ == "__main__":
print("开始测试模型训练器...\n")
gemma_success = test_gemma3_trainer()
print()
qwen_success = test_qwen3_trainer()
print(f"\n测试结果:")
print(f"Gemma3ModelTrainer: {'✓ 通过' if gemma_success else '✗ 失败'}")
print(f"Qwen3ModelTrainer: {'✓ 通过' if qwen_success else '✗ 失败'}")
if gemma_success and qwen_success:
print("\n🎉 所有测试通过!模型训练器类已正确实现。")
else:
print("\n❌ 部分测试失败,请检查代码。")