|
详解Python_Numpy库函数take_along_axis()【由索引矩阵生成新的矩阵】
提问:由已有矩阵的索引生成新的矩阵为什么要用函数take_along_axis(),我用Numpy库ndarray对象的切片操作不行么?
答案是:Numpy库ndarray对象的切片操作不是万能的,比如下面的两种情况它就不能解决,而下面两种情况可以用函数take_along_axis()解决。
情况一:
我由argsort()函数得到了矩阵元素按从小到大排序的索引,接下来我想由个这个排序索引得到一个新的矩阵,这个新矩阵的元素就是按从小到大排列的。这种情况下光靠切片操作就很难实现这个功能了。不信的话诸君可以试一试,反正昊虹君是试了的,很麻烦。但是此时用函数take_along_axis()就很方便,示例如下:
- import numpy as np
- A = np.array([[10, 30, 20], [60, 40, 50]])
- B = np.sort(A, axis=1)
- index1 = np.argsort(A, axis=1)
- C = np.take_along_axis(A, index1, axis=1)
复制代码
运行结果如下:
从这个示例可以看出,函数take_along_axis()很方便的帮我们通过索引值矩阵index1按序取出了A中的元素形成了数组C。
情况二:
现有三维矩阵A如下:
- A = np.arange(2*3*4).reshape([2, 3, 4])
复制代码
现在要实现下面这个目标:
选取A的第0页的第1行和A的第1页的第2行构成一个新的三维矩阵B,B矩阵的形状为(2, 1, 4)。
这个目标用切片操作是无法实现的,昊虹君也尝试过直接用切片实现这个目标,但无奈没有成功。
不过这个目标用函数take_along_axis()就很容易实现了,实现的代码如下:
- # -*- coding: utf-8 -*-
- # 出处:昊虹AI笔记网(hhai.cc)
- # 用心记录计算机视觉和AI技术
- # 博主微信/QQ 2487872782
- # QQ群 271891601
- # 欢迎技术交流与咨询
- # OpenCV的版本为4.4.0
- import numpy as np
- A = np.arange(2*3*4).reshape([2, 3, 4])
- index1 = np.zeros([2, 1, 1]).astype('int')
- index1[0, 0, :] = 1
- index1[1, 0, :] = 2
- B = np.take_along_axis(A, index1, axis=1)
复制代码
运行结果如下:
具体是怎么实现的,参考博文https://blog.csdn.net/baidu_37157624/article/details/123124561,
并仔细思考后得到其实现原理的精炼理解如下:
①显然,B矩阵的形状为为(2, 1, 4),又加上我们是以行为单位进行数据选取,即最小选取单位为一行,此时元素的列索引无意义,所以索引矩阵的形状为index1(2,1,1)。
②当axis=1时,有:
索引矩阵每个元素自身索引值的行索引值代表新矩阵B中的行索引值;
索引矩阵每个元素自身索引值的页索引值代表原矩阵A和新矩阵B的页索引值;
索引矩阵每个元素的值代表原矩阵A中的行索引值;
由于最小选取单位为一行,所以这里列索引值不用考虑。
基于以上认识,所以有:
index1[0, 0, :] = 1 代表取A的第0页的第1行形成B的第0页第0行;
index1[1, 0, :] = 2 代表取A的第1页的第2行形成B的第0页第0行。
补充说明:
①使用函数take_along_axis()时要注意,索引矩阵index1的维度数应该和原矩阵A的维度数相同。
②二维以下时实现上面的功能是完全可以用ndarray的切片或方法take()实现的。
关于ndarray的切片操作的详细介绍见博文 https://www.hhai.cc/thread-117-1-1.html
关于方法take()的详细介绍见博文 https://www.hhai.cc/thread-121-1-1.html |
|