如何在 Python 中使用 NumPy argmax() 函数
已发表: 2022-09-14在本教程中,您将学习如何使用NumPy argmax() 函数来查找数组中最大元素的索引。
NumPy 是一个强大的 Python 科学计算库; 它提供了比 Python 列表更高效的 N 维数组。 使用 NumPy 数组时,您将执行的常见操作之一是查找数组中的最大值。 但是,您有时可能希望找到出现最大值的索引。
argmax()
函数可帮助您找到一维和多维数组中最大值的索引。 让我们继续了解它是如何工作的。
如何在 NumPy 数组中找到最大元素的索引
要学习本教程,您需要安装 Python 和 NumPy。 您可以通过启动 Python REPL 或启动 Jupyter 笔记本来编写代码。
首先,让我们在通常的别名np
下导入 NumPy。
import numpy as np
您可以使用 NumPy max()
函数获取数组中的最大值(可选地沿特定轴)。
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
在这种情况下, np.max(array_1)
返回 10,这是正确的。
假设您想找到数组中出现最大值的索引。 您可以采取以下两步方法:
- 找到最大元素。
- 找到最大元素的索引。
在array_1
中,最大值 10 出现在索引 4 处,紧随零索引。 第一个元素在索引 0 处; 第二个元素在索引 1 处,依此类推。

要查找出现最大值的索引,可以使用 NumPy where() 函数。 np.where(condition)
返回condition
为True
的所有索引的数组。
您必须点击数组并访问第一个索引处的项目。 为了找到最大值出现的位置,我们将condition
设置为array_1==10
; 回想一下,10 是array_1
中的最大值。
print(int(np.where(array_1==10)[0])) # Output 4
我们只使用了条件的np.where()
,但这不是使用此函数的推荐方法。
注意:NumPy where() 函数:
np.where(condition,x,y)
返回:– 条件为
True
时来自x
的元素,并且
– 条件为False
时来自y
的元素。
因此,链接np.max()
和np.where()
函数,我们可以找到最大元素,然后是它出现的索引。
您可以使用 NumPy argmax() 函数来获取数组中最大元素的索引,而不是上面的两步过程。
NumPy argmax() 函数的语法
使用 NumPy argmax() 函数的一般语法如下:
np.argmax(array,axis,out) # we've imported numpy under the alias np
在上面的语法中:
- array是任何有效的 NumPy 数组。
- 轴是一个可选参数。 使用多维数组时,您可以使用轴参数来查找沿特定轴的最大值的索引。
- out是另一个可选参数。 您可以将
out
参数设置为 NumPy 数组以存储argmax()
函数的输出。
注意:从 NumPy 版本 1.22.0 开始,有一个额外的
keepdims
参数。 当我们在argmax()
函数调用中指定axis
参数时,数组沿该轴缩小。 但是将keepdims
参数设置为True
可确保返回的输出与输入数组具有相同的形状。
使用 NumPy argmax() 查找最大元素的索引
#1 。 让我们使用 NumPy argmax() 函数来查找array_1
中最大元素的索引。
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
argmax()
函数返回 4,这是正确的!
#2 。 如果我们重新定义array_1
使得 10 出现两次,那么argmax()
函数只返回第一次出现的索引。
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
对于其余的示例,我们将使用我们在示例 #1 中定义的array_1
的元素。
使用 NumPy argmax() 查找二维数组中最大元素的索引
让我们将 NumPy 数组array_1
重塑为一个两行四列的二维数组。
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
对于二维数组,轴 0 表示行,轴 1 表示列。 NumPy 数组遵循零索引。 所以 NumPy 数组array_2
的行和列的索引如下:

