The "cursed multidimensional index rearrangement" is mostly there to compute the inner product in the weird way demonstrated. Granted, it proves the author's point: you need to be fluent in NumPy to write this, and you can't take the Go approach of "I refuse to learn anything, you'd better get me up to speed in 5 minutes".
The code could be just:
AiX = np.linalg.solve(A, X[:, np.newaxis, :, np.newaxis]).squeeze(-1)
Z = np.vecdot(Y, AiX)
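Or if you don't like NumPy 2.0 (this relies on the 1.x `solve`, which treats a `b` with one fewer dimension than `a` as a stack of vectors):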
AiX = np.linalg.solve(A, X[:, np.newaxis, :])
Z = np.einsum("jk,ijk->ij", Y, AiX)
The NumPy 1.x code is very intuitive: you have A[i, j] (a K×K matrix) and X[i] (a length-K vector), and you want to use the same X[i] for all A[i, j], so just add an axis in the middle. Broadcasting, which the author very strongly refuses to understand, deals with all the messy parts.
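To make the "add an axis in the middle" step concrete, here is a minimal sketch of how the shapes line up. The sizes are made up for illustration, and `np.broadcast_shapes` needs NumPy 1.20 or newer.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
I, J, K = 4, 3, 2
A = np.zeros((I, J, K, K))   # one K×K matrix for every (i, j)
X = np.zeros((I, K))         # one length-K vector for every i

X_mid = X[:, np.newaxis, :]  # shape (I, 1, K)

# The size-1 middle axis stretches to J, so X[i] lines up with every A[i, j].
batch_shape = np.broadcast_shapes(A.shape[:2], X_mid.shape[:2])
print(X_mid.shape, batch_shape)  # (4, 1, 2) (4, 3)
```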
Alternatively, if you hate the `[:, np.newaxis, :]` syntax, you can write:
I, J, K, _ = A.shape
AiX = np.linalg.solve(A, X.reshape(I, 1, K, 1)).squeeze(-1)
Z = np.vecdot(Y, AiX)
I can read this faster than nested for-loops. YMMV.
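If you want to convince yourself the two agree, here is a rough sanity check against the nested loops. It is only a sketch: the sizes and random data are invented, and I am assuming the intended result is Z[i, j] = Y[j] · solve(A[i, j], X[i]). It passes X as a stack of K×1 columns and uses `einsum` instead of `vecdot`, which as far as I know behaves the same on NumPy 1.x and 2.x.

```python
import numpy as np

rng = np.random.default_rng(0)
I, J, K = 4, 3, 2  # hypothetical sizes
# Add a multiple of the identity to make a near-singular matrix unlikely.
A = rng.normal(size=(I, J, K, K)) + 5 * np.eye(K)
X = rng.normal(size=(I, K))
Y = rng.normal(size=(J, K))

# Reference: nested loops, one solve per (i, j).
Z_loop = np.empty((I, J))
for i in range(I):
    for j in range(J):
        Z_loop[i, j] = Y[j] @ np.linalg.solve(A[i, j], X[i])

# Vectorized: X as a stack of K×1 columns, broadcast over j.
AiX = np.linalg.solve(A, X[:, np.newaxis, :, np.newaxis])[..., 0]  # (I, J, K)
Z_vec = np.einsum("jk,ijk->ij", Y, AiX)

assert np.allclose(Z_loop, Z_vec)
```

Indexing with `[..., 0]` drops only the trailing column axis, so it stays correct even if I or J happens to be 1.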