윤영준 윤영준 2023-06-21
workaround for nn.Sequential can not have multiple input, and I need LSTM for this project
@127379ef99d8fdaa089a70f7bec82ae6c4eaf35b
model/generator.py
--- model/generator.py
+++ model/generator.py
@@ -2,6 +2,13 @@
 from torch import nn
 from torch.nn import functional as F
 
+# nn.Sequential does not handel multiple input by design
+# https://github.com/pytorch/pytorch/issues/19808#
+class mySequential(nn.Sequential):
+    def forward(self, *input):
+        for module in self._modules.values():
+            input = module(*input)
+        return input
 
 def conv3x3(in_ch, out_ch, stride=1, padding=1, groups=1, dilation=1):
     return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=padding, groups=groups, dilation=dilation)
@@ -50,13 +57,13 @@
         self.layers = layers
 
     def forward(self, x):
+        x = self.conv1(x)
         shortcut = x
-        x = self.conv1(shortcut)
         for i, hidden_layer in enumerate(self.conv_hidden):
             x = hidden_layer(x)
             if (i % self.layers == 0) & (i != 0):
-                x = self.relu(x)
-                x += shortcut
+                x = F.relu(x)
+                x = x + shortcut
         return x
 
 
@@ -136,7 +143,7 @@
                 dilation=self.dilation
             )
         )
-        self.LSTM = nn.Sequential(
+        self.LSTM = mySequential(
             ConvLSTM(
                 ch=out_ch, kernel_size=kernel_size,
             )
@@ -168,7 +175,7 @@
         self.groups = groups
         self.dilation = dilation
         self.repetition = repetition
-        self.generator_block = nn.Sequential(
+        self.generator_block = mySequential(
             GeneratorBlock(blocks=blocks,
                            layers=layers,
                            input_ch=input_ch,
@@ -252,5 +259,5 @@
 if __name__ == "__main__":
     from torchinfo import summary
     generator = Generator(3)
-    batch_size = 10
-    summary(generator, input_size=(batch_size, 3, 1920,1080))
+    batch_size = 1
+    summary(generator, input_size=(batch_size, 3, 960,540))
Add a comment
List