代码拉取完成,页面将自动刷新
1、PyTorch 实现 CNN,以及一些使用CNN图片分类项目
1、CNN / 猫狗大战
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
machine-learn/ CNN / 猫狗大战 使用说明:
1. AverageMeter.py - 保存和更新准确率
2. CNNet.py - CNN网络Pytorch实现
3. dataset.py - (训练、验证、测试)数据集处理
4. detect.py - 使用模型进行图片分类
5. imageResize.py - 图片尺寸统一化(有助于优化训练模型精度)
6. main.py - 运行主函数及基本配置(需要自行配置自己计算机的路径)
7. model2.pth - 我的猫狗分类训练模型(仅供参考)
8. test.py - 模型验证代码并模型计算准确率
9. train.py - 模型训练函数
每个文件都有相应的注释,根据自己的需要进行更改
注:main.py 中需要对基本的数据集文件路径进行更改,替换为自己的数据集文件路径
--yourselves config--
basePath = "D:\\myInterestTest\\objectDetect\\data\\catVSdog" # 数据集文件父路径
dogFolderPath = basePath + os.sep + "train" + os.sep + "dog" # dog train images path
catFolderPath = basePath + os.sep + "train" + os.sep + "cat" # cat train images path
testImgPath = basePath + os.sep + "test" # 验证图片数据集
testPath = basePath + os.sep + "test.csv" # 验证图片类别csv文件
model_cp = './model2.pth' # 模型保存路径
tensorboard_path = 'D:\\myInterestTest\\objectDetect\\tensorBoard' # tensorboard 文件夹路径
dogAct = 0 # 狗的类别数字
catAct = 1 # 猫的类别数字
EPOCH = 20 # 训练轮数
workers = 10 # PyTorch读取数据线程数量
batch_size = 16 # 训练所抓取的数据样本数量
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。