"""Source code for ``pyharmx.functions``: polyharmonic spline kernels and squared-distance matrices."""

from typing import Callable

import torch

# Lower bound applied with ``torch.clamp`` to kernel inputs before they are
# passed to sqrt / log / pow, so log(0) and non-positive bases never occur.
_EPSILON = 1e-10


# Nonlinear kernel function
# ====================================
def get_phi(order: int) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""
    Coordinate-wise nonlinear function used to define the order of the interpolation.

    See `here <https://en.wikipedia.org/wiki/Polyharmonic_spline>`_ for the definition.

    :param order: Interpolation order. Must be at least 1.
    :type order: int
    :return: Coordinate-wise nonlinear kernel :math:`{\phi}`.
    :rtype: callable
    :raises ValueError: If ``order`` is smaller than 1.
    """
    # Orders below 1 do not define a polyharmonic kernel: order 0 would fall
    # into the even branch and yield 0.5 * log(r), and negative orders would
    # produce negative powers of the (clamped) input.
    if order < 1:
        raise ValueError(f"Interpolation order must be >= 1, got {order}.")
    # Dedicated implementations for the most common orders.
    if order == 1:
        return _phi_1
    if order == 2:
        return _phi_2
    if order == 4:
        return _phi_4
    # Generic orders: even orders carry a log factor, odd orders are a pure power.
    if order % 2 == 0:
        return lambda r: _phi_even(r, order)
    return lambda r: _phi_odd(r, order)
def _phi_1(r: torch.Tensor) -> torch.Tensor:
    """Order-1 kernel: square root of ``r`` after clamping it to ``_EPSILON``.

    NOTE(review): presumably ``r`` holds squared distances, so this yields the
    plain distance -- confirm at the call site.
    """
    r_safe = torch.clamp(r, min=_EPSILON)
    return r_safe.sqrt()


def _phi_2(r: torch.Tensor) -> torch.Tensor:
    """Order-2 kernel: ``0.5 * r * log(r)``.

    Only the argument of the logarithm is clamped; the linear factor uses the
    raw ``r``.
    """
    log_term = torch.clamp(r, min=_EPSILON).log()
    return 0.5 * r * log_term


def _phi_4(r: torch.Tensor) -> torch.Tensor:
    """Order-4 kernel: ``0.5 * r**2 * log(r)``.

    Only the argument of the logarithm is clamped; the quadratic factor uses
    the raw ``r``.
    """
    log_term = torch.clamp(r, min=_EPSILON).log()
    return 0.5 * r.square() * log_term


def _phi_even(r: torch.Tensor, order: int) -> torch.Tensor:
    """Generic even-order kernel: ``0.5 * r**(order / 2) * log(r)``.

    Both the power and the logarithm act on the clamped ``r``.
    """
    r_safe = torch.clamp(r, min=_EPSILON)
    exponent = 0.5 * order
    return 0.5 * r_safe.pow(exponent) * r_safe.log()


def _phi_odd(r: torch.Tensor, order: int) -> torch.Tensor:
    """Generic odd-order kernel: ``r**(order / 2)`` of the clamped ``r``."""
    r_safe = torch.clamp(r, min=_EPSILON)
    exponent = 0.5 * order
    return r_safe.pow(exponent)


# Tensor operation - Distance matrix
# ====================================
def cross_squared_distance_matrix(
    x: torch.Tensor, y: torch.Tensor
) -> torch.Tensor:
    """
    Pairwise squared distance between two (batch) matrices' rows (second-to-last dimension).

    Computes the pairwise squared Euclidean distances between rows of `x` and rows of `y`
    via the expansion ``|x|^2 - 2 x.y + |y|^2``. Any number of leading batch dimensions
    (including none) is supported, provided they broadcast between `x` and `y`.

    :param x: Tensor with shape `[..., n, d]`, typically `[batch_size, n, d]`.
    :type x: torch.Tensor
    :param y: Tensor with shape `[..., m, d]`, typically `[batch_size, m, d]`.
    :type y: torch.Tensor
    :return: Tensor with shape `[..., n, m]`. Each element represents the squared
        Euclidean distance between vectors `x[..., i, :]` and `y[..., j, :]`.
        Small negative values caused by floating-point cancellation are clamped to zero.
    :rtype: torch.Tensor
    """
    # Squared row norms, shaped for broadcasting: [..., n, 1] and [..., 1, m].
    x_sq_norm = torch.sum(torch.square(x), dim=-1).unsqueeze(-1)
    y_sq_norm = torch.sum(torch.square(y), dim=-1).unsqueeze(-2)
    # Inner products between every row of x and every row of y: [..., n, m].
    x_yt = torch.matmul(x, y.transpose(-2, -1))
    sq_dist = x_sq_norm - 2 * x_yt + y_sq_norm
    # The expansion above can dip slightly below zero for (near-)coincident
    # points; a squared distance is never negative.
    return torch.clamp(sq_dist, min=0)
def pairwise_squared_distance_matrix(
    x: torch.Tensor
) -> torch.Tensor:
    """
    Compute pairwise squared distance among a (batch) matrix's rows (second-to-last dimension).

    It is faster than `cross_squared_distance_matrix` because the Gram matrix is computed
    once and the squared norms are read from its diagonal. Any number of leading batch
    dimensions (including none) is supported.

    :param x: Tensor with shape `[..., n, d]`, typically `[batch_size, n, d]`.
    :type x: torch.Tensor
    :return: Tensor with shape `[..., n, n]`. Each element represents the squared
        Euclidean distance between vectors `x[..., i, :]` and `x[..., j, :]`.
        Small negative values caused by floating-point cancellation are clamped to zero.
    :rtype: torch.Tensor
    """
    # Gram matrix of the rows: [..., n, n].
    x_xt = torch.matmul(x, x.transpose(-2, -1))
    # Its diagonal holds the squared row norms; shape it as a column [..., n, 1].
    x_xt_diag = torch.diagonal(x_xt, offset=0, dim1=-2, dim2=-1).unsqueeze(-1)
    sq_dist = x_xt_diag - 2 * x_xt + x_xt_diag.transpose(-2, -1)
    # Off-diagonal entries of near-duplicate rows can come out slightly
    # negative; a squared distance is never negative.
    return torch.clamp(sq_dist, min=0)