
--- model/Autoencoder.py
+++ model/Autoencoder.py
... | ... | @@ -3,10 +3,11 @@ |
3 | 3 |
from torch.nn import functional as F |
4 | 4 |
from torchvision.models import vgg16 |
5 | 5 |
|
6 |
+ |
|
6 | 7 |
class AutoEncoder(nn.Module): |
7 | 8 |
def __init__(self): |
8 | 9 |
super(AutoEncoder, self).__init__() |
9 |
- |
|
10 |
+ #layers |
|
10 | 11 |
self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2, bias=False) |
11 | 12 |
self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False) |
12 | 13 |
|
... | ... | @@ -37,7 +38,13 @@ |
37 | 38 |
self.skip_output2 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False) |
38 | 39 |
self.skip_output3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False) |
39 | 40 |
|
40 |
- # maybe change it into concat Networks? this seems way to cumbersome. |
|
41 |
+ # Loss specific definitions |
|
42 |
+ self.vgg = vgg16(pretrained=True).features |
|
43 |
+ self.vgg.eval() |
|
44 |
+ for param in self.vgg.parameters(): |
|
45 |
+ param.requires_grad = False |
|
46 |
+ self.lambda_i = [0.6, 0.8, 1.0] |
|
47 |
+ |
|
41 | 48 |
def forward(self, input_tensor): |
42 | 49 |
# Feed the input through each layer |
43 | 50 |
x = torch.relu(self.conv1(input_tensor)) |
... | ... | @@ -80,20 +87,7 @@ |
80 | 87 |
|
81 | 88 |
return ret |
82 | 89 |
|
83 |
- |
|
84 |
-class LossFunction(nn.Module): |
|
85 |
- def __init__(self): |
|
86 |
- super(LossFunction, self).__init__() |
|
87 |
- |
|
88 |
- # Load pre-trained VGG model for feature extraction |
|
89 |
- self.vgg = vgg16(pretrained=True).features |
|
90 |
- self.vgg.eval() |
|
91 |
- for param in self.vgg.parameters(): |
|
92 |
- param.requires_grad = False |
|
93 |
- |
|
94 |
- self.lambda_i = [0.6, 0.8, 1.0] |
|
95 |
- |
|
96 |
- def forward(self, input_tensor, label_tensor): |
|
90 |
+ def loss(self, input_tensor, label_tensor): |
|
97 | 91 |
ori_height, ori_width = label_tensor.shape[2:] |
98 | 92 |
|
99 | 93 |
# Rescale labels to match the scales of the outputs |
... | ... | @@ -101,9 +95,7 @@ |
101 | 95 |
label_tensor_resize_4 = F.interpolate(label_tensor, size=(ori_height // 4, ori_width // 4)) |
102 | 96 |
label_list = [label_tensor_resize_4, label_tensor_resize_2, label_tensor] |
103 | 97 |
|
104 |
- # Initialize autoencoder model |
|
105 |
- autoencoder = AutoEncoder() |
|
106 |
- inference_ret = autoencoder(input_tensor) |
|
98 |
+ inference_ret = self.forward(input_tensor) |
|
107 | 99 |
|
108 | 100 |
output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']] |
109 | 101 |
|
... | ... | @@ -134,12 +126,4 @@ |
134 | 126 |
x = layer(x) |
135 | 127 |
if layer_num in {3, 8, 15, 22, 29}: |
136 | 128 |
feats.append(x) |
137 |
- return feats |
|
138 |
- |
|
139 |
- |
|
140 |
-if __name__ == "__main__": |
|
141 |
- from torchinfo import summary |
|
142 |
- torch.set_default_tensor_type(torch.FloatTensor) |
|
143 |
- generator = AutoEncoder() |
|
144 |
- batch_size = 2 |
|
145 |
- summary(generator, input_size=(batch_size, 3, 960,540)) |
|
129 |
+ return feats(파일 끝에 줄바꿈 문자 없음) |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?