# networks.py -- Background-Matting (Soumyadip Sengupta)
# Source: https://gitee.com/giteebytsl/Background-Matting.git
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np


class ResnetConditionHR(nn.Module):
    """Four-encoder matting network: separate encoders for the input image, the
    background, the soft segmentation and the motion (multi-frame) cue, whose
    features are fused and decoded into an alpha matte and a foreground image."""

    def __init__(self, input_nc, output_nc, ngf=64, nf_part=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks1=7, n_blocks2=3, padding_type='reflect'):
        assert n_blocks1 >= 0
        assert n_blocks2 >= 0
        super(ResnetConditionHR, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        use_bias = True

        # main (image) encoder: output 256 x W/4 x H/4
        model_enc1 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc[0], ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)]
        model_enc1 += [nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 2), nn.ReLU(True)]
        model_enc2 = [nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 4), nn.ReLU(True)]

        # background encoder: output 256 x W/4 x H/4
        model_enc_back = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc[1], ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model_enc_back += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.ReLU(True)]

        # segmentation encoder: output 256 x W/4 x H/4
        model_enc_seg = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc[2], ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model_enc_seg += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.ReLU(True)]

        mult = 2**n_downsampling

        # motion (multi-frame) encoder: output 256 x W/4 x H/4
        model_enc_multi = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc[3], ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model_enc_multi += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.ReLU(True)]

        self.model_enc1 = nn.Sequential(*model_enc1)
        self.model_enc2 = nn.Sequential(*model_enc2)
        self.model_enc_back = nn.Sequential(*model_enc_back)
        self.model_enc_seg = nn.Sequential(*model_enc_seg)
        self.model_enc_multi = nn.Sequential(*model_enc_multi)

        mult = 2**n_downsampling
        # 1x1 blocks that combine the image features with each cue's features
        # (norm_layer(ngf) matches the nf_part output channels only for the default nf_part == ngf)
        self.comb_back = nn.Sequential(nn.Conv2d(ngf * mult * 2, nf_part, kernel_size=1, stride=1, padding=0, bias=False), norm_layer(ngf), nn.ReLU(True))
        self.comb_seg = nn.Sequential(nn.Conv2d(ngf * mult * 2, nf_part, kernel_size=1, stride=1, padding=0, bias=False), norm_layer(ngf), nn.ReLU(True))
        self.comb_multi = nn.Sequential(nn.Conv2d(ngf * mult * 2, nf_part, kernel_size=1, stride=1, padding=0, bias=False), norm_layer(ngf), nn.ReLU(True))

        # shared residual decoder over the concatenated image + fused cue features
        model_res_dec = [nn.Conv2d(ngf * mult + 3 * nf_part, ngf * mult, kernel_size=1, stride=1, padding=0, bias=False), norm_layer(ngf * mult), nn.ReLU(True)]
        for i in range(n_blocks1):
            model_res_dec += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        # separate residual heads for alpha and foreground
        model_res_dec_al = []
        for i in range(n_blocks2):
            model_res_dec_al += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
        model_res_dec_fg = []
        for i in range(n_blocks2):
            model_res_dec_fg += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        # alpha decoder: upsample back to full resolution, 1-channel tanh output
        model_dec_al = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            # model_dec_al += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
            model_dec_al += [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(ngf * mult, int(ngf * mult / 2), 3, stride=1, padding=1), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
        model_dec_al += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 1, kernel_size=7, padding=0), nn.Tanh()]

        # foreground decoder: a first upsampling stage, then a stage that also takes
        # the skip connection from the first image-encoder block (see forward)
        model_dec_fg1 = [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(ngf * 4, int(ngf * 2), 3, stride=1, padding=1), norm_layer(int(ngf * 2)), nn.ReLU(True)]
        model_dec_fg2 = [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(ngf * 4, ngf, 3, stride=1, padding=1), norm_layer(ngf), nn.ReLU(True), nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc - 1, kernel_size=7, padding=0)]

        self.model_res_dec = nn.Sequential(*model_res_dec)
        self.model_res_dec_al = nn.Sequential(*model_res_dec_al)
        self.model_res_dec_fg = nn.Sequential(*model_res_dec_fg)
        self.model_al_out = nn.Sequential(*model_dec_al)
        self.model_dec_fg1 = nn.Sequential(*model_dec_fg1)
        self.model_fg_out = nn.Sequential(*model_dec_fg2)

    def forward(self, image, back, seg, multi):
        # encode each input separately
        img_feat1 = self.model_enc1(image)
        img_feat = self.model_enc2(img_feat1)
        back_feat = self.model_enc_back(back)
        seg_feat = self.model_enc_seg(seg)
        multi_feat = self.model_enc_multi(multi)

        # combine the image features with each cue through the comb_* 1x1 blocks
        # (note: comb_multi is applied to back_feat here, so multi_feat is computed
        # but not used in this concatenation)
        oth_feat = torch.cat([self.comb_back(torch.cat([img_feat, back_feat], dim=1)),
                              self.comb_seg(torch.cat([img_feat, seg_feat], dim=1)),
                              self.comb_multi(torch.cat([img_feat, back_feat], dim=1))], dim=1)

        # shared residual decoder, then separate alpha and foreground heads
        out_dec = self.model_res_dec(torch.cat([img_feat, oth_feat], dim=1))
        out_dec_al = self.model_res_dec_al(out_dec)
        al_out = self.model_al_out(out_dec_al)

        out_dec_fg = self.model_res_dec_fg(out_dec)
        out_dec_fg1 = self.model_dec_fg1(out_dec_fg)
        # skip connection from the first image-encoder block
        fg_out = self.model_fg_out(torch.cat([out_dec_fg1, img_feat1], dim=1))

        return al_out, fg_out
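
# A construction sketch (not from this file; the channel counts below are an
# assumption, e.g. RGB image, RGB background, 1-channel soft segmentation and a
# 4-channel motion cue):
#   netM = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
#   alpha, fg = netM(image, back, seg, multi)   # alpha: N x 1 x H x W, fg: N x 3 x H x W
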
############################## part ##################################

def conv_init(m):
    """Weight initialization, meant to be applied with model.apply(conv_init)."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        # init.normal_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    if classname.find('Linear') != -1:
        init.normal_(m.weight)
        init.constant_(m.bias, 1)
    if classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.2)
        init.constant_(m.bias.data, 0.0)
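
# Typical usage (a sketch, assuming a network built as in the example above):
#   netM.apply(conv_init)   # nn.Module.apply calls conv_init on every sub-module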


class conv3x3(nn.Module):
    '''(conv => BN => LeakyReLU), stride 2'''
    def __init__(self, in_ch, out_ch):
        super(conv3x3, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class conv3x3s1(nn.Module):
    '''(conv => BN => LeakyReLU), stride 1'''
    def __init__(self, in_ch, out_ch):
        super(conv3x3s1, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class conv1x1(nn.Module):
    '''(1x1 conv => BN => LeakyReLU)'''
    def __init__(self, in_ch, out_ch):
        super(conv1x1, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class upconv3x3(nn.Module):
    '''(bilinear upsample => conv => BN => ReLU)'''
    def __init__(self, in_ch, out_ch):
        super(upconv3x3, self).__init__()
        self.conv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class fc(nn.Module):
    '''(linear => ReLU)'''
    def __init__(self, in_ch, out_ch):
        super(fc, self).__init__()
        self.fullc = nn.Sequential(
            nn.Linear(in_ch, out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.fullc(x)
        return x


# Define a ResNet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
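

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original file): the channel
    # counts below are assumptions (RGB image, RGB background, 1-channel
    # segmentation, 4-channel motion cue) matching the construction sketch above.
    net = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
    net.apply(conv_init)
    N, H, W = 1, 64, 64  # spatial size must be divisible by 4
    image = torch.randn(N, 3, H, W)
    back = torch.randn(N, 3, H, W)
    seg = torch.randn(N, 1, H, W)
    multi = torch.randn(N, 4, H, W)
    with torch.no_grad():
        alpha, fg = net(image, back, seg, multi)
    print(alpha.shape, fg.shape)  # expected: (N, 1, H, W) and (N, 3, H, W)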