# resnet_train.py
# Forked from ry/tensorflow-resnet.
# resnet provides loss(), MOVING_AVERAGE_DECAY and UPDATE_OPS_COLLECTION.
from resnet import *

import os
import sys
import time

import numpy as np
import tensorflow as tf

MOMENTUM = 0.9
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/resnet_train_ckpt',
                           """Directory where to write event logs """
                           """and checkpoints.""")
tf.app.flags.DEFINE_float('learning_rate', 0.1, 'learning rate')
tf.app.flags.DEFINE_integer('batch_size', 128, 'batch size')
tf.app.flags.DEFINE_integer('max_steps', 500000, 'max steps')
tf.app.flags.DEFINE_boolean('resume', False,
                            'resume from latest saved state')
tf.app.flags.DEFINE_boolean('minimal_summaries', True,
                            'produce fewer summaries to save disk space')
tf.app.flags.DEFINE_boolean('is_use_ckpt', False,
                            'whether to load a checkpoint and continue training')
tf.app.flags.DEFINE_string('ckpt_path', '/tmp/resnet_train_ckpt/model.ckpt-11001',
                           'checkpoint path to restore from')
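
# Example invocation (a sketch, assuming this module is driven by a script
# such as the upstream repo's train_imagenet.py that builds the model and
# calls train() below; the flag names come from the definitions above):
#
#   python train_imagenet.py --train_dir=/tmp/resnet_train_ckpt \
#       --learning_rate=0.01 --batch_size=64 --resume=True
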
def top_k_error(predictions, labels, k):
    """Fraction of the batch whose true label is NOT in the top-k predictions."""
    # Use the static batch size; tf.shape(predictions)[0] would also work.
    batch_size = float(FLAGS.batch_size)
    # Fixed: the original hard-coded k=1 here, silently ignoring the k argument.
    in_top_k = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k))
    num_correct = tf.reduce_sum(in_top_k)
    return (batch_size - num_correct) / batch_size
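
# Worked example: with batch_size = 4 and in_top_k yielding [1., 0., 1., 1.]
# (three of four examples have the true label among their top k predictions),
# the returned error is (4 - 3) / 4 = 0.25.
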
def train(is_training, logits, images, labels):
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    val_step = tf.get_variable('val_step', [],
                               initializer=tf.constant_initializer(0),
                               trainable=False)

    loss_ = loss(logits, labels)
    predictions = tf.nn.softmax(logits)
    top1_error = top_k_error(predictions, labels, 1)

    # loss_avg: exponential moving average of the training loss.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    tf.summary.scalar('loss_avg', ema.average(loss_))

    # Validation stats: a separate moving average of the top-1 error.
    ema = tf.train.ExponentialMovingAverage(0.9, val_step)
    val_op = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
    top1_error_avg = ema.average(top1_error)
    tf.summary.scalar('val_top1_error_avg', top1_error_avg)

    tf.summary.scalar('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
    grads = opt.compute_gradients(loss_)
    for grad, var in grads:
        if grad is not None and not FLAGS.minimal_summaries:
            tf.summary.histogram(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    if not FLAGS.minimal_summaries:
        # Display the training images in the visualizer.
        tf.summary.image('images', images)
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

    # Fold the batch-norm moving-average updates into the train op.
    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.global_variables())

    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    if FLAGS.is_use_ckpt:
        saver.restore(sess, FLAGS.ckpt_path)
        print('Restored from checkpoint...')
    else:
        sess.run(init)
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    if FLAGS.resume:
        latest = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not latest:
            print("No checkpoint to continue from in", FLAGS.train_dir)
            sys.exit(1)
        print("resume", latest)
        saver.restore(sess, latest)

    for _ in range(FLAGS.max_steps + 1):
        start_time = time.time()

        step = sess.run(global_step)
        i = [train_op, loss_, top1_error_avg]

        # Fixed: the original `step % 1000 and step > 1` was truthy on almost
        # every step; the intent is to write summaries every 1000 steps.
        write_summary = step > 1 and step % 1000 == 0
        if write_summary:
            i.append(summary_op)

        o = sess.run(i, {is_training: True})

        loss_value = o[1]
        top1_error_avg_value = o[2]

        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 100 == 0:
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('step %d, loss = %.2f, top1_error_avg = %.5f '
                          '(%.1f examples/sec; %.3f sec/batch)')
            print(format_str % (step, loss_value, top1_error_avg_value,
                                examples_per_sec, duration))

        if write_summary:
            summary_str = o[3]
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step > 1 and step % 1000 == 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)

        # Run validation periodically.
        if step > 1 and step % 1000 == 0:
            _, loss_value, top1_error_value = sess.run(
                [val_op, loss_, top1_error], {is_training: False})
            # Fixed: the original print passed loss_value as a second argument
            # to print() instead of to the % format.
            print('Validation top1 error %.2f, loss %.5f' %
                  (top1_error_value, loss_value))
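
# How train() is expected to be driven (a minimal sketch, not part of the
# original file): a driver script builds the input pipeline and the ResNet
# graph, then hands the ops to train(). The names `distorted_inputs` and
# `inference` below are assumptions about resnet.py / the surrounding repo's
# API, shown only to illustrate the wiring:
#
#   is_training = tf.placeholder('bool', [], name='is_training')
#   images, labels = distorted_inputs()        # hypothetical input pipeline
#   logits = inference(images,                 # hypothetical model builder
#                      num_classes=1000,
#                      is_training=is_training)
#   train(is_training, logits, images, labels)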