代码拉取完成,页面将自动刷新
该教程以图像分类模型MobileNetV2为例,说明如何在cifar10数据集上快速使用网络结构搜索接口。该示例包含以下步骤:导入依赖、初始化SANAS搜索实例、构建网络、定义输入数据函数、定义训练函数、定义评估函数、启动搜索实验、完整示例。
以下章节依次介绍每个步骤的内容。
请确认已正确安装Paddle,导入需要的依赖包。
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
import paddleslim as slim
import numpy as np
# Draw a random port number in [8337, 8773) for the SANAS server address.
port = np.random.randint(low=8337, high=8773)
# NOTE: ('MobileNetV2Space') is a plain string, not a 1-tuple -- the
# parentheses are redundant. It is kept exactly as in the original tutorial.
# NOTE(review): save_checkpoint=None presumably disables checkpoint saving --
# confirm against the PaddleSlim SANAS documentation.
sanas = slim.nas.SANAS(
    configs=[('MobileNetV2Space')],
    server_addr=("", port),
    save_checkpoint=None,
)
根据传入的网络结构构造训练program和测试program。
# Switch Paddle to static-graph mode; the code below builds static Programs.
paddle.enable_static()
def build_program(archs):
    """Build the training and evaluation programs for one candidate network.

    Args:
        archs: Callable applied to the input image variable; its output is fed
            into a 10-way fully connected classification layer.

    Returns:
        A tuple ``(exe, train_program, test_program, (data, label), avg_cost,
        acc_top1, acc_top5)`` where ``exe`` is a CPU executor that has already
        run the startup program, ``test_program`` is a ``for_test`` clone taken
        before the optimizer ops are added, and the remaining items are the
        input variables and the loss / accuracy variables to fetch.
    """
    train_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(train_program, startup_program):
        data = static.data(name='data', shape=[None, 3, 32, 32], dtype='float32')
        label = static.data(name='label', shape=[None, 1], dtype='int64')
        gt = paddle.reshape(label, [-1, 1])
        output = archs(data)
        output = static.nn.fc(output, size=10)

        # Probabilities are only needed for the accuracy metrics.
        softmax_out = F.softmax(output)
        # F.cross_entropy applies softmax itself by default, so it must be
        # given the raw logits. Passing softmax_out (as the original did)
        # applies softmax twice and flattens the loss / gradients.
        cost = F.cross_entropy(output, label=gt)
        avg_cost = paddle.mean(cost)
        acc_top1 = paddle.metric.accuracy(input=softmax_out, label=gt, k=1)
        acc_top5 = paddle.metric.accuracy(input=softmax_out, label=gt, k=5)

        # Clone for evaluation before the optimizer adds its ops.
        test_program = static.default_main_program().clone(for_test=True)

        optimizer = paddle.optimizer.Adam(learning_rate=0.1)
        optimizer.minimize(avg_cost)

    place = paddle.CPUPlace()
    exe = static.Executor(place)
    # Run the startup program once to initialize the parameters.
    exe.run(startup_program)
    return exe, train_program, test_program, (data, label), avg_cost, acc_top1, acc_top5
为了快速执行该示例,我们使用的数据集为CIFAR10。Paddle框架的paddle.vision.datasets.Cifar10类封装了CIFAR10数据的下载和读取,代码如下:
import paddle.vision.transforms as T
def input_data(image, label):
    """Create the CIFAR10 training and evaluation data loaders.

    Args:
        image: Input image variable, used as the first entry of ``feed_list``.
        label: Label variable, used as the second entry of ``feed_list``.

    Returns:
        ``(train_loader, eval_loader)``. Both use batch size 64 on CPU; the
        training loader shuffles and drops the last incomplete batch, the
        evaluation loader does neither.
    """
    preprocess = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])

    def _make_loader(split, training):
        # `training` toggles both shuffling and dropping the last batch,
        # which are the only settings that differ between the two loaders.
        dataset = paddle.vision.datasets.Cifar10(
            mode=split, transform=preprocess, backend='cv2')
        return paddle.io.DataLoader(
            dataset,
            places=paddle.CPUPlace(),
            feed_list=[image, label],
            drop_last=training,
            batch_size=64,
            return_list=False,
            shuffle=training,
        )

    train_loader = _make_loader("train", True)
    eval_loader = _make_loader("test", False)
    return train_loader, eval_loader
