GQA.py is a streamlined port of the Annotated BUTD repository, condensed into a single file (for those who prefer to see the whole story all at once).
It steps through each of the stages of training a Bottom-Up Top-Down (BUTD) model on the GQA Dataset, from data preprocessing, to model definition, to training and logging.
Note that this file only includes streamlined code for the original Bottom-Up Top-Down Model, with the simple product-based fusion operation.
For the BUTD-FiLM Model, consult the Modular branch of the Annotated BUTD repository – the code is quite similar.
To run this executable file, invoke the following (append --gpus 1 if running on a GPU):
python gqa.py --run_name GQA
from argparse import Namespace
from datetime import datetime
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from tap import Tap
from torch.nn.utils.weight_norm import weight_norm
from torch.utils.data import Dataset, DataLoader
import base64
import csv
import numpy as np
import h5py
import json
import os
import pickle
import pytorch_lightning as pl
import random
import sys
import torch
import torch.nn as nn
Defines an argument parser with the appropriate arguments and paths – we use Tap for readability.
Note the arguments gqa_questions and gqa_features. These are paths to the GQA Questions and the extracted Bottom-Up Features, and should be pre-downloaded using this script.
We also define an argument gqa_cache that we use to store serialized/formatted data created during the preprocessing step (e.g. HDF5 files). Feel free to change this to a path that is convenient for you (and has enough storage – this directory can grow large!).
Similarly, note the argument glove, which contains a path to pre-trained GloVe Embeddings. These can be downloaded via this script.
Other important arguments include gpus (the number of GPUs to run with :: default = 0) and the random seed seed.
The argument checkpoint is a path to a checkpoint directory, used to store model metrics and checkpoints (saved based on best validation accuracy). Feel free to change this as you see fit.
All other arguments contain sane defaults for initializing different parts of the BUTD model – these are not optimized, but seem to work well.
class ArgumentParser(Tap):
run_name: str # Run Name -- for informative logging
data: str = "data/" # Where downloaded data is located
checkpoint: str = "checkpoints/" # Where to save model checkpoints and serialized statistics
gqa_questions: str = 'data/GQA-Questions' # Path to GQA Balanced Training Set of Questions
gqa_features: str = 'data/GQA-Features' # Path to GQA Features
gqa_cache: str = 'data/GQA-Cache' # Path to GQA Cache Directory for storing serialized data
glove: str = 'data/GloVe/glove.6B.300d.txt' # Path to GloVe Embeddings File (300-dim)
gpus: int = 0 # Number of GPUs to run with (default :: 0)
model: str = 'butd' # Model Architecture to run with -- < butd | film >
dataset: str = 'gqa' # Dataset to run BUTD Model with -- < vqa2 | gqa | nlvr2 >
emb_dim: int = 300 # Word Embedding Dimension --> Should Match GloVe (300)
emb_dropout: float = 0.0 # Dropout to Apply to Word Embeddings
rnn: str = 'GRU' # RNN Type for Question Encoder --> one of < 'GRU' | 'LSTM' >
rnn_layers: int = 1 # Number of RNN Stacked Layers (for Statement Encoder)
bidirectional: bool = False # Whether or not RNN is Bidirectional
q_dropout: float = 0.0 # RNN Dropout (for Question Encoder)
attention_dropout: float = 0.2 # Dropout for Attention Operation (fusing Image + Question)
answer_dropout: float = 0.5 # Dropout to Apply to Answer Classifier
hidden: int = 1024 # Dimensionality of Hidden Layer (Question Encoder & Object Encoder)
weight_norm: bool = True # Boolean whether or not to use Weight Normalization
bsz: int = 256 # Batch Size --> the Bigger the Better
epochs: int = 15 # Number of Training Epochs
opt: str = 'adamax' # Optimizer for Performing Gradient Updates
gradient_clip: float = 0.25 # Value for Gradient Clipping
seed: int = 7 # Random Seed (for Reproducibility)
Parse arguments – convert to and from a Namespace (to work around a PyTorch Lightning bug), and set the name of the run for meaningful logging.
args = Namespace(**ArgumentParser().parse_args().as_dict())
run_name = args.run_name + '-%s' % args.model + '-x%d' % args.seed + '+' + datetime.now().strftime('%m-%d-[%H:%M]')
print('[*] Starting Train Job in Mode %s with Run Name: %s' % (args.dataset.upper(), run_name))
Book-Keeping – Set the Random Seed for all relevant libraries.
print('[*] Setting Random Seed to %d!' % args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
There are 4 steps to the preprocessing pipeline: (1) assembling a token Dictionary and loading the corresponding GloVe embeddings, (2) creating the answer label mappings, (3) serializing the Bottom-Up image features to HDF5 files, and (4) defining a torch.Dataset wrapping the GQA data in an easy-to-batch format.
Assemble a Dictionary mapping question tokens to integer indices. Additionally, use the created dictionary to index and load in the GloVe vectors.
class Dictionary(object):
def __init__(self, word2idx=None, idx2word=None):
if word2idx is None:
word2idx = {}
if idx2word is None:
idx2word = []
self.word2idx, self.idx2word = word2idx, idx2word
@property
def ntoken(self):
return len(self.word2idx)
@property
def padding_idx(self):
return len(self.word2idx)
def tokenize(self, sentence, add_word):
sentence = sentence.lower().replace(',', '').replace('.', '').replace('?', '').replace('\'s', ' \'s')
words, tokens = sentence.split(), []
if add_word:
for w in words:
tokens.append(self.add_word(w))
else:
for w in words:
tokens.append(self.word2idx.get(w, self.padding_idx))
return tokens
def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word]
def __len__(self):
return len(self.idx2word)
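To make the behavior concrete, here is a minimal usage sketch (illustrative only – not part of gqa.py); note that unseen words fall back to the padding index:
d = Dictionary()
d.tokenize('is the dog on the couch', add_word=True)         # Build the vocabulary from a (hypothetical) question
print(d.ntoken, d.padding_idx)                                # Padding index == current vocabulary size
print(d.tokenize('is the cat on the couch', add_word=False))  # 'cat' is unseen --> maps to padding_idx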
Create Dictionary from GQA Question Files, and Initialize GloVe Embeddings from File
def gqa_create_dictionary_glove(gqa_q='data/GQA-Questions', glove='data/GloVe/glove.6B.300d.txt',
cache='data/GQA-Cache'):
Note: it’s worth calling out a common design pattern that you’ll see throughout this codebase: making full use of the gqa_cache directory.
As we compute serialized/formatted versions of data (token dictionaries, embedding matrices, HDF5 files), we cache them so that future runs can skip the expensive preprocessing and iterate faster.
For a research codebase (where speedy iteration is the name of the game), we find this to be a useful practice.
dfile, gfile = os.path.join(cache, 'dictionary.pkl'), os.path.join(cache, 'glove.npy')
if os.path.exists(dfile) and os.path.exists(gfile):
with open(dfile, 'rb') as f:
dictionary = pickle.load(f)
weights = np.load(gfile)
return dictionary, weights
elif not os.path.exists(cache):
os.makedirs(cache)
dictionary = Dictionary()
questions = ['train_balanced_questions.json', 'val_balanced_questions.json', 'testdev_balanced_questions.json',
'test_balanced_questions.json']
Iterate through the Questions in each Question File and update the Vocabulary
print('\t[*] Creating Dictionary from GQA Questions...')
for qfile in questions:
qpath = os.path.join(gqa_q, qfile)
with open(qpath, 'r') as f:
examples = json.load(f)
for ex_key in examples:
ex = examples[ex_key]
dictionary.tokenize(ex['question'], add_word=True)
Load GloVe Embeddings
print('\t[*] Loading GloVe Embeddings...')
with open(glove, 'r') as f:
entries = f.readlines()
Assert that we’re using the 300-Dimensional GloVe Embeddings
assert len(entries[0].split()) - 1 == 300, 'ERROR - Not using 300-dimensional GloVe Embeddings!'
Create Embedding Weights
weights = np.zeros((len(dictionary.idx2word), 300), dtype=np.float32)
Populate Embedding Weights
for entry in entries:
word_vec = entry.split()
word, vec = word_vec[0], list(map(float, word_vec[1:]))
if word in dictionary.word2idx:
weights[dictionary.word2idx[word]] = vec
Dump Dictionary and Weights to file
with open(dfile, 'wb') as f:
pickle.dump(dictionary, f)
np.save(gfile, weights)
Return Dictionary and Weights
return dictionary, weights
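As a quick sanity check (illustrative only – not part of gqa.py, and assuming the GQA questions and GloVe file have been downloaded to the default paths), the first call builds and caches the dictionary and embedding matrix, while the second returns immediately from the cached dictionary.pkl and glove.npy:
dictionary, weights = gqa_create_dictionary_glove()  # Slow: parses question files + GloVe text
dictionary, weights = gqa_create_dictionary_glove()  # Fast: loads from data/GQA-Cache
print(len(dictionary), weights.shape)                # Vocabulary size, (ntoken, 300)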
Assemble dictionaries mapping answer strings to indices and vice-versa, for priming the Softmax in the final layer of the BUTD model.
Create mapping from answers to labels
def gqa_create_answers(gqa_q='data/GQA-Questions', cache='data/GQA-Cache'):
Create File Paths and Load from Disk (if cached)
dfile = os.path.join(cache, 'answers.pkl')
if os.path.exists(dfile):
with open(dfile, 'rb') as f:
ans2label, label2ans = pickle.load(f)
return ans2label, label2ans
ans2label, label2ans = {}, []
questions = ['train_balanced_questions.json', 'val_balanced_questions.json', 'testdev_balanced_questions.json']
Iterate through the Answers in each Question File and update the Mapping
print('\t[*] Creating Answer Labels from GQA Question/Answers...')
for qfile in questions:
qpath = os.path.join(gqa_q, qfile)
with open(qpath, 'r') as f:
examples = json.load(f)
for ex_key in examples:
ex = examples[ex_key]
if not ex['answer'].lower() in ans2label:
ans2label[ex['answer'].lower()] = len(ans2label)
label2ans.append(ex['answer'])
Dump Dictionaries to File
with open(dfile, 'wb') as f:
pickle.dump((ans2label, label2ans), f)
return ans2label, label2ans
Reads in a TSV file with pre-trained bottom-up attention features and writes them to an HDF5 file. Additionally builds an Image ID –> Feature Index mapping.
Hierarchy of HDF5 file:
{
'image_features': num_images x num_boxes x 2048
'spatial_features': num_images x num_boxes x 6
'image_bb': num_images x num_boxes x 4
}
Set CSV Field Size Limit (Big TSV Files…)
csv.field_size_limit(sys.maxsize)
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", "attrs_id", "attrs_conf", "num_boxes", "boxes",
"features"]
NUM_FIXED_BOXES = 36
FEATURE_LENGTH = 2048
Iterate through BUTD TSV and Build HDF5 Files with Bounding Box Features, Image ID –> IDX Mappings
def gqa_create_image_features(gqa_f='data/GQA-Features', cache='data/GQA-Cache'):
print('\t[*] Setting up HDF5 Files for Image/Object Features...')
Create Trackers for Image ID –> Index
trainval_indices, testdev_indices = {}, {}
tv_file = os.path.join(cache, 'trainval36.hdf5')
td_file = os.path.join(cache, 'testdev36.hdf5')
tv_idxfile = os.path.join(cache, 'trainval36_img2idx.pkl')
td_idxfile = os.path.join(cache, 'testdev36_img2idx.pkl')
if os.path.exists(tv_file) and os.path.exists(td_file) and os.path.exists(tv_idxfile) and \
os.path.exists(td_idxfile):
with open(tv_idxfile, 'rb') as f:
trainval_indices = pickle.load(f)
with open(td_idxfile, 'rb') as f:
testdev_indices = pickle.load(f)
return trainval_indices, testdev_indices
with h5py.File(tv_file, 'w') as h_trainval, h5py.File(td_file, 'w') as h_testdev:
Get Number of Images in each Split
with open(os.path.join(gqa_f, 'vg_gqa_obj36.tsv'), 'r') as f:
ntrainval = len(f.readlines())
with open(os.path.join(gqa_f, 'gqa_testdev_obj36.tsv'), 'r') as f:
ntestdev = len(f.readlines())
Setup HDF5 Files
trainval_img_features = h_trainval.create_dataset('image_features', (ntrainval, NUM_FIXED_BOXES,
FEATURE_LENGTH), 'f')
trainval_img_bb = h_trainval.create_dataset('image_bb', (ntrainval, NUM_FIXED_BOXES, 4), 'f')
trainval_spatial_features = h_trainval.create_dataset('spatial_features', (ntrainval, NUM_FIXED_BOXES, 6), 'f')
testdev_img_features = h_testdev.create_dataset('image_features', (ntestdev, NUM_FIXED_BOXES, FEATURE_LENGTH),
'f')
testdev_img_bb = h_testdev.create_dataset('image_bb', (ntestdev, NUM_FIXED_BOXES, 4), 'f')
testdev_spatial_features = h_testdev.create_dataset('spatial_features', (ntestdev, NUM_FIXED_BOXES, 6), 'f')
Start Iterating through TSV
print('\t[*] Reading Train-Val TSV File and Populating HDF5 File...')
trainval_counter, testdev_counter = 0, 0
with open(os.path.join(gqa_f, 'vg_gqa_obj36.tsv'), 'r') as tsv:
reader = csv.DictReader(tsv, delimiter='\t', fieldnames=FIELDNAMES)
for item in reader:
item['num_boxes'] = int(item['num_boxes'])
image_id = item['img_id']
image_w = float(item['img_w'])
image_h = float(item['img_h'])
bb = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape((item['num_boxes'], -1))
box_width = bb[:, 2] - bb[:, 0]
box_height = bb[:, 3] - bb[:, 1]
scaled_width = box_width / image_w
scaled_height = box_height / image_h
scaled_x = bb[:, 0] / image_w
scaled_y = bb[:, 1] / image_h
scaled_width = scaled_width[..., np.newaxis]
scaled_height = scaled_height[..., np.newaxis]
scaled_x = scaled_x[..., np.newaxis]
scaled_y = scaled_y[..., np.newaxis]
spatial_features = np.concatenate(
(scaled_x,
scaled_y,
scaled_x + scaled_width,
scaled_y + scaled_height,
scaled_width,
scaled_height),
axis=1)
trainval_indices[image_id] = trainval_counter
trainval_img_bb[trainval_counter, :, :] = bb
trainval_img_features[trainval_counter, :, :] = \
np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape((item['num_boxes'], -1))
trainval_spatial_features[trainval_counter, :, :] = spatial_features
trainval_counter += 1
print('\t[*] Reading Test-Dev TSV File and Populating HDF5 File...')
with open(os.path.join(gqa_f, 'gqa_testdev_obj36.tsv'), 'r') as tsv:
reader = csv.DictReader(tsv, delimiter='\t', fieldnames=FIELDNAMES)
for item in reader:
item['num_boxes'] = int(item['num_boxes'])
image_id = item['img_id']
image_w = float(item['img_w'])
image_h = float(item['img_h'])
bb = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape((item['num_boxes'], -1))
box_width = bb[:, 2] - bb[:, 0]
box_height = bb[:, 3] - bb[:, 1]
scaled_width = box_width / image_w
scaled_height = box_height / image_h
scaled_x = bb[:, 0] / image_w
scaled_y = bb[:, 1] / image_h
scaled_width = scaled_width[..., np.newaxis]
scaled_height = scaled_height[..., np.newaxis]
scaled_x = scaled_x[..., np.newaxis]
scaled_y = scaled_y[..., np.newaxis]
spatial_features = np.concatenate(
(scaled_x,
scaled_y,
scaled_x + scaled_width,
scaled_y + scaled_height,
scaled_width,
scaled_height),
axis=1)
testdev_indices[image_id] = testdev_counter
testdev_img_bb[testdev_counter, :, :] = bb
testdev_img_features[testdev_counter, :, :] = \
np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape((item['num_boxes'], -1))
testdev_spatial_features[testdev_counter, :, :] = spatial_features
testdev_counter += 1
Dump TrainVal and TestDev Indices to File
with open(tv_idxfile, 'wb') as f:
pickle.dump(trainval_indices, f)
with open(td_idxfile, 'wb') as f:
pickle.dump(testdev_indices, f)
return trainval_indices, testdev_indices
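Once this function has populated the cache, the resulting HDF5 file can be inspected directly (an illustrative sketch, assuming the default cache path):
with h5py.File('data/GQA-Cache/trainval36.hdf5', 'r') as hf:
    print(hf['image_features'].shape)    # (num_images, 36, 2048)
    print(hf['spatial_features'].shape)  # (num_images, 36, 6)
    print(hf['image_bb'].shape)          # (num_images, 36, 4)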
Define the GQA Feature Dataset (a torch.Dataset), with utilities for loading image features from HDF5 files and tensorizing data.
class GQAFeatureDataset(Dataset):
def __init__(self, dictionary, ans2label, label2ans, img2idx, gqa_q='data/GQA-Questions', cache='data/GQA-Cache',
mode='train'):
super(GQAFeatureDataset, self).__init__()
self.dictionary, self.ans2label, self.label2ans, self.img2idx = dictionary, ans2label, label2ans, img2idx
Load HDF5 Image Features
print('\t[*] Loading HDF5 Features...')
if mode in ['train', 'val']:
prefix = 'trainval'
else:
prefix = 'testdev'
self.v_dim, self.s_dim = 2048, 6
self.hf = h5py.File(os.path.join(cache, '%s36.hdf5' % prefix), 'r')
self.features = self.hf.get('image_features')
self.spatials = self.hf.get('spatial_features')
Create the Dataset Entries by Iterating through the Data
self.entries = load_dataset(self.img2idx, ans2label, gqa_q=gqa_q, mode=mode)
self.tokenize()
self.tensorize()
Tokenize and Front-Pad the Questions in the Dataset
def tokenize(self, max_length=40):
for entry in self.entries:
tokens = self.dictionary.tokenize(entry['question'], False)
tokens = tokens[:max_length]
if len(tokens) < max_length:
Note that we pad in front of the sentence (GRU reads left-to-right)
padding = [self.dictionary.padding_idx] * (max_length - len(tokens))
tokens = padding + tokens
assert len(tokens) == max_length, "Tokenized & Padded Question != Max Length!"
entry['q_token'] = tokens
def tensorize(self):
for entry in self.entries:
question = torch.from_numpy(np.array(entry['q_token']))
entry['q_token'] = question
def __getitem__(self, index):
entry = self.entries[index]
Get Features
features = torch.from_numpy(np.array(self.features[entry['image']]))
spatials = torch.from_numpy(np.array(self.spatials[entry['image']]))
question = entry['q_token']
target = entry['answer']
return features, spatials, question, target
def __len__(self):
return len(self.entries)
Load Dataset Entries
def load_dataset(img2idx, ans2label, gqa_q='data/GQA-Questions', mode='train'):
question_path = os.path.join(gqa_q, '%s_balanced_questions.json' % mode)
with open(question_path, 'r') as f:
examples = json.load(f)
print('\t[*] Creating GQA %s Entries...' % mode)
entries = []
for ex_key in sorted(examples):
entry = create_entry(examples[ex_key], ex_key, img2idx, ans2label)
entries.append(entry)
return entries
def create_entry(example, qid, img2idx, ans2label):
img_id = example['imageId']
assert img_id in img2idx, 'Image ID not in Index!'
entry = {
'question_id': qid,
'image_id': img_id,
'image': img2idx[img_id],
'question': example['question'],
'answer': ans2label[example['answer'].lower()]
}
return entry
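With the preprocessing in place, a single batch looks as follows (an illustrative sketch using the train_dataset constructed in the final section of this file):
loader = DataLoader(train_dataset, batch_size=4)
features, spatials, question, target = next(iter(loader))
print(features.shape)  # torch.Size([4, 36, 2048]) -- object features
print(spatials.shape)  # torch.Size([4, 36, 6])    -- normalized bounding-box features
print(question.shape)  # torch.Size([4, 40])       -- front-padded question token IDs
print(target.shape)    # torch.Size([4])           -- answer labels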
In this section, we formally define the Bottom-Up Top-Down Model with product-based multi-modal fusion.
This model is moderately different from the one originally proposed, and is instead inspired by the implementation by Hengyuan Hu et al., with some minor tweaks around the handling of spatial features.
It’s also worth noting that this Model is built using the PyTorch-Lightning library – an excellent resource for quickly prototyping research-based models.
Simple utility class defining a fully connected network (multi-layer perceptron)
class MLP(nn.Module):
def __init__(self, dims, use_weight_norm=True):
super(MLP, self).__init__()
layers = []
for i in range(len(dims) - 1):
in_dim, out_dim = dims[i], dims[i + 1]
if use_weight_norm:
layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
else:
layers.append(nn.Linear(in_dim, out_dim))
layers.append(nn.ReLU())
self.mlp = nn.Sequential(*layers)
def forward(self, x):
output: [bsz, *, dims[0]] –> [bsz, *, dims[-1]]
return self.mlp(x)
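For example (illustrative only), an MLP([2048, 1024]) projects each of the K object features independently:
mlp = MLP([2048, 1024])
x = torch.randn(8, 36, 2048)  # [bsz, K, image_dim]
print(mlp(x).shape)           # torch.Size([8, 36, 1024])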
Initialize an Embedding Matrix with the appropriate dimensions –> Defines padding as last token in dict
class WordEmbedding(nn.Module):
def __init__(self, ntoken, dim, dropout=0.0):
super(WordEmbedding, self).__init__()
self.ntoken, self.dim = ntoken, dim
self.emb = nn.Embedding(ntoken + 1, dim, padding_idx=ntoken)
self.dropout = nn.Dropout(dropout)
Set Embedding Weights from Numpy Array
def load_embeddings(self, weights):
assert weights.shape == (self.ntoken, self.dim)
self.emb.weight.data[:self.ntoken] = torch.from_numpy(weights)
def forward(self, x):
x : [bsz, seq_len] output: [bsz, seq_len, emb_dim]
return self.dropout(self.emb(x))
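A small shape check (illustrative only) – the extra row at index ntoken is the padding embedding:
w_emb = WordEmbedding(ntoken=10, dim=300)
tokens = torch.tensor([[0, 3, 10, 10]])  # Index 10 == padding for this toy vocabulary
print(w_emb(tokens).shape)               # torch.Size([1, 4, 300])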
Initialize the RNN Question Encoder with the appropriate configuration
class QuestionEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, nlayers=1, bidirectional=False, dropout=0.0, rnn='GRU'):
super(QuestionEncoder, self).__init__()
self.in_dim, self.hidden, self.nlayers, self.bidirectional = in_dim, hidden_dim, nlayers, bidirectional
self.rnn_type, self.rnn_cls = rnn, nn.GRU if rnn == 'GRU' else nn.LSTM
Initialize RNN
self.rnn = self.rnn_cls(self.in_dim, self.hidden, self.nlayers, bidirectional=self.bidirectional,
dropout=dropout, batch_first=True)
x: [bsz, seq_len, emb_dim]
output[0]: [bsz, seq_len, ndirections * hidden]
output[1]: [nlayers * ndirections, bsz, hidden]
def forward(self, x):
output, hidden = self.rnn(x) # Note that Hidden Defaults to 0
If not Bidirectional –> Just return last output state
if not self.bidirectional:
output: [bsz, hidden]
return output[:, -1]
Otherwise, concat forward state for last element and backward state for first element
else:
output: [bsz, 2 * hidden]
f, b = output[:, -1, :self.hidden], output[:, 0, self.hidden:]
return torch.cat([f, b], dim=1)
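A quick shape check (illustrative only), matching the defaults used below (300-dim embeddings, hidden size 1024):
q_enc = QuestionEncoder(in_dim=300, hidden_dim=1024)
x = torch.randn(8, 40, 300)  # [bsz, seq_len, emb_dim]
print(q_enc(x).shape)        # torch.Size([8, 1024]) -- last output state per question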
Initialize the Attention Mechanism with the appropriate fusion operation
class Attention(nn.Module):
def __init__(self, image_dim, question_dim, hidden, dropout=0.2, use_weight_norm=True):
super(Attention, self).__init__()
Attention w/ Product Fusion
self.image_proj = MLP([image_dim, hidden], use_weight_norm=use_weight_norm)
self.question_proj = MLP([question_dim, hidden], use_weight_norm=use_weight_norm)
self.dropout = nn.Dropout(dropout)
self.linear = weight_norm(nn.Linear(hidden, 1), dim=None) if use_weight_norm else nn.Linear(hidden, 1)
def forward(self, image_features, question_emb):
image_features: [bsz, k, image_dim = 2048]
question_emb: [bsz, question_dim]
Project both image and question embedding to hidden and repeat question_emb
num_objs = image_features.size(1)
image_proj = self.image_proj(image_features)
question_proj = self.question_proj(question_emb).unsqueeze(1).repeat(1, num_objs, 1)
Key: Fuse w/ Product
image_question = image_proj * question_proj
Dropout Joint Representation
joint_representation = self.dropout(image_question)
Compute Logits – Softmax
logits = self.linear(joint_representation)
return nn.functional.softmax(logits, dim=1)
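A quick check (illustrative only) that the attention weights form a distribution over the 36 objects – note the image_dim of 2048 + 6, matching the spatial-feature concatenation in the BUTD model below:
att = Attention(image_dim=2048 + 6, question_dim=1024, hidden=1024)
img, q = torch.randn(8, 36, 2048 + 6), torch.randn(8, 1024)
weights = att(img, q)
print(weights.shape)          # torch.Size([8, 36, 1])
print(weights.sum(dim=1)[0])  # ~tensor([1.]) -- softmax over objects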
class BUTD(pl.LightningModule):
def __init__(self, hparams, train_dataset, val_dataset, ans2label=None, label2ans=None):
super(BUTD, self).__init__()
Save Hyper-Parameters and Dataset
self.hparams = hparams
self.train_dataset, self.val_dataset = train_dataset, val_dataset
self.ans2label, self.label2ans = ans2label, label2ans
Build Model
self.build_model()
def build_model(self):
Build Word Embeddings (for Questions)
self.w_emb = WordEmbedding(ntoken=self.train_dataset.dictionary.ntoken, dim=self.hparams.emb_dim,
dropout=self.hparams.emb_dropout)
Build Question Encoder
self.q_enc = QuestionEncoder(in_dim=self.hparams.emb_dim, hidden_dim=self.hparams.hidden,
nlayers=self.hparams.rnn_layers, bidirectional=self.hparams.bidirectional,
dropout=self.hparams.q_dropout, rnn=self.hparams.rnn)
Build Attention Mechanism
self.att = Attention(image_dim=self.train_dataset.v_dim + 6, question_dim=self.q_enc.hidden,
hidden=self.hparams.hidden, dropout=self.hparams.attention_dropout,
use_weight_norm=self.hparams.weight_norm)
Build Projection Networks
self.q_project = MLP([self.q_enc.hidden, self.hparams.hidden], use_weight_norm=self.hparams.weight_norm)
self.img_project = MLP([self.train_dataset.v_dim + 6, self.hparams.hidden],
use_weight_norm=self.hparams.weight_norm)
Build Answer Classifier
self.ans_classifier = nn.Sequential(*[
weight_norm(nn.Linear(self.hparams.hidden, 2 * self.hparams.hidden), dim=None)
if self.hparams.weight_norm else nn.Linear(self.hparams.hidden, 2 * self.hparams.hidden),
nn.ReLU(),
nn.Dropout(self.hparams.answer_dropout),
weight_norm(nn.Linear(2 * self.hparams.hidden, len(self.ans2label)), dim=None)
if self.hparams.weight_norm else nn.Linear(2 * self.hparams.hidden, len(self.ans2label))
])
def forward(self, image_features, spatial_features, question_features, indicator_features=None):
image_features: [bsz, K, image_dim] question_features: [bsz, seq_len]
Embed and Encode Question – [bsz, q_hidden]
w_emb = self.w_emb(question_features)
q_enc = self.q_enc(w_emb)
Create new Image Features
Key: Concatenate Spatial Features!
if indicator_features is not None:
image_features = torch.cat([image_features, spatial_features, indicator_features], dim=2)
else:
image_features = torch.cat([image_features, spatial_features], dim=2)
Attend over Image Features and Create Image Encoding
img_enc: [bsz, img_hidden]
att = self.att(image_features, q_enc)
img_enc = (image_features * att).sum(dim=1)
Project Image and Question Features –> [bsz, hidden]
q_repr = self.q_project(q_enc)
img_repr = self.img_project(img_enc)
Merge
joint_repr = q_repr * img_repr
Compute and Return Logits
return self.ans_classifier(joint_repr)
def configure_optimizers(self):
if self.hparams.opt == 'adamax':
return torch.optim.Adamax(self.parameters())
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.hparams.bsz, shuffle=True)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.hparams.bsz)
def training_step(self, train_batch, batch_idx):
img, spatials, question, answer = train_batch
Run Forward Pass
logits = self.forward(img, spatials, question)
Compute Loss (Cross-Entropy)
loss = nn.functional.cross_entropy(logits, answer)
Compute Answer Accuracy
accuracy = torch.mean((logits.argmax(dim=1) == answer).float())
Set up Data to be Logged
log = {'train_loss': loss, 'train_acc': accuracy}
return {'loss': loss, 'train_loss': loss, 'train_acc': accuracy, 'progress_bar': log, 'log': log}
def training_epoch_end(self, outputs):
Outputs –> List of Individual Step Outputs
avg_loss = torch.stack([x['callback_metrics']['train_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['callback_metrics']['train_acc'] for x in outputs]).mean()
log = {'train_epoch_loss': avg_loss, 'train_epoch_acc': avg_acc}
return {'progress_bar': log, 'log': log}
def validation_step(self, val_batch, batch_idx):
img, spatials, question, answer = val_batch
Run Forward Pass
logits = self.forward(img, spatials, question)
Compute Loss (Cross-Entropy)
loss = nn.functional.cross_entropy(logits, answer)
Compute Answer Accuracy
accuracy = torch.mean((logits.argmax(dim=1) == answer).float())
return {'val_loss': loss, 'val_acc': accuracy}
def validation_epoch_end(self, outputs):
Outputs –> List of Individual Step Outputs
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
log = {'val_loss': avg_loss, 'val_acc': avg_acc}
return {'progress_bar': log, 'log': log}
We tap into PyTorch-Lightning’s extensive Logging Capabilities and define our own simple logger to log metrics like training loss, training accuracy, validation loss, and validation accuracy to straightforward JSON files.
class MetricLogger(LightningLoggerBase):
def __init__(self, name, save_dir):
super(MetricLogger, self).__init__()
self._name, self._save_dir = name, os.path.join(save_dir, 'metrics')
Create Massive Dictionary to JSONify
self.events = {}
@property
def name(self):
return self._name
@property
def experiment(self):
return None
@property
def version(self):
return 1.0
@rank_zero_only
def log_hyperparams(self, params):
Params is an argparse.Namespace
self.events['hyperparams'] = vars(params)
@rank_zero_only
def log_metrics(self, metrics, step):
Metrics is a dictionary of metric names and values
for metric in metrics:
if metric in self.events:
self.events[metric].append(metrics[metric])
self.events["%s_step" % metric].append(step)
else:
self.events[metric] = [metrics[metric]]
self.events["%s_step" % metric] = [step]
@rank_zero_only
def finalize(self, status):
Optional. Any code that needs to be run after training
self.events['status'] = status
if not os.path.exists(self._save_dir):
os.makedirs(self._save_dir)
with open(os.path.join(self._save_dir, '%s-metrics.json' % self._name), 'w') as f:
json.dump(self.events, f, indent=4)
Here, we bring all the pieces together, calling each of the 4 preprocessing steps, assembling the training and development datasets, and initializing and training the BUTD model.
Preprocess Question Data – Return Dictionary and GloVe-initialized Embeddings
print('\n[*] Pre-processing GQA Questions...')
dictionary, emb = gqa_create_dictionary_glove(gqa_q=args.gqa_questions, glove=args.glove, cache=args.gqa_cache)
Preprocess Answer Data
print('\n[*] Pre-processing GQA Answers...')
ans2label, label2ans = gqa_create_answers(gqa_q=args.gqa_questions, cache=args.gqa_cache)
Create Image Features
print('\n[*] Pre-processing GQA BUTD Image Features')
trainval_img2idx, testdev_img2idx = gqa_create_image_features(gqa_f=args.gqa_features, cache=args.gqa_cache)
Build Train and TestDev Datasets – Note here that we use the TestDev split of GQA instead of Val (as is common practice) because of Visual Genome data leakage in the Validation Set
print('\n[*] Building GQA Train and TestDev Datasets...')
train_dataset = GQAFeatureDataset(dictionary, ans2label, label2ans, trainval_img2idx, gqa_q=args.gqa_questions,
cache=args.gqa_cache, mode='train')
dev_dataset = GQAFeatureDataset(dictionary, ans2label, label2ans, testdev_img2idx, gqa_q=args.gqa_questions,
cache=args.gqa_cache, mode='testdev')
Create BUTD Module (and load Embeddings!)
print('\n[*] Initializing Bottom-Up Top-Down Model...')
butd = BUTD(args, train_dataset, dev_dataset, ans2label, label2ans)
butd.w_emb.load_embeddings(emb)
Setup Logger for PyTorch-Lightning
mt_logger = MetricLogger(name=run_name, save_dir=args.checkpoint)
Saves the top-3 Checkpoints based on Validation Accuracy – feel free to change this metric to suit your needs
checkpoint_callback = ModelCheckpoint(filepath=os.path.join(args.checkpoint, 'runs', run_name,
'butd-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}'),
monitor='val_acc', mode='max', save_top_k=3)
Create PyTorch-Lightning Trainer – run for the given number of epochs, with gradient clipping!
trainer = pl.Trainer(default_root_dir=args.checkpoint, max_epochs=args.epochs, gradient_clip_val=args.gradient_clip,
gpus=args.gpus, benchmark=True, logger=mt_logger, checkpoint_callback=checkpoint_callback)
Fit and Profit!
print('\n[*] Training...\n')
trainer.fit(butd)