代码拉取完成,页面将自动刷新
import traceback
import numpy as np
import easynn as nn
import easynn_golden as golden
import easynn_cpp as cpp
def random_kwargs(kwargs):
    """Draw one set of random inputs for a compiled expression.

    Args:
        kwargs: mapping from input name to the shape of the array to draw
            for that input, or ``None`` to draw a single scalar.

    Returns:
        dict mapping each input name to ``np.random.random(shape)`` (an
        ndarray of samples in [0, 1)) or, when the shape is ``None``, to the
        scalar float returned by ``np.random.random()``.
    """
    # Identity test against None: ``shape != None`` would compare element-wise
    # if a shape were passed as an ndarray, and the resulting boolean array
    # has no single truth value.
    return {
        name: np.random.random() if shape is None else np.random.random(shape)
        for name, shape in kwargs.items()
    }
def is_same(p, n, **kwargs):
    """Check that the golden and C++ backends agree on ``n`` random inputs.

    Args:
        p: an easynn expression; it is compiled once with
            ``golden.Builder()`` and once with ``cpp.Builder()``.
        n: number of random input sets to try.
        **kwargs: input name -> shape (or ``None`` for a scalar), passed to
            ``random_kwargs`` to draw each input set.

    Returns:
        True if, for every input set, both compiled backends return outputs
        of the same shape whose values are ``np.allclose``; False as soon as
        one input set disagrees (True when ``n`` is 0).
    """
    e0 = p.compile(golden.Builder())
    e1 = p.compile(cpp.Builder())
    for _ in range(n):
        inputs = random_kwargs(kwargs)
        expected = e0(**inputs)
        actual = e1(**inputs)
        # np.allclose broadcasts its operands, so a wrongly-shaped result
        # could still compare "close"; require identical shapes explicitly.
        if np.shape(expected) != np.shape(actual):
            return False
        if not np.allclose(expected, actual):
            return False
    return True
def grade_Q1():
    """Q1: compare a MaxPool2d(3, 3) expression across backends.

    The single input "x" is declared as ``nn.Input2d("x", 12, 15, 3)`` and
    checked by ``is_same`` on one random array of shape (10, 12, 15, 3).

    Returns:
        the boolean result of ``is_same``.
    """
    max_pool = nn.MaxPool2d(3, 3)
    batch_shape = (10, 12, 15, 3)
    # Input2d takes the per-sample dims, i.e. the batch shape minus axis 0.
    expr = max_pool(nn.Input2d("x", *batch_shape[1:]))
    return is_same(expr, 1, x=batch_shape)
def _grade_conv2d(in_channels, out_channels, kernel, height, width):
    """Check one Conv2d named "c" with random parameters across backends.

    The layer is ``nn.Conv2d("c", in_channels, out_channels, kernel)`` applied
    to ``nn.Input2d("x", height, width, in_channels)``. Its parameters are
    resolved from a random weight of shape
    (out_channels, in_channels, kernel, kernel) and a random bias of shape
    (out_channels,). The expression is then checked with ``is_same`` on one
    random batch of shape (10, height, width, in_channels).

    Returns:
        the boolean result of ``is_same``.
    """
    conv = nn.Conv2d("c", in_channels, out_channels, kernel)
    expr = conv(nn.Input2d("x", height, width, in_channels))
    # Draw the weight before the bias to keep the RNG sequence unchanged.
    weight = np.random.random((out_channels, in_channels, kernel, kernel))
    bias = np.random.random((out_channels,))
    expr.resolve({"c.weight": weight, "c.bias": bias})
    return is_same(expr, 1, x=(10, height, width, in_channels))


def grade_Q2():
    """Q2: single-channel 3x3 Conv2d (1 -> 1) on a 12x15x1 input."""
    return _grade_conv2d(1, 1, 3, 12, 15)


def grade_Q3():
    """Q3: 5x5 Conv2d (3 -> 16) on a 15x20x3 input."""
    return _grade_conv2d(3, 16, 5, 15, 20)
def _build_mnist_classifier():
    """Build the MNIST CNN shared by Q4 and Q5 and load its trained parameters.

    Architecture (spatial sizes in the comments): two 3x3 convolutions
    1->8->8 with ReLU, a 2x2 max-pool, two 3x3 convolutions 8->16->16 with
    ReLU, another 2x2 max-pool, flatten, then a linear layer 16*4*4 -> 10.

    Returns:
        the easynn expression whose single input is named "images" (declared
        as ``nn.Input2d("images", 28, 28, 1)``), with its parameters resolved
        from "mnist_params.npz".
    """
    pool = nn.MaxPool2d(2, 2)
    relu = nn.ReLU()
    flatten = nn.Flatten()
    x = nn.Input2d("images", 28, 28, 1)
    c1 = nn.Conv2d("c1", 1, 8, 3)  # 28->26
    c2 = nn.Conv2d("c2", 8, 8, 3)  # 26->24
    x = pool(relu(c2(relu(c1(x)))))  # 24->12
    c3 = nn.Conv2d("c3", 8, 16, 3)  # 12->10
    c4 = nn.Conv2d("c4", 16, 16, 3)  # 10->8
    x = pool(relu(c4(relu(c3(x)))))  # 8->4
    f = nn.Linear("f", 16 * 4 * 4, 10)
    x = f(flatten(x))
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # read every array into a plain dict, then close the file deterministically.
    with np.load("mnist_params.npz") as params:
        x.resolve({name: params[name] for name in params.files})
    return x


def _load_mnist_images(count):
    """Return the first ``count`` entries of the "images" array in mnist_test.npz."""
    # Indexing the NpzFile fully reads the array, so the slice stays valid
    # after the archive is closed.
    with np.load("mnist_test.npz") as mnist_test:
        return mnist_test["images"][:count]


def grade_Q4():
    """Q4: the MNIST CNN must give matching logits on the first 20 test images.

    Returns:
        True if the golden and C++ backends produce logits of the same shape
        that are ``np.allclose``.
    """
    x = _build_mnist_classifier()
    images = _load_mnist_images(20)
    infer0 = x.compile(golden.Builder())
    infer1 = x.compile(cpp.Builder())
    logit0 = infer0(images=images)
    logit1 = infer1(images=images)
    # np.allclose broadcasts, so also require identical output shapes.
    return np.shape(logit0) == np.shape(logit1) and np.allclose(logit0, logit1)


def grade_Q5():
    """Q5: the MNIST CNN must predict identical labels on 1000 test images.

    Returns:
        True if the per-image argmax over axis 1 of the two backends' outputs
        is exactly equal.
    """
    x = _build_mnist_classifier()
    images = _load_mnist_images(1000)
    infer0 = x.compile(golden.Builder())
    infer1 = x.compile(cpp.Builder())
    label0 = infer0(images=images).argmax(axis=1)
    label1 = infer1(images=images).argmax(axis=1)
    # Labels are integers: compare exactly (array_equal also checks shape)
    # rather than with floating-point tolerances.
    return np.array_equal(label0, label1)
# Run each grading question in order, print a banner (plus the traceback if
# the question raised) for every failure, then report how many passed.
grade = 0
# Grade every question defined above (grade_Q1 .. grade_Q5); the previous
# range(1, 2) only ever ran Q1.
for q in range(1, 6):
    func = globals()["grade_Q%d" % q]
    try:
        passed = bool(func())
    except Exception:
        # A crash inside a question counts as a failure; show where it happened.
        print("============Q%d failed!============" % q)
        print(traceback.format_exc())
        continue
    if passed:
        grade += 1
    else:
        print("============Q%d failed!============\n" % q)
print("Total questions passed: %d" % grade)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。