-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathTrainModelAndPredict.py
More file actions
87 lines (65 loc) · 3.63 KB
/
TrainModelAndPredict.py
File metadata and controls
87 lines (65 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Historical market data used for both training and evaluation.
data_file_path = 'market_data_hist_2y.csv'
mkt_data = pd.read_csv(data_file_path)

# Names of the model inputs followed by the label column.
field_names = [
    'pctChg1', 'pctChg2', 'pctChg3', 'pctChg4', 'pctChg5', 'pctChg6', 'pctChg7',
    'pctChg8', 'pctChg9', 'pctChg10', 'pctChg11', 'pctChg12', 'pctChg13', 'pctChg14',
    'y',
]
label_name = 'y'

# Create feature columns: one numeric column per non-label field, in the
# order the fields are listed above.
feature_cols = [
    tf.feature_column.numeric_column(name)
    for name in field_names
    if name != label_name
]

# Separate the inputs from the label.
x_data = mkt_data.drop(label_name, axis=1)
y_data = mkt_data[label_name]

# Create train/test split (20% held out; fixed seed keeps the split repeatable)
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=.2, random_state=99)

# Create and train model: a binary classifier with three hidden layers of 15 units.
input_func = tf.estimator.inputs.pandas_input_fn(x_train, y_train, batch_size=10, num_epochs=100, shuffle=True)
nn_model = tf.estimator.DNNClassifier(hidden_units=[15, 15, 15], feature_columns=feature_cols, n_classes=2)
# steps=100 caps training at 100 batches of 10 rows each.
nn_model.train(input_fn=input_func, steps=100)

# Evaluate model on the held-out rows (single pass, original order)
eval_input_func = tf.estimator.inputs.pandas_input_fn(x=x_test, y=y_test, batch_size=10, num_epochs=1, shuffle=False)
results = nn_model.evaluate(eval_input_func)
print('Model Results:')
print(results)
#####################
# Make Prediction
#####################
# Daily market data to run the trained model on.
data_file_path_pred = 'market_data_daily.csv'
mkt_data_pred = pd.read_csv(data_file_path_pred)

# No feature columns are rebuilt here: nn_model was constructed with its
# feature columns and looks the inputs up by column name in the frame below.
pred_input_func = tf.estimator.inputs.pandas_input_fn(x=mkt_data_pred, batch_size=10, num_epochs=1, shuffle=False)
predictions = nn_model.predict(pred_input_func)

# predict() is consumed into a list so the results can be printed and indexed.
my_pred = list(predictions)
print('Prediction Results:')
print(my_pred)

# Report the predicted class of the first row only.
# NOTE(review): 'classes' presumably holds the label as a string/bytes value,
# hence the int() conversion - confirm against the Estimator docs.
pred_val = int(my_pred[0]['classes'][0])
print('Prediction Value:' + str(pred_val))