如何在 Python 中使用 NumPy argmax() 函数

已发表: 2022-09-14

在本教程中,您将学习如何使用NumPy argmax() 函数来查找数组中最大元素的索引。

NumPy 是一个强大的 Python 科学计算库; 它提供了比 Python 列表更高效的 N 维数组。 使用 NumPy 数组时,您将执行的常见操作之一是查找数组中的最大值。 但是,您有时可能希望找到出现最大值的索引

argmax()函数可帮助您找到一维和多维数组中最大值的索引。 让我们继续了解它是如何工作的。

如何在 NumPy 数组中找到最大元素的索引

要学习本教程,您需要安装 Python 和 NumPy。 您可以通过启动 Python REPL 或启动 Jupyter 笔记本来编写代码。

首先,让我们在通常的别名np下导入 NumPy。

 import numpy as np

您可以使用 NumPy max()函数获取数组中的最大值(可选地沿特定轴)。

 array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10

在这种情况下, np.max(array_1)返回 10,这是正确的。

假设您想找到数组中出现最大值的索引。 您可以采取以下两步方法:

  1. 找到最大元素。
  2. 找到最大元素的索引。

array_1中,最大值 10 出现在索引 4 处,紧随零索引。 第一个元素在索引 0 处; 第二个元素在索引 1 处,依此类推。

numpy-argmax

要查找出现最大值的索引,可以使用 NumPy where() 函数。 np.where(condition)返回conditionTrue的所有索引的数组。

您必须点击数组并访问第一个索引处的项目。 为了找到最大值出现的位置,我们将condition设置为array_1==10 ; 回想一下,10 是array_1中的最大值。

 print(int(np.where(array_1==10)[0])) # Output 4

我们使用了条件的np.where() ,但这不是使用此函数的推荐方法。

注意:NumPy where() 函数
np.where(condition,x,y)返回:

– 条件为True时来自x的元素,并且
– 条件为False时来自y的元素。

因此,链接np.max()np.where()函数,我们可以找到最大元素,然后是它出现的索引。

您可以使用 NumPy argmax() 函数来获取数组中最大元素的索引,而不是上面的两步过程。

NumPy argmax() 函数的语法

使用 NumPy argmax() 函数的一般语法如下:

 np.argmax(array,axis,out) # we've imported numpy under the alias np

在上面的语法中:

  • array是任何有效的 NumPy 数组。
  • 是一个可选参数。 使用多维数组时,您可以使用轴参数来查找沿特定轴的最大值的索引。
  • out是另一个可选参数。 您可以将out参数设置为 NumPy 数组以存储argmax()函数的输出。

注意:从 NumPy 版本 1.22.0 开始,有一个额外的keepdims参数。 当我们在argmax()函数调用中指定axis参数时,数组沿该轴缩小。 但是将keepdims参数设置为True可确保返回的输出与输入数组具有相同的形状。

使用 NumPy argmax() 查找最大元素的索引

#1 。 让我们使用 NumPy argmax() 函数来查找array_1中最大元素的索引。

 array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4

argmax()函数返回 4,这是正确的!

#2 。 如果我们重新定义array_1使得 10 出现两次,那么argmax()函数返回第一次出现的索引。

 array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4

对于其余的示例,我们将使用我们在示例 #1 中定义的array_1的元素。

使用 NumPy argmax() 查找二维数组中最大元素的索引

让我们将 NumPy 数组array_1重塑为一个两行四列的二维数组。

 array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]

对于二维数组,轴 0 表示行,轴 1 表示列。 NumPy 数组遵循零索引。 所以 NumPy 数组array_2的行和列的索引如下:

numpy-argmax-2darray

现在,让我们在二维数组array_2上调用argmax()函数。

 print(np.argmax(array_2)) # Output 4

即使我们在二维数组上调用argmax() ,它仍然返回 4。这与上一节中一维数组array_1的输出相同。

为什么会这样?

这是因为我们没有为轴参数指定任何值。 如果未设置此轴参数,默认情况下, argmax()函数返回扁平数组中最大元素的索引。

