如何在 Python 中使用 NumPy argmax() 函數

已發表: 2022-09-14

在本教程中,您將學習如何使用NumPy argmax() 函數來查找數組中最大元素的索引。

NumPy 是一個強大的 Python 科學計算庫; 它提供了比 Python 列表更高效的 N 維數組。 使用 NumPy 數組時,您將執行的常見操作之一是查找數組中的最大值。 但是,您有時可能希望找到出現最大值的索引

argmax()函數可幫助您找到一維和多維數組中最大值的索引。 讓我們繼續了解它是如何工作的。

如何在 NumPy 數組中找到最大元素的索引

要學習本教程,您需要安裝 Python 和 NumPy。 您可以通過啟動 Python REPL 或啟動 Jupyter 筆記本來編寫代碼。

首先,讓我們在通常的別名np下導入 NumPy。

 import numpy as np

您可以使用 NumPy max()函數獲取數組中的最大值(可選地沿特定軸)。

 array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10

在這種情況下, np.max(array_1)返回 10,這是正確的。

假設您想找到數組中出現最大值的索引。 您可以採取以下兩步方法:

  1. 找到最大元素。
  2. 找到最大元素的索引。

array_1中,最大值 10 出現在索引 4 處,緊隨零索引。 第一個元素在索引 0 處; 第二個元素在索引 1 處,依此類推。

numpy-argmax

要查找出現最大值的索引,可以使用 NumPy where() 函數。 np.where(condition)返回conditionTrue的所有索引的數組。

您必須點擊數組並訪問第一個索引處的項目。 為了找到最大值出現的位置,我們將condition設置為array_1==10 ; 回想一下,10 是array_1中的最大值。

 print(int(np.where(array_1==10)[0])) # Output 4

我們使用了條件的np.where() ,但這不是使用此函數的推薦方法。

注意:NumPy where() 函數
np.where(condition,x,y)返回:

– 條件為True時來自x的元素,並且
– 條件為False時來自y的元素。

因此,鏈接np.max()np.where()函數,我們可以找到最大元素,然後是它出現的索引。

您可以使用 NumPy argmax() 函數來獲取數組中最大元素的索引,而不是上面的兩步過程。

NumPy argmax() 函數的語法

使用 NumPy argmax() 函數的一般語法如下:

 np.argmax(array,axis,out) # we've imported numpy under the alias np

在上面的語法中:

  • array是任何有效的 NumPy 數組。
  • 是一個可選參數。 使用多維數組時,您可以使用軸參數來查找沿特定軸的最大值的索引。
  • out是另一個可選參數。 您可以將out參數設置為 NumPy 數組以存儲argmax()函數的輸出。

注意:從 NumPy 版本 1.22.0 開始,有一個額外的keepdims參數。 當我們在argmax()函數調用中指定axis參數時,數組沿該軸縮小。 但是將keepdims參數設置為True可確保返回的輸出與輸入數組具有相同的形狀。

使用 NumPy argmax() 查找最大元素的索引

#1 。 讓我們使用 NumPy argmax() 函數來查找array_1中最大元素的索引。

 array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4

argmax()函數返回 4,這是正確的!

#2 。 如果我們重新定義array_1使得 10 出現兩次,那麼argmax()函數返回第一次出現的索引。

 array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4

對於其餘的示例,我們將使用我們在示例 #1 中定義的array_1的元素。

使用 NumPy argmax() 查找二維數組中最大元素的索引

讓我們將 NumPy 數組array_1重塑為一個兩行四列的二維數組。

 array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]

對於二維數組,軸 0 表示行,軸 1 表示列。 NumPy 數組遵循零索引。 所以 NumPy 數組array_2的行和列的索引如下:

numpy-argmax-2darray

現在,讓我們在二維數組array_2上調用argmax()函數。

 print(np.argmax(array_2)) # Output 4

即使我們在二維數組上調用argmax() ,它仍然返回 4。這與上一節中一維數組array_1的輸出相同。

為什麼會這樣?

這是因為我們沒有為軸參數指定任何值。 如果未設置此軸參數,默認情況下, argmax()函數返回扁平數組中最大元素的索引。

