Como usar a função NumPy argmax() em Python
Publicados: 2022-09-14Neste tutorial, você aprenderá como usar a função NumPy argmax() para encontrar o índice do elemento máximo em arrays.
NumPy é uma biblioteca poderosa para computação científica em Python; ele fornece arrays N-dimensionais com melhor desempenho do que as listas do Python. Uma das operações comuns que você realizará ao trabalhar com matrizes NumPy é encontrar o valor máximo na matriz. No entanto, às vezes você pode querer localizar o índice no qual ocorre o valor máximo.
A função argmax()
ajuda a encontrar o índice do máximo em arrays unidimensionais e multidimensionais. Vamos continuar a aprender como funciona.
Como encontrar o índice do elemento máximo em uma matriz NumPy
Para acompanhar este tutorial, você precisa ter o Python e o NumPy instalados. Você pode codificar iniciando um Python REPL ou iniciando um notebook Jupyter.
Primeiro, vamos importar NumPy sob o alias usual np
.
import numpy as np
Você pode usar a função NumPy max()
para obter o valor máximo em uma matriz (opcionalmente ao longo de um eixo específico).
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
Nesse caso, np.max(array_1)
retorna 10, o que está correto.
Suponha que você queira encontrar o índice no qual o valor máximo ocorre na matriz. Você pode adotar a seguinte abordagem em duas etapas:
- Encontre o elemento máximo.
- Encontre o índice do elemento máximo.
Em array_1
, o valor máximo de 10 ocorre no índice 4, seguindo a indexação zero. O primeiro elemento está no índice 0; o segundo elemento está no índice 1 e assim por diante.

Para encontrar o índice no qual ocorre o máximo, você pode usar a função NumPy where(). np.where(condition)
retorna um array de todos os índices onde a condition
é True
.
Você terá que tocar no array e acessar o item no primeiro índice. Para descobrir onde ocorre o valor máximo, definimos a condition
para array_1==10
; lembre-se de que 10 é o valor máximo em array_1
.
print(int(np.where(array_1==10)[0])) # Output 4
Usamos np.where()
apenas com a condição, mas este não é o método recomendado para usar esta função.
Nota: NumPy where() Função :
np.where(condition,x,y)
retorna:– Elementos de
x
quando a condição forTrue
, e
– Elementos dey
quando a condição forFalse
.
Portanto, encadeando as np.max()
e np.where()
, podemos encontrar o elemento máximo, seguido do índice em que ele ocorre.
Em vez do processo de duas etapas acima, você pode usar a função NumPy argmax() para obter o índice do elemento máximo na matriz.
Sintaxe da função NumPy argmax()
A sintaxe geral para usar a função NumPy argmax() é a seguinte:
np.argmax(array,axis,out) # we've imported numpy under the alias np
Na sintaxe acima:
- array é qualquer array NumPy válido.
- axis é um parâmetro opcional. Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro axis para encontrar o índice de máximo ao longo de um eixo específico.
- out é outro parâmetro opcional. Você pode definir o parâmetro
out
para um array NumPy para armazenar a saída da funçãoargmax()
.
Nota : A partir da versão 1.22.0 do NumPy, há um parâmetro
keepdims
adicional. Quando especificamos o parâmetroaxis
na chamada da funçãoargmax()
, a matriz é reduzida ao longo desse eixo. Mas definir o parâmetrokeepdims
comoTrue
garante que a saída retornada tenha a mesma forma que a matriz de entrada.
Usando NumPy argmax() para encontrar o índice do elemento máximo
#1 . Vamos usar a função NumPy argmax() para encontrar o índice do elemento máximo em array_1
.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
A função argmax()
retorna 4, o que está correto!
#2 . Se redefinirmos array_1
de forma que 10 ocorra duas vezes, a função argmax()
retornará apenas o índice da primeira ocorrência.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
Para o restante dos exemplos, usaremos os elementos de array_1
que definimos no exemplo #1.
Usando NumPy argmax() para encontrar o índice do elemento máximo em uma matriz 2D
Vamos remodelar o array NumPy array_1
em um array bidimensional com duas linhas e quatro colunas.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
Para uma matriz bidimensional, o eixo 0 denota as linhas e o eixo 1 denota as colunas. As matrizes NumPy seguem a indexação zero . Portanto, os índices das linhas e colunas da matriz NumPy array_2
são os seguintes:

