Source code for mprod._base

import numpy as np
from typing import Callable, Tuple, Dict, List

import scipy.fft
from scipy.fft import dct, idct, rfft, irfft
from scipy.stats import ortho_group

# Alias used in annotations throughout this module for numpy arrays
# (tensors, matrices and tube fibers alike).
NumpynDArray = np.ndarray
# Signature of a tensor transform: an array in, an array out. The forward and
# inverse maps returned by the generate_* functions in this module have this type.
MatrixTensorProduct = Callable[[np.ndarray], np.ndarray]


def _default_transform(tube_size: int) -> Tuple[MatrixTensorProduct, MatrixTensorProduct]:
    """Return the default forward/inverse transform pair: an orthonormal DCT-II.

    Both maps act along the last axis of their input. ``tube_size`` is passed to
    scipy as the transform length ``n`` (scipy zero-pads or truncates the last
    axis to that length).
    """
    # Shared by both directions so forward and inverse can never drift apart.
    dct_options = {'type': 2, 'n': tube_size, 'axis': -1, 'norm': 'ortho'}

    def fun_m(x):
        return dct(x, **dct_options)

    def inv_m(x):
        return idct(x, **dct_options)

    return fun_m, inv_m


def generate_dct(tube_size: int, dct_type: int = 2) -> Tuple[MatrixTensorProduct, MatrixTensorProduct]:
    """Build a forward/inverse pair of orthonormal DCT transforms over tube fibers.

    Both transforms act on the last axis of their input, using ``norm='ortho'``
    so that the inverse DCT of the same type undoes the forward one.

    Parameters
    ----------

    tube_size: int
        Length of the tube fibers; passed to scipy as the transform length ``n``

    dct_type: int, default = 2
        Which DCT variant to use (see ``scipy.fft.dct`` for the available types)

    Returns
    -------

    fun_m: MatrixTensorProduct
        The forward DCT along the last axis

    inv_m: MatrixTensorProduct
        The inverse DCT along the last axis (undoes `fun_m`)

    """

    def fun_m(x):
        # positional order of scipy.fft.dct is (x, type, n, ...)
        return dct(x, dct_type, tube_size, axis=-1, norm='ortho')

    def inv_m(x):
        return idct(x, dct_type, tube_size, axis=-1, norm='ortho')

    return fun_m, inv_m


# noinspection PyPep8Naming
def _mod3prod(A: NumpynDArray, funM: MatrixTensorProduct) -> NumpynDArray:
    """Transform every tube fiber of `A` with `funM` (a mode-3 style product).

    The tube fibers ``A[i, j, :]`` are laid out as the columns of an
    ``(n, m * p)`` matrix, `funM` is applied to that matrix once, and the result
    is folded back into a tensor.

    Parameters
    ----------

    A: NumpynDArray
        3rd-order tensor of shape ``(m, p, n)``

    funM: MatrixTensorProduct
        Mapping applied to the ``(n, m * p)`` fiber matrix. It must act on the
        columns (e.g. ``lambda X: M @ X`` for an ``(n, n)`` matrix ``M``) and
        return an array that can be reshaped to ``(n, p, m)``. Transforms that act
        along the last axis (such as those from `generate_dct`) do NOT fit here.

    Returns
    -------

    hatA: NumpynDArray
        Tensor of shape ``(m, p, n)`` holding the transformed tube fibers of `A`
    """
    m, p, n = A.shape
    # Column j * m + i of `fibers` is the tube fiber A[i, j, :].
    fibers = A.transpose((2, 1, 0)).reshape(n, m * p)
    transformed = funM(fibers)
    # Undo the layout: (n, p, m) -> (m, p, n).
    return transformed.reshape((n, p, m)).transpose((2, 1, 0))


