# Get indices of numpy.argmax elements over an axis

I have an N-dimensional matrix that contains the values of a function with N parameters. Each parameter takes a discrete number of values. I need to maximize the function over all parameters but one, resulting in a one-dimensional vector whose size equals the number of values of the non-maximized parameter. I also need to record which values the other parameters take at each maximum.

To do so, I wanted to apply numpy.max iteratively over different axes, reducing the dimensionality of the matrix one axis at a time. The final vector then depends only on the parameter I left out.
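As a side note on the iterative approach: reducing one axis at a time with numpy.max gives the same values as a single call that passes a tuple of axes. A minimal sketch on the 2×2×2 example array used below, keeping axis 0. Note that this recovers only the maxima, not the positions they came from.

```python
import numpy as np

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])

# Maximize over axis 2, then axis 1 (highest axis first, so the remaining
# axis numbers do not shift between steps), keeping axis 0.
step_by_step = x.max(axis=2).max(axis=1)

# The same reduction in a single call with a tuple of axes.
at_once = x.max(axis=(1, 2))

print(step_by_step.tolist())  # [2, 7]
print(at_once.tolist())       # [2, 7]
```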

However, I'm having trouble finding the original indices of the final elements (which carry the information about the values taken by the other parameters). I thought about using numpy.argmax in the same way as numpy.max, but I can't recover the original indices.

Here is an example of what I'm trying:

```python
import numpy as np

x = [[[1,2],[0,1]],[[3,4],[6,7]]]
args = np.argmax(x, 0)
```

This returns:

```
[[1 1]
 [1 1]]
```

This means that argmax is selecting the elements (3, 4, 6, 7), i.e. all of x[1], within the original matrix. But how do I get their indices? I tried unravel_index, using args directly as an index into x, and a number of other NumPy indexing functions, all with no success.
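Since unravel_index came up: it can recover the original indices once the maximized axes are collapsed into one. The sketch below uses a different technique from the np.indices approach in the answer (np.moveaxis, reshape, argmax, then np.unravel_index), and the helper name argmax_all_but is hypothetical. It assumes `keep` is a non-negative axis number and that x has at least two dimensions. Because argmax returns the first maximum in row-major order of the collapsed axes, repeated values yield exactly one position per kept index.

```python
import numpy as np

def argmax_all_but(x, keep):
    """Maximize x over every axis except `keep` (a non-negative axis number).

    Returns (values, idx): values[k] is the maximum over the other axes at
    position k along `keep`, and idx holds one integer array per axis of x
    such that x[idx] == values.
    """
    x = np.asarray(x)
    moved = np.moveaxis(x, keep, 0)           # kept axis first, others in order
    other_shape = moved.shape[1:]
    flat = moved.reshape(moved.shape[0], -1)  # collapse the maximized axes
    flat_arg = flat.argmax(axis=1)            # first maximum in each row
    values = flat[np.arange(flat.shape[0]), flat_arg]
    # Map flat positions back to coordinates on the collapsed axes.
    idx = list(np.unravel_index(flat_arg, other_shape))
    # Re-insert the kept axis' own coordinate at its original position.
    idx.insert(keep, np.arange(x.shape[keep]))
    return values, tuple(idx)

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
values, idx = argmax_all_but(x, keep=0)
print(values.tolist())            # [2, 7]
print([i.tolist() for i in idx])  # [[0, 1], [0, 1], [1, 1]]
print(x[idx].tolist())            # [2, 7]
```

Each column of `idx` is the full coordinate of one maximum, so the values of the other parameters can be read off directly.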

Using numpy.where is not a solution, since the input matrix may contain repeated values, so I could not tell which of the matching positions is the one I want.
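A small illustration of the tie problem, using a hypothetical array y whose maximum appears twice: np.where returns every matching position, while argmax returns exactly one.

```python
import numpy as np

# A hypothetical array whose maximum (5) appears twice.
y = np.array([[5, 1],
              [5, 2]])

# np.where reports every position equal to the maximum...
rows, cols = np.where(y == y.max())
where_hits = list(zip(rows.tolist(), cols.tolist()))
print(where_hits)  # [(0, 0), (1, 0)]

# ...while argmax returns a single flat index (the first maximum in
# row-major order), which unravel_index turns back into coordinates.
first = tuple(int(i) for i in np.unravel_index(y.argmax(), y.shape))
print(first)       # (0, 0)
```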

x.argmax(0) gives the indices along the first axis (axis 0) of the maximum values. Use np.indices to generate the indices for the other axes.

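```python
>>> import numpy as np
>>> x = np.array([[[1,2],[0,1]],[[3,4],[6,7]]])
>>> x.argmax(0)
array([[1, 1],
       [1, 1]])
>>> # a1, a2 index the two axes that argmax(0) leaves in place
>>> a1, a2 = np.indices((2,2))
>>> (x.argmax(0),a1,a2)
(array([[1, 1],
       [1, 1]]),
 array([[0, 0],
       [1, 1]]),
 array([[0, 1],
       [0, 1]]))
>>> # maxima along axis 0: argmax supplies the axis-0 index
>>> x[x.argmax(0),a1,a2]
array([[3, 4],
       [6, 7]])
>>> # maxima along axis 1
>>> x[a1,x.argmax(1),a2]
array([[1, 2],
       [6, 7]])
>>> # maxima along axis 2
>>> x[a1,a2,x.argmax(2)]
array([[2, 1],
       [4, 7]])
```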
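In NumPy 1.15 and later, np.take_along_axis performs the same selection without building the a1, a2 grids by hand, provided the argmax result is given back a length-1 axis with np.expand_dims. A sketch reproducing the three selections above:

```python
import numpy as np

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])

results = {}
for axis in range(x.ndim):
    # argmax drops `axis`; expand_dims puts it back with length 1 so that
    # take_along_axis can broadcast the index against the other axes.
    idx = np.expand_dims(x.argmax(axis=axis), axis)
    # Select the maxima, then drop the length-1 axis again.
    results[axis] = np.take_along_axis(x, idx, axis=axis).squeeze(axis)
    print(axis, results[axis].tolist())
# 0 [[3, 4], [6, 7]]
# 1 [[1, 2], [6, 7]]
# 2 [[2, 1], [4, 7]]
```

This returns the maximum values only; the a1, a2 grids are still the simplest way to get the full coordinates.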

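If x has a different shape or more dimensions, generate a1, a2, ... accordingly: one grid per remaining axis, for example with np.indices on the shape of x.argmax(axis).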
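One way to do that for an array of any rank, sketched under the assumption that a single axis is reduced; the helper name argmax_index is hypothetical. It builds the grids with np.indices over the shape of x.argmax(axis) and splices the argmax result in at position `axis`, as in the 3-D session above. The 4-D array is random data for illustration only.

```python
import numpy as np

def argmax_index(x, axis):
    """Return a tuple that indexes the maxima of x along `axis`.

    The tuple has one integer array per axis of x: the argmax result at
    position `axis`, and np.indices grids (the a1, a2 of the 3-D case) for
    every other axis.
    """
    x = np.asarray(x)
    arg = x.argmax(axis=axis)            # raises if `axis` is out of range
    axis = axis % x.ndim                 # allow negative axis numbers
    grids = list(np.indices(arg.shape))  # one grid per remaining axis
    grids.insert(axis, arg)
    return tuple(grids)

# Hypothetical 4-D data, for illustration only.
rng = np.random.default_rng(0)
x4 = rng.integers(0, 10, size=(2, 3, 4, 5))
idx = argmax_index(x4, axis=2)
print(bool((x4[idx] == x4.max(axis=2)).all()))  # True
```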

The official documentation does not say much about using argmax this way, but earlier Stack Overflow threads have discussed it. I got the general idea from "Using numpy.argmax() on multidimensional arrays".