开发者

PyTorch torch.unique() 基础与实战应用指南

开发者 https://www.devze.com 2025-10-30 09:26 出处:网络 作者: Geoking.
目录一、什么是torch.unique()?二、函数语法三、参数说明四、基本用法 示例 1:基础去重 示例 2:不排序五、返回索引与计数 示例 3:return_inverse 示例 4:return_counts 示例 5:同时返回多个结果六、按维度去重
目录
  • 一、什么是torch.unique()?
  • 二、函数语法
  • 三、参数说明
  • 四、基本用法
    • 示例 1:基础去重
    • 示例 2:不排序
  • 五、返回索引与计数
    • 示例 3:return_inverse
    • 示例 4:return_counts
    • 示例 5:同时返回多个结果
  • 六、按维度去重(dim 参数)
    • 示例 6:按行去重
    • 示例 7:按列去重
  • 七、torch.unique()与 NumPy 对比
    • 八、实际应用场景
      • 1. 分类问题中统计类别数量python
      • 2. 计算样本分布(类别频率)
      • 3. 在图像分割中统计像素类别
    • ⚠️ 九、注意事项
      • 参考资料

        在深度学习的数据处理中经常需要统计或筛选 张量(Tensor) 中的唯一值,比如去重、统计类别数量、计算唯一标签数等。

        PyTorch 提供了一个非常方便的函数 —— torch.unique(),可以轻松完成这些操作。

        本文将带你深入了解 torch.unique() 的用法、参数、返回值以及实际应用场景。

        一、什么是torch.unique()?

        torch.unique() 是 PyTorch 中的一个去重函数,用于返回张量中所有的唯一元素(unique elements)。

        它类似于 python 的 set() 或 NumPy 的 np.unique(),但专为 GPU 加速的张量操作 设计。

        二、函数语法

        torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None)
        

        三、参数说明

        参数类型说明
        inputTensor输入张量
        sortedbool是否对结果排序(默认 True
        return_inversebool是否返回原张量中每个值在唯一值列表中的索引
        return_countsboivvbuol是否返回每个唯一值的出现次数
        dimintNone按指定维度去重,默认对整个张量去重

        四、基本用法

        示例 1:基础去重

        import torch
        x = torch.tensor([1, 2, 2, 3, 3, 3])
        unique_x = torch.unique(x)
        print(unique_x)

        输出:

        tensor([1, 2, 3])
        

        ✅ 结果去除了重复值,并自动排序。

        示例 2:不排序

        x = torch.tensor([3, 2, 1, 3, 2])
        unique_x = torch.unique(x, sorted=False)
        print(unique_x)
        

        输出:

        tensor([3, 2, 1])
        

        sorted=False 时,结果的顺序与首次出现的顺序一致。

        五、返回索引与计数

        示例 3:return_inverse

        return_inverse=True 会返回一个索引张量,表示原张量中每个元素在唯一值(即新张量)中的位置。

        x = torch.tensor([2, 1, 2, 3])
        u, inv = torch.unique(x, return_inverse=True)
        print(u)
        print(inv)ivvbu

        输出:

        tensor([1, 2, 3])
        tensor([1, 0, 1, 2])
        

        解释:

        • 唯一值为 [1, 2, 3]
        • 原数组 [2, 1, 2, 3] 中:
          • 第一个元素 2 → 索引 1
          • 第二个元素 1 → 索引 0
          • 第三个元素 2 → 索引 1
          • 第四个元素 3 → 索引 2

        示例 4:return_counts

        return_counts=True 会返回每个唯一值出现的次数。

        x = torch.tensor([1, 2, 2, 3, 3javascript, 3])
        u, counts = torch.unique(x, return_counts=True)
        print(u)
        print(counts)

        输出:

        tensor([1, 2, 3])
        tensor([1, 2, 3])
        

        表示:

        • 值 1 出现 1 次
        • 值 2 出现 2 次
        • 值 3 出现 3 次

        示例 5:同时返回多个结果

        你可以同时返回 unique 值、inverse 索引和计数

        x = torch.tensor([1, 2, 2, 3, 3, 3])
        u, inv, counts = torch.unique(x, return_inverse=True, return_counts=True)
        print(u)
        print(inv)
        print(counts)
        

        输出:

        tensor([1, 2, 3])
        tensor([0, 1, 1, 2, 2, 2])
        tensor([1, 2, 3])
        

        六、按维度去重(dim 参数)

        默认情况下,torch.unique() 会将张量展开成一维后去重。

        但如果你希望在特定维度上去重(如按行或按列),可以使用 dim 参数。

        示例 6:按行去重

        x = torch.tensor([[1, 2],
                          [1, 2],
                          [3, 4]])
        unique_rows = torch.unique(x, dim=0)
        print(unique_rows)

        输出:

        tensor([[1, 2],
                [3, 4]])
        

        表示第 1、2 行重复,只保留一个。

        示例 7:按列去重

        x = torch.tensor([[1, 1, 3],
                          [2, 2, 4]])
        unique_cols = torch.unique(x, dim=1)
        print(unique_cols)

        输出:

        tensor([[1, 3],
                [2, 4]])
        

        七、torch.unique()与 NumPy 对比

        功能PyTorch (torch.unique)NumPy (np.unique)
        默认排序✅ 是✅ 是
        支持 GPU✅ 是❌ 否
        返回 inverse 索引✅ 是✅ 是
        返回 counts✅ 是✅ 是
        按维度去重✅ 是(dim❌ 不直接支持
        性能高(GPU 支持)仅 CPU

        八、实际应用场景http://www.devze.com

        1. 分类问题中统计类别数量

        labels = torch.tensor([0, 1, 0, 2, 2, 1, 3])
        classes = torch.unique(labels)
        print(f"共有 {len(classes)} 个类别: {classes.tolist()}")
        

        输出:

        共有 4 个类别: [0, 1, 2, 3]
        

        2. 计算样本分布(类别频率)

        labels = torch.tensor([0, 1, 0, 2, 2, 1, 3])
        u, counts = torch.unique(labels, return_counts=True)
        for c, cnt in zip(u.tolist(), counts.tolist()):
            print(f"类别 {c}: {cnt} 个样本")
        

        输出:

        类别 0: 2 个样本
        类别 1: 2 个样本
        类别 2: 2 个样本
        类别 3: 1 个样本
        

        3. 在图像分割中统计像素类别

        例如在语义分割任务中,计算 mask 图像中有多少个不同的像素类别:

        mask = torch.randint(0, 5, (256, 256))  # 随机生成类别标签
        num_classes = len(torch.unique(mask))
        print(f"图像中共有 {num_classes} 个类别")
        

        ⚠️ 九、注意事项

        1. torch.unique()** 默认会对结果排序**,如果在意性能,可以设置 sorted=False
        2. 对高维张量使用 dim 去重时,必须保证该维度的所有元素形状一致。
        3. 对大张量使用 return_countsreturn_inverse 时可能会消耗更多显存。

        参考资料

        • PyTorch 官方文档 – torch.unique

        PyTorch torch.unique() 基础与实战应用指南

        NumPy 官方文档 – numpy.unique

        PyTorch torch.unique() 基础与实战应用指南

        到此这篇关于PyTorch torch.unique() 基础与实战应用指南的文章就介绍到这了,更多相关PyTorch torch.unique() 使用内容请搜索编程客栈(www.devze.com)以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程客栈(www.devze.com)!

        0

        精彩评论

        暂无评论...
        验证码 换一张
        取 消

        关注公众号