
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -17,22 +17,31 @@
         self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
         self.fc2 = nn.Linear(1, 1)
 
-    def loss(self, input_tensor, label_tensor, attention_map, name):
-        batch_size, image_h, image_w, _ = input_tensor.size()
+    def loss(self, real_clean, label_tensor, attention_map):
+        """
+        :param real_clean: Images treated as the generated ("fake") samples in the loss.
+        :param label_tensor: Ground-truth images treated as the real samples.
+        :param attention_map: The final attention map from the generator.
+        :return: The discriminator output on real_clean and the total loss.
+        """
+        with torch.no_grad():
+            batch_size, image_h, image_w, _ = real_clean.size()
 
-        zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
+            zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
 
-        # Inference function
-        fc_out_o, attention_mask_o, fc2_o = self.forward(input_tensor)
-        fc_out_r, attention_mask_r, fc2_r = self.forward(label_tensor)
+        # Inference function
+        ret = self.forward(real_clean)
+        fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
+        ret = self.forward(label_tensor)
+        fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
 
-        l_map = F.mse_loss(attention_map, attention_mask_o) + \
-                F.mse_loss(attention_mask_r, zeros_mask)
+        l_map = F.mse_loss(attention_map, attention_mask_o) + \
+                F.mse_loss(attention_mask_r, zeros_mask)
 
-        entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0))
-        entropy_loss = torch.mean(entropy_loss)
+        entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0))
+        entropy_loss = torch.mean(entropy_loss)
 
-        loss = entropy_loss + 0.05 * l_map
+        loss = entropy_loss + 0.05 * l_map
 
         return fc_out_o, loss
 
@@ -60,7 +69,7 @@
             "attention_map": attention_map,
             "fc_raw": fc_raw
         }
-        return fc_out, attention_map, fc_raw
+        return ret
 
 if __name__ == "__main__":
     import torch
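
Since `forward` now returns a dict instead of a 3-tuple, every call site has to move to keyed access. Below is a minimal, self-contained sketch of the new contract; the `TinyDisc` stand-in, its single conv layer, and the NHWC test shapes are placeholders for illustration, not the repo's actual architecture:

```python
import torch
import torch.nn.functional as F

# Stand-in module that mimics the refactored contract: forward() returns a
# dict keyed "fc_out" / "attention_map" / "fc_raw" instead of the old 3-tuple.
# The layers and NHWC shapes here are placeholders, not the real architecture.
class TinyDisc(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x_nhwc):
        x = x_nhwc.permute(0, 3, 1, 2)  # NHWC -> NCHW for the conv
        attention_map = torch.sigmoid(self.conv(x)).permute(0, 2, 3, 1)
        fc_raw = attention_map.mean(dim=[1, 2, 3]).unsqueeze(1)
        fc_out = torch.sigmoid(fc_raw)
        return {"fc_out": fc_out, "attention_map": attention_map, "fc_raw": fc_raw}

disc = TinyDisc()
real_clean = torch.rand(2, 8, 8, 3)     # first argument to loss()
label_tensor = torch.rand(2, 8, 8, 3)   # ground-truth images
attention_map = torch.rand(2, 8, 8, 1)  # generator's final attention map

# Keyed access replaces tuple unpacking at every call site.
ret = disc(real_clean)
fc_out_o, mask_o = ret["fc_out"], ret["attention_map"]
ret = disc(label_tensor)
fc_out_r, mask_r = ret["fc_out"], ret["attention_map"]

# Same loss shape as the patch; 1.0 - fc_out_o is equivalent to
# -torch.sub(fc_out_o, 1.0).
l_map = F.mse_loss(attention_map, mask_o) + F.mse_loss(mask_r, torch.zeros_like(mask_r))
entropy_loss = torch.mean(-torch.log(fc_out_r) - torch.log(1.0 - fc_out_o))
loss = entropy_loss + 0.05 * l_map
loss.backward()
```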
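
One thing worth flagging on the loss itself: both log terms diverge once the sigmoid saturates, since `-torch.log(fc_out_r)` blows up as `fc_out_r` approaches 0 and `torch.log(-torch.sub(fc_out_o, 1.0))` as `fc_out_o` approaches 1. Clamping the discriminator output before the logs (e.g. `fc_out.clamp(1e-7, 1.0 - 1e-7)`), or switching to `F.binary_cross_entropy`, may be worth considering.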