class _Mod3Transform:
    """Picklable callable computing the mode-3 product of a tensor with a fixed matrix.

    Created by `x_m3`. A module-level class is used instead of a nested closure
    because pickle cannot serialize local functions.
    """

    # noinspection PyPep8Naming
    def __init__(self, M: NumpynDArray):
        self.M = M

    def __call__(self, A: NumpynDArray) -> NumpynDArray:
        """Apply ``M`` to every tube fiber of `A`.

        For a 3rd-order tensor of shape ``(m, p, n)`` and ``M`` of shape ``(k, n)``
        the result has shape ``(m, p, k)``. Any input that is not 3rd-order is
        multiplied directly as ``M @ A``.

        Raises
        ------
        ValueError
            If ``M.shape[1]`` does not match the tube size ``n`` of a 3rd-order `A`.
        """
        M = self.M
        if np.ndim(A) != 3:
            # Not a 3rd-order tensor (e.g. a single tube fiber): apply M directly.
            return M @ A
        m, p, n = np.shape(A)
        # Stack all tube fibers as columns of an (n, m*p) matrix so one matmul
        # transforms them all. Shape mismatches raise here instead of being
        # masked by a fallback.
        fibers = np.transpose(A, (2, 1, 0)).reshape(n, m * p)
        # -1 keeps the leading size equal to M.shape[0], so rectangular M works too.
        return (M @ fibers).reshape((-1, p, m)).transpose((2, 1, 0))

    def __repr__(self):
        return f"{type(self).__name__}(M.shape={np.shape(self.M)})"


# noinspection PyPep8Naming
def x_m3(M: NumpynDArray) -> MatrixTensorProduct:
    """
    Creates a picklable tensor transformation forming the mod3 tensor-matrix
    multiplication required in the M product definition.

    Parameters
    ----------

    M: np.ndarray
        A matrix of shape `(n,n)` (more generally `(k,n)`)

    Returns
    -------

    fun: Callable[[NumpynDArray], NumpynDArray]
        Picklable mapping that multiplies every (n dimensional) tube fiber of a
        3rd-order tensor by `M`; inputs that are not 3rd-order are returned as
        ``M @ A``. Raises ``ValueError`` when the tube size does not match `M`.
    """
    return _Mod3Transform(M)
def generate_haar(tube_size: int, random_state=None) -> Tuple[MatrixTensorProduct, MatrixTensorProduct]:
    """Generates a tensor-matrix transformation from a random orthogonal matrix.

    The matrix ``M`` is sampled from the Haar distribution on the orthogonal
    group O(n) using ``scipy.stats.ortho_group``. Since ``M`` is orthogonal, its
    transpose is its inverse, which is used to build `inv_m`.

    Parameters
    ----------

    tube_size: int
        the fiber-tube size of the tensors of interest

    random_state: None, int, numpy.random.Generator or numpy.random.RandomState, default = None
        Forwarded to ``ortho_group.rvs``; fix it to obtain a reproducible transform

    Returns
    -------

    fun_m: MatrixTensorProduct
        A tensor transform

    inv_m: MatrixTensorProduct
        A tensor transform (the inverse of `fun_m`)
    """
    M = ortho_group.rvs(tube_size, random_state=random_state)
    fun_m = x_m3(M)
    # M is orthogonal, so M^{-1} == M^T.
    inv_m = x_m3(M.T)
    return fun_m, inv_m
