1 Star 15 Fork 6

Liereyy / Satori

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
test.py 2.47 KB
一键复制 编辑 原始数据 按行查看 历史
Liereyy 提交于 2024-05-17 01:51 . fix bugs
import argparse
import math

import torch
import torch.nn as nn
import torchsummary
class ResBlock(nn.Module):
    """Residual block: two (1, 3) convolutions plus a skip connection.

    Channel count (256) and spatial size are preserved by the
    padding, so the conv output can be added directly to the input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(256, 256, (1, 3), padding=(0, 1))
        self.conv2 = nn.Conv2d(256, 256, (1, 3), padding=(0, 1))

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        # Bug fix: the original used `x += y`, an in-place update of the
        # block's input.  That mutates the caller's tensor and breaks
        # autograd (raises for leaf tensors requiring grad, and corrupts
        # tensors saved for backward otherwise).  The out-of-place add
        # produces identical forward values.
        return x + y
class ResNet(nn.Module):
    """Residual CNN: 838 input planes -> 50-block residual tower -> 1 plane."""

    def __init__(self):
        super().__init__()
        # Stem: project the 838 input channels down to the 256-channel trunk.
        self.conv1 = nn.Conv2d(838, 256, (1, 3), padding=(0, 1))
        # Tower of 50 identical residual blocks.
        self.blocks = nn.Sequential(*(ResBlock() for _ in range(50)))
        # Head: 1x1 convolution collapsing the trunk to one output channel.
        self.conv2 = nn.Conv2d(256, 1, 1)

    def forward(self, x):
        # Stem -> residual tower -> head, applied in order.
        out = self.conv1(x)
        out = self.blocks(out)
        return self.conv2(out)
def lr_step(epoch, *, max_epochs=12, warm_up_steps=0,
            init_lr=1e-6, max_lr=3e-4, final_lr=1e-6):
    """Return the learning rate for `epoch`.

    The schedule is an optional linear warm-up followed by cosine decay
    from `max_lr` down to `final_lr` over one cycle of `max_epochs`
    epochs; each completed cycle shrinks the variable part of the LR by 5x.

    Args:
        epoch: global epoch index (may span several cycles).
        max_epochs: length of one schedule cycle.
        warm_up_steps: number of warm-up epochs at the start of each cycle.
        init_lr: LR at the start of warm-up.
        max_lr: peak LR reached after warm-up.
        final_lr: LR floor at the end of the cosine decay.

    Returns:
        The learning rate (float) for this epoch.
    """
    # Each completed cycle of `max_epochs` divides the amplitude by 5.
    scale = 5 ** (epoch // max_epochs)
    epoch %= max_epochs
    if warm_up_steps > 0 and epoch < warm_up_steps:
        # Linear ramp from init_lr towards max_lr.
        return init_lr + (max_lr - init_lr) / warm_up_steps * epoch / scale
    # Cosine decay for the remainder of the cycle.  Note: after the modulo
    # above, `epoch < max_epochs` always holds, so the original trailing
    # `return final_lr` was unreachable and has been removed.
    cos_steps = epoch - warm_up_steps
    cos_max_steps = max_epochs - warm_up_steps - 1
    if cos_max_steps < 1:
        # Degenerate one-epoch decay window: avoid division by zero.
        cos_max_steps = 1
    return (final_lr
            + 0.5 * (max_lr - final_lr)
            * (1 + math.cos(cos_steps / cos_max_steps * math.pi)) / scale)


if __name__ == '__main__':
    # policy = ResNet()
    # torchsummary.summary(policy, (838, 1, 34), device='cpu')
    #
    # torchsummary output for ResNet on input (838, 1, 34):
    # ================================================================
    # Total params: 20,330,497
    # Trainable params: 20,330,497
    # Non-trainable params: 0
    # ----------------------------------------------------------------
    # Input size (MB): 0.11
    # Forward/backward pass size (MB): 10.03
    # Params size (MB): 77.55
    # Estimated Total Size (MB): 87.69
    # ----------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_epochs', type=int, default=12)
    parser.add_argument('--tot_epochs', type=int, default=36)
    parser.add_argument('--warm_up_steps', default=0, type=int)
    parser.add_argument('--init_lr', default=1e-6, type=float)
    parser.add_argument('--max_lr', default=3e-4, type=float)
    parser.add_argument('--final_lr', default=1e-6, type=float)
    args = parser.parse_args()

    for epoch in range(args.tot_epochs):
        lr = lr_step(epoch,
                     max_epochs=args.max_epochs,
                     warm_up_steps=args.warm_up_steps,
                     init_lr=args.init_lr,
                     max_lr=args.max_lr,
                     final_lr=args.final_lr)
        print(f'{epoch} - {1 + epoch // args.max_epochs}: {lr}')
Python
1
https://gitee.com/princesslaffey/satori.git
git@gitee.com:princesslaffey/satori.git
princesslaffey
satori
Satori
master

搜索帮助