Agora, vamos chamar a função argmax()
no array bidimensional, array_2
.
print(np.argmax(array_2)) # Output 4
Embora tenhamos chamado argmax()
no array bidimensional, ele ainda retorna 4. Isso é idêntico à saída para o array unidimensional, array_1
da seção anterior.
Por que isso acontece?
Isso ocorre porque não especificamos nenhum valor para o parâmetro axis. Quando este parâmetro de eixo não está definido, por padrão, a função argmax()
retorna o índice do elemento máximo ao longo do array nivelado.
O que é uma matriz achatada? Se houver uma matriz N-dimensional de forma d1 x d2 x … x dN , onde d1, d2, até dN são os tamanhos da matriz ao longo das N dimensões, então a matriz achatada é uma longa matriz unidimensional de tamanho d1 * d2 * … * dN.
Para verificar a aparência do array achatado para array_2
, você pode chamar o método flatten()
, conforme mostrado abaixo:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
Índice do elemento máximo ao longo das linhas (eixo = 0)
Vamos prosseguir para encontrar o índice do elemento máximo ao longo das linhas (eixo = 0).
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])
Essa saída pode ser um pouco difícil de compreender, mas vamos entender como ela funciona.
Definimos o parâmetro axis
como zero ( axis = 0
), pois gostaríamos de encontrar o índice do elemento máximo ao longo das linhas. Portanto, a função argmax()
retorna o número da linha em que o elemento máximo ocorre — para cada uma das três colunas.
Vamos visualizar isso para melhor compreensão.

Do diagrama acima e da saída argmax()
, temos o seguinte:
- Para a primeira coluna no índice 0, o valor máximo 10 ocorre na segunda linha, no índice = 1.
- Para a segunda coluna no índice 1, o valor máximo 9 ocorre na segunda linha, no índice = 1.
- Para a terceira e quarta colunas no índice 2 e 3, os valores máximos 8 e 4 ocorrem na segunda linha, no índice = 1.
É exatamente por isso que temos o array([1, 1, 1, 1])
porque o elemento máximo ao longo das linhas ocorre na segunda linha (para todas as colunas).
Índice do Elemento Máximo ao Longo das Colunas (eixo = 1)
Em seguida, vamos usar a função argmax()
para encontrar o índice do elemento máximo ao longo das colunas.
Execute o trecho de código a seguir e observe a saída.
np.argmax(array_2,axis=1)
array([2, 0])
Você pode analisar a saída?
Definimos axis = 1
para calcular o índice do elemento máximo ao longo das colunas.
A função argmax()
retorna, para cada linha, o número da coluna em que ocorre o valor máximo.
Aqui está uma explicação visual:

Do diagrama acima e da saída argmax()
, temos o seguinte:
- Para a primeira linha no índice 0, o valor máximo 7 ocorre na terceira coluna, no índice = 2.
- Para a segunda linha no índice 1, o valor máximo 10 ocorre na primeira coluna, no índice = 0.
Espero que agora você entenda o que a saída, array([2, 0])
significa.
Usando o parâmetro opcional de saída no NumPy argmax ()
Você pode usar o parâmetro opcional out
the na função NumPy argmax() para armazenar a saída em uma matriz NumPy.
Vamos inicializar um array de zeros para armazenar a saída da chamada de função argmax()
anterior – para encontrar o índice do máximo ao longo das colunas ( axis= 1
).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
Agora, vamos revisitar o exemplo de encontrar o índice do elemento máximo ao longo das colunas ( axis = 1
) e definir out
como out_arr
que definimos acima.
np.argmax(array_2,axis=1,out=out_arr)
Vemos que o interpretador Python lança um TypeError
, pois o out_arr
foi inicializado para um array de floats por padrão.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Portanto, ao definir o parâmetro out
para a matriz de saída, é importante garantir que a matriz de saída tenha a forma e o tipo de dados corretos. Como os índices de array são sempre inteiros, devemos definir o parâmetro dtype
como int
ao definir o array de saída.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
Agora podemos ir em frente e chamar a função argmax()
com os parâmetros axis
e out
e, desta vez, ela é executada sem erros.
np.argmax(array_2,axis=1,out=out_arr)
A saída da função argmax()
agora pode ser acessada no array out_arr
.
print(out_arr) # Output [2 0]
Conclusão
Espero que este tutorial tenha ajudado você a entender como usar a função NumPy argmax(). Você pode executar os exemplos de código em um notebook Jupyter.
Vamos rever o que aprendemos.
- A função NumPy argmax() retorna o índice do elemento máximo em uma matriz. Se o elemento máximo ocorrer mais de uma vez em uma matriz a , então np.argmax(a) retornará o índice da primeira ocorrência do elemento.
- Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro opcional axis para obter o índice do elemento máximo ao longo de um eixo específico. Por exemplo, em uma matriz bidimensional: definindo axis = 0 e axis = 1 , você pode obter o índice do elemento máximo ao longo das linhas e colunas, respectivamente.
- Se você quiser armazenar o valor retornado em outro array, você pode definir o parâmetro opcional out para o array de saída. No entanto, a matriz de saída deve ter formato e tipo de dados compatíveis.
Em seguida, confira o guia detalhado sobre conjuntos de Python.