윤영준 윤영준 2023-06-27
Discriminator Loss
@3f5673b9231bdca670a02a1219ce06e2d13f3f06
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -17,22 +17,31 @@
         self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
         self.fc2 = nn.Linear(1, 1)
 
-    def loss(self, input_tensor, label_tensor, attention_map, name):
-        batch_size, image_h, image_w, _ = input_tensor.size()
+    def loss(self, real_clean, label_tensor, attention_map):
+        """
+        :param real_clean:
+        :param label_tensor:
+        :param attention_map: This is the final attention map from the generator.
+        :return:
+        """
+        with torch.no_grad():
+            batch_size, image_h, image_w, _ = real_clean.size()
 
-        zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
+            zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
 
-        # Inference function
-        fc_out_o, attention_mask_o, fc2_o = self.forward(input_tensor)
-        fc_out_r, attention_mask_r, fc2_r = self.forward(label_tensor)
+            # Inference function
+            ret = self.forward(real_clean)
+            fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"] , ret["fc_raw"]
+            ret = self.forward(label_tensor)
+            fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"] , ret["fc_raw"]
 
-        l_map = F.mse_loss(attention_map, attention_mask_o) + \
-                F.mse_loss(attention_mask_r, zeros_mask)
+            l_map = F.mse_loss(attention_map, attention_mask_o) + \
+                    F.mse_loss(attention_mask_r, zeros_mask)
 
-        entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0))
-        entropy_loss = torch.mean(entropy_loss)
+            entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0))
+            entropy_loss = torch.mean(entropy_loss)
 
-        loss = entropy_loss + 0.05 * l_map
+            loss = entropy_loss + 0.05 * l_map
 
         return fc_out_o, loss
 
@@ -60,7 +69,7 @@
             "attention_map": attention_map,
             "fc_raw" : fc_raw
         }
-        return fc_out, attention_map, fc_raw
+        return ret
 
 if __name__ == "__main__":
     import torch
Add a comment
List