Как использовать функцию NumPy argmax() в Python
Опубликовано: 2022-09-14В этом руководстве вы узнаете, как использовать функцию NumPy argmax() для поиска индекса максимального элемента в массивах.
NumPy — мощная библиотека для научных вычислений на Python; он предоставляет N-мерные массивы, которые более эффективны, чем списки Python. Одной из распространенных операций, которую вы будете выполнять при работе с массивами NumPy, является поиск максимального значения в массиве. Однако иногда вам может понадобиться найти индекс, при котором происходит максимальное значение.
Функция argmax()
помогает найти индекс максимума как в одномерных, так и в многомерных массивах. Давайте продолжим изучать, как это работает.
Как найти индекс максимального элемента в массиве NumPy
Чтобы следовать этому руководству, вам необходимо установить Python и NumPy. Вы можете кодировать, запустив Python REPL или запустив блокнот Jupyter.
Во-первых, давайте импортируем NumPy под обычным псевдонимом np
.
import numpy as np
Вы можете использовать функцию NumPy max()
, чтобы получить максимальное значение в массиве (необязательно вдоль определенной оси).
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
В этом случае np.max(array_1)
возвращает 10, что правильно.
Предположим, вы хотите найти индекс, при котором в массиве встречается максимальное значение. Вы можете использовать следующий двухэтапный подход:
- Найдите максимальный элемент.
- Найдите индекс максимального элемента.
В array_1
максимальное значение 10 встречается в индексе 4 после нулевой индексации. Первый элемент имеет индекс 0; второй элемент имеет индекс 1 и так далее.

Чтобы найти индекс, при котором происходит максимум, вы можете использовать функцию NumPy where(). np.where(condition)
возвращает массив всех индексов, где condition
имеет значение True
.
Вам нужно будет подключиться к массиву и получить доступ к элементу по первому индексу. Чтобы найти, где находится максимальное значение, мы устанавливаем condition
в array_1==10
; напомним, что 10 — это максимальное значение в array_1
.
print(int(np.where(array_1==10)[0])) # Output 4
Мы использовали np.where()
только с условием, но это не рекомендуемый метод для использования этой функции.
Примечание. Функция NumPy where() :
np.where(condition,x,y)
возвращает:– Элементы из
x
, когда условиеTrue
, и
– Элементы изy
, когда условие равноFalse
.
Следовательно, объединяя функции np.max()
и np.where()
, мы можем найти максимальный элемент, за которым следует индекс, по которому он встречается.
Вместо описанного выше двухэтапного процесса вы можете использовать функцию NumPy argmax(), чтобы получить индекс максимального элемента в массиве.
Синтаксис функции NumPy argmax()
Общий синтаксис для использования функции NumPy argmax() выглядит следующим образом:
np.argmax(array,axis,out) # we've imported numpy under the alias np
В приведенном выше синтаксисе:
- array — любой допустимый массив NumPy.
- ось является необязательным параметром. При работе с многомерными массивами вы можете использовать параметр оси, чтобы найти индекс максимума по определенной оси.
- out — еще один необязательный параметр. Вы можете установить параметр
out
в массив NumPy для хранения вывода функцииargmax()
.
Примечание . В NumPy версии 1.22.0 появился дополнительный параметр
keepdims
. Когда мы указываем параметрaxis
в вызове функцииargmax()
, массив уменьшается вдоль этой оси. Но установка для параметраkeepdims
True
гарантирует, что возвращаемый результат имеет ту же форму, что и входной массив.
Использование NumPy argmax() для поиска индекса максимального элемента
# 1 . Давайте воспользуемся функцией NumPy argmax(), чтобы найти индекс максимального элемента в array_1
.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
Функция argmax()
возвращает 4, и это правильно!
# 2 . Если мы переопределим array_1
так, чтобы 10 встречалось дважды, argmax()
возвращает только индекс первого вхождения.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
В остальных примерах мы будем использовать элементы array_1
, которые мы определили в примере №1.
Использование NumPy argmax() для поиска индекса максимального элемента в 2D-массиве
Давайте изменим массив NumPy array_1
в двумерный массив с двумя строками и четырьмя столбцами.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
Для двумерного массива ось 0 обозначает строки, а ось 1 — столбцы. Массивы NumPy следуют нулевой индексации . Таким образом, индексы строк и столбцов для массива NumPy array_2
следующие:

