-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathGIC.py
More file actions
76 lines (66 loc) · 2.12 KB
/
GIC.py
File metadata and controls
76 lines (66 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import os
import pickle
import copy
import json
import warnings
warnings.filterwarnings("ignore")
os.environ["CUDA_VISIBLE_DEVICES"]="6,7"
from load_img import Load_from_Folder, Load_Images
from evaluate import Time, MSE, PSNR
from MBMBVQ import MBMBVQ
from EntropyCoding import EntropyCoding
class GIC:
    """Codec wrapper that chains an ``MBMBVQ`` stage with an ``EntropyCoding`` stage.

    Both stages are constructed from the same ``par`` object.
    """

    def __init__(self, par):
        """Build the two stages from ``par`` (MBMBVQ first, then EntropyCoding)."""
        self.MBMBVQ = MBMBVQ(par)
        self.EC = EntropyCoding(par)

    def change_n_img(self, n_img):
        """Store ``n_img`` at index 0 of every shape entry of every hop.

        Walks ``self.EC.par['shape']['hop1']`` .. ``['hop<n_hop>']`` and
        overwrites element 0 of each entry in place.
        """
        n_hop = self.EC.par['n_hop']
        for hop in range(1, n_hop + 1):
            shapes = self.EC.par['shape']['hop' + str(hop)]
            for idx in range(len(shapes)):
                shapes[idx][0] = n_img

    @Time
    def fit(self, Y):
        """Fit the MBMBVQ stage on ``Y``, then fit the entropy coder on its encoding.

        Each stage call receives its own deep copy of ``Y``. Returns ``self``.
        """
        n_img = Y.shape[0]
        self.change_n_img(n_img)
        self.MBMBVQ.fit(copy.deepcopy(Y))
        encoded = self.MBMBVQ.encode(copy.deepcopy(Y))
        self.EC.fit(encoded)
        return self

    @Time
    def refit(self, Y, par):
        """Refit both stages on ``Y``, forwarding ``par`` to each ``refit`` call.

        Each MBMBVQ call receives its own deep copy of ``Y``. Returns ``self``.
        """
        n_img = Y.shape[0]
        self.change_n_img(n_img)
        self.MBMBVQ.refit(copy.deepcopy(Y), par)
        encoded = self.MBMBVQ.encode(copy.deepcopy(Y))
        self.EC.refit(encoded, par)
        return self

    @Time
    def encode(self, Y):
        """Encode ``Y`` and return ``(stream, DC, save)``.

        ``save`` is the MBMBVQ encoding of ``Y`` (passed without copying),
        ``stream`` is the entropy coder's output for it (called with
        ``S=Y.shape[1]``), and ``DC`` is ``save['DC']``.
        """
        self.change_n_img(Y.shape[0])
        encoded = self.MBMBVQ.encode(Y)
        bitstream = self.EC.encode(encoded, S=Y.shape[1])
        return bitstream, encoded['DC'], encoded

    @Time
    def decode(self, stream, DC):
        """Decode ``stream``, attach ``DC`` under key ``'DC'``, and run the MBMBVQ decoder."""
        decoded = self.EC.decode(stream)
        decoded['DC'] = DC
        return self.MBMBVQ.decode(decoded)

    def save(self):
        """Return ``self`` after clearing the nested ``KM.KM`` handles.

        For every item in each ``self.MBMBVQ.km[...]`` entry's ``KM``
        collection, sets ``KM.KM`` to ``None`` and ``KM.saveObj`` to ``False``.
        Per the original author's note, this is meant to make the object
        pickleable.
        """
        codebooks = self.MBMBVQ.km
        for key in codebooks.keys():
            for node in codebooks[key].KM:
                node.KM.KM = None
                node.KM.saveObj = False
        return self
if __name__ == "__main__":
    # Smoke test: fit on the images in ./test_data/, round-trip them through
    # encode/decode, and print the reconstruction error next to a reference value.
    with open('./test_data/test_par1.json', 'r') as f:
        par = json.load(f)
    # The class defined in this file is GIC; GIC_Y is neither defined nor
    # imported here, so using it raised NameError.
    gic = GIC(par)
    Y_list = Load_from_Folder(folder='./test_data/', color='YUV', ct=-1)
    # Keep only index 0 of the last axis (presumably the Y channel -- confirm).
    Y = np.array(Y_list)[:, :, :, :1]
    gic.fit(Y)
    # encode() returns (stream, DC, save); decode() needs only the first two,
    # so slice instead of unpacking three values into two names.
    stream, dc = gic.encode(Y)[:2]
    iY = gic.decode(stream, dc)
    print('MSE=%5.3f, PSNR=%3.5f'%(MSE(Y, iY), PSNR(Y, iY)))
    print('------------------')
    print(" * Ref result: "+'MSE=129.342, PSNR=27.01340')