-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_image_classifier.py
More file actions
162 lines (140 loc) · 8.15 KB
/
train_image_classifier.py
File metadata and controls
162 lines (140 loc) · 8.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -*- encoding: utf-8 -*-
import tensorflow as tf
from easydict import EasyDict as edict
from base import env_utils
from config.global_configs import TrainBaseConfig, \
TrainConfig, TFRecordBaseConfig, TFRecordConfig, ProjectConfig
from hook_and_exporter import BetterExporter, EvalEarlyStoppingHook, TrainEarlyStoppingHook
from nets.neural_network import NeuralNetwork
def init_training_params():
    """Gather the training hyper-parameters into one attribute-accessible dict.

    Every value is read from the project configuration classes
    (``TrainBaseConfig``, ``TrainConfig``, ``TFRecordBaseConfig``,
    ``TFRecordConfig`` and ``ProjectConfig``) at call time.

    Returns:
        EasyDict: the collected parameters, readable either as
        ``params.key`` or ``params['key']``.
    """
    # Background reading: https://zhuanlan.zhihu.com/p/74857888
    collected = dict(
        drop_rate=TrainBaseConfig.DROP_RATE,
        learning_rate=TrainConfig.getDefault().initial_learning_rate,
        decay_steps=TrainConfig.getDefault().decay_steps,
        decay_rate=TrainConfig.getDefault().decay_rate,
        num_classes=TFRecordConfig.getDefault().num_classes,
        batch_size=TFRecordBaseConfig.BATCH_SIZE,
        evaluation_step=TrainBaseConfig.EVALUATION_STEP,
        max_steps=TrainConfig.getDefault().max_steps,
        early_stopping_patience=TrainBaseConfig.EARLY_STOPPING_PATIENCE,
        eval_throttle_secs=TrainBaseConfig.EVAL_THROTTLE_SECS,
        save_checkpoints_secs=TrainBaseConfig.SAVE_CHECKPOINTS_SECS,
        keep_checkpoint_max=TrainBaseConfig.KEEP_CHECKPOINT_MAX,
        eval_start_delay_secs=TrainBaseConfig.EVAL_START_DELAY_SECS,
        shuffle_buffer_size=TrainBaseConfig.SHUFFLE_BUFFER_SIZE,
        quant=TrainBaseConfig.QUANT,
        quant_delay=TrainBaseConfig.QUANT_DELAY,
        exports_to_keep=TrainBaseConfig.EXPORTS_TO_KEEP,
        net=ProjectConfig.getDefault().net,
        input_tensor_name=TrainBaseConfig.INPUT_TENSOR_NAME,
        output_tensor_name=TrainBaseConfig.OUTPUT_TENSOR_NAME,
        image=TFRecordBaseConfig.IMAGE,
        label=TFRecordBaseConfig.LABEL,
        shape=TFRecordConfig.getDefault().image_shape,
    )
    return edict(collected)
def serving_input_receiver_fn():
    """Describe the input the exported model accepts at serving time.

    A single float32 placeholder is created, named
    ``TrainBaseConfig.INPUT_TENSOR_NAME`` and keyed by
    ``TFRecordBaseConfig.IMAGE``.

    Returns:
        tf.estimator.export.ServingInputReceiver: receiver whose features are
        a shallow copy of the receiver tensors, i.e. the placeholder is handed
        to the model as-is.
    """
    # The leading None leaves the batch dimension open; the remaining
    # dimensions are taken from the configured image shape.
    input_shape = [None, *TFRecordConfig.getDefault().image_shape]
    image_input = tf.placeholder(
        dtype=tf.float32,
        shape=input_shape,
        name=TrainBaseConfig.INPUT_TENSOR_NAME,
    )
    receiver_tensors = {TFRecordBaseConfig.IMAGE: image_input}

    # No conversion (e.g. reshaping) is applied between what the client sends
    # and what the model receives.
    return tf.estimator.export.ServingInputReceiver(
        receiver_tensors=receiver_tensors,
        features=receiver_tensors.copy(),
    )
