Patch notes (text before the first "diff --git" header is skipped by
git apply / patch):
  * libs/configs/config_v1.py: adds the integer flag
    --train_checkpoint_interval (default 3000), changes the dataset_dir
    default to './data/coco/', writes the learning_rate default as 2e-3,
    and turns num_epochs_per_decay into an integer flag (default 2).
  * train/train.py: restore() tracks is_restored and skips the pretrained
    model when a previous checkpoint was restored; train() prints the
    starting learning rate, saves a checkpoint every
    FLAGS.train_checkpoint_interval steps, and passes the global_step
    variable to saver.save().
  * Fixed relative to the earlier revision of this patch: the missing comma
    after 2e-3 (a SyntaxError in config_v1.py) and a string literal that
    was split across two lines in restore().
  * NOTE(review): line breaks and indentation were reconstructed from a
    whitespace-collapsed copy. Confirm the indentation of the added lines
    against the target files; `git apply --ignore-whitespace` may be needed
    for the context lines. The "index" hashes are those of the earlier
    revision.

diff --git a/libs/configs/config_v1.py b/libs/configs/config_v1.py
index 6f7a996..5a2ddcb 100644
--- a/libs/configs/config_v1.py
+++ b/libs/configs/config_v1.py
@@ -11,6 +11,10 @@
     'train_dir', './output/mask_rcnn/',
     'Directory where checkpoints and event logs are written to.')
 
+tf.app.flags.DEFINE_integer(
+    'train_checkpoint_interval', 3000,
+    'The number of steps per saved checkpoint.')
+
 tf.app.flags.DEFINE_string(
     'pretrained_model', './data/pretrained_models/resnet_v1_50.ckpt',
     'Path to pretrained model')
@@ -42,7 +46,7 @@
     'The name of the train/test/val split.')
 
 tf.app.flags.DEFINE_string(
-    'dataset_dir', 'data/coco/',
+    'dataset_dir', './data/coco/',
     'The directory where the dataset files are stored.')
 
 tf.app.flags.DEFINE_integer(
@@ -130,7 +134,7 @@
     'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
     ' or "polynomial"')
 
-tf.app.flags.DEFINE_float('learning_rate', 0.002,
+tf.app.flags.DEFINE_float('learning_rate', 2e-3,
                           'Initial learning rate.')
 
 tf.app.flags.DEFINE_float(
@@ -143,8 +147,8 @@
 tf.app.flags.DEFINE_float(
     'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
 
-tf.app.flags.DEFINE_float(
-    'num_epochs_per_decay', 2.0,
+tf.app.flags.DEFINE_integer(
+    'num_epochs_per_decay', 2,
     'Number of epochs after which learning rate decays.')
 
 tf.app.flags.DEFINE_bool(
diff --git a/train/train.py b/train/train.py
index 839c178..f327a7d 100644
--- a/train/train.py
+++ b/train/train.py
@@ -63,21 +63,23 @@ def solve(global_step):
 
 def restore(sess):
     """choose which param to restore"""
+    is_restored = False
     if FLAGS.restore_previous_if_exists:
         try:
             checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)
             restorer = tf.train.Saver()
             restorer.restore(sess, checkpoint_path)
+            is_restored = True
             print ('restored previous model %s from %s'\
                     %(checkpoint_path, FLAGS.train_dir))
             time.sleep(2)
             return
         except:
-            print ('--restore_previous_if_exists is set, but failed to restore in %s %s'\
+            print ('--restore_previous_if_exists is set, but FAILED TO RESTORE in %s %s'\
                     % (FLAGS.train_dir, checkpoint_path))
             time.sleep(2)
 
-    if FLAGS.pretrained_model:
+    if FLAGS.pretrained_model and not is_restored:
         if tf.gfile.IsDirectory(FLAGS.pretrained_model):
             checkpoint_path = tf.train.latest_checkpoint(FLAGS.pretrained_model)
         else:
@@ -105,6 +107,8 @@ def restore(sess):
 def train():
     """The main function that runs training"""
 
+    print("Starting learning rate %.7f"%(FLAGS.learning_rate))
+
     ## data
     image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
         datasets.get_dataset(FLAGS.dataset_name,
@@ -143,6 +147,7 @@ def train():
 
     ## solvers
     global_step = slim.create_global_step()
+    # global_step = tf.Variable(0, name='global_step', trainable=False)
     update_op = solve(global_step)
 
     cropped_rois = tf.get_collection('__CROPPED__')[0]
@@ -208,10 +213,10 @@ def train():
             summary_str = sess.run(summary_op)
             summary_writer.add_summary(summary_str, step)
 
-        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
+        if (step % FLAGS.train_checkpoint_interval == 0 or step + 1 == FLAGS.max_iters) and step != 0:
             checkpoint_path = os.path.join(FLAGS.train_dir,
                                            FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
-            saver.save(sess, checkpoint_path, global_step=step)
+            saver.save(sess, checkpoint_path, global_step=global_step)
 
         if coord.should_stop():
             coord.request_stop()