# Fine-tuning on VQA

## Requirements

We use PyTorch Lightning to fine-tune the pre-trained VLEModel on VQA. To speed up training, we use DeepSpeed. The main packages are listed below; an installation sketch follows the list:

```
pytorch_lightning==1.5.10
transformers==4.26.0
deepspeed==0.7.7
Pillow==8.1.0
tqdm==4.64.1
ipdb==0.13.4
numpy==1.21.6
einops==0.3.0
pyarrow==2.0.0
sacred==0.8.2
pandas==1.1.5
timm==0.4.12
ftfy
torchvision~=0.8.2
torch~=1.7.1
```
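
If the list above is saved verbatim as a `requirements.txt` file (the file name is our assumption; the repo may not ship one), the pinned versions can be installed in one step:

```bash
# Install the pinned dependencies listed above
pip install -r requirements.txt
```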

## Dataset Preparation for VQAv2

Download the VQAv2 dataset from the VQA official site, including the COCO 2014 train images, 2014 val images, 2015 test images, annotations (train and val), and questions (train, val, and test).
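
For reference, a download sketch is shown below. These URLs reflect the COCO and VQA download pages at the time of writing and may change, so verify them against the official site:

```bash
# COCO images (hosted by the COCO project)
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/zips/test2015.zip
# VQAv2 questions and annotations (the test questions zip also contains test-dev)
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
```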

Please unzip and organize the dataset as follows:

```
root
├── train2014
│   ├── COCO_train2014_000000000009.jpg
│   └── ...
├── val2014
│   ├── COCO_val2014_000000000042.jpg
│   └── ...
├── test2015
│   ├── COCO_test2015_000000000001.jpg
│   └── ...
├── v2_OpenEnded_mscoco_train2014_questions.json
├── v2_OpenEnded_mscoco_val2014_questions.json
├── v2_OpenEnded_mscoco_test2015_questions.json
├── v2_OpenEnded_mscoco_test-dev2015_questions.json
├── v2_mscoco_train2014_annotations.json
└── v2_mscoco_val2014_annotations.json
```

We use pyarrow to serialize the datasets; the conversion script is write_vqa.py. Please replace the value of id2label in the VLE model's config with the generated mapping file label2answer.json (see config.json of hfl/vle-base-for-vqa for an example).
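
A minimal sketch of that replacement using jq (jq is our tooling choice here, not part of the repo; it assumes label2answer.json holds a single JSON object that becomes the new value of id2label):

```bash
# Copy the mapping from label2answer.json into the id2label field of config.json
jq --slurpfile m label2answer.json '.id2label = $m[0]' config.json > config.tmp.json
mv config.tmp.json config.json
```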

## Fine-tuning VLE on VQAv2

Hyperparameters for training are set in vqa_train_config.json.

Move the training-related files to the same directory level as models, as follows:

```
root
├── models
│   └── VLE
│       └── ...
├── run_vqav2_ft.py
├── vqav2_datamodule.py
└── vqav2_train_module.py
```

Specify the config file via --train_config_file and run the training script run_vqav2_ft.py. Here is an example:

```bash
# Rendezvous settings for multi-node distributed training; the DIST_* values
# are provided by your cluster scheduler
export MASTER_ADDR=$DIST_0_IP
export MASTER_PORT=$DIST_0_PORT
export NODE_RANK=$DIST_RANK
python run_vqav2_ft.py --train_config_file=vqa_train_config.json
```
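
For a single-node run, the same rendezvous variables can simply point at the local machine (the address and port below are conventional PyTorch distributed defaults, chosen here for illustration):

```bash
# Single-machine training
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29500
export NODE_RANK=0
python run_vqav2_ft.py --train_config_file=vqa_train_config.json
```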

## Postprocess the checkpoint

After training, we convert the saved checkpoint so that it can be loaded by VLEModel.

We first convert the checkpoint saved by DeepSpeed to a standard PyTorch checkpoint. The conversion script is zero_to_fp32.py. If you did not use DeepSpeed during training, this step can be skipped.

```bash
python zero_to_fp32.py <ckpt_dir> <output_file> <tag>
# for example:
python zero_to_fp32.py ./logs/VQAv2_seed0_from_vle-base-ft-vqa/version_0/checkpoints/epoch\=0-step\=0.ckpt step\=0.ckpt global_step0
```

Then, we convert the parameter names to the format expected by VLEModel. The conversion script is convert_checkpoint_after_ft.py.
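
The script's command-line interface is not documented here, so the invocation below is hypothetical (the --input and --output flags are our guesses); check the script itself for its actual arguments:

```bash
# Hypothetical arguments for illustration only; step\=0.ckpt is the output of
# the zero_to_fp32.py example above
python convert_checkpoint_after_ft.py --input step\=0.ckpt --output pytorch_model.bin
```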
