开发者

PyTorch 中mm和bmm函数的使用示例详解

开发者 https://www.devze.com 2025-06-24 10:00 出处:网络 作者: 点云SLAM
目录一、函数定义二、使用条件和注意事项三、示例代码示例 1:基本矩阵乘法示例 2:不匹配维度导致报错示例 3:推荐写法(推荐使用 @ 或 matmul)四、与其他乘法函数的比较五、典型应用场景六、总结:
目录
  • 一、函数定义
  • 二、使用条件和注意事项
  • 三、示例代码
    • 示例 1:基本矩阵乘法
    • 示例 2:不匹配维度导致报错
    • 示例 3:推荐写法(推荐使用 @ 或 matmul)
  • 四、与其他乘法函数的比较
    • 五、典型应用场景
      • 六、总结:什么时候用 mm?
        • 一、torch.bmm 语法
          • 二、示例演示
            • 示例 1:基础用法
            • 示例 2:手动循环 vs bmm 效率对比
          • 三、注意事项
            • 1. 维度必须是三维张量
            • 2. 维度必须满足矩阵乘法规则
            • 3. bmm 不支持广播(broadcasting)
          • 四、在实际应用中的例子
            • 在点云变换中:批量乘旋转矩阵
          • 五、总结

            torch.mm 是 PyTorch 中用于 二维矩阵乘法(matrix-matrix multiplication) 的http://www.devze.com函数,等价于数学中的 A × B 矩阵乘积。

            一、函数定义

            torch.mm(input, mat2) → Tensor

            执行的是两个 2D Tensor(矩阵)的标准矩阵乘法。

            • input: 第一个二维张量,形状为 (n × m)
            • mat2: 第二个二维张量http://www.devze.com,形状为 (m × p)
            • 返回:形状为 (n × p) 的张量

            二、使用条件和注意事项

            条件说明
            仅支持 2D 张量一维或三维以上使用 torch.matmul 或 @ 操作符
            维度要匹配即 input.shape[1] == mat2.shape[0]
            不支持广播两个矩阵维度不匹配会直接报错
            结果是普通矩阵乘积不是逐元素乘法(Hadamard),即不是 * 或 torch.mul()

            三、示例代码

            示例 1:基本矩阵乘法

            import torch
            A = torch.tensor([[1., 2.], [3., 4.]])   # 2x2
            B = torch.tensor([[5., 6.], [7., 8.]])   # 2x2
            C = torch.mm(A, B)
            print(C)

            输出:

            tensor([[19., 22.],

                    [43., 50.]])

            计算步骤:

            C[0][0] = 1*5 + 2*7 = 19
            C[0][1] = 1*6 + 2*8 = 22
            ...

            示例 2:不匹配维度导致报错

            A = torch.rand(2, 3)
            B = torch.rand(4, 2)
            C = torch.mm(A, B)  # ❌ 会报错

            报错:

            RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)

            示例 3:推荐写法(推荐使用 @ 或 matmul)

            A = torch.rand(3, 4)
            B = torch.rand(4, 5)
            C1 = torch.mm(A, B)
            C2 = A @ B                # 推荐用法
            C3 = torch.matmul(A, B)   # 推荐用法

            四、与其他乘法函数的比较

            函数名支持维度运算类型支持广播
            torch.mm仅限二维矩阵乘法❌ 不支持
            torch.matmul1D, 2D, ND自动判断点乘 / 矩阵乘✅ 支持
            torch.bmm批量二维乘法3D Tensor BATch × batch❌ 不支持
            torch.mul任意维度元素乘(Hadamard)✅ 支持
            * 运算符任意维度元素乘✅ 支持
            @ 运算符ND(推荐用)矩阵乘法(和 matmul 一样)

            五、典型应用场景

            • 神经网络权重乘法:output = torch.mm(W, x)
            • 点云 / 图像变换:x' = torch.mm(R, x) + t
            • 多层感知机中的矩阵计算
            • 注意力机制中 QK^T 乘积

            六、总结:什么时候用 mm?

            使用场景用什么
            仅二维矩阵乘法torch.mm
            高维或支持广播乘法torch.matmul / @
            批量矩阵乘法 (如 batch_size×3×3)torch.bmm
            元素乘torch.mul or *

            在 PyTorch 中,torch.bmm 是 批量矩阵乘法(batch matrix multiplication) 的操作,专用于处理三维张量(batch of matrices)。它的主要作用是对一组矩阵成对进行乘法,效率远高于手动循环计算。

            一、torch.bmm 语法

            torch.bmm(input, mat2, *, out=None) → Tensor
            • inputTensor,形状为 (B, N, M)
            • mat2Tensor,形状为 (B, M, P)
            • 返回结果形状为 (B, N, P)

            这表示对 B 对 N×M 和 M×P 的矩阵进行成对相乘。

            二、示例演示

            示例 1:基础用法

            import torch
            # 定义两个 batch 矩阵
            A = torch.randn(4, 2, 3)  # shape: (B=4, N=2, M=3)
            B = torch.randn(4, 3, 5)  # shape: (B=4, M=3, P=5)
            # 批量矩阵乘法
            C = torch.bmm(A, B)       # shape: (4, 2, 5)
            print(C.shape)  # 输出: torch.Size([4, 2, 5])

            示例 2:手动循环 vs bmm 效率对比

            # 慢速手动方式
            C_manual = torch.stack([A[i] @ B[i] for i in range(A.size(0))])
            # 等效于 bmm
            C_bmm = torch.bmm(A, B)
            print(torch.allclose(C_manual, C_bmm))  # True

            三、注意事项

            1. 维度必须是三维张量

            • 否则会报错:
            RuntimeError: batch1 must be a 3D tensor

            你可以通过 .unsqueeze() 手动调整维度:

            a = torch.randn(2, 3)
            b = torch.randn(3, 4)
            # 升维
            a_batch = a.unsqueeze(0)  # (1, 2, 3)
            b_batch = b.unsqueeze(0)  # (1, 3, 4)
            c = torch.bmm(a_batch, b_batch)  # (1, 2, 4)

            2. 维度必须满足矩阵乘法规则

            • (B, N, M) × (B, M, P) → (B, N, P)
            • 若 M 不一致会报错:
            RuntimeError: Expected size for the second dimension of batch2 tensor to match the first dimension oqUVCYCydgf batch1 tensor

            3. bmm 不支持广播(broadcasting)

            • 必须显式提供相同的 batch sijsze。
            • 如果只有一个矩阵固定,可以使用 .expand()
            A = torch.randn(1, 2, 3)  # 单个矩阵
            B = torch.randn(4, 3, 5)  # 4 个矩阵
            # 扩展 A 以进行 batch 乘法
            A_expand = A.expand(4, -1, -1)
            C = torch.bmm(A_expand, B)  # (4, 2, 5)

            四、在实际应用中的例子

            在点云变换中:批量乘旋转矩阵

            # 假设有 B 个旋转矩阵和点坐标
            R = torch.randn(B, 3, 3)       # 旋转矩阵
            points = torch.randn(B, 3, N)  # 点云
            # 先转置点坐标为 (B, N, 3)
            points_T = points.transpose(1, 2)  # (B, N, 3)
            # 用 bmm 做点变换:每组点乘旋转
            transformed = torch.bmm(points_T, R.transpose(1, 2))  # (B, N, 3)

            五、总结

            特性torch.bmm
            操作对象三维张量(batch of matrices)
            核心规则(B, N, M) x (B, M, P) = (B, N, P)
            是否支持广播❌ 不支持,需要手动 .expand()
            与 matmul 区别matmul 支持更多广播,bmm 更高效用于纯批量矩阵乘法
            应用场景批量线性变换、点云配准、神经网络前向传播等

            到此这篇关于PyTorch 中mm和bmm函数的使用详解的文章就介绍到这了,更多相关PyTorch mm和bmm函数内容请搜索编程客栈(androidwww.devze.com)以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程客栈(www.devze.com)!

            0

            精彩评论

            暂无评论...
            验证码 换一张
            取 消

            关注公众号