# ResNet-style policy network definition and learning-rate schedule preview script.
import argparse
import torch
import torch.nn as nn
import torchsummary
class ResBlock(nn.Module):
    """Residual block: two 1x3 convolutions at 256 channels with a skip connection."""

    def __init__(self):
        super().__init__()
        # Kernel (1, 3) with padding (0, 1) preserves the (1, W) spatial size,
        # so the skip connection can be added without any projection.
        self.conv1 = nn.Conv2d(256, 256, (1, 3), padding=(0, 1))
        self.conv2 = nn.Conv2d(256, 256, (1, 3), padding=(0, 1))

    def forward(self, x):
        """Return ``x + conv2(conv1(x))``.

        Bug fix: the original used ``x += y``, an in-place add on the input
        tensor. That mutates a tensor the caller may still hold and raises a
        RuntimeError under autograd when ``x`` is needed for the backward
        pass. The out-of-place add is behaviorally equivalent otherwise.
        """
        y = self.conv1(x)
        y = self.conv2(y)
        return x + y
class ResNet(nn.Module):
    """Input projection -> stack of residual blocks -> 1x1 single-plane head.

    Generalized from the original hard-coded network: the input width and
    depth are now parameters whose defaults reproduce the original exactly.

    Args:
        in_channels: number of input feature planes (default 838, matching
            the original script and its torchsummary figures).
        num_blocks: number of ``ResBlock`` layers in the trunk (default 50).
    """

    def __init__(self, in_channels=838, num_blocks=50):
        super().__init__()
        # Project the input planes down to the 256-channel residual trunk.
        self.conv1 = nn.Conv2d(in_channels, 256, (1, 3), padding=(0, 1))
        self.blocks = nn.Sequential(*(ResBlock() for _ in range(num_blocks)))
        # 1x1 convolution collapsing the trunk to a single output plane.
        self.conv2 = nn.Conv2d(256, 1, 1)

    def forward(self, x):
        """Map input of shape (N, in_channels, H, W) to (N, 1, H, W)."""
        x = self.conv1(x)
        x = self.blocks(x)
        return self.conv2(x)
if __name__ == '__main__':
    # policy = ResNet()
    # torchsummary.summary(policy, (838, 1, 34), device='cpu')
    """
    ================================================================
    Total params: 20,330,497
    Trainable params: 20,330,497
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.11
    Forward/backward pass size (MB): 10.03
    Params size (MB): 77.55
    Estimated Total Size (MB): 87.69
    ----------------------------------------------------------------
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_epochs', type=int, default=12)
    parser.add_argument('--tot_epochs', type=int, default=36)
    parser.add_argument('--warm_up_steps', default=0, type=int)
    parser.add_argument('--init_lr', default=1e-6, type=float)
    parser.add_argument('--max_lr', default=3e-4, type=float)
    parser.add_argument('--final_lr', default=1e-6, type=float)
    args = parser.parse_args()

    # math.cos replaces the original mid-script `import numpy as np` used
    # only for np.cos — no need for a third-party dependency here.
    import math

    def lr_step(epoch):
        """Cosine-annealed learning rate with optional linear warm-up.

        The schedule restarts every ``max_epochs`` epochs; each restart is
        scaled down by a further factor of 5 (scale = 5 ** cycle_index).
        """
        scale = pow(5, epoch // args.max_epochs)
        epoch %= args.max_epochs
        if args.warm_up_steps > 0 and epoch < args.warm_up_steps:
            # Linear ramp from init_lr towards max_lr during warm-up.
            return args.init_lr + (args.max_lr - args.init_lr) / args.warm_up_steps * epoch / scale
        # The modulo above guarantees epoch < max_epochs, so the original
        # trailing `return args.final_lr` was unreachable and is dropped.
        cos_steps = epoch - args.warm_up_steps
        cos_max_steps = args.max_epochs - args.warm_up_steps - 1
        if cos_max_steps <= 0:
            # Degenerate single-step cycle: avoid division by zero.
            return args.max_lr / scale
        return (args.final_lr
                + 0.5 * (args.max_lr - args.final_lr)
                * (1 + math.cos(cos_steps / cos_max_steps * math.pi)) / scale)

    for epoch in range(args.tot_epochs):
        print(f'{epoch} - {1 + epoch // args.max_epochs}: {lr_step(epoch)}')