什么是扁平阵列? 如果有一个形状为d1 x d2 x ... x dN的 N 维数组,其中 d1, d2, 直到 dN 是沿 N 维的数组的大小,那么展平的数组是一个大小为长的一维数组d1 * d2 * ... * dN。

要检查array_2的展平数组的外观,可以调用flatten()方法,如下所示:

 array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])

沿行的最大元素的索引(轴 = 0)

让我们继续查找沿行(轴 = 0)的最大元素的索引。

 np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])

这个输出可能有点难以理解,但我们会理解它是如何工作的。

我们已将axis参数设置为零( axis = 0 ),因为我们希望找到沿行的最大元素的索引。 因此, argmax()函数返回出现最大元素的行号——对于三列中的每一列。

让我们将其可视化以便更好地理解。

numpy-argmax-axis0

从上图和argmax()输出,我们有以下内容:

  • 对于索引 0 处的第一列,最大值10出现在第二行,索引 = 1。
  • 对于索引 1 处的第二列,最大值9出现在第二行,索引 = 1。
  • 对于索引 2 和 3 处的第三和第四列,最大值84都出现在第二行,索引 = 1。

这正是我们有输出array([1, 1, 1, 1])的原因,因为沿行的最大元素出现在第二行(对于所有列)。

沿列的最大元素的索引(轴 = 1)

接下来,让我们使用argmax()函数来查找沿列的最大元素的索引。

运行以下代码片段并观察输出。

 np.argmax(array_2,axis=1)
 array([2, 0])

你能解析输出吗?

我们设置axis = 1来计算沿列的最大元素的索引。

argmax()函数为每一行返回最大值所在的列号。

这是一个视觉解释:

numpy-argmax-axis1

从上图和argmax()输出,我们有以下内容:

  • 对于索引 0 处的第一行,最大值7出现在索引 = 2 处的第三列。
  • 对于索引 1 处的第二行,最大值10出现在索引 = 0 的第一列中。

我希望您现在了解输出array([2, 0])的含义。

在 NumPy argmax() 中使用可选输出参数

您可以使用 NumPy argmax() 函数中的可选参数out将输出存储在 NumPy 数组中。

让我们初始化一个零数组来存储先前的argmax()函数调用的输出 - 沿着列 ( axis= 1 ) 找到最大值的索引。

 out_arr = np.zeros((2,)) print(out_arr) [0. 0.]

现在,让我们重新审视沿列 ( axis = 1 ) 查找最大元素的索引并将out设置为我们上面定义的out_arr的示例。

 np.argmax(array_2,axis=1,out=out_arr)

我们看到 Python 解释器抛出了一个TypeError ,因为out_arr默认初始化为一个浮点数组。

 --------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

因此,在将out参数设置为输出数组时,确保输出数组的形状和数据类型正确非常重要。 由于数组索引总是整数,我们应该在定义输出数组时将dtype参数设置为int

 out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]

我们现在可以继续使用axisout参数调用argmax()函数,这一次,它运行没有错误。

 np.argmax(array_2,axis=1,out=out_arr)

现在可以在数组out_arr中访问argmax()函数的输出。

 print(out_arr) # Output [2 0]

结论

我希望本教程能帮助您了解如何使用 NumPy argmax() 函数。 您可以在 Jupyter 笔记本中运行代码示例。

让我们回顾一下我们学到了什么。

  • NumPy argmax() 函数返回数组中最大元素的索引。 如果最大元素在数组a中多次出现,则np.argmax(a)返回该元素第一次出现的索引。
  • 使用多维数组时,您可以使用可选的参数来获取沿特定轴的最大元素的索引。 例如,在二维数组中:通过设置axis = 0axis = 1 ,可以分别获取沿行和列的最大元素的索引。
  • 如果您想将返回的值存储在另一个数组中,您可以将可选的out参数设置为输出数组。 但是,输出数组应具有兼容的形状和数据类型。

接下来,查看 Python 集的深入指南。