现在,让我们在二维数组array_2
上调用argmax()
函数。
print(np.argmax(array_2)) # Output 4
即使我们在二维数组上调用argmax()
,它仍然返回 4。这与上一节中一维数组array_1
的输出相同。
为什么会这样?
这是因为我们没有为轴参数指定任何值。 如果未设置此轴参数,默认情况下, argmax()
函数返回扁平数组中最大元素的索引。
什么是扁平阵列? 如果有一个形状为d1 x d2 x ... x dN的 N 维数组,其中 d1, d2, 直到 dN 是沿 N 维的数组的大小,那么展平的数组是一个大小为长的一维数组d1 * d2 * ... * dN。
要检查array_2
的展平数组的外观,可以调用flatten()
方法,如下所示:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
沿行的最大元素的索引(轴 = 0)
让我们继续查找沿行(轴 = 0)的最大元素的索引。
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])
这个输出可能有点难以理解,但我们会理解它是如何工作的。
我们已将axis
参数设置为零( axis = 0
),因为我们希望找到沿行的最大元素的索引。 因此, argmax()
函数返回出现最大元素的行号——对于三列中的每一列。
让我们将其可视化以便更好地理解。

从上图和argmax()
输出,我们有以下内容:
- 对于索引 0 处的第一列,最大值10出现在第二行,索引 = 1。
- 对于索引 1 处的第二列,最大值9出现在第二行,索引 = 1。
- 对于索引 2 和 3 处的第三和第四列,最大值8和4都出现在第二行,索引 = 1。
这正是我们有输出array([1, 1, 1, 1])
的原因,因为沿行的最大元素出现在第二行(对于所有列)。
沿列的最大元素的索引(轴 = 1)
接下来,让我们使用argmax()
函数来查找沿列的最大元素的索引。
运行以下代码片段并观察输出。
np.argmax(array_2,axis=1)
array([2, 0])
你能解析输出吗?
我们设置axis = 1
来计算沿列的最大元素的索引。
argmax()
函数为每一行返回最大值所在的列号。
这是一个视觉解释:

从上图和argmax()
输出,我们有以下内容:
- 对于索引 0 处的第一行,最大值7出现在索引 = 2 处的第三列。
- 对于索引 1 处的第二行,最大值10出现在索引 = 0 的第一列中。
我希望您现在了解输出array([2, 0])
的含义。
在 NumPy argmax() 中使用可选输出参数
您可以使用 NumPy argmax() 函数中的可选参数out
将输出存储在 NumPy 数组中。
让我们初始化一个零数组来存储先前的argmax()
函数调用的输出 - 沿着列 ( axis= 1
) 找到最大值的索引。
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
现在,让我们重新审视沿列 ( axis = 1
) 查找最大元素的索引并将out
设置为我们上面定义的out_arr
的示例。
np.argmax(array_2,axis=1,out=out_arr)
我们看到 Python 解释器抛出了一个TypeError
,因为out_arr
默认初始化为一个浮点数组。
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
因此,在将out
参数设置为输出数组时,确保输出数组的形状和数据类型正确非常重要。 由于数组索引总是整数,我们应该在定义输出数组时将dtype
参数设置为int
。
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
我们现在可以继续使用axis
和out
参数调用argmax()
函数,这一次,它运行没有错误。
np.argmax(array_2,axis=1,out=out_arr)
现在可以在数组out_arr
中访问argmax()
函数的输出。
print(out_arr) # Output [2 0]
结论
我希望本教程能帮助您了解如何使用 NumPy argmax() 函数。 您可以在 Jupyter 笔记本中运行代码示例。
让我们回顾一下我们学到了什么。
- NumPy argmax() 函数返回数组中最大元素的索引。 如果最大元素在数组a中多次出现,则np.argmax(a)返回该元素第一次出现的索引。
- 使用多维数组时,您可以使用可选的轴参数来获取沿特定轴的最大元素的索引。 例如,在二维数组中:通过设置axis = 0和axis = 1 ,可以分别获取沿行和列的最大元素的索引。
- 如果您想将返回的值存储在另一个数组中,您可以将可选的out参数设置为输出数组。 但是,输出数组应具有兼容的形状和数据类型。
接下来,查看 Python 集的深入指南。