spainn.loss
SPaiNN Loss Functions
- class loss.PhaseLossMSE(*args: Any, **kwargs: Any)
The PhaseLossMSE class is a custom loss function for bulk properties that arise from the coupling of two distinct electronic states. It is intended for non-atomistic properties, such as dipoles, with shape (batch_size*n_dipoles, xyz).
One example is the transition dipole moment, the matrix element of the dipole operator \(\hat{\mu}\) between the states \(\Psi_i\) and \(\Psi_j\):
\[\mu_{ij}(\mathbf{R}) = \left\langle\Psi_i(\mathbf{R})|\hat{\mu}|\Psi_j(\mathbf{R})\right\rangle\]
Because the coupled wavefunctions \(\Psi_i\) and \(\Psi_j\) each have an arbitrary sign, the resulting property also has an arbitrary sign. The key feature of this loss function is therefore that it is phase-independent: it computes a Mean Squared Error (MSE) for each dipole with the prediction multiplied by +1 and by -1, and keeps the lower of the two values.
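As a worked illustration (the numbers are invented for this example), take a reference dipole \(D^{ref} = (0.5, -1.0, 0.2)\) and a prediction with the opposite phase, \(D^{pred} = -D^{ref}\). A plain squared error penalizes this prediction even though it differs only by the arbitrary sign, whereas the minimum over the factors +1 and -1 is zero:
\[\begin{aligned}
\sum_{l=1}^{3}\left|D^{\mathrm{ref}}_l - D^{\mathrm{pred}}_l\right|^2 &= 1.0^2 + (-2.0)^2 + 0.4^2 = 5.16,\\
\sum_{l=1}^{3}\left|D^{\mathrm{ref}}_l + D^{\mathrm{pred}}_l\right|^2 &= 0,\\
\min(5.16,\ 0) &= 0.
\end{aligned}\]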
During the calculation, the predictions (inputs) are subtracted from and added to the reference data (targets), which corresponds to phase factors of +1 and -1, respectively. The absolute values of both results are squared and summed over the xyz axis, giving two tensors: a positive tensor and a negative tensor. The element-wise minimum of these two tensors is then taken, summed over all axes, and divided by the total number of elements in the target.
The forward() method takes inputs and targets as arguments and returns a single float, the MSE loss value \(\mathcal{L}\).
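The steps above can be sketched in plain Python as follows. This is a minimal illustration, not the SPaiNN source: the real class works on tensors inside forward(), whereas this sketch assumes the dipoles are given as lists of (x, y, z) components in the flattened (batch_size*n_dipoles, xyz) layout. The function name `phase_loss_mse` is hypothetical.

```python
from typing import Sequence

Vector = Sequence[float]  # one dipole: (x, y, z)


def phase_loss_mse(inputs: Sequence[Vector], targets: Sequence[Vector]) -> float:
    """Hypothetical sketch of a phase-independent MSE (not the SPaiNN code)."""
    if not targets or len(inputs) != len(targets):
        raise ValueError("need the same, non-zero number of input and target dipoles")

    total = 0.0
    n_elements = 0
    for pred, ref in zip(inputs, targets):
        # "positive" term: targets - inputs (phase factor +1), summed over xyz.
        # abs() mirrors the description; for real values it does not change the result.
        pos = sum(abs(r - p) ** 2 for r, p in zip(ref, pred))
        # "negative" term: targets + inputs (phase factor -1), summed over xyz
        neg = sum(abs(r + p) ** 2 for r, p in zip(ref, pred))
        # keep the better-matching phase for this dipole
        total += min(pos, neg)
        n_elements += len(ref)
    # divide by the total number of elements in the target (3 * n_dipoles)
    return total / n_elements


targets = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
inputs = [[-0.9, 0.0, 0.0], [0.0, 0.8, 0.1]]
print(phase_loss_mse(inputs, targets))  # ≈ 0.01
```

In this example the first dipole is matched with the flipped phase and the second with the original phase; a prediction that equals the negated targets gives a loss of exactly zero.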
For dipoles of shape \((1, N_D, 3)\), i.e. \(N_D\) dipole vectors \(D_k\) with Cartesian components \(l\), the PhaseLossMSE is defined as
\[\mathcal{L} = \frac{1}{3N_D} \sum_{k=1}^{N_D} \min_{i \in \{1,2\}} \left( \sum_{l=1}^{3} \left| D_{k,l}^{\mathrm{ref}} - s_i\, D_{k,l}^{\mathrm{pred}} \right|^2 \right), \qquad \mathbf{s} = \begin{pmatrix} 1 \\ -1 \end{pmatrix}\]
- Parameters:
natoms – number of atoms
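As a sanity check of the normalization, consider two dipoles (again invented numbers): \(D_1^{ref} = (1, 0, 0)\), \(D_1^{pred} = (-0.9, 0, 0)\), \(D_2^{ref} = (0, 1, 0)\), \(D_2^{pred} = (0, 0.8, 0.1)\). The first prediction fits best with the flipped sign, the second with the original sign, and the sum of the two minima is divided by the total number of target elements, \(3 N_D = 6\):
\[\begin{aligned}
k=1:&\quad \min\left(1.9^2,\ 0.1^2\right) = \min(3.61,\ 0.01) = 0.01,\\
k=2:&\quad \min\left(0.2^2 + 0.1^2,\ 1.8^2 + 0.1^2\right) = \min(0.05,\ 3.25) = 0.05,\\
\mathcal{L} &= \frac{0.01 + 0.05}{3 \cdot 2} = 0.01.
\end{aligned}\]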
SchNarc Loss Functions
- class loss.PhysPhaseLoss(*args: Any, **kwargs: Any)
The PhysPhaseLoss class is a custom loss function for bulk properties that arise from the coupling of two distinct electronic states.
- Parameters:
natoms – number of atoms
mse – if True, calculates the loss based on the mean squared error; otherwise, based on the mean absolute error
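The documentation does not state how PhysPhaseLoss computes its loss or how it uses natoms. Purely as a hypothetical sketch, assuming it applies the same +1/-1 phase minimization as PhaseLossMSE, an `mse` switch could select between squared and absolute per-component errors as below. The function `phase_loss_sketch` is invented for illustration, ignores natoms, and is not the SchNarc implementation.

```python
from typing import Sequence

Vector = Sequence[float]  # one (x, y, z) vector


def phase_loss_sketch(inputs: Sequence[Vector], targets: Sequence[Vector], mse: bool = True) -> float:
    """Hypothetical phase-independent loss with an MSE/MAE switch (not SchNarc code)."""
    if not targets or len(inputs) != len(targets):
        raise ValueError("need the same, non-zero number of input and target vectors")

    def err(d: float) -> float:
        # squared error for MSE, absolute error for MAE
        return d * d if mse else abs(d)

    total = 0.0
    n_elements = 0
    for pred, ref in zip(inputs, targets):
        pos = sum(err(r - p) for r, p in zip(ref, pred))  # phase factor +1
        neg = sum(err(r + p) for r, p in zip(ref, pred))  # phase factor -1
        total += min(pos, neg)  # keep the better-matching phase per vector
        n_elements += len(ref)
    return total / n_elements


targets = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
inputs = [[-0.9, 0.0, 0.0], [0.0, 0.8, 0.1]]
print(phase_loss_sketch(inputs, targets, mse=True))   # ≈ 0.01
print(phase_loss_sketch(inputs, targets, mse=False))  # ≈ 0.0667 (= 0.4 / 6)
```

In this sketch the minimum over the two phases is taken per vector for both error types; whether the real class does the same is an assumption.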