Теперь давайте вызовем функцию argmax()
для двумерного массива array_2
.
print(np.argmax(array_2)) # Output 4
Несмотря на то, что мы вызвали argmax()
для двумерного массива, он по-прежнему возвращает 4. Это идентично выходным данным для одномерного массива array_1
из предыдущего раздела.
Почему это происходит?
Это потому, что мы не указали никакого значения для параметра оси. Если этот параметр оси не установлен, по умолчанию argmax()
возвращает индекс максимального элемента в сглаженном массиве.
Что такое плоский массив? Если имеется N-мерный массив формы d1 x d2 x … x dN , где d1, d2, to dN — размеры массива по N измерениям, то сплющенный массив представляет собой длинный одномерный массив размера д1*д2*…*дН.
Чтобы проверить, как выглядит сглаженный массив для array_2
, вы можете вызвать метод flatten flatten()
, как показано ниже:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
Индекс максимального элемента вдоль строк (ось = 0)
Перейдем к поиску индекса максимального элемента по строкам (ось = 0).
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])
Этот вывод может быть немного сложным для понимания, но мы поймем, как он работает.
Мы установили параметр axis
равным нулю ( axis = 0
), так как мы хотели бы найти индекс максимального элемента в строках. Поэтому argmax()
возвращает номер строки, в которой встречается максимальный элемент, — для каждого из трех столбцов.
Давайте визуализируем это для лучшего понимания.

Из приведенной выше диаграммы и argmax()
мы имеем следующее:
- Для первого столбца с индексом 0 максимальное значение 10 встречается во второй строке с индексом = 1.
- Для второго столбца с индексом 1 максимальное значение 9 встречается во второй строке с индексом = 1.
- Для третьего и четвертого столбцов с индексами 2 и 3 максимальные значения 8 и 4 встречаются во второй строке с индексом = 1.
Именно поэтому у нас есть выходной array([1, 1, 1, 1])
потому что максимальный элемент по строкам приходится на вторую строку (для всех столбцов).
Индекс максимального элемента вдоль столбцов (ось = 1)
Далее воспользуемся argmax()
, чтобы найти индекс максимального элемента по столбцам.
Запустите следующий фрагмент кода и посмотрите на результат.
np.argmax(array_2,axis=1)
array([2, 0])
Можете ли вы разобрать вывод?
Мы установили axis = 1
, чтобы вычислить индекс максимального элемента по столбцам.
Функция argmax()
возвращает для каждой строки номер столбца, в котором встречается максимальное значение.
Вот визуальное объяснение:

Из приведенной выше диаграммы и argmax()
мы имеем следующее:
- Для первой строки с индексом 0 максимальное значение 7 встречается в третьем столбце с индексом = 2.
- Для второй строки с индексом 1 максимальное значение 10 встречается в первом столбце с индексом = 0.
Надеюсь, теперь вы понимаете, что означает вывод array([2, 0])
.
Использование необязательного параметра вывода в NumPy argmax()
Вы можете использовать необязательный параметр out
the в функции NumPy argmax() для сохранения вывода в массиве NumPy.
Давайте инициализируем массив нулей для хранения вывода предыдущего вызова функции argmax()
— для нахождения индекса максимума по столбцам ( axis= 1
).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
Теперь давайте вернемся к примеру с поиском индекса максимального элемента по столбцам ( axis = 1
) и установим значение out
out_arr
, которое мы определили выше.
np.argmax(array_2,axis=1,out=out_arr)
Мы видим, что интерпретатор Python выдает TypeError
, так как out_arr
по умолчанию был инициализирован массивом с плавающей запятой.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Поэтому при установке параметра out
для выходного массива важно убедиться, что выходной массив имеет правильную форму и тип данных. Поскольку индексы массива всегда являются целыми числами, мы должны установить для параметра dtype
значение int
при определении выходного массива.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
Теперь мы можем продолжить и вызвать argmax()
с параметрами axis
и out
, и на этот раз она выполняется без ошибок.
np.argmax(array_2,axis=1,out=out_arr)
Теперь к выходным данным функции argmax()
можно получить доступ в массиве out_arr
.
print(out_arr) # Output [2 0]
Вывод
Я надеюсь, что это руководство помогло вам понять, как использовать функцию NumPy argmax(). Вы можете запустить примеры кода в блокноте Jupyter.
Давайте повторим, что мы узнали.
- Функция NumPy argmax() возвращает индекс максимального элемента в массиве. Если максимальный элемент встречается в массиве a более одного раза, то np.argmax(a) возвращает индекс первого вхождения элемента.
- При работе с многомерными массивами вы можете использовать необязательный параметр оси , чтобы получить индекс максимального элемента по определенной оси. Например, в двумерном массиве: задав ось = 0 и ось = 1 , можно получить индекс максимального элемента по строкам и столбцам соответственно.
- Если вы хотите сохранить возвращаемое значение в другом массиве, вы можете установить необязательный параметр out для выходного массива. Однако выходной массив должен иметь совместимую форму и тип данных.
Затем ознакомьтесь с подробным руководством по наборам Python.