PyTorchで行列を用いて連立方程式を解くtorch.linalg.solve

行列を用いて連立方程式を解く際、torch.linalg.solveを用いて解を求めることができます。公式ドキュメントのサンプルコードをベースに、使い方を見ていきます。

torch.linalg.solve — PyTorch 2.1 documentation

torch.linalg.solveの基本的な使い方

この関数は、以下のような行列による式の解\(X\)を求めることができます。

$$AX = B$$

ただし、\(A\in\mathbb{R}^{n\times n}, B\in\mathbb{R}^{n\times k}, X\in\mathbb{R}^{n\times k}\) とする。

torch.linalg.solveの第一引数に\(A\)を、第二引数に\(B\)を入力します。また、上のような式を解く場合はleft=Trueを指定します。

import torch

A = torch.randn(3, 3)
b = torch.randn(3)

# 解を求める
x = torch.linalg.solve(A, b, left=True)

# AXがBに一致しているか確認
torch.allclose(A @ x, b)
# True

次に、left=Falseとして、以下のような式を解いてみます。

$$XA = B$$

ただし、\(A\in\mathbb{R}^{k\times k}, B\in\mathbb{R}^{n\times k}, X\in\mathbb{R}^{n\times k}\) です。

A = torch.randn(3, 3)
b = torch.randn(1,3)

# 解を求める
x = torch.linalg.solve(A, b, left=False)

# XAがBに一致しているか確認
torch.allclose(x @ A, b)
# True

バッチごと計算する

\(A\)は、正方行列が複数集まったバッチでも解くことができます。以下の例では、\(A, B\)の形状がそれぞれ(*, n, n), (*, n, k)の場合です。ただし、*がバッチサイズを表します。

A = torch.randn(2, 3, 3)
B = torch.randn(2, 3, 4)
X = torch.linalg.solve(A, B)

torch.allclose(A @ X, B)
# True

\(B\)が(*, n, 1)の場合は、以下のようにブロードキャストされます。

A = torch.randn(2, 3, 3)
b = torch.randn(3, 1)
x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1)

torch.allclose(A @ x, b)
# True

複素数を用いる

また、torch.linalg.solveは複素数にも対応しています。

A = torch.randn(2, 3, 3, dtype=torch.complex64)
B = torch.randn(2, 3, 4, dtype=torch.complex64)
X = torch.linalg.solve(A, B)

torch.allclose(A @ X, B)
# True

\(A, B\)のdtypeが一致していないと、エラーとなってしまいます。

A = torch.randn(2, 3, 3, dtype=torch.complex64)
B = torch.randn(2, 3, 4, dtype=torch.float32)
X = torch.linalg.solve(A, B)
# RuntimeError: linalg.solve: Expected A and B to have the same dtype, but found A of type ComplexFloat and B of type Float instead

コメント

タイトルとURLをコピーしました