什麼是扁平陣列? 如果有一個形狀為d1 x d2 x ... x dN的 N 維數組,其中 d1, d2, 直到 dN 是沿 N 維的數組的大小,那麼展平的數組是一個大小為長的一維數組d1 * d2 * ... * dN。

要檢查array_2的展平數組的外觀,可以調用flatten()方法,如下所示:

 array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])

沿行的最大元素的索引(軸 = 0)

讓我們繼續查找沿行(軸 = 0)的最大元素的索引。

 np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])

這個輸出可能有點難以理解,但我們會理解它是如何工作的。

我們已將axis參數設置為零( axis = 0 ),因為我們希望找到沿行的最大元素的索引。 因此, argmax()函數返回出現最大元素的行號——對於三列中的每一列。

讓我們將其可視化以便更好地理解。

numpy-argmax-axis0

從上圖和argmax()輸出,我們有以下內容:

  • 對於索引 0 處的第一列,最大值10出現在第二行,索引 = 1。
  • 對於索引 1 處的第二列,最大值9出現在第二行,索引 = 1。
  • 對於索引 2 和 3 處的第三和第四列,最大值84都出現在第二行,索引 = 1。

這正是我們有輸出array([1, 1, 1, 1])的原因,因為沿行的最大元素出現在第二行(對於所有列)。

沿列的最大元素的索引(軸 = 1)

接下來,讓我們使用argmax()函數來查找沿列的最大元素的索引。

運行以下代碼片段並觀察輸出。

 np.argmax(array_2,axis=1)
 array([2, 0])

你能解析輸出嗎?

我們設置axis = 1來計算沿列的最大元素的索引。

argmax()函數為每一行返回最大值所在的列號。

這是一個視覺解釋:

numpy-argmax-axis1

從上圖和argmax()輸出,我們有以下內容:

  • 對於索引 0 處的第一行,最大值7出現在索引 = 2 處的第三列。
  • 對於索引 1 處的第二行,最大值10出現在索引 = 0 的第一列中。

我希望您現在了解輸出array([2, 0])的含義。

在 NumPy argmax() 中使用可選輸出參數

您可以使用 NumPy argmax() 函數中的可選參數out將輸出存儲在 NumPy 數組中。

讓我們初始化一個零數組來存儲先前的argmax()函數調用的輸出 - 沿著列 ( axis= 1 ) 找到最大值的索引。

 out_arr = np.zeros((2,)) print(out_arr) [0. 0.]

現在,讓我們重新審視沿列 ( axis = 1 ) 查找最大元素的索引並將out設置為我們上面定義的out_arr的示例。

 np.argmax(array_2,axis=1,out=out_arr)

我們看到 Python 解釋器拋出了一個TypeError ,因為out_arr默認初始化為一個浮點數組。

 --------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

因此,在將out參數設置為輸出數組時,確保輸出數組的形狀和數據類型正確非常重要。 由於數組索引總是整數,我們應該在定義輸出數組時將dtype參數設置為int

 out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]

我們現在可以繼續使用axisout參數調用argmax()函數,這一次,它運行沒有錯誤。

 np.argmax(array_2,axis=1,out=out_arr)

現在可以在數組out_arr中訪問argmax()函數的輸出。

 print(out_arr) # Output [2 0]

結論

我希望本教程能幫助您了解如何使用 NumPy argmax() 函數。 您可以在 Jupyter 筆記本中運行代碼示例。

讓我們回顧一下我們學到了什麼。

  • NumPy argmax() 函數返回數組中最大元素的索引。 如果最大元素在數組a中多次出現,則np.argmax(a)返回該元素第一次出現的索引。
  • 使用多維數組時,您可以使用可選的參數來獲取沿特定軸的最大元素的索引。 例如,在二維數組中:通過設置axis = 0axis = 1 ,可以分別獲取沿行和列的最大元素的索引。
  • 如果您想將返回的值存儲在另一個數組中,您可以將可選的out參數設置為輸出數組。 但是,輸出數組應具有兼容的形狀和數據類型。

接下來,查看 Python 集的深入指南。