[解決済み] RuntimeError "Expected object of scalar type Float but got scalar type Double for argument" を修正する方法は?
2022-03-04 15:33:18
質問
PyTorchを使って分類器を学習させようとしています。しかし、モデルに学習データを与えると、学習で問題が発生します。
以下のようなエラーが発生します。
y_pred = model(X_trainTensor)
:
RuntimeError: スカラータイプの Float オブジェクトを期待したが、引数 #4 'mat1' にスカラータイプの Double が渡された。
以下は、私のコードの主要部分です。
# Hyper-parameters
D_in = 47 # there are 47 parameters I investigate
H = 33
D_out = 2 # output should be either 1 or 0
# Format and load the data
y = np.array( df['target'] )
X = np.array( df.drop(columns = ['target'], axis = 1) )
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size = 0.8) # split training/test data
X_trainTensor = torch.from_numpy(X_train) # convert to tensors
y_trainTensor = torch.from_numpy(y_train)
X_testTensor = torch.from_numpy(X_test)
y_testTensor = torch.from_numpy(y_test)
# Define the model
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
nn.LogSoftmax(dim = 1)
)
# Define the loss function
loss_fn = torch.nn.NLLLoss()
for i in range(50):
y_pred = model(X_trainTensor)
loss = loss_fn(y_pred, y_trainTensor)
model.zero_grad()
loss.backward()
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad
解決方法は?
参照元は このgithubの問題 .
エラーの場合
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'
を使用する必要があります。
.float()
と書かれているので、関数
Expected object of scalar type Float
.
したがって、解答を変更する
y_pred = model(X_trainTensor)
から
y_pred = model(X_trainTensor.float())
.
同様に、別のエラーが発生した場合
loss = loss_fn(y_pred, y_trainTensor)
が必要です。
y_trainTensor.long()
というエラーメッセージが表示されるので
Expected object of scalar type Long
.
また、次のようにすることもできます。
model.double()
は、@Paddy が提案したように
.
関連
-
Pythonコンテナのための組み込み汎用関数操作
-
Python関数の高度な応用を解説
-
任意波形を生成してtxtで保存するためのPython実装
-
[解決済み】RuntimeWarning: 割り算で無効な値が発生しました。
-
[解決済み】なぜ「LinAlgError: Grangercausalitytestsから「Singular matrix」と表示されるのはなぜですか?
-
[解決済み] データ型が理解できない
-
[解決済み】Pythonスクリプトで「Expected 2D array, got 1D array instead: 」というエラーが発生?
-
[解決済み】csv.Error:イテレータはバイトではなく文字列を返すべき
-
[解決済み】ImportError: bs4という名前のモジュールがない(BeautifulSoup)
-
[解決済み】「OverflowError: Python int too large to convert to C long" on windows but not mac
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
Python jiabaライブラリの使用方法について説明
-
Python interpreted model libraryによる機械学習モデル出力の可視化 Shap
-
Python LeNetネットワークの説明とpytorchでの実装
-
Python Pillow Image.save jpg画像圧縮問題
-
FacebookオープンソースワンストップサービスpythonのタイミングツールKats詳細
-
[解決済み】なぜ「LinAlgError: Grangercausalitytestsから「Singular matrix」と表示されるのはなぜですか?
-
[解決済み】numpyの配列連結。"ValueError:すべての入力配列は同じ次元数でなければならない"
-
[解決済み】"No JSON object could be decoded "よりも良いエラーメッセージを表示する。
-
[解決済み】IndexError: invalid index to scalar variableを修正する方法
-
[解決済み】Flask ImportError: Flask という名前のモジュールがない