Cum să utilizați funcția NumPy argmax() în Python
Publicat: 2022-09-14În acest tutorial, veți învăța cum să utilizați funcția NumPy argmax() pentru a găsi indexul elementului maxim din matrice.
NumPy este o bibliotecă puternică pentru calcul științific în Python; oferă tablouri N-dimensionale care sunt mai performante decât listele Python. Una dintre operațiunile obișnuite pe care le veți efectua atunci când lucrați cu matrice NumPy este să găsiți valoarea maximă din matrice. Cu toate acestea, poate doriți uneori să găsiți indexul la care apare valoarea maximă.
Funcția argmax()
vă ajută să găsiți indicele maximului atât în tablourile unidimensionale, cât și multidimensionale. Să continuăm să aflăm cum funcționează.
Cum să găsiți indicele elementului maxim într-o matrice NumPy
Pentru a urma acest tutorial, trebuie să aveți instalate Python și NumPy. Puteți codifica prin pornirea unui REPL Python sau lansând un notebook Jupyter.
Mai întâi, să importăm NumPy sub aliasul obișnuit np
.
import numpy as np
Puteți utiliza funcția NumPy max()
pentru a obține valoarea maximă într-o matrice (opțional de-a lungul unei axe specifice).
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
În acest caz, np.max(array_1)
returnează 10, ceea ce este corect.
Să presupunem că doriți să găsiți indexul la care apare valoarea maximă în matrice. Puteți adopta următoarea abordare în doi pași:
- Găsiți elementul maxim.
- Găsiți indicele elementului maxim.
În array_1
, valoarea maximă de 10 apare la indexul 4, după indexarea zero. Primul element este la indicele 0; al doilea element este la indicele 1 și așa mai departe.

Pentru a găsi indexul la care apare maximul, puteți utiliza funcția NumPy where(). np.where(condition)
returnează o matrice a tuturor indicilor în care condition
este True
.
Va trebui să accesați matrice și să accesați elementul de la primul index. Pentru a afla unde apare valoarea maximă, setăm condition
la array_1==10
; reamintim că 10 este valoarea maximă în array_1
.
print(int(np.where(array_1==10)[0])) # Output 4
Am folosit np.where()
doar cu condiția, dar aceasta nu este metoda recomandată pentru a utiliza această funcție.
Notă: NumPy unde() Funcția :
np.where(condition,x,y)
returnează:– Elemente din
x
când condiția esteTrue
și
– Elemente diny
când condiția esteFalse
.
Prin urmare, înlănțuind funcțiile np.max()
și np.where()
, putem găsi elementul maxim, urmat de indicele la care apare.
În loc de procesul în doi pași de mai sus, puteți utiliza funcția NumPy argmax() pentru a obține indexul elementului maxim din matrice.
Sintaxa funcției NumPy argmax().
Sintaxa generală pentru a utiliza funcția NumPy argmax() este următoarea:
np.argmax(array,axis,out) # we've imported numpy under the alias np
În sintaxa de mai sus:
- array este orice tablou NumPy valid.
- axa este un parametru opțional. Când lucrați cu tablouri multidimensionale, puteți utiliza parametrul axă pentru a găsi indicele maximului de-a lungul unei axe specifice.
- out este un alt parametru opțional. Puteți seta parametrul
out
la o matrice NumPy pentru a stoca rezultatul funcțieiargmax()
.
Notă : Din versiunea NumPy 1.22.0, există un parametru suplimentar
keepdims
. Când specificăm parametrulaxis
în apelul funcțieiargmax()
, matricea este redusă de-a lungul acelei axe. Dar setarea parametruluikeepdims
laTrue
asigură că rezultatul returnat are aceeași formă ca și matricea de intrare.
Folosind NumPy argmax() pentru a găsi indicele elementului maxim
#1 . Să folosim funcția NumPy argmax() pentru a găsi indexul elementului maxim din array_1
.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
Funcția argmax()
returnează 4, ceea ce este corect!
#2 . Dacă redefinim array_1
astfel încât 10 să apară de două ori, funcția argmax()
returnează doar indexul primei apariții.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
Pentru restul exemplelor, vom folosi elementele array_1
pe care le-am definit în exemplul #1.
Utilizarea NumPy argmax() pentru a găsi indicele elementului maxim într-o matrice 2D
Să remodelăm tabloul array_1
într-o matrice bidimensională cu două rânduri și patru coloane.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
Pentru o matrice bidimensională, axa 0 indică rândurile, iar axa 1 indică coloanele. Matricele NumPy urmează indexarea zero . Deci, indicii rândurilor și coloanelor pentru tabloul array_2
sunt după cum urmează:

