-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathlayers.py
More file actions
118 lines (98 loc) · 5.1 KB
/
layers.py
File metadata and controls
118 lines (98 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
from torch.autograd import Function, Variable
import torch.nn.functional as Func
from torch.nn import Module, Parameter
import math
from utils import hard_sigm, bound
import time
class HM_LSTMCell(Module):
    """One layer of a Hierarchical Multiscale LSTM (HM-LSTM).

    Each step blends three operations — FLUSH, COPY, UPDATE — softly,
    weighted by the boundary indicators ``z`` (this layer's boundary from
    the previous step) and ``z_bottom`` (the layer below's boundary).

    Args:
        bottom_size: feature size of the layer below (input to this layer).
        hidden_size: feature size of this layer's hidden/cell state.
        top_size: feature size of the layer above (``None`` for the last layer).
        a: slope parameter passed to ``hard_sigm`` for the boundary detector.
        last_layer: if True, there is no top-down connection.
    """

    def __init__(self, bottom_size, hidden_size, top_size, a, last_layer):
        super(HM_LSTMCell, self).__init__()
        self.bottom_size = bottom_size
        self.hidden_size = hidden_size
        self.top_size = top_size
        self.a = a
        self.last_layer = last_layer
        # U_11: state transition from layer l (current) to layer l (recurrent).
        # U_21: state transition from layer l+1 (top) to layer l; absent on the
        #       last layer.
        # W_01: state transition from layer l-1 (bottom) to layer l.
        # Rows pack the pre-activations [f; i; o; g; z_hat], hence the
        # 4 * hidden_size + 1 leading dimension.
        # NOTE(review): parameters are allocated directly on CUDA, as in the
        # original code; the module cannot run on CPU without changing this.
        self.U_11 = Parameter(torch.cuda.FloatTensor(4 * self.hidden_size + 1, self.hidden_size))
        if not self.last_layer:
            self.U_21 = Parameter(torch.cuda.FloatTensor(4 * self.hidden_size + 1, self.top_size))
        self.W_01 = Parameter(torch.cuda.FloatTensor(4 * self.hidden_size + 1, self.bottom_size))
        self.bias = Parameter(torch.cuda.FloatTensor(4 * self.hidden_size + 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniformly initialize all parameters in [-stdv, stdv], stdv = 1/sqrt(hidden_size)."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for par in self.parameters():
            par.data.uniform_(-stdv, stdv)

    def forward(self, c, h_bottom, h, h_top, z, z_bottom):
        """One HM-LSTM step.

        All state tensors are feature-major: size = feature_dim x batch_size
        (e.g. h_bottom is bottom_size x batch_size). ``z`` and ``z_bottom``
        are 1 x batch_size boundary indicators in [0, 1].

        Returns:
            (h_new, c_new, z_new): updated hidden state, cell state, and the
            new boundary indicator for this layer.
        """
        hs = self.hidden_size
        s_recur = torch.mm(self.W_01, h_bottom)
        if not self.last_layer:
            # Top-down contribution, gated by this layer's own boundary z.
            s_topdown_ = torch.mm(self.U_21, h_top)
            s_topdown = z.expand_as(s_topdown_) * s_topdown_
        else:
            # No layer above: zero contribution on the same device/dtype.
            # (Replaces the deprecated Variable(...).cuda().cuda() construct.)
            s_topdown = torch.zeros_like(s_recur)
        s_bottomup_ = torch.mm(self.U_11, h)
        s_bottomup = z_bottom.expand_as(s_bottomup_) * s_bottomup_
        # NOTE(review): relative to the HM-LSTM paper, the z_bottom gate sits
        # on the recurrent U_11*h term while the bottom-up W_01*h_bottom term
        # is ungated — these look swapped. Kept as-is to preserve the trained
        # behavior; TODO confirm against the reference implementation.
        f_s = s_recur + s_topdown + s_bottomup + self.bias.unsqueeze(1).expand_as(s_recur)
        # f_s.size = (4 * hidden_size + 1) x batch_size; slice out the gates.
        f = torch.sigmoid(f_s[0:hs, :])            # forget gate, hs x batch
        i = torch.sigmoid(f_s[hs:2 * hs, :])       # input gate
        o = torch.sigmoid(f_s[2 * hs:3 * hs, :])   # output gate
        g = torch.tanh(f_s[3 * hs:4 * hs, :])      # candidate cell update
        z_hat = hard_sigm(self.a, f_s[4 * hs:4 * hs + 1, :])  # boundary logit
        one = torch.ones_like(f)
        z = z.expand_as(f)
        z_bottom = z_bottom.expand_as(f)
        # Soft blend of the three HM-LSTM operations:
        #   z == 1            (FLUSH):  c_new = i*g,        h_new = o*tanh(c_new)
        #   z == 0, zb == 0   (COPY):   c_new = c,          h_new = h
        #   z == 0, zb == 1   (UPDATE): c_new = f*c + i*g,  h_new = o*tanh(c_new)
        c_new = z * (i * g) + (one - z) * (one - z_bottom) * c + (one - z) * z_bottom * (f * c + i * g)
        h_new = z * o * torch.tanh(c_new) + (one - z) * (one - z_bottom) * h + (one - z) * z_bottom * o * torch.tanh(c_new)
        # NOTE(review): bound()(z_hat) is the legacy autograd.Function calling
        # convention; modern torch requires bound.apply(z_hat). Left unchanged
        # because `bound` lives in utils — TODO confirm its definition.
        z_new = bound()(z_hat)
        return h_new, c_new, z_new
class HM_LSTM(Module):
    """Two-layer Hierarchical Multiscale LSTM built from HM_LSTMCell.

    Args:
        a: slope parameter forwarded to each cell's boundary detector.
        input_size: feature size of the input embeddings.
        size_list: [hidden_size_layer1, hidden_size_layer2].
    """

    def __init__(self, a, input_size, size_list):
        super(HM_LSTM, self).__init__()
        self.a = a
        self.input_size = input_size
        self.size_list = size_list
        # Layer 1 sees layer 2 as its top; layer 2 is the last layer (no top).
        self.cell_1 = HM_LSTMCell(self.input_size, self.size_list[0], self.size_list[1], self.a, False)
        self.cell_2 = HM_LSTMCell(self.size_list[0], self.size_list[1], None, self.a, True)

    def forward(self, inputs, hidden):
        """Run the two-layer HM-LSTM over a batch of sequences.

        Args:
            inputs: (batch_size, time_steps, input_size) tensor. Must live on
                the same device as the cells' parameters (CUDA as constructed).
            hidden: ``None`` for a zero initial state, or the 6-tuple
                (h_t1, c_t1, z_t1, h_t2, c_t2, z_t2) returned by a prior call.

        Returns:
            h_1, h_2: (batch, time, hidden) hidden states per layer.
            z_1, z_2: (batch, time, 1) boundary indicators per layer.
            hidden: the final 6-tuple state, reusable as the next call's input.
        """
        time_steps = inputs.size(1)
        batch_size = inputs.size(0)
        if hidden is None:  # fix: identity comparison, not `== None`
            # Zero initial state; device/dtype follow the inputs instead of a
            # hardcoded .cuda() (identical placement for all working calls,
            # since the cell parameters are CUDA-resident).
            dev = inputs.device
            h_t1 = torch.zeros(self.size_list[0], batch_size, device=dev)
            c_t1 = torch.zeros(self.size_list[0], batch_size, device=dev)
            z_t1 = torch.zeros(1, batch_size, device=dev)
            h_t2 = torch.zeros(self.size_list[1], batch_size, device=dev)
            c_t2 = torch.zeros(self.size_list[1], batch_size, device=dev)
            z_t2 = torch.zeros(1, batch_size, device=dev)
        else:
            (h_t1, c_t1, z_t1, h_t2, c_t2, z_t2) = hidden
        # The bottom of layer 1 is the raw input stream, so its z_bottom is
        # always 1 (every input step is a "boundary").
        z_one = torch.ones(1, batch_size, device=inputs.device)

        h_1, h_2, z_1, z_2 = [], [], [], []
        for t in range(time_steps):
            # Cells work feature-major, hence the .t() on the input slice.
            h_t1, c_t1, z_t1 = self.cell_1(c=c_t1, h_bottom=inputs[:, t, :].t(), h=h_t1, h_top=h_t2, z=z_t1, z_bottom=z_one)
            h_t2, c_t2, z_t2 = self.cell_2(c=c_t2, h_bottom=h_t1, h=h_t2, h_top=None, z=z_t2, z_bottom=z_t1)
            h_1.append(h_t1.t())
            h_2.append(h_t2.t())
            z_1.append(z_t1.t())
            z_2.append(z_t2.t())

        hidden = (h_t1, c_t1, z_t1, h_t2, c_t2, z_t2)
        return torch.stack(h_1, dim=1), torch.stack(h_2, dim=1), torch.stack(z_1, dim=1), torch.stack(z_2, dim=1), hidden