Como usar o método “torch.argmax()” no PyTorch?

Como Usar O Metodo Torch Argmax No Pytorch



No PyTorch, o “ tocha.argmax() ”O método é uma função integrada que retorna índices de valores máximos de um tensor específico em uma determinada dimensão. Os usuários usam esta função quando trabalham com tensores e desejam encontrar o índice do valor máximo ao longo de uma determinada dimensão de um tensor. Além disso, este método também pode ser útil para classificação onde os usuários desejam saber qual classe tem a maior probabilidade.

Este blog irá exemplificar o método para usar o método “torch.argmax()” no PyTorch.

Como usar o método “torch.argmax()” no PyTorch?

O método “torch.argmax()” pega qualquer tensor 1D ou 2D como entrada e retorna um tensor que contém os índices/índices dos valores máximos ao longo da dimensão dada.







A sintaxe do método “torch.argmax()” é fornecida abaixo:



tocha. argmax ( < tensor_de_entrada > )

Para usar este método no PyTorch, siga os exemplos a seguir para um melhor entendimento:



Exemplo 1: Use o método “torch.argmax()” com tensor 1D

No primeiro exemplo, criaremos um tensor 1D e usaremos o método “torch.argmax()” com ele. Vamos seguir o procedimento passo a passo abaixo:





Etapa 1: importar biblioteca PyTorch

Primeiro, importe o “ tocha ”biblioteca para usar o método “torch.argmax()”:

importar tocha

Etapa 2: criar tensor 1D

Em seguida, crie um tensor 1D e imprima seus elementos. Aqui, estamos criando o seguinte “ Dezenas1 ”tensor de uma lista usando o“ tocha.tensor() ”função:



Dezenas1 = tocha. tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

imprimir ( Dezenas1 )

Isso criou um tensor 1D conforme visto abaixo:

Etapa 3: Encontre índices de valor máximo

Agora, utilize o “ tocha.argmax() ”Função para encontrar o índice/índices do valor máximo no“ Dezenas1 ” tensor:

T1_ind = tocha. argmax ( Dezenas1 )

Etapa 4: Imprimir índice de valor máximo

Por último, exiba o índice do valor máximo no tensor de entrada:

imprimir ( 'Índices:' , T1_ind )

A saída abaixo mostra o índice do valor máximo no “ Dezenas1 ”tensor, ou seja, 4. Significa que o valor mais alto do tensor está no 4º índice que é“ 9 ”:

Exemplo 2: Use o método “torch.argmax()” com tensor 2D

No segundo exemplo, criaremos um tensor 2D e usaremos o método “torch.argmax()” com ele. Vamos seguir as etapas fornecidas:

Etapa 1: importar biblioteca PyTorch

Primeiro, importe o “ tocha ”biblioteca para usar o método “torch.argmax()”:

importar tocha

Etapa 2: criar tensor 2D

Em seguida, use o “ tocha.tensor() ”Função para criar um tensor 2D e imprimir seus elementos. Aqui, estamos criando o seguinte “ Dezenas2 “Tensor 2D:

Dezenas2 = tocha. tensor ( [ [ 4 , 1 , - 7 ] , [ quinze , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

imprimir ( Dezenas2 )

Isso criou um tensor 2D conforme visto abaixo:

Etapa 3: Encontre índices de valor máximo

Agora, encontre o índice do valor máximo no “ Dezenas2 ”tensor utilizando o“ tocha.argmax() ”função:

T2_ind = tocha. argmax ( Dezenas2 )

Etapa 4: Imprimir índice de valor máximo

Finalmente, exiba o índice do valor máximo no tensor de entrada:

imprimir ( 'Índices:' , T2_ind )

De acordo com a saída abaixo, o índice do valor máximo no “ Dezenas2 ”tensor é “3”. Isso significa que o valor mais alto do tensor está no 3º índice que é “ quinze ”:

Etapa 5: Encontre índices de valor máximo ao longo das colunas

Além disso, os usuários também podem encontrar os índices/índices dos valores máximos ao longo de cada coluna de um tensor. Por exemplo, podemos usar o “ escuro = 0 ”Argumento com a função “torch.argmax()”. Ele encontra os índices dos valores máximos ao longo das colunas no “ Dezenas2 ”tensor e então imprime esses índices:

col_index = tocha. argmax ( Dezenas2 , escurecer = 0 )

imprimir ( 'Índices em colunas:' , col_index )

A saída abaixo mostra os índices dos valores máximos ao longo de cada coluna do tensor:

Etapa 6: Encontre índices de valor máximo ao longo das linhas

Da mesma forma, os usuários também podem encontrar os índices/índices dos valores máximos ao longo de cada linha de um tensor. Por exemplo, use o “ escuro = 1 ”Argumento com a função “torch.argmax()” para encontrar os índices dos valores máximos ao longo das linhas no tensor “Tens2” e então imprimir esses índices:

índice_linha = tocha. argmax ( Dezenas2 , escurecer = 1 )

imprimir ( 'Índices em linhas:' , índice_linha )

Os índices dos valores máximos ao longo de cada linha de um tensor “Tens2” podem ser vistos abaixo:

Explicamos com eficiência o método para usar o método “torch.argmax()” no PyTorch.

Observação : Você pode acessar nosso Google Colab Notebook neste link .

Conclusão

Para usar o método “torch.argmax()” no PyTorch, primeiro importe o “ tocha ' biblioteca. Em seguida, crie o tensor 1D ou 2D desejado e visualize seus elementos. Em seguida, use o “ tocha.argmax() ”Método para encontrar/calcular os índices/índices dos valores máximos no tensor. Além disso, os usuários também podem encontrar os índices do valor máximo ao longo de cada linha ou coluna do tensor usando o “ escurecer ”Argumento. Por fim, exiba o índice do valor máximo no tensor de entrada. Este blog exemplificou o método para usar o método “torch.argmax()” no PyTorch.