from . import base as base_impl
from functools import partial
from xbarray.backends.jax import JaxComputeBackend as BindingBackend

# Public API: each name below is the same-named ``base_impl`` routine with the
# compute backend (``BindingBackend``, i.e. ``JaxComputeBackend``) pre-supplied
# as its first positional argument.
__all__ = [
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]


def _bind_to_backend(fn):
    """Return ``partial(fn, BindingBackend)``.

    Callers of the returned object pass every argument of ``fn`` except the
    leading backend argument, which is fixed to ``BindingBackend``.
    """
    return partial(fn, BindingBackend)


pixel_coordinate_and_depth_to_world = _bind_to_backend(
    base_impl.pixel_coordinate_and_depth_to_world
)
depth_image_to_world = _bind_to_backend(base_impl.depth_image_to_world)
world_to_pixel_coordinate_and_depth = _bind_to_backend(
    base_impl.world_to_pixel_coordinate_and_depth
)
world_to_depth = _bind_to_backend(base_impl.world_to_depth)
farthest_point_sampling = _bind_to_backend(base_impl.farthest_point_sampling)
random_point_sampling = _bind_to_backend(base_impl.random_point_sampling)