根据训练program和训练数据进行训练。
def start_train(program, data_loader):
    """Run ``program`` once over every batch of ``data_loader``.

    Reads the module-level ``exe``, ``avg_cost``, ``acc_top1`` and
    ``acc_top5`` names, and prints the fetched loss and accuracies per batch.

    Args:
        program: Program to execute for each batch.
        data_loader: Callable returning an iterable of feed data.
    """
    fetch_names = [var.name for var in (avg_cost, acc_top1, acc_top5)]
    for batch in data_loader():
        loss, acc1, acc5 = exe.run(program, feed=batch, fetch_list=fetch_names)
        print("TRAIN: loss: {}, acc1: {}, acc5:{}".format(loss, acc1, acc5))
根据评估program和评估数据进行评估。
def start_eval(program, data_loader):
    """Evaluate ``program`` over ``data_loader`` and return averaged metrics.

    Reads the module-level ``exe``, ``avg_cost``, ``acc_top1`` and
    ``acc_top5`` names. Prints the fetched values for every batch and the
    final averages.

    Args:
        program: Program to execute for each batch.
        data_loader: Callable returning an iterable of feed data.

    Returns:
        Array holding the mean over batches of (loss, top-1 accuracy,
        top-5 accuracy), in that order.
    """
    fetch_names = [avg_cost.name, acc_top1.name, acc_top5.name]
    per_batch = []
    for batch in data_loader():
        fetched = exe.run(program, feed=batch, fetch_list=fetch_names)
        # Collapse each fetched value to one number for this batch.
        per_batch.append(np.mean(np.array(fetched), axis=1))
        loss, acc1, acc5 = fetched
        print("TEST: loss: {}, acc1: {}, acc5:{}".format(loss, acc1, acc5))

    # Average the per-batch metrics across all batches.
    finally_reward = np.mean(np.array(per_batch), axis=0)
    mean_cost = finally_reward[0]
    mean_acc1 = finally_reward[1]
    mean_acc5 = finally_reward[2]
    print("FINAL TEST: avg_cost: {}, acc1: {}, acc5: {}".format(mean_cost, mean_acc1, mean_acc5))
    return finally_reward
以下步骤拆解说明了如何获得当前模型结构以及获得当前模型结构之后应该有的步骤,如果想要看如何启动搜索实验的完整示例可以看步骤9。
调用next_archs()函数获取下一个模型结构。
# next_archs() returns an indexable result; its first element is used as the
# candidate network passed to build_program().
archs = sanas.next_archs()[0]
调用前面定义的build_program函数,根据上一步获取的模型结构构造相应的program,并调用input_data函数创建训练和评估数据读取器。
# Build the programs for the sampled architecture. The names bound here
# (exe, avg_cost, acc_top1, acc_top5) are read as globals by start_train()
# and start_eval(), so they must keep exactly these names.
exe, train_program, eval_program, (image, label), avg_cost, acc_top1, acc_top5 = build_program(archs)
# Create the data loaders that feed the program's input variables.
train_loader, eval_loader = input_data(image, label)
根据上面得到的训练program和训练数据启动训练。
# Run one pass over the training loader with the training program.
start_train(train_program, train_loader)
根据上面得到的评估program和评估数据启动评估。
# Evaluate; the result holds the averaged (loss, top-1 acc, top-5 acc).
finally_reward = start_eval(eval_program, eval_loader)
# Report the top-1 accuracy (index 1) back to SANAS as this architecture's reward.
sanas.reward(float(finally_reward[1]))
以下是一个完整的搜索实验示例,示例中使用FLOPs作为约束条件。搜索实验一共执行3个step,每个step获取一个模型结构:若其FLOPs超过限制则直接跳过,否则训练7个epoch后进行评估,并将top1准确率作为得分回传。
# FLOPs budget: architectures above this limit are skipped without training.
MAX_FLOPS = 6555276

for step in range(3):
    archs = sanas.next_archs()[0]
    exe, train_program, eval_program, (images, label), avg_cost, acc_top1, acc_top5 = build_program(archs)

    # Check the FLOPs constraint before creating the data loaders, so the
    # CIFAR10 datasets are not constructed for architectures that are
    # rejected anyway. A rejected architecture still uses up one step and no
    # reward is reported for it.
    current_flops = slim.analysis.flops(train_program)
    if current_flops > MAX_FLOPS:
        continue

    train_loader, eval_loader = input_data(images, label)

    # Train the accepted architecture for 7 epochs.
    for epoch in range(7):
        start_train(train_program, train_loader)

    # Evaluate and report the top-1 accuracy (index 1) as the reward.
    finally_reward = start_eval(eval_program, eval_loader)
    sanas.reward(float(finally_reward[1]))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。