Acum, să apelăm funcția argmax()
pe tabloul bidimensional, array_2
.
print(np.argmax(array_2)) # Output 4
Chiar dacă am apelat argmax()
pe tabloul bidimensional, totuși returnează 4. Acesta este identic cu rezultatul pentru tabloul unidimensional, array_1
din secțiunea anterioară.
De ce se întâmplă asta?
Acest lucru se datorează faptului că nu am specificat nicio valoare pentru parametrul axei. Când acest parametru de axă nu este setat, în mod implicit, funcția argmax()
returnează indexul elementului maxim de-a lungul matricei aplatizate.
Ce este o matrice aplatizată? Dacă există o matrice N-dimensională de formă d1 x d2 x … x dN , unde d1, d2, până la dN sunt dimensiunile matricei de-a lungul N dimensiuni, atunci matricea aplatizată este o matrice lungă unidimensională de dimensiune d1 * d2 * … * dN.
Pentru a verifica cum arată matricea aplatizată pentru array_2
, puteți apela metoda flatten()
, după cum se arată mai jos:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
Indicele elementului maxim de-a lungul rândurilor (axa = 0)
Să continuăm să găsim indicele elementului maxim de-a lungul rândurilor (axa = 0).
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])
Această ieșire poate fi puțin dificil de înțeles, dar vom înțelege cum funcționează.
Am setat parametrul axis
la zero ( axis = 0
), deoarece am dori să găsim indicele elementului maxim de-a lungul rândurilor. Prin urmare, funcția argmax()
returnează numărul rândului în care apare elementul maxim — pentru fiecare dintre cele trei coloane.
Să vizualizăm acest lucru pentru o mai bună înțelegere.

Din diagrama de mai sus și rezultatul argmax()
avem următoarele:
- Pentru prima coloană la indexul 0, valoarea maximă 10 apare în al doilea rând, la index = 1.
- Pentru a doua coloană la indexul 1, valoarea maximă 9 apare în al doilea rând, la index = 1.
- Pentru a treia și a patra coloană de la indexul 2 și 3, valorile maxime 8 și 4 apar ambele în al doilea rând, la index = 1.
Tocmai de aceea avem array([1, 1, 1, 1])
deoarece elementul maxim de-a lungul rândurilor apare în al doilea rând (pentru toate coloanele).
Indicele elementului maxim de-a lungul coloanelor (axa = 1)
În continuare, să folosim funcția argmax()
pentru a găsi indicele elementului maxim de-a lungul coloanelor.
Rulați următorul fragment de cod și observați rezultatul.
np.argmax(array_2,axis=1)
array([2, 0])
Puteți analiza rezultatul?
Am stabilit axis = 1
pentru a calcula indicele elementului maxim de-a lungul coloanelor.
Funcția argmax()
returnează, pentru fiecare rând, numărul coloanei în care apare valoarea maximă.
Iată o explicație vizuală:

Din diagrama de mai sus și rezultatul argmax()
avem următoarele:
- Pentru primul rând la indicele 0, valoarea maximă 7 apare în a treia coloană, la index = 2.
- Pentru al doilea rând la indexul 1, valoarea maximă 10 apare în prima coloană, la index = 0.
Sper că acum înțelegeți ce înseamnă rezultatul, array([2, 0])
.
Utilizarea parametrului opțional out în NumPy argmax()
Puteți utiliza parametrul opțional out
funcția NumPy argmax() pentru a stoca rezultatul într-o matrice NumPy.
Să inițializam o matrice de zerouri pentru a stoca rezultatul apelului anterior al funcției argmax()
– pentru a găsi indicele maximului de-a lungul coloanelor ( axis= 1
).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
Acum, să revedem exemplul de găsire a indexului elementului maxim de-a lungul coloanelor ( axis = 1
) și să out_arr
out
care l-am definit mai sus.
np.argmax(array_2,axis=1,out=out_arr)
Vedem că interpretul Python aruncă o TypeError
, deoarece out_arr
a fost inițializat la o serie de floats în mod implicit.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Prin urmare, atunci când setați parametrul de out
la matricea de ieșire, este important să vă asigurați că matricea de ieșire are forma și tipul de date corecte. Deoarece indicii de matrice sunt întotdeauna numere întregi, ar trebui să setăm parametrul dtype
la int
atunci când definim tabloul de ieșire.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
Acum putem merge mai departe și apelăm funcția argmax()
atât cu parametrii axis
, cât și cu parametrii out
, iar de această dată, rulează fără eroare.
np.argmax(array_2,axis=1,out=out_arr)
Ieșirea funcției argmax()
poate fi acum accesată în tabloul out_arr
.
print(out_arr) # Output [2 0]
Concluzie
Sper că acest tutorial v-a ajutat să înțelegeți cum să utilizați funcția NumPy argmax(). Puteți rula exemplele de cod într-un blocnotes Jupyter.
Să revizuim ceea ce am învățat.
- Funcția NumPy argmax() returnează indexul elementului maxim dintr-o matrice. Dacă elementul maxim apare de mai multe ori într-o matrice a , atunci np.argmax(a) returnează indexul primei apariții a elementului.
- Când lucrați cu matrice multidimensionale, puteți utiliza parametrul opțional de axă pentru a obține indicele elementului maxim de-a lungul unei anumite axe. De exemplu, într-o matrice bidimensională: setând axa = 0 și axa = 1 , puteți obține indicele elementului maxim de-a lungul rândurilor și, respectiv, coloanelor.
- Dacă doriți să stocați valoarea returnată într-o altă matrice, puteți seta parametrul opțional out la matricea de ieșire. Cu toate acestea, matricea de ieșire ar trebui să fie de formă și tip de date compatibile.
Apoi, consultați ghidul aprofundat despre seturile Python.