윤영준 윤영준 2023-06-21
testing update
@66fa69f192cd7b7466e2b6b160a5ae9503b1bc41
 
model/discriminator.py (deleted)
--- model/discriminator.py
@@ -1,1 +0,0 @@
-from torch import nn(파일 끝에 줄바꿈 문자 없음)
model/generator.py
--- model/generator.py
+++ model/generator.py
@@ -63,7 +63,7 @@
 class ConvLSTM(nn.Module):
     def __init__(self, ch, kernel_size=3):
         super(ConvLSTM, self).__init__()
-        self.padding = (kernel_size-1)/2
+        self.padding = (len(kernel_size)-1)/2
         self.conv_i = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1, padding=1,
                                 bias=False)
         self.conv_f = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1, padding=1,
@@ -181,7 +181,7 @@
         )
         self.generator_blocks = nn.ModuleList()
         for repetition in range(repetition):
-            self.conv_hidden.append(
+            self.generator_blocks.append(
                 self.generator_block
             )
 
@@ -248,3 +248,9 @@
         fc_out = torch.clamp(fc_out, min=1e-7, max=1 - 1e-7)
 
         return fc_out, attention_map, fc2
+
+if __name__ == "__main__":
+    from torchinfo import summary
+    generator = Generator(3)
+    batch_size = 10
+    summary(generator, input_size=(batch_size, 3, 1920,1080))
Add a comment
List