-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy path train.py
More file actions
87 lines (71 loc) · 3 KB
/
train.py
File metadata and controls
87 lines (71 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
import torch
import utils
import modules
def _positive_int(text):
    """Parse *text* as an int, rejecting values below 1.

    Used as an argparse ``type`` for options where zero or a negative value
    breaks the training loop: ``step % args.log_interval`` divides by the
    interval, and an empty batch cannot be padded or sampled with
    ``inputs[-1]``.
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            'expected a positive integer, got {}'.format(text))
    return value


# Command-line options for the toy-data training run.
parser = argparse.ArgumentParser()
parser.add_argument('--iterations', type=int, default=100,
                    help='Number of training iterations.')
parser.add_argument('--learning-rate', type=float, default=1e-2,
                    help='Learning rate.')
parser.add_argument('--hidden-dim', type=int, default=64,
                    help='Number of hidden units.')
parser.add_argument('--latent-dim', type=int, default=32,
                    help='Dimensionality of latent variables.')
# Restrict to the two supported latent distributions so a typo fails at
# parse time instead of deep inside model construction.
parser.add_argument('--latent-dist', type=str, default='gaussian',
                    choices=['gaussian', 'concrete'],
                    help='Choose: "gaussian" or "concrete" latent variables.')
parser.add_argument('--batch-size', type=_positive_int, default=512,
                    help='Mini-batch size (for averaging gradients).')
parser.add_argument('--num-symbols', type=int, default=5,
                    help='Number of distinct symbols in data generation.')
parser.add_argument('--num-segments', type=int, default=3,
                    help='Number of segments in data generation.')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disable CUDA training.')
parser.add_argument('--log-interval', type=_positive_int, default=5,
                    help='Logging interval.')
# Parse the command line, then pick the compute device: CUDA is used only
# when it is not disabled with --no-cuda and torch reports it as available.
args = parser.parse_args()
args.cuda = torch.cuda.is_available() if not args.no_cuda else False
if args.cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the CompILE model on the chosen device and attach an Adam optimizer.
model_kwargs = {
    # One extra input index is reserved for the EOS/padding symbol.
    'input_dim': args.num_symbols + 1,
    'hidden_dim': args.hidden_dim,
    'latent_dim': args.latent_dim,
    'max_num_segments': args.num_segments,
    'latent_dist': args.latent_dist,
}
model = modules.CompILE(**model_kwargs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
# Train model.
print('Training model...')
for step in range(args.iterations):
    optimizer.zero_grad()

    # Generate a fresh batch of toy sequences. Sequences may differ in
    # length, so record each length before padding them into one tensor.
    data = [
        utils.generate_toy_data(
            num_symbols=args.num_symbols,
            num_segments=args.num_segments)
        for _ in range(args.batch_size)
    ]
    lengths = torch.tensor(list(map(len, data))).to(device)
    inputs = torch.nn.utils.rnn.pad_sequence(data, batch_first=True).to(device)

    # Run forward pass. Call the module itself (not .forward) so that any
    # registered hooks run.
    model.train()
    outputs = model(inputs, lengths)
    loss, nll, kl_z, kl_b = utils.get_losses(inputs, outputs, args)

    loss.backward()
    optimizer.step()

    if step % args.log_interval == 0:
        # Run evaluation in eval mode on the batch just trained on. No
        # gradients are needed here, so skip building an autograd graph.
        model.eval()
        with torch.no_grad():
            outputs = model(inputs, lengths)
            acc, rec = utils.get_reconstruction_accuracy(inputs, outputs, args)

        print('step: {}, nll_train: {:.6f}, rec_acc_eval: {:.3f}'.format(
            step, nll.item(), acc.item()))
        # Drop the last position of the final sample's valid length
        # (presumably the EOS symbol — confirm against utils).
        print('input sample: {}'.format(inputs[-1, :lengths[-1] - 1]))
        print('reconstruction: {}'.format(rec[-1]))