윤영준 2023-06-27
Loss function for Autoencoder
@dd1c901957f5e408d58311acf5dada79c93a7ae2
model/Autoencoder.py
--- model/Autoencoder.py
+++ model/Autoencoder.py
@@ -3,10 +3,11 @@
 from torch.nn import functional as F
 from torchvision.models import vgg16
 
+
 class AutoEncoder(nn.Module):
     def __init__(self):
         super(AutoEncoder, self).__init__()
-
+        # Layer definitions
         self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2, bias=False)
         self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)
 
@@ -37,7 +38,13 @@
         self.skip_output2 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
         self.skip_output3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
 
-    # maybe change it into concat Networks? this seems way to cumbersome.
+        # Loss specific definitions
+        self.vgg = vgg16(pretrained=True).features
+        self.vgg.eval()
+        for param in self.vgg.parameters():
+            param.requires_grad = False
+        self.lambda_i = [0.6, 0.8, 1.0]
+
     def forward(self, input_tensor):
         # Feed the input through each layer
         x = torch.relu(self.conv1(input_tensor))
@@ -80,20 +87,7 @@
 
         return ret
 
-
-class LossFunction(nn.Module):
-    def __init__(self):
-        super(LossFunction, self).__init__()
-
-        # Load pre-trained VGG model for feature extraction
-        self.vgg = vgg16(pretrained=True).features
-        self.vgg.eval()
-        for param in self.vgg.parameters():
-            param.requires_grad = False
-
-        self.lambda_i = [0.6, 0.8, 1.0]
-
-    def forward(self, input_tensor, label_tensor):
+    def loss(self, input_tensor, label_tensor):
         ori_height, ori_width = label_tensor.shape[2:]
 
         # Rescale labels to match the scales of the outputs
@@ -101,9 +95,7 @@
         label_tensor_resize_4 = F.interpolate(label_tensor, size=(ori_height // 4, ori_width // 4))
         label_list = [label_tensor_resize_4, label_tensor_resize_2, label_tensor]
 
-        # Initialize autoencoder model
-        autoencoder = AutoEncoder()
-        inference_ret = autoencoder(input_tensor)
+        inference_ret = self.forward(input_tensor)
 
         output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]
 
@@ -134,12 +126,4 @@
             x = layer(x)
             if layer_num in {3, 8, 15, 22, 29}:
                 feats.append(x)
-        return feats
-
-
-if __name__ == "__main__":
-    from torchinfo import summary
-    torch.set_default_tensor_type(torch.FloatTensor)
-    generator = AutoEncoder()
-    batch_size = 2
-    summary(generator, input_size=(batch_size, 3, 960,540))
+        return feats
(No newline at end of file)
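
A minimal usage sketch of the relocated loss method (not part of this commit; it assumes `loss()` returns a scalar tensor, and the optimizer, learning rate, and random tensors below are illustrative only):

    import torch
    from model.Autoencoder import AutoEncoder

    model = AutoEncoder()  # also loads the frozen VGG16 feature extractor
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Hypothetical input/target batch; shapes are illustrative only.
    images = torch.rand(2, 3, 256, 256)
    labels = torch.rand(2, 3, 256, 256)

    optimizer.zero_grad()
    # loss() calls self.forward() internally and compares the three skip
    # outputs against the label rescaled to 1/4, 1/2, and full resolution.
    loss = model.loss(images, labels)
    loss.backward()
    optimizer.step()

Compared to the previous `LossFunction` module, this keeps a single model instance: the old `forward()` constructed a fresh, untrained `AutoEncoder()` on every call, so gradients never reached the network being trained.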