def m_prod(tens_a: NumpynDArray,
           tens_b: NumpynDArray,
           fun_m: MatrixTensorProduct,
           inv_m: MatrixTensorProduct) -> NumpynDArray:
    """
    Returns the :math:`\\star_{\\mathbf{M}}` product of tensors `tens_a` and `tens_b`
    where ``tens_a.shape == (m,p,n)`` and ``tens_b.shape == (p,r,n)``.

    Both tensors are mapped to the transform domain with `fun_m`, multiplied
    slice by slice (frontal slices), and the result is mapped back with `inv_m`.

    Parameters
    ----------

    tens_a: array-like
        3'rd order tensor with shape `m x p x n`

    tens_b: array-like
        3'rd order tensor with shape `p x r x n`

    fun_m: MatrixTensorProduct, Callable[[NumpynDArray], NumpynDArray]
        Invertible linear mapping from `R^n` to `R^n`

    inv_m: MatrixTensorProduct, Callable[[NumpynDArray], NumpynDArray]
        Invertible linear mapping from R^n to R^n ( `fun_m(inv_m(x)) = inv_m(fun_m(x)) = x` )

    Returns
    -------
    tensor: array-like
        3'rd order tensor of shape `m x r x n` that is the star :math:`\\star_{\\mathbf{M}}`
        product of `tens_a` and `tens_b`

    Raises
    ------
    ValueError
        If the inner dimensions (``p``) or the tube sizes (``n``) do not match.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if tens_a.shape[1] != tens_b.shape[0]:
        raise ValueError(
            f"inner dimensions do not match: {tens_a.shape} and {tens_b.shape}")
    if tens_a.shape[-1] != tens_b.shape[-1]:
        raise ValueError(
            f"tube sizes do not match: {tens_a.shape} and {tens_b.shape}")
    a_hat = fun_m(tens_a)
    b_hat = fun_m(tens_b)
    # Facewise matrix product: for each frontal slice i, c_hat[:, :, i] = a_hat[:, :, i] @ b_hat[:, :, i].
    c_hat = np.einsum('mpi,pli->mli', a_hat, b_hat)
    return inv_m(c_hat)
# copied version from transformers.py # def m_prod(A: NumpynDArray, B: NumpynDArray, funM: MatrixTensorProduct, invM: MatrixTensorProduct) -> NumpynDArray: # # assert A.shape[1] == B.shape[0] # # assert A.shape[-1] == B.shape[-1] # A_hat = funM(A) # B_hat = funM(B) # # calE_hat = np.einsum('mpi,pli->mli', A_hat, B_hat) # return invM(calE_hat)
def tensor_mtranspose(tensor: NumpynDArray,
                      mfun: MatrixTensorProduct,
                      minv: MatrixTensorProduct) -> NumpynDArray:
    """Return the M-transpose of a 3rd-order tensor.

    The tensor is mapped to the transform domain with `mfun`, each frontal slice
    is transposed (axes 0 and 1 are swapped), and the result is mapped back with
    `minv`.

    Parameters
    ----------

    tensor: NumpynDArray
        3rd-order tensor, e.g. of shape ``(m, p, n)``

    mfun: MatrixTensorProduct
        Forward tensor transform

    minv: MatrixTensorProduct
        Inverse of `mfun`

    Returns
    -------

    tensor_t: NumpynDArray
        The M-transpose; of shape ``(p, m, n)`` when `mfun`/`minv` preserve shape
    """
    tensor_hat_t = mfun(tensor).transpose((1, 0, 2))
    return minv(tensor_hat_t)
# noinspection PyPep8Naming
def _t_pinv_fdiag(F, Mfun, Minv, tol=1e-6) -> NumpynDArray:
    """Pseudo-inverse of an f-diagonal tensor under the M-product.

    Only the main diagonal of every frontal slice of ``Mfun(F)`` is read.
    Off-diagonal entries are ignored, so `F` is expected to be f-diagonal in the
    transform domain.

    Parameters
    ----------

    F: NumpynDArray
        3rd-order tensor of shape ``(m, p, n)``

    Mfun: MatrixTensorProduct
        Forward tensor transform

    Minv: MatrixTensorProduct
        Inverse of `Mfun`

    tol: float, default = 1e-6
        Diagonal entries ``d`` with ``|d| ** 2 <= tol`` are treated as zero and
        map to zero in the pseudo-inverse

    Returns
    -------

    pinv_f: NumpynDArray
        Tensor of shape ``(p, m, n)``: in the transform domain each frontal slice
        is the transpose of the corresponding slice with its significant diagonal
        entries inverted
    """
    m, p, _ = F.shape
    hat_f = Mfun(F)
    idx = np.arange(min(m, p))
    # (k, n): the main diagonal of every frontal slice in the transform domain.
    diag = hat_f[idx, idx, :]
    keep = np.abs(diag) ** 2 > tol
    # Invert the significant entries and zero the rest. Small entries must not
    # survive unchanged; the inner np.where avoids dividing by zero.
    inv_diag = np.where(keep, 1 / np.where(keep, diag, 1), 0)
    # The pseudo-inverse of an (m x p) diagonal slice is the (p x m) diagonal of
    # reciprocals, so fill the transposed layout directly. This avoids a
    # forward/inverse round trip through tensor_mtranspose.
    pinv_hat_f = np.zeros((p, m, hat_f.shape[2]), dtype=inv_diag.dtype)
    pinv_hat_f[idx, idx, :] = inv_diag
    return Minv(pinv_hat_f)


# # TODO: Is TensorArray needed ?
# # noinspection PyPep8Naming
# class TensorArray(np.ndarray):
#     def __new__(cls, input_array):
#         # Input array is an already formed ndarray instance
#         # We first cast to be our class type
#         obj = np.asarray(input_array).view(cls)
#         # add the new attribute to the created instance
#         # Finally, we must return the newly created object:
#         return obj
#
#     @property
#     def TT(self):
#         return self.transpose((1, 0, 2))
#
#     def __array_finalize__(self, obj):
#         # see InfoArray.__array_finalize__ for comments
#         if obj is None: return
#         self.info = getattr(obj, 'info', None)