PaddlePaddle / PaddleClas

resnet34_distill_resnet18_afd.yaml 5.50 KB
littletomatodonkey committed on 2022-08-19 12:15: fix config yaml (#2214)
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"
# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  models:
    - Teacher:
        name: AttentionModel
        pretrained_list:
        freeze_params_list:
          - True
          - False
        models:
          - ResNet34:
              name: ResNet34
              pretrained: True
              return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]",
                                        "blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]",
                                        "blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"]
          - LinearTransformTeacher:
              name: LinearTransformTeacher
              qk_dim: 128
              keys: *t_keys
              t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28],
                                   [128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14],
                                   [256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14],
                                   [256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]]
    - Student:
        name: AttentionModel
        pretrained_list:
        freeze_params_list:
          - False
          - False
        models:
          - ResNet18:
              name: ResNet18
              pretrained: False
              return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
                                        "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"]
          - LinearTransformStudent:
              name: LinearTransformStudent
              qk_dim: 128
              keys: *s_keys
              s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28],
                                   [256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]]
              t_shapes: *t_shapes
  infer_model_name: "Student"
# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
        key: logits
    - DistillationKLDivLoss:
        weight: 0.9
        model_name_pairs: [["Student", "Teacher"]]
        temperature: 4
        key: logits
    - AFDLoss:
        weight: 50.0
        model_name_pair: ["Student", "Teacher"]
        student_keys: ["bilinear_key", "value"]
        teacher_keys: ["query", "value"]
        s_shapes: *s_shapes
        t_shapes: *t_shapes
  Eval:
    - CELoss:
        weight: 1.0
Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: MultiStepDecay
    learning_rate: 0.1
    milestones: [30, 60, 90]
    step_each_epoch: 1
    gamma: 0.1
# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
            interpolation: bicubic
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True
  Eval:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True
Infer:
  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]