class ResNet(nn.Module):
    """1-D ResNet-style network with a 4-stage residual backbone and an MLP head.

    Args:
        block: residual block class. The first block of each stage is built as
            ``block(in_channels, out_channels, identity_downsample, stride=stride)``
            and the remaining ones as ``block(in_channels, out_channels)``.
            ``_make_layer`` assumes each block outputs ``out_channels * 2``
            channels (that is the width of the shortcut projection it builds).
        layers: number of residual blocks in each of the four stages.
        image_channels: number of input channels of the tensor passed to
            ``forward`` (the stem convolution's input width).
    """

    def __init__(self, block, layers, image_channels):
        super().__init__()
        # Training bookkeeping; only initialised here, not read by this class.
        self.history_loss = []
        self.history_eval = []
        self.classific_accuracy_training = []
        self.current_epoch = 0

        # Channel width produced by the stem; _make_layer advances it per stage.
        self.in_channels = 12
        # Stem input width previously hard-coded to 12, silently ignoring
        # ``image_channels``; identical behaviour for the default of 12.
        self.conv1 = nn.Conv1d(image_channels, 12, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm1d(12)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, layers[0], out_channels=12, stride=1)
        self.layer2 = self._make_layer(block, layers[1], out_channels=24, stride=2)
        self.layer3 = self._make_layer(block, layers[2], out_channels=24 * 2, stride=2)
        self.layer4 = self._make_layer(block, layers[3], out_channels=24 * 4, stride=2)

        pooled_length = 25
        self.avgpool = nn.AdaptiveAvgPool1d(pooled_length)
        # After layer4, self.in_channels == 24 * 4 * 2 == 192, so the flattened
        # feature size is 192 * 25 == 4800 (the value previously hard-coded).
        self.lin1 = nn.Linear(self.in_channels * pooled_length, 800)
        self.lin2 = nn.Linear(800, 40)
        self.lin3 = nn.Linear(40, 10)
        self.lin4 = nn.Linear(10, 2)

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor fed to ``nn.Conv1d``, i.e. (batch, image_channels, length);
                ``test()`` uses (30, 12, 1000).

        Returns:
            Tensor of shape (batch, 2) with raw outputs of ``lin4`` (no
            softmax/sigmoid is applied here).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten (batch, channels, 25) -> (batch, channels * 25). Equivalent to
        # the former x.reshape(x.shape[0], -1); works on non-contiguous input,
        # unlike x.view.
        x = torch.flatten(x, 1)
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.relu(self.lin3(x))
        x = self.lin4(x)
        return x

    def _make_layer(self, block, num_residual_blocks, out_channels, stride):
        """Build one residual stage and advance ``self.in_channels``.

        A 1x1 conv + BatchNorm shortcut projection is added to the first block
        whenever the stride or the channel width changes. Afterwards
        ``self.in_channels`` is set to ``out_channels * 2``, the width the
        following blocks (and the next stage) receive.
        """
        identity_downsample = None
        layers = []

        if stride != 1 or self.in_channels != out_channels * 2:
            identity_downsample = nn.Sequential(
                nn.Conv1d(self.in_channels, out_channels * 2, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels * 2),
            )

        layers.append(block(self.in_channels, out_channels, identity_downsample, stride=stride))
        self.in_channels = out_channels * 2

        for _ in range(num_residual_blocks - 1):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)


def ResNet50(img_channels=12):
    """Build a ResNet with stage depths [3, 4, 6, 3].

    NOTE(review): ``block`` is a module-level name not visible in this chunk —
    presumably the residual block class; verify it is defined before this call.
    """
    return ResNet(block, [3, 4, 6, 3], img_channels)


def test():
    """Smoke test: forward a random (30, 12, 1000) batch and return the output."""
    net = ResNet50()
    x = torch.randn(30, 12, 1000)
    y = net(x)
    return y


# Module-level training setup (kept at import time; other code may rely on it).
net = ResNet50()
learning_rate = 0.15
batch_size = 50
optimizer = optim.SGD(net.parameters(), lr=learning_rate)