81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试Gemma3和Qwen3模型训练器
|
|
"""
|
|
|
|
from train import Gemma3ModelTrainer, Qwen3ModelTrainer
|
|
from datasets import Dataset
|
|
|
|
def test_gemma3_trainer():
|
|
"""测试Gemma3模型训练器"""
|
|
print("测试Gemma3ModelTrainer...")
|
|
|
|
# 创建一个简单的测试数据集
|
|
test_data = {
|
|
"conversations": [
|
|
[
|
|
{"role": "user", "content": "你好"},
|
|
{"role": "assistant", "content": "你好!有什么可以帮助你的吗?"}
|
|
]
|
|
]
|
|
}
|
|
dataset = Dataset.from_dict(test_data)
|
|
|
|
try:
|
|
trainer = Gemma3ModelTrainer("google/gemma-3-4b-it", "lora")
|
|
trainer.set_dataset(dataset)
|
|
print("✓ Gemma3ModelTrainer 初始化成功")
|
|
|
|
# 测试聊天模板应用
|
|
result = trainer.apply_chat_template({"conversations": test_data["conversations"]})
|
|
print("✓ Gemma3聊天模板应用成功")
|
|
|
|
return True
|
|
except Exception as e:
|
|
print(f"✗ Gemma3ModelTrainer 测试失败: {e}")
|
|
return False
|
|
|
|
def test_qwen3_trainer():
|
|
"""测试Qwen3模型训练器"""
|
|
print("测试Qwen3ModelTrainer...")
|
|
|
|
# 创建一个简单的测试数据集
|
|
test_data = {
|
|
"conversations": [
|
|
[
|
|
{"role": "user", "content": "你好"},
|
|
{"role": "assistant", "content": "你好!有什么可以帮助你的吗?"}
|
|
]
|
|
]
|
|
}
|
|
dataset = Dataset.from_dict(test_data)
|
|
|
|
try:
|
|
trainer = Qwen3ModelTrainer("Qwen/Qwen3-600m", "lora")
|
|
trainer.set_dataset(dataset)
|
|
print("✓ Qwen3ModelTrainer 初始化成功")
|
|
|
|
# 测试聊天模板应用
|
|
result = trainer.apply_chat_template({"conversations": test_data["conversations"]})
|
|
print("✓ Qwen3聊天模板应用成功")
|
|
|
|
return True
|
|
except Exception as e:
|
|
print(f"✗ Qwen3ModelTrainer 测试失败: {e}")
|
|
return False
|
|
|
|
if __name__ == "__main__":
|
|
print("开始测试模型训练器...\n")
|
|
|
|
gemma_success = test_gemma3_trainer()
|
|
print()
|
|
qwen_success = test_qwen3_trainer()
|
|
|
|
print(f"\n测试结果:")
|
|
print(f"Gemma3ModelTrainer: {'✓ 通过' if gemma_success else '✗ 失败'}")
|
|
print(f"Qwen3ModelTrainer: {'✓ 通过' if qwen_success else '✗ 失败'}")
|
|
|
|
if gemma_success and qwen_success:
|
|
print("\n🎉 所有测试通过!模型训练器类已正确实现。")
|
|
else:
|
|
print("\n❌ 部分测试失败,请检查代码。") |