def running_train(train_dataset, valid_dataset, test_dataset, gpu='0'):
    """Train and periodically evaluate the classifier via the Estimator API.

    Args:
        train_dataset: passed as the first argument of
            ``tf.estimator.TrainSpec``.
        valid_dataset: passed as the first argument of
            ``tf.estimator.EvalSpec``.
        test_dataset: accepted but not used inside this function.
        gpu: forwarded to ``env_utils.select_gpu``.
    """
    hparams = init_training_params()

    # Pick the GPU before any TensorFlow session is configured.
    env_utils.select_gpu(gpu)

    # Session settings (tf.compat.v1.ConfigProto in newer namespaces):
    # fixed CPU device count and thread-pool sizes, with device placement
    # logging switched on.
    proto_options = {
        'device_count': {"CPU": 4},
        'log_device_placement': True,
        'inter_op_parallelism_threads': 2,
        'intra_op_parallelism_threads': 5,
    }
    session_config = tf.ConfigProto(**proto_options)
    # Allocate GPU memory on demand. The alternative of a fixed fraction
    # (gpu_options.per_process_gpu_memory_fraction = 0.9) is not used.
    session_config.gpu_options.allow_growth = True

    estimator_config = tf.estimator.RunConfig(
        session_config=session_config,
        save_checkpoints_secs=hparams.save_checkpoints_secs,
        keep_checkpoint_max=hparams.keep_checkpoint_max,
    )

    # Both paths feed the estimator datasets shaped as
    # ({xxx, image}, {yyy, classifier}); only the model function differs.
    if ProjectConfig.getDefault().keras:
        # NOTE: a tf.keras.estimator.model_to_estimator(...) variant was
        # tried earlier (it takes (image, classifier) datasets from
        # readTfrecord.get_keras_datasets()), but it currently does not save
        # the PB file, so it is not used. Unresolved.
        #
        # For the path used here, ReadTfrecord must read the dataset with
        # get_dataset_from_tfrecord.
        # With this path the output tensor name is not the configured
        # Softmax.
        # TODO
        # 1. Investigate why the output is not Softmax.
        # 2. Use tf.get_default_graph().get_operations() to find out what
        #    the output actually is.
        model_fn = NeuralNetwork.build_keras_network
    else:
        model_fn = NeuralNetwork.build_network

    estimator_network = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=TrainConfig.getDefault().train_dir,
        config=estimator_config,
        params=hparams,
    )

    # Early-stopping hooks. A KerasTrainEarlyStoppingHook configured with
    # monitor / min_delta / patience from TrainConfig is a disabled
    # alternative for the training-side hook.
    train_stop_hook = TrainEarlyStoppingHook()
    eval_stop_hook = EvalEarlyStoppingHook(
        hparams.evaluation_step,
        patience=hparams.early_stopping_patience,
        total_eval_examples=TFRecordConfig.getDefault().val_numbers,
        batch_size=hparams.batch_size,
    )

    # Exporters. The project's BetterExporter is used in place of the stock
    # tf.estimator.BestExporter:
    # https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator/BestExporter
    better_exporter = BetterExporter(
        TrainBaseConfig.BEST_EXPORT,
        serving_input_receiver_fn=serving_input_receiver_fn,
        exports_to_keep=hparams.exports_to_keep,
    )
    final_exporter = tf.estimator.FinalExporter(
        TrainBaseConfig.FINAL_EXPORT,
        serving_input_receiver_fn=serving_input_receiver_fn,
    )

    # Assemble the train / eval specs and run the combined loop.
    train_spec = tf.estimator.TrainSpec(
        train_dataset,
        max_steps=hparams.max_steps,
        hooks=[train_stop_hook],
    )
    eval_spec = tf.estimator.EvalSpec(
        valid_dataset,
        steps=hparams.evaluation_step,
        start_delay_secs=hparams.eval_start_delay_secs,
        throttle_secs=hparams.eval_throttle_secs,
        exporters=[final_exporter, better_exporter],
        hooks=[eval_stop_hook],
    )
    tf.estimator.train_and_evaluate(estimator_network, train_spec, eval_spec)