PyTorch
I worked through the article below.
On the second example, the Boston house price dataset, I got the following error.
PyTorch version: 0.4.1
RuntimeError: input and target shapes do not match: input [339 x 1], target [339] at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1532498166916\work\aten\src\thnn\generic/MSECriterion.c:12
I didn't really understand what the message meant, but I forced the shapes to match with reshape:
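import torch
import torch.nn as nn

# loss and optimizer (model and learning_rate are defined earlier, as in the article)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train(X_train, y_train):
    # Since PyTorch 0.4, Variable is merged into Tensor, so plain tensors are enough
    inputs = torch.from_numpy(X_train).float()
    targets = torch.from_numpy(y_train).float()
    # Model output is [N, 1] but y_train is [N]; reshape the target to [N, 1]
    # (-1 infers N instead of hard-coding 339)  # 2018.9.21
    targets = targets.reshape(-1, 1)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    # loss is a 0-dim tensor; .item() replaces the deprecated loss.data[0]
    return loss.item()

def valid(X_test, y_test):
    inputs = torch.from_numpy(X_test).float()
    targets = torch.from_numpy(y_test).float()
    targets = targets.reshape(-1, 1)  # 2018.9.21 (was 167 hard-coded)
    # No gradients are needed for validation
    with torch.no_grad():
        outputs = model(inputs)
        val_loss = criterion(outputs, targets)
    return val_loss.item()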
epoch 0, loss: 591.5585 val_loss: 605.6591
epoch 200, loss: 458.1043 val_loss: 483.0244
epoch 400, loss: 376.4112 val_loss: 404.5938
epoch 600, loss: 310.6950 val_loss: 339.3514
epoch 800, loss: 255.1574 val_loss: 282.6092
epoch 1000, loss: 207.9300 val_loss: 233.6992
epoch 1200, loss: 168.0744 val_loss: 192.1623
epoch 1400, loss: 134.8108 val_loss: 157.2863
epoch 1600, loss: 107.4093 val_loss: 128.3291
epoch 1800, loss: 85.1776 val_loss: 104.5960
epoch 2000, loss: 67.4564 val_loss: 85.4385
epoch 2200, loss: 53.6185 val_loss: 70.2452
epoch 2400, loss: 43.0692 val_loss: 58.4374
epoch 2600, loss: 35.2481 val_loss: 49.4687
epoch 2800, loss: 29.6342 val_loss: 42.8279
epoch 3000, loss: 25.7527 val_loss: 38.0465
epoch 3200, loss: 23.1822 val_loss: 34.7050
epoch 3400, loss: 21.5619 val_loss: 32.4404
epoch 3600, loss: 20.5965 val_loss: 30.9517
epoch 3800, loss: 20.0568 val_loss: 30.0003
epoch 4000, loss: 19.7761 val_loss: 29.4071
epoch 4200, loss: 19.6413 val_loss: 29.0446
epoch 4400, loss: 19.5821 val_loss: 28.8268
epoch 4600, loss: 19.5586 val_loss: 28.6982
epoch 4800, loss: 19.5503 val_loss: 28.6241
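The results look plausible: by epoch 4800 the training loss has dropped from 591.6 to 19.6 and the validation loss from 605.7 to 28.6.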
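For reference, the error message says the model output is a column vector of shape [339 x 1] while the target built from y_train is a flat vector of shape [339], and MSELoss in 0.4.1 requires identical shapes. As far as I know, newer PyTorch versions broadcast the two with a warning instead of raising, which is worse. PyTorch follows NumPy's broadcasting rules, so the minimal NumPy sketch below (random stand-in data, not the Boston dataset; the names are hypothetical) shows what broadcasting would do: subtracting an (N,) array from an (N, 1) array produces an (N, N) matrix that compares every prediction with every target.

```python
import numpy as np

# Hypothetical stand-in data: N samples, as in the error message
N = 339
rng = np.random.default_rng(0)
outputs = rng.normal(size=(N, 1))                        # model output: shape (N, 1)
targets = outputs.ravel() + rng.normal(scale=0.1, size=N)  # target: shape (N,), close to outputs

# Broadcasting (N, 1) against (N,) yields (N, N): every prediction vs. every target
diff_broadcast = outputs - targets
print(diff_broadcast.shape)  # → (339, 339)

# Reshaping the target to (N, 1), like targets.reshape(-1, 1), pairs each
# prediction with its own target
diff_matched = outputs - targets.reshape(-1, 1)
print(diff_matched.shape)  # → (339, 1)

mse_wrong = np.mean(diff_broadcast ** 2)  # mixes unrelated prediction/target pairs
mse_right = np.mean(diff_matched ** 2)    # the intended mean squared error
print(mse_wrong, mse_right)
```

In PyTorch terms, targets.reshape(-1, 1) or targets.unsqueeze(1) gives the target the same [N, 1] shape as the output, and squeezing the output with outputs.squeeze(1) works equally well. Using -1 avoids hard-coding 339 and 167.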