PyTorch

I tried out the article below.

Running the second example, the Boston house price dataset, gave the following error.

PyTorch version: 0.4.1

RuntimeError: input and target shapes do not match: input [339 x 1], target [339] at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1532498166916\work\aten\src\thnn\generic/MSECriterion.c:12
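For context, a minimal sketch of what the message is saying, assuming the model ends in a single-output linear layer: its output has shape [N, 1], while `y_train` converted from NumPy is 1-D with shape [N]. In 0.4.1 `MSELoss` rejects this mismatch outright. Later PyTorch releases, as far as I know, only emit a `UserWarning` and broadcast the two tensors, which silently turns the element-wise difference into an N x N matrix. The tensors below are random stand-ins, not the Boston data.

```python
import torch

# Hypothetical stand-ins for the real tensors (random values, not Boston data)
N = 339
outputs = torch.randn(N, 1)  # what a model ending in nn.Linear(..., 1) returns: shape (N, 1)
targets = torch.randn(N)     # y_train converted from NumPy is 1-D: shape (N,)

print(outputs.shape, targets.shape)  # torch.Size([339, 1]) torch.Size([339])

# Broadcasting (N, 1) against (N,) yields an (N, N) matrix, not N paired differences
diff = outputs - targets
print(diff.shape)  # torch.Size([339, 339])

# Reshaping the target into a column makes the shapes match element for element
fixed = targets.reshape(-1, 1)
print((outputs - fixed).shape)  # torch.Size([339, 1])
```

`targets.unsqueeze(1)` is an equivalent way to add the column dimension.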

I didn't really understand what the message meant, but I forced the shapes to match with reshape:

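import torch
import torch.nn as nn

# loss and optimizer (model and learning_rate are defined in the article's earlier code)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train(X_train, y_train):
    # Since 0.4, plain tensors replace Variable
    inputs = torch.from_numpy(X_train).float()
    targets = torch.from_numpy(y_train).float()
    # Model output is (N, 1) but y_train is (N,): reshape the target to (N, 1)
    # (originally reshape(339, 1), added 2018.9.21)
    targets = targets.reshape(-1, 1)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    # .item() returns a Python float; indexing a 0-dim tensor with [0] is deprecated in 0.4
    return loss.item()

def valid(X_test, y_test):
    inputs = torch.from_numpy(X_test).float()
    targets = torch.from_numpy(y_test).float()
    targets = targets.reshape(-1, 1)  # (N,) -> (N, 1); originally reshape(167, 1), added 2018.9.21

    with torch.no_grad():  # no gradients needed for validation
        outputs = model(inputs)
        val_loss = criterion(outputs, targets)

    return val_loss.item()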

epoch 0, loss: 591.5585 val_loss: 605.6591
epoch 200, loss: 458.1043 val_loss: 483.0244
epoch 400, loss: 376.4112 val_loss: 404.5938
epoch 600, loss: 310.6950 val_loss: 339.3514
epoch 800, loss: 255.1574 val_loss: 282.6092
epoch 1000, loss: 207.9300 val_loss: 233.6992
epoch 1200, loss: 168.0744 val_loss: 192.1623
epoch 1400, loss: 134.8108 val_loss: 157.2863
epoch 1600, loss: 107.4093 val_loss: 128.3291
epoch 1800, loss: 85.1776 val_loss: 104.5960
epoch 2000, loss: 67.4564 val_loss: 85.4385
epoch 2200, loss: 53.6185 val_loss: 70.2452
epoch 2400, loss: 43.0692 val_loss: 58.4374
epoch 2600, loss: 35.2481 val_loss: 49.4687
epoch 2800, loss: 29.6342 val_loss: 42.8279
epoch 3000, loss: 25.7527 val_loss: 38.0465
epoch 3200, loss: 23.1822 val_loss: 34.7050
epoch 3400, loss: 21.5619 val_loss: 32.4404
epoch 3600, loss: 20.5965 val_loss: 30.9517
epoch 3800, loss: 20.0568 val_loss: 30.0003
epoch 4000, loss: 19.7761 val_loss: 29.4071
epoch 4200, loss: 19.6413 val_loss: 29.0446
epoch 4400, loss: 19.5821 val_loss: 28.8268
epoch 4600, loss: 19.5586 val_loss: 28.6982
epoch 4800, loss: 19.5503 val_loss: 28.6241 

The results look reasonable.
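About the reshape workaround: hard-coding 339 and 167 ties `train` and `valid` to this particular split (339 + 167 = 506, the number of rows in the Boston dataset), whereas `reshape(-1, 1)` lets PyTorch infer the row count. Flattening the model output instead of the target also works. A minimal sketch, assuming a hypothetical `nn.Linear(13, 1)` model and random data rather than the article's model, that checks both options give the same loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.MSELoss()

# Hypothetical stand-ins: a single-output linear model on 13 features, random data
model = nn.Linear(13, 1)
inputs = torch.randn(167, 13)
y = torch.randn(167)  # 1-D target, shape (N,)

with torch.no_grad():
    outputs = model(inputs)  # shape (N, 1)
    # Option 1 (the fix in this post): target (N,) -> (N, 1)
    loss_a = criterion(outputs, y.reshape(-1, 1))
    # Option 2: output (N, 1) -> (N,) instead
    loss_b = criterion(outputs.squeeze(1), y)
    # Both should equal the plain mean of squared errors
    manual = ((outputs.squeeze(1) - y) ** 2).mean()

print(loss_a.item(), loss_b.item(), manual.item())
```

Option 2 leaves the data-loading code untouched, while option 1 leaves the model's output shape as is; either works as long as both sides end up with the same shape.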