Como usar a função NumPy argmax() em Python

Publicados: 2022-09-14

Neste tutorial, você aprenderá como usar a função NumPy argmax() para encontrar o índice do elemento máximo em arrays.

NumPy é uma biblioteca poderosa para computação científica em Python; ele fornece arrays N-dimensionais com melhor desempenho do que as listas do Python. Uma das operações comuns que você realizará ao trabalhar com matrizes NumPy é encontrar o valor máximo na matriz. No entanto, às vezes você pode querer localizar o índice no qual ocorre o valor máximo.

A função argmax() ajuda a encontrar o índice do máximo em arrays unidimensionais e multidimensionais. Vamos continuar a aprender como funciona.

Como encontrar o índice do elemento máximo em uma matriz NumPy

Para acompanhar este tutorial, você precisa ter o Python e o NumPy instalados. Você pode codificar iniciando um Python REPL ou iniciando um notebook Jupyter.

Primeiro, vamos importar NumPy sob o alias usual np .

 import numpy as np

Você pode usar a função NumPy max() para obter o valor máximo em uma matriz (opcionalmente ao longo de um eixo específico).

 array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10

Nesse caso, np.max(array_1) retorna 10, o que está correto.

Suponha que você queira encontrar o índice no qual o valor máximo ocorre na matriz. Você pode adotar a seguinte abordagem em duas etapas:

  1. Encontre o elemento máximo.
  2. Encontre o índice do elemento máximo.

Em array_1 , o valor máximo de 10 ocorre no índice 4, seguindo a indexação zero. O primeiro elemento está no índice 0; o segundo elemento está no índice 1 e assim por diante.

numpy-argmax

Para encontrar o índice no qual ocorre o máximo, você pode usar a função NumPy where(). np.where(condition) retorna um array de todos os índices onde a condition é True .

Você terá que tocar no array e acessar o item no primeiro índice. Para descobrir onde ocorre o valor máximo, definimos a condition para array_1==10 ; lembre-se de que 10 é o valor máximo em array_1 .

 print(int(np.where(array_1==10)[0])) # Output 4

Usamos np.where() apenas com a condição, mas este não é o método recomendado para usar esta função.

Nota: NumPy where() Função :
np.where(condition,x,y) retorna:

– Elementos de x quando a condição for True , e
– Elementos de y quando a condição for False .

Portanto, encadeando as np.max() e np.where() , podemos encontrar o elemento máximo, seguido do índice em que ele ocorre.

Em vez do processo de duas etapas acima, você pode usar a função NumPy argmax() para obter o índice do elemento máximo na matriz.

Sintaxe da função NumPy argmax()

A sintaxe geral para usar a função NumPy argmax() é a seguinte:

 np.argmax(array,axis,out) # we've imported numpy under the alias np

Na sintaxe acima:

  • array é qualquer array NumPy válido.
  • axis é um parâmetro opcional. Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro axis para encontrar o índice de máximo ao longo de um eixo específico.
  • out é outro parâmetro opcional. Você pode definir o parâmetro out para um array NumPy para armazenar a saída da função argmax() .

Nota : A partir da versão 1.22.0 do NumPy, há um parâmetro keepdims adicional. Quando especificamos o parâmetro axis na chamada da função argmax() , a matriz é reduzida ao longo desse eixo. Mas definir o parâmetro keepdims como True garante que a saída retornada tenha a mesma forma que a matriz de entrada.

Usando NumPy argmax() para encontrar o índice do elemento máximo

#1 . Vamos usar a função NumPy argmax() para encontrar o índice do elemento máximo em array_1 .

 array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4

A função argmax() retorna 4, o que está correto!

#2 . Se redefinirmos array_1 de forma que 10 ocorra duas vezes, a função argmax() retornará apenas o índice da primeira ocorrência.

 array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4

Para o restante dos exemplos, usaremos os elementos de array_1 que definimos no exemplo #1.

Usando NumPy argmax() para encontrar o índice do elemento máximo em uma matriz 2D

Vamos remodelar o array NumPy array_1 em um array bidimensional com duas linhas e quatro colunas.

 array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]

Para uma matriz bidimensional, o eixo 0 denota as linhas e o eixo 1 denota as colunas. As matrizes NumPy seguem a indexação zero . Portanto, os índices das linhas e colunas da matriz NumPy array_2 são os seguintes:

numpy-argmax-2darray

Agora, vamos chamar a função argmax() no array bidimensional, array_2 .

 print(np.argmax(array_2)) # Output 4

Embora tenhamos chamado argmax() no array bidimensional, ele ainda retorna 4. Isso é idêntico à saída para o array unidimensional, array_1 da seção anterior.

Por que isso acontece?

Isso ocorre porque não especificamos nenhum valor para o parâmetro axis. Quando este parâmetro de eixo não está definido, por padrão, a função argmax() retorna o índice do elemento máximo ao longo do array nivelado.

