@@ -104,55 +104,142 @@ print(ssim_score_0.shape, ssim_score_1.shape)
104104
105105## As A Loss
106106
107- ![ prediction] ( https://user-images.githubusercontent.com/26847524/174814849-f80ec67c-5397-4ce6-bf4e-8b0aa568ed6f.png )
107+ As the stopping thresholds of the two examples below show (0.999 versus -0.94), it is easier to optimize towards MSSIM=1 than towards MSSIM=-1.
108+
109+ ### Optimize towards MSSIM=1
110+
111+ ![prediction](https://user-images.githubusercontent.com/26847524/174930091-9d7f7505-1752-423a-b7c3-d4dbfeb8d336.png)
108112
109113``` python
# Example: use (1 - SSIM) as a loss and optimise an image's pixels until its
# SSIM against a reference image reaches 0.999, then plot the result.
import matplotlib.pyplot as plt
import torch
from pytorch_ssim import SSIM
from skimage import data
from torch import optim

# Fall back to the CPU when CUDA is unavailable so the example runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reference image scaled by 1/255; the two unsqueeze calls add two leading
# dimensions before it is handed to SSIM.
original_image = data.moon() / 255
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().to(device)

# The pixels of an all-zero image are the only parameters being optimised.
# zeros_like already inherits the device and dtype of target_image.
predicted_image = torch.zeros_like(target_image, requires_grad=True)
# Detached snapshot of the starting point, kept for the "Initial Image" panels.
initial_image = predicted_image.detach().clone()

ssim = SSIM().to(device)
# The initial score is only reported, never back-propagated, so skip autograd
# and keep it as a plain Python float.
with torch.no_grad():
    initial_ssim_value = ssim(predicted_image, target_image).item()

ssim_value = initial_ssim_value
optimizer = optim.Adam([predicted_image], lr=0.01)
loss_curves = []
# Safety net: stop after this many steps even if the threshold is never reached,
# so a stalled optimisation cannot loop forever.
max_steps = 100_000
while ssim_value < 0.999 and len(loss_curves) < max_steps:
    # Minimising (1 - SSIM) pushes the SSIM towards 1.
    ssim_out = 1 - ssim(predicted_image, target_image)
    loss_curves.append(ssim_out.item())
    ssim_value = 1 - ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert once; both the image panels and the histograms reuse these arrays.
initial_pixels = initial_image.squeeze().cpu().numpy()
predicted_pixels = predicted_image.squeeze().detach().cpu().numpy()

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
ax = axes.ravel()

# Top row: the three images and the loss curve.
ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[0].set_title("Original Image")

ax[1].imshow(initial_pixels, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
ax[1].set_title("Initial Image")

ax[2].imshow(predicted_pixels, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
ax[2].set_title("Predicted Image")

ax[3].plot(loss_curves)
ax[3].set_title("SSIM Loss Curve")

# Bottom row: pixel-intensity histogram underneath each image.
histogram_panels = (
    ("Original Image", original_image),
    ("Initial Image", initial_pixels),
    ("Predicted Image", predicted_pixels),
)
for axis, (title, pixels) in zip(ax[4:7], histogram_panels):
    axis.set_title(title)
    axis.hist(pixels.ravel(), bins=256)
    axis.ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
    axis.set_xlabel("Pixel Intensity")

# The eighth subplot has no content; hide its empty frame.
ax[7].axis("off")

plt.tight_layout()
plt.savefig("prediction.png")
176+ ```
177+
178+ ### Optimize towards MSSIM=-1
179+
180+ ![prediction](https://user-images.githubusercontent.com/26847524/174929574-5332cab2-104f-4aab-a4e5-35e7635a793f.png)
181+
182+ ``` python
# Example: use the SSIM itself as a loss and optimise an image's pixels until
# its SSIM against a reference image drops to -0.94, then plot the result.
import matplotlib.pyplot as plt
import torch
from pytorch_ssim import SSIM
from skimage import data
from torch import optim

# Fall back to the CPU when CUDA is unavailable so the example runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reference image scaled by 1/255; the two unsqueeze calls add two leading
# dimensions before it is handed to SSIM.
original_image = data.moon() / 255
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().to(device)

# The pixels of an all-zero image are the only parameters being optimised.
# zeros_like already inherits the device and dtype of target_image.
predicted_image = torch.zeros_like(target_image, requires_grad=True)
# Detached snapshot of the starting point, kept for the "Initial Image" panels.
initial_image = predicted_image.detach().clone()

# L is set to the value range (max - min) of the reference image.
# NOTE(review): presumably L is SSIM's dynamic-range parameter; confirm against pytorch_ssim.
ssim = SSIM(L=original_image.max() - original_image.min()).to(device)
# The initial score is only reported, never back-propagated, so skip autograd
# and keep it as a plain Python float.
with torch.no_grad():
    initial_ssim_value = ssim(predicted_image, target_image).item()

ssim_value = initial_ssim_value
optimizer = optim.Adam([predicted_image], lr=0.01)
loss_curves = []
# Safety net: stop after this many steps even if the threshold is never reached,
# so a stalled optimisation cannot loop forever.
max_steps = 100_000
while ssim_value > -0.94 and len(loss_curves) < max_steps:
    # Minimising the SSIM directly pushes it towards -1.
    ssim_out = ssim(predicted_image, target_image)
    loss_curves.append(ssim_out.item())
    ssim_value = ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert once; both the image panels and the histograms reuse these arrays.
initial_pixels = initial_image.squeeze().cpu().numpy()
predicted_pixels = predicted_image.squeeze().detach().cpu().numpy()

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
ax = axes.ravel()

# Top row: the three images and the loss curve.
ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[0].set_title("Original Image")

ax[1].imshow(initial_pixels, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
ax[1].set_title("Initial Image")

ax[2].imshow(predicted_pixels, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
ax[2].set_title("Predicted Image")

ax[3].plot(loss_curves)
ax[3].set_title("SSIM Loss Curve")

# Bottom row: pixel-intensity histogram underneath each image.
histogram_panels = (
    ("Original Image", original_image),
    ("Initial Image", initial_pixels),
    ("Predicted Image", predicted_pixels),
)
for axis, (title, pixels) in zip(ax[4:7], histogram_panels):
    axis.set_title(title)
    axis.hist(pixels.ravel(), bins=256)
    axis.ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
    axis.set_xlabel("Pixel Intensity")

# The eighth subplot has no content; hide its empty frame.
ax[7].axis("off")

plt.tight_layout()
plt.savefig("prediction.png")
158245```
0 commit comments