
workaround for nn.Sequential can not have multiple input, and I need LSTM for this project
@127379ef99d8fdaa089a70f7bec82ae6c4eaf35b
--- model/generator.py
+++ model/generator.py
... | ... | @@ -2,6 +2,13 @@ |
2 | 2 |
from torch import nn |
3 | 3 |
from torch.nn import functional as F |
4 | 4 |
|
5 |
+# nn.Sequential does not handel multiple input by design |
|
6 |
+# https://github.com/pytorch/pytorch/issues/19808# |
|
7 |
+class mySequential(nn.Sequential): |
|
8 |
+ def forward(self, *input): |
|
9 |
+ for module in self._modules.values(): |
|
10 |
+ input = module(*input) |
|
11 |
+ return input |
|
5 | 12 |
|
6 | 13 |
def conv3x3(in_ch, out_ch, stride=1, padding=1, groups=1, dilation=1): |
7 | 14 |
return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=padding, groups=groups, dilation=dilation) |
... | ... | @@ -50,13 +57,13 @@ |
50 | 57 |
self.layers = layers |
51 | 58 |
|
52 | 59 |
def forward(self, x): |
60 |
+ x = self.conv1(x) |
|
53 | 61 |
shortcut = x |
54 |
- x = self.conv1(shortcut) |
|
55 | 62 |
for i, hidden_layer in enumerate(self.conv_hidden): |
56 | 63 |
x = hidden_layer(x) |
57 | 64 |
if (i % self.layers == 0) & (i != 0): |
58 |
- x = self.relu(x) |
|
59 |
- x += shortcut |
|
65 |
+ x = F.relu(x) |
|
66 |
+ x = x + shortcut |
|
60 | 67 |
return x |
61 | 68 |
|
62 | 69 |
|
... | ... | @@ -136,7 +143,7 @@ |
136 | 143 |
dilation=self.dilation |
137 | 144 |
) |
138 | 145 |
) |
139 |
- self.LSTM = nn.Sequential( |
|
146 |
+ self.LSTM = mySequential( |
|
140 | 147 |
ConvLSTM( |
141 | 148 |
ch=out_ch, kernel_size=kernel_size, |
142 | 149 |
) |
... | ... | @@ -168,7 +175,7 @@ |
168 | 175 |
self.groups = groups |
169 | 176 |
self.dilation = dilation |
170 | 177 |
self.repetition = repetition |
171 |
- self.generator_block = nn.Sequential( |
|
178 |
+ self.generator_block = mySequential( |
|
172 | 179 |
GeneratorBlock(blocks=blocks, |
173 | 180 |
layers=layers, |
174 | 181 |
input_ch=input_ch, |
... | ... | @@ -252,5 +259,5 @@ |
252 | 259 |
if __name__ == "__main__": |
253 | 260 |
from torchinfo import summary |
254 | 261 |
generator = Generator(3) |
255 |
- batch_size = 10 |
|
256 |
- summary(generator, input_size=(batch_size, 3, 1920,1080)) |
|
262 |
+ batch_size = 1 |
|
263 |
+ summary(generator, input_size=(batch_size, 3, 960,540)) |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?