
Bug fixes and tweaks for a stronger baseline
Karan Desai committed Feb 11, 2019
2 parents feba6de + 711689c commit de30951
Showing 5 changed files with 66 additions and 25 deletions.
@@ -25,10 +25,14 @@ model:

# Optimization related arguments
solver:
  batch_size: 32
  batch_size: 128 # 32 x num_gpus is a good rule of thumb
  num_epochs: 20
  initial_lr: 0.001
  lr_gamma: 0.9997592083
  minimum_lr: 0.00005
  initial_lr: 0.01
  training_splits: "train" # "trainval"

  lr_gamma: 0.1
  lr_milestones: # epochs when lr --> lr * lr_gamma
    - 4
    - 7
    - 10
  warmup_factor: 0.2
  warmup_epochs: 1
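
For reference (not part of the commit): the values above feed the warmup-then-milestone schedule added in train.py below. A minimal sketch of the learning rate they produce per epoch; the function name lr_at is made up for illustration:

from bisect import bisect

initial_lr, warmup_factor, warmup_epochs = 0.01, 0.2, 1
lr_gamma, lr_milestones = 0.1, [4, 7, 10]

def lr_at(epoch):
    # linear warmup from warmup_factor * initial_lr up to initial_lr during the first epoch
    if epoch <= warmup_epochs:
        alpha = epoch / warmup_epochs
        return initial_lr * (warmup_factor * (1 - alpha) + alpha)
    # after warmup: multiply by lr_gamma once per crossed milestone
    return initial_lr * lr_gamma ** bisect(lr_milestones, epoch)

for e in (0, 0.5, 1, 3, 5, 8, 12):
    print(e, lr_at(e))  # 0.002, 0.006, 0.01, 0.01, then roughly 1e-3, 1e-4, 1e-5
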
2 changes: 1 addition & 1 deletion evaluate.py
@@ -18,7 +18,7 @@

parser = argparse.ArgumentParser("Evaluate and/or generate EvalAI submission file.")
parser.add_argument(
    "--config-yml", default="configs/lf_disc_faster_rcnn_x101_bs32.yml",
    "--config-yml", default="configs/lf_disc_faster_rcnn_x101.yml",
    help="Path to a config file listing reader, model and optimization parameters."
)
parser.add_argument(
56 changes: 42 additions & 14 deletions train.py
@@ -8,6 +8,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
import yaml
from bisect import bisect

from visdialch.data.dataset import VisDialDataset
from visdialch.encoders import Encoder
@@ -19,7 +20,7 @@

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config-yml", default="configs/lf_disc_faster_rcnn_x101_bs32.yml",
    "--config-yml", default="configs/lf_disc_faster_rcnn_x101.yml",
    help="Path to a config file listing reader, model and solver parameters."
)
parser.add_argument(
@@ -76,6 +77,7 @@
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# ================================================================================================
# INPUT ARGUMENTS AND CONFIG
# ================================================================================================
@@ -95,14 +97,14 @@


# ================================================================================================
# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER
# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER
# ================================================================================================

train_dataset = VisDialDataset(
    config["dataset"], args.train_json, overfit=args.overfit, in_memory=args.in_memory
)
train_dataloader = DataLoader(
    train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers
    train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers, shuffle=True
)

val_dataset = VisDialDataset(
@@ -126,9 +128,31 @@
if -1 not in args.gpu_ids:
    model = nn.DataParallel(model, args.gpu_ids)

# Loss function.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config["solver"]["initial_lr"])
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config["solver"]["lr_gamma"])

if config["solver"]["training_splits"] == "trainval":
    iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1
else:
    iterations = len(train_dataset) // config["solver"]["batch_size"] + 1


def lr_lambda_fun(current_iteration: int) -> float:
    """Returns a learning rate multiplier.
    Till `warmup_epochs`, learning rate linearly increases to `initial_lr`,
    and then gets multiplied by `lr_gamma` every time a milestone is crossed.
    """
    current_epoch = float(current_iteration) / iterations
    if current_epoch <= config["solver"]["warmup_epochs"]:
        alpha = current_epoch / float(config["solver"]["warmup_epochs"])
        return config["solver"]["warmup_factor"] * (1. - alpha) + alpha
    else:
        idx = bisect(config["solver"]["lr_milestones"], current_epoch)
        return pow(config["solver"]["lr_gamma"], idx)

optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"])
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun)
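
Not part of the commit: a toy sketch of the LambdaLR wiring above, using a dummy parameter and a made-up multiplier function, just to show that the lambda returns a multiplier on the base lr and that the scheduler is stepped once per iteration:

import torch
from torch import nn, optim
from torch.optim import lr_scheduler

param = nn.Parameter(torch.zeros(1))
toy_optimizer = optim.Adamax([param], lr=0.01)
toy_scheduler = lr_scheduler.LambdaLR(
    toy_optimizer, lr_lambda=lambda it: min(1.0, (it + 1) / 100))  # toy linear warmup
for it in range(300):
    toy_optimizer.step()
    toy_scheduler.step()
print(toy_optimizer.param_groups[0]["lr"])  # back at the base lr, 0.01
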


# ================================================================================================
@@ -159,14 +183,10 @@
# TRAINING LOOP
# ================================================================================================

# Forever increasing counter keeping track of iterations completed.
if config["solver"]["training_splits"] == "trainval":
    iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1
else:
    iterations = len(train_dataset) // config["solver"]["batch_size"] + 1

# Forever increasing counter keeping track of iterations completed (for tensorboard logging).
global_iteration_step = start_epoch * iterations
for epoch in range(start_epoch, config["solver"]["num_epochs"] + 1):

for epoch in range(start_epoch, config["solver"]["num_epochs"]):

    # --------------------------------------------------------------------------------------------
    # ON EPOCH START (combine dataloaders if training on train + val)
@@ -189,9 +209,10 @@

        summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step)
        summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step)
        if optimizer.param_groups[0]["lr"] > config["solver"]["minimum_lr"]:
            scheduler.step()

        scheduler.step(global_iteration_step)
        global_iteration_step += 1
        torch.cuda.empty_cache()

    # --------------------------------------------------------------------------------------------
    # ON EPOCH END (checkpointing and validation)
@@ -200,6 +221,10 @@

    # Validate and report automatic metrics.
    if args.validate:

        # Switch dropout, batchnorm etc to the correct mode.
        model.eval()

        print(f"\nValidation after epoch {epoch}:")
        for i, batch in enumerate(tqdm(val_dataloader)):
            for key in batch:
@@ -217,3 +242,6 @@
        for metric_name, metric_value in all_metrics.items():
            print(f"{metric_name}: {metric_value}")
        summary_writer.add_scalars("metrics", all_metrics, global_iteration_step)

        model.train()
        torch.cuda.empty_cache()
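
Side note (not in the diff): model.eval() / model.train() is what toggles dropout and batchnorm around the validation pass above; additionally wrapping it in torch.no_grad() would be an optional further saving that this commit does not add. A self-contained illustration on a dummy module:

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
net.eval()                     # dropout becomes a no-op, outputs are deterministic
with torch.no_grad():          # optional: skip autograd bookkeeping during evaluation
    out = net(torch.randn(2, 4))
net.train()                    # restore training behaviour for the next epoch
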
4 changes: 3 additions & 1 deletion visdialch/decoders/disc.py
@@ -14,7 +14,9 @@ def __init__(self, config, vocabulary):
                                       padding_idx=vocabulary.PAD_INDEX)
        self.option_rnn = nn.LSTM(config["word_embedding_size"],
                                  config["lstm_hidden_size"],
                                  batch_first=True)
                                  config["lstm_num_layers"],
                                  batch_first=True,
                                  dropout=config["dropout"])

        # Options are variable length padded sequences, use DynamicRNN.
        self.option_rnn = DynamicRNN(self.option_rnn)
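
One caveat worth noting (not stated in the diff): nn.LSTM applies dropout only between stacked layers, so the new dropout argument has an effect only when config["lstm_num_layers"] is at least 2; with a single layer PyTorch emits a warning. A small sketch with assumed sizes:

import torch
from torch import nn

# assumed sizes for illustration: embedding 300, hidden 512, 2 layers
option_rnn = nn.LSTM(300, 512, 2, batch_first=True, dropout=0.5)
out, (h, c) = option_rnn(torch.randn(8, 20, 300))  # (batch, seq_len, embedding)
print(out.shape, h.shape)  # torch.Size([8, 20, 512]) torch.Size([2, 8, 512])
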
15 changes: 11 additions & 4 deletions visdialch/encoders/lf.py
@@ -39,6 +39,9 @@ def __init__(self, config, vocabulary):
            config["img_feature_size"], config["lstm_hidden_size"]
        )

        # fc layer for image * question to attention weights
        self.attention_proj = nn.Linear(config["lstm_hidden_size"], 1)

        # fusion layer (attended_image_features + question + history)
        fusion_size = config["img_feature_size"] + config["lstm_hidden_size"] * 2
        self.fusion = nn.Linear(fusion_size, config["lstm_hidden_size"])
@@ -78,10 +81,14 @@ def forward(self, batch):
            batch_size * num_rounds, -1, self.config["lstm_hidden_size"]
        )

        # attend the features using question
        # computing attention weights
        # shape: (batch_size * num_rounds, num_proposals)
        image_attention_weights = projected_image_features.bmm(
            ques_embed.unsqueeze(-1)).squeeze()
        projected_ques_features = ques_embed.unsqueeze(1).repeat(
            1, img.shape[1], 1)
        projected_ques_image = projected_ques_features * projected_image_features
        projected_ques_image = self.dropout(projected_ques_image)
        image_attention_weights = self.attention_proj(
            projected_ques_image).squeeze()
        image_attention_weights = F.softmax(image_attention_weights, dim=-1)

        # shape: (batch_size * num_rounds, num_proposals, img_features_size)
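
Not part of the commit: a shape-only sketch of the new attention (tile the question over the proposals, multiply element-wise with the projected image features, score each proposal with a 1-unit linear layer, softmax, then weight the raw image features). All sizes below are made up for illustration:

import torch
from torch import nn
import torch.nn.functional as F

B, P, H, D = 4, 36, 512, 2048   # assumed: batch*rounds, proposals, lstm hidden, img feature size
projected_image_features = torch.randn(B, P, H)
ques_embed = torch.randn(B, H)
img_features = torch.randn(B, P, D)

attention_proj = nn.Linear(H, 1)
ques_tiled = ques_embed.unsqueeze(1).repeat(1, P, 1)                         # (B, P, H)
scores = attention_proj(ques_tiled * projected_image_features).squeeze(-1)   # (B, P)
weights = F.softmax(scores, dim=-1)
attended = (weights.unsqueeze(-1) * img_features).sum(dim=1)                 # (B, D)
print(weights.shape, attended.shape)  # torch.Size([4, 36]) torch.Size([4, 2048])
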
@@ -105,7 +112,7 @@ def forward(self, batch):
        hist_embed = self.word_embed(hist)

        # shape: (batch_size * num_rounds, lstm_hidden_size)
        _ , (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"])
        _, (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"])

        fused_vector = torch.cat((img, ques_embed, hist_embed), 1)
        fused_vector = self.dropout(fused_vector)
