We use PyTorch Lightning to fine-tune the pre-trained VLEModel on VQA. To speed up training, we use DeepSpeed. The main packages are as follows:
```
pytorch_lightning==1.5.10
transformers==4.26.0
deepspeed==0.7.7
Pillow==8.1.0
tqdm==4.64.1
ipdb==0.13.4
numpy==1.21.6
einops==0.3.0
pyarrow==2.0.0
sacred==0.8.2
pandas==1.1.5
timm==0.4.12
ftfy
torchvision~=0.8.2
torch~=1.7.1
```
Download the VQAv2 dataset from the VQA official site, including the COCO 2014 train images, 2014 val images, 2015 test images, annotations (train, val), and questions (train, val, test). Then unzip and organize the dataset as follows:
```
root
├── train2014
│   ├── COCO_train2014_000000000009.jpg
│   └── ...
├── val2014
│   ├── COCO_val2014_000000000042.jpg
│   └── ...
├── test2015
│   ├── COCO_test2015_000000000001.jpg
│   └── ...
├── v2_OpenEnded_mscoco_train2014_questions.json
├── v2_OpenEnded_mscoco_val2014_questions.json
├── v2_OpenEnded_mscoco_test2015_questions.json
├── v2_OpenEnded_mscoco_test-dev2015_questions.json
├── v2_mscoco_train2014_annotations.json
└── v2_mscoco_val2014_annotations.json
```
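A quick sanity check can catch path mistakes before conversion. The helper below is hypothetical (not part of the repo); it simply verifies that the files and directories listed above exist under `root`:

```python
import os

# Expected top-level entries of the VQAv2 `root` directory shown above.
EXPECTED = [
    "train2014",
    "val2014",
    "test2015",
    "v2_OpenEnded_mscoco_train2014_questions.json",
    "v2_OpenEnded_mscoco_val2014_questions.json",
    "v2_OpenEnded_mscoco_test2015_questions.json",
    "v2_OpenEnded_mscoco_test-dev2015_questions.json",
    "v2_mscoco_train2014_annotations.json",
    "v2_mscoco_val2014_annotations.json",
]

def missing_entries(root):
    """Return the expected files/directories that are missing under `root`."""
    return [name for name in EXPECTED
            if not os.path.exists(os.path.join(root, name))]
```

If `missing_entries("/path/to/root")` returns a non-empty list, fix the layout before running the conversion script.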
We use `pyarrow` to serialize the datasets; the conversion script is `write_vqa.py`.
Please replace the value of `id2label` in the VLE model's config with the generated mapping file `label2answer.json` (see, for example, the `config.json` of `hfl/vle-base-for-vqa`).
Hyperparameters for training are set in `vqa_train_config.json`.
Move the training-related files to the same directory level as `models`, as follows:
```
root
├── models
│   └── VLE
│       └── ...
├── run_vqav2_ft.py
├── vqav2_datamodule.py
└── vqav2_train_module.py
```
Specify the config file with `--train_config_file` and run the training script `run_vqav2_ft.py`. Here is an example:
```bash
export MASTER_ADDR=$DIST_0_IP
export MASTER_PORT=$DIST_0_PORT
export NODE_RANK=$DIST_RANK
python run_vqav2_ft.py --train_config_file=vqa_train_config.json
```
After training, we convert the saved checkpoint so that it can be loaded by `VLEModel`.
First, convert the DeepSpeed checkpoint to a PyTorch checkpoint with the script `zero_to_fp32.py`. If you did not use DeepSpeed during training, this step can be skipped.
```bash
python zero_to_fp32.py <ckpt_dir> <output_file> <tag>
# for example:
python zero_to_fp32.py ./logs/VQAv2_seed0_from_vle-base-ft-vqa/version_0/checkpoints/epoch\=0-step\=0.ckpt step\=0.ckpt global_step0
```
Then, convert the parameter names to the same format as `VLEModel`; the conversion script is `convert_checkpoint_after_ft.py`.
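The core of such a conversion is key renaming. As an illustrative sketch only (the real logic lives in `convert_checkpoint_after_ft.py`): Lightning typically nests weights under `state_dict` and prefixes each parameter name with the attribute holding the model (assumed to be `model.` here), while `VLEModel` expects the bare names:

```python
def rename_keys(state_dict, prefix="model."):
    """Strip `prefix` from every matching key of a checkpoint state dict."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```

A typical use would be `torch.save(rename_keys(torch.load(path, map_location="cpu")["state_dict"]), out_path)`.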