O que é uma matriz achatada? Se houver uma matriz N-dimensional de forma d1 x d2 x … x dN , onde d1, d2, até dN são os tamanhos da matriz ao longo das N dimensões, então a matriz achatada é uma longa matriz unidimensional de tamanho d1 * d2 * … * dN.

Para verificar a aparência do array achatado para array_2 , você pode chamar o método flatten() , conforme mostrado abaixo:

 array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])

Índice do elemento máximo ao longo das linhas (eixo = 0)

Vamos prosseguir para encontrar o índice do elemento máximo ao longo das linhas (eixo = 0).

 np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])

Essa saída pode ser um pouco difícil de compreender, mas vamos entender como ela funciona.

Definimos o parâmetro axis como zero ( axis = 0 ), pois gostaríamos de encontrar o índice do elemento máximo ao longo das linhas. Portanto, a função argmax() retorna o número da linha em que o elemento máximo ocorre — para cada uma das três colunas.

Vamos visualizar isso para melhor compreensão.

numpy-argmax-axis0

Do diagrama acima e da saída argmax() , temos o seguinte:

  • Para a primeira coluna no índice 0, o valor máximo 10 ocorre na segunda linha, no índice = 1.
  • Para a segunda coluna no índice 1, o valor máximo 9 ocorre na segunda linha, no índice = 1.
  • Para a terceira e quarta colunas no índice 2 e 3, os valores máximos 8 e 4 ocorrem na segunda linha, no índice = 1.

É exatamente por isso que temos o array([1, 1, 1, 1]) porque o elemento máximo ao longo das linhas ocorre na segunda linha (para todas as colunas).

Índice do Elemento Máximo ao Longo das Colunas (eixo = 1)

Em seguida, vamos usar a função argmax() para encontrar o índice do elemento máximo ao longo das colunas.

Execute o trecho de código a seguir e observe a saída.

 np.argmax(array_2,axis=1)
 array([2, 0])

Você pode analisar a saída?

Definimos axis = 1 para calcular o índice do elemento máximo ao longo das colunas.

A função argmax() retorna, para cada linha, o número da coluna em que ocorre o valor máximo.

Aqui está uma explicação visual:

numpy-argmax-axis1

Do diagrama acima e da saída argmax() , temos o seguinte:

  • Para a primeira linha no índice 0, o valor máximo 7 ocorre na terceira coluna, no índice = 2.
  • Para a segunda linha no índice 1, o valor máximo 10 ocorre na primeira coluna, no índice = 0.

Espero que agora você entenda o que a saída, array([2, 0]) significa.

Usando o parâmetro opcional de saída no NumPy argmax ()

Você pode usar o parâmetro opcional out the na função NumPy argmax() para armazenar a saída em uma matriz NumPy.

Vamos inicializar um array de zeros para armazenar a saída da chamada de função argmax() anterior – para encontrar o índice do máximo ao longo das colunas ( axis= 1 ).

 out_arr = np.zeros((2,)) print(out_arr) [0. 0.]

Agora, vamos revisitar o exemplo de encontrar o índice do elemento máximo ao longo das colunas ( axis = 1 ) e definir out como out_arr que definimos acima.

 np.argmax(array_2,axis=1,out=out_arr)

Vemos que o interpretador Python lança um TypeError , pois o out_arr foi inicializado para um array de floats por padrão.

 --------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

Portanto, ao definir o parâmetro out para a matriz de saída, é importante garantir que a matriz de saída tenha a forma e o tipo de dados corretos. Como os índices de array são sempre inteiros, devemos definir o parâmetro dtype como int ao definir o array de saída.

 out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]

Agora podemos ir em frente e chamar a função argmax() com os parâmetros axis e out e, desta vez, ela é executada sem erros.

 np.argmax(array_2,axis=1,out=out_arr)

A saída da função argmax() agora pode ser acessada no array out_arr .

 print(out_arr) # Output [2 0]

Conclusão

Espero que este tutorial tenha ajudado você a entender como usar a função NumPy argmax(). Você pode executar os exemplos de código em um notebook Jupyter.

Vamos rever o que aprendemos.

  • A função NumPy argmax() retorna o índice do elemento máximo em uma matriz. Se o elemento máximo ocorrer mais de uma vez em uma matriz a , então np.argmax(a) retornará o índice da primeira ocorrência do elemento.
  • Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro opcional axis para obter o índice do elemento máximo ao longo de um eixo específico. Por exemplo, em uma matriz bidimensional: definindo axis = 0 e axis = 1 , você pode obter o índice do elemento máximo ao longo das linhas e colunas, respectivamente.
  • Se você quiser armazenar o valor retornado em outro array, você pode definir o parâmetro opcional out para o array de saída. No entanto, a matriz de saída deve ter formato e tipo de dados compatíveis.

Em seguida, confira o guia detalhado sobre conjuntos de Python.