1 Star 0 Fork 11

TYLove516 / PaddleVideo

forked from PaddlePaddle / PaddleVideo 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
attention_lstm.md 3.84 KB
一键复制 编辑 原始数据 按行查看 历史
huangjun12 提交于 2022-05-25 07:18 . init commit

简体中文 | English

AttentionLSTM

内容

模型简介

循环神经网络(RNN)常用于序列数据的处理,可建模视频连续多帧的时序信息,在视频分类领域为基础常用方法。 该模型采用了双向长短时记忆网络(LSTM),将视频的所有帧特征依次编码。与传统方法直接采用LSTM最后一个时刻的输出不同,该模型增加了一个Attention层,每个时刻的隐状态输出都有一个自适应权重,然后线性加权得到最终特征向量。参考论文中实现的是两层LSTM结构,而本模型实现的是带Attention的双向LSTM

Attention层可参考论文AttentionCluster

数据准备

PaddleVide提供了在Youtube-8M数据集上训练和测试脚本。Youtube-8M数据下载及准备请参考YouTube-8M数据准备

模型训练

Youtube-8M数据集训练

开始训练

  • Youtube-8M数据集使用8卡训练,feature格式下会使用视频和音频特征作为输入,数据的训练启动命令如下

    python3.7 -B -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" --log_dir=log_attetion_lstm  main.py  --validate -c configs/recognition/attention_lstm/attention_lstm_youtube8m.yaml

模型测试

命令如下:

python3.7 -B -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" --log_dir=log_attetion_lstm  main.py  --test -c configs/recognition/attention_lstm/attention_lstm_youtube8m.yaml -w "output/AttentionLSTM/AttentionLSTM_best.pdparams"

当测试配置采用如下参数时,在Youtube-8M的validation数据集上的测试指标如下:

Hit@1 PERR GAP checkpoints
89.05 80.49 86.30 AttentionLSTM_yt8.pdparams

模型推理

导出inference模型

python3.7 tools/export_model.py -c configs/recognition/attention_lstm/attention_lstm_youtube8m.yaml \
                                -p data/AttentionLSTM_yt8.pdparams \
                                -o inference/AttentionLSTM

上述命令将生成预测所需的模型结构文件AttentionLSTM.pdmodel和模型权重文件AttentionLSTM.pdiparams

各参数含义可参考模型推理方法

使用预测引擎推理

python3.7 tools/predict.py --input_file data/example.pkl \
                           --config configs/recognition/attention_lstm/attention_lstm_youtube8m.yaml \
                           --model_file inference/AttentionLSTM/AttentionLSTM.pdmodel \
                           --params_file inference/AttentionLSTM/AttentionLSTM.pdiparams \
                           --use_gpu=True \
                           --use_tensorrt=False

输出示例如下:

Current video file: data/example.pkl
        top-1 class: 11
        top-1 score: 0.9841002225875854

可以看到,使用在Youtube-8M上训练好的AttentionLSTM模型对data/example.pkl进行预测,输出的top1类别id为11,置信度为0.98。

参考论文

Python
1
https://gitee.com/xhd0115/PaddleVideo.git
git@gitee.com:xhd0115/PaddleVideo.git
xhd0115
PaddleVideo
PaddleVideo
master

搜索帮助