[解決済み] RuntimeError: 入力型(torch.FloatTensor)とウェイト型(torch.cuda.FloatTensor)は同じであるべきです。
2023-01-08 07:18:56
質問
以下のようなCNNを学習させようとしているのですが、.cuda()に関して同じエラーが出続けていて、どのように修正したらいいのかわかりません。以下は、これまでの私のコードの塊です。
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms, models
from torch.utils.data.sampler import SubsetRandomSampler
data_dir = "/home/ubuntu/ML2/ExamII/train2/"
valid_size = .2
# Normalize the test and train sets with torchvision
train_transforms = transforms.Compose([transforms.Resize(224),
transforms.ToTensor(),
])
test_transforms = transforms.Compose([transforms.Resize(224),
transforms.ToTensor(),
])
# ImageFolder class to load the train and test images
train_data = datasets.ImageFolder(data_dir, transform=train_transforms)
test_data = datasets.ImageFolder(data_dir, transform=test_transforms)
# Number of train images
num_train = len(train_data)
indices = list(range(num_train))
# Split = 20% of train images
split = int(np.floor(valid_size * num_train))
# Shuffle indices of train images
np.random.shuffle(indices)
# Subset indices for test and train
train_idx, test_idx = indices[split:], indices[:split]
# Samples elements randomly from a given list of indices
train_sampler = SubsetRandomSampler(train_idx)
test_sampler = SubsetRandomSampler(test_idx)
# Batch and load the images
trainloader = torch.utils.data.DataLoader(train_data, sampler=train_sampler, batch_size=1)
testloader = torch.utils.data.DataLoader(test_data, sampler=test_sampler, batch_size=1)
#print(trainloader.dataset.classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(pretrained=True)
model.fc = nn.Sequential(nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(512, 10),
nn.LogSigmoid())
# nn.LogSoftmax(dim=1))
# criterion = nn.NLLLoss()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.003)
model.to(device)
#Train the network
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
しかし、コンソールにこのようなエラーが表示され続けます。
RuntimeErrorです。入力タイプ (torch.FloatTensor) とウェイトタイプ (torch.cuda.FloatTensor) は同じであるべきです`。
それを修正する方法について何か考えがありますか?私は、モデルが私のGPUにプッシュされていないかもしれないと読みましたが、それを修正する方法はよくわかりません。ありがとうございます。
どのように解決するのですか?
モデルがGPUにあり、データがCPUにあるため、このエラーが発生します。したがって、入力テンソルをGPUに送る必要があります。
inputs, labels = data # this is what you had
inputs, labels = inputs.cuda(), labels.cuda() # add this line
あるいは、残りのコードとの一貫性を保つために、このようにします。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
inputs, labels = inputs.to(device), labels.to(device)
この 同じエラー は、入力テンソルが GPU 上にあり、モデル重みが GPU 上にない場合に発生します。この場合、モデルの重みをGPUに送る必要があります。
model = MyModel()
if torch.cuda.is_available():
model.cuda()
関連
-
[解決済み】__str__と__repr__の違いは何ですか?
-
[解決済み】type()とisinstance()の違いは何ですか?)
-
[解決済み] Pythonのキャッシュライブラリはありますか?
-
[解決済み] django.db.migrations.exceptions.InconsistentMigrationHistory
-
[解決済み] Pythonのインスタンス変数とクラス変数
-
[解決済み] Python 2.7サポート終了?
-
[解決済み] Jupyter (IPython)ノートブックのセッションをpickleして保存する方法
-
[解決済み] Python Logging でログメッセージが2回表示される件
-
[解決済み] Celeryタスクのユニットテストはどのように行うのですか?
-
[解決済み] Pythonでランダムなファイル名を生成する最良の方法
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
[解決済み] Pythonのキャッシュライブラリはありますか?
-
[解決済み] PILからopenCVフォーマットへの変換
-
[解決済み] PythonでSVGからPNGに変換する
-
[解決済み] バブルソートの宿題
-
[解決済み] Cythonのコードを含むPythonパッケージはどのように構成すればよいのでしょうか?
-
[解決済み] 異なる順序で同じ要素を持つ2つのJSONオブジェクトを等しく比較するには?
-
[解決済み] Python Logging でログメッセージが2回表示される件
-
[解決済み] Flask でグローバル変数はスレッドセーフか?リクエスト間でデータを共有するには?
-
[解決済み] matplotlib でプロットの軸、目盛、ラベルの色を変更する方法
-
[解決済み] Flaskで非同期タスクを作る