from __future__ import annotations
import builtins as __builtins__
import cupy as cp
import cupy._core.raw
import typing
# Public names exported by this module (note that the ``cp`` alias for CuPy
# is re-exported alongside the kernels, their sources and the two functions).
__all__ = [
    'cp',
    'reconstruct_pattern_kernel',
    'reconstruct_pattern_kernel_code',
    'reconstruct_shared_kernel',
    'reconstruct_shared_kernel_code',
    'reconstruct_vectorized_kernel',
    'reconstruct_vectorized_kernel_code',
    'subsample_image_back_shared',
    'subsample_image_back_vectorized',
]
def subsample_image_back_shared(subsampled_images: cp.ndarray | list[cp.ndarray], dim: int) -> cp.ndarray:
    """
    Rebuild an image from its subsampled images (shared memory optimized version).

    Only the signature is visible in this stub. The notes below are inferred
    from ``reconstruct_shared_kernel_code`` in this module and must be
    confirmed against the real implementation.

    Args:
        subsampled_images: A CuPy array, or a list of CuPy arrays, holding the
            subsampled images. NOTE(review): presumably ``dim * dim``
            sub-images in channel-first layout, which is how the shared
            kernel indexes its input -- confirm.
        dim: NOTE(review): presumably the subsampling factor along each axis
            (the shared kernel selects the sub-image with ``y % dim`` and
            ``x % dim``) -- confirm.

    Returns:
        A CuPy array. NOTE(review): presumably the reassembled image in
        height-width-channel layout, matching the kernel's output indexing
        -- confirm.
    """
def subsample_image_back_vectorized(subsampled_images: cp.ndarray | list[cp.ndarray], dim: int) -> cp.ndarray:
    """
    Rebuild an image from its subsampled images (vectorized version optimized for RGB images).

    Only the signature is visible in this stub. The notes below are inferred
    from ``reconstruct_vectorized_kernel_code`` in this module and must be
    confirmed against the real implementation.

    Args:
        subsampled_images: A CuPy array, or a list of CuPy arrays, holding the
            subsampled images. NOTE(review): the vectorized kernel source
            hard-codes three channels, so inputs are presumably expected to
            be 3-channel (RGB) -- confirm how other channel counts are
            handled.
        dim: NOTE(review): presumably the subsampling factor along each axis
            (the vectorized kernel selects the sub-image with ``y % dim`` and
            ``x % dim``) -- confirm.

    Returns:
        A CuPy array. NOTE(review): presumably the reassembled image in
        height-width-channel layout, matching the kernel's output indexing
        -- confirm.
    """
# Empty module-level ``__test__`` mapping.
# NOTE(review): presumably emitted by the extension build tooling; ``doctest``
# looks up a module attribute of this name for extra test sources -- confirm.
__test__: dict = {}
# RawKernel handle for the pattern-driven reconstruction kernel.
# NOTE(review): presumably compiled from ``reconstruct_pattern_kernel_code``
# defined right after it -- confirm in the implementation.
reconstruct_pattern_kernel: cupy._core.raw.RawKernel  # value = <cupy._core.raw.RawKernel object>

# CUDA C source for ``reconstruct_pattern_kernel`` (one literal per source
# line; adjacent literals concatenate to the exact original string).
#
# What the source does:
#   * one thread per pixel index; threads with ``pixel_idx >= total_pixels``
#     return immediately;
#   * the input element is located through the ``pattern`` lookup table:
#     ``pattern[pixel_idx % pattern_size] + (pixel_idx / pattern_size) * pattern_size``;
#   * channel ``c`` is read at a stride of ``total_pixels`` in the input and
#     written interleaved at ``pixel_idx * channels + c`` in the output;
#   * the three-channel case is unrolled by hand, other counts use a loop.
reconstruct_pattern_kernel_code: str = (
    '\n'
    'extern "C" __global__ void reconstruct_pattern_kernel(\n'
    '    const float* __restrict__ input,\n'
    '    float* __restrict__ output,\n'
    '    const int* __restrict__ pattern,\n'
    '    const int pattern_size,\n'
    '    const int total_pixels,\n'
    '    const int channels\n'
    ') {\n'
    '    const int pixel_idx = blockIdx.x * blockDim.x + threadIdx.x;\n'
    '    if (pixel_idx >= total_pixels) return;\n'
    '    \n'
    '    const int pattern_idx = pixel_idx % pattern_size;\n'
    '    const int pattern_offset = pixel_idx / pattern_size;\n'
    '    \n'
    '    const int input_offset = pattern[pattern_idx] + pattern_offset * pattern_size;\n'
    '    const int output_offset = pixel_idx * channels;\n'
    '    \n'
    '    // Unrolled copy for common channel counts\n'
    '    if (channels == 3) {\n'
    '        output[output_offset] = input[input_offset];\n'
    '        output[output_offset + 1] = input[input_offset + total_pixels];\n'
    '        output[output_offset + 2] = input[input_offset + 2 * total_pixels];\n'
    '    } else {\n'
    '        for (int c = 0; c < channels; c++) {\n'
    '            output[output_offset + c] = input[input_offset + c * total_pixels];\n'
    '        }\n'
    '    }\n'
    '}\n'
)
# RawKernel handle for the tiled ("shared") reconstruction kernel.
# NOTE(review): presumably compiled from ``reconstruct_shared_kernel_code``
# defined right after it -- confirm in the implementation.
reconstruct_shared_kernel: cupy._core.raw.RawKernel  # value = <cupy._core.raw.RawKernel object>

# CUDA C source for ``reconstruct_shared_kernel`` (one literal per source
# line; adjacent literals concatenate to the exact original string).
#
# What the source does:
#   * each thread starts at ``(blockIdx * TILE_DIM + threadIdx)`` and walks
#     rows ``y, y + 8, y + 16, y + 24`` (``k`` steps by BLOCK_ROWS up to
#     TILE_DIM); NOTE(review): this suggests a 32x8 thread block -- confirm
#     the launch configuration;
#   * for output pixel ``(yIndex, x)`` the sub-image index is
#     ``(yIndex % dim) * dim + (x % dim)`` and the source pixel inside it is
#     ``(yIndex / dim, x / dim)``;
#   * input is indexed channel-first, output is written channel-last.
#
# Observations for reviewers (the string itself is left untouched):
#   * the ``__shared__ float tile[...]`` array is declared but never read or
#     written anywhere in the kernel body;
#   * the ``batch_size`` parameter is never referenced in the body.
reconstruct_shared_kernel_code: str = (
    '\n'
    '#define TILE_DIM 32\n'
    '#define BLOCK_ROWS 8\n'
    '\n'
    'extern "C" __global__ void reconstruct_shared_kernel(\n'
    '    const float* __restrict__ input,  // NCHW format\n'
    '    float* __restrict__ output,       // HWC format\n'
    '    const int batch_size,\n'
    '    const int channels,\n'
    '    const int input_height,\n'
    '    const int input_width,\n'
    '    const int output_height,\n'
    '    const int output_width,\n'
    '    const int dim\n'
    ') {\n'
    '    __shared__ float tile[TILE_DIM][TILE_DIM+1];  // +1 to avoid bank conflicts\n'
    '    \n'
    '    int x = blockIdx.x * TILE_DIM + threadIdx.x;\n'
    '    int y = blockIdx.y * TILE_DIM + threadIdx.y;\n'
    '    \n'
    '    // Process multiple elements per thread\n'
    '    for (int k = 0; k < TILE_DIM; k += BLOCK_ROWS) {\n'
    '        int yIndex = y + k;\n'
    '        \n'
    '        if (x < output_width && yIndex < output_height) {\n'
    '            // Calculate subsample indices\n'
    '            int dy = yIndex % dim;\n'
    '            int dx = x % dim;\n'
    '            int subsample_idx = dy * dim + dx;\n'
    '            int y_in = yIndex / dim;\n'
    '            int x_in = x / dim;\n'
    '            \n'
    '            // Process all channels\n'
    '            for (int c = 0; c < channels; c++) {\n'
    '                // Read from NCHW input\n'
    '                int idx_in = subsample_idx * channels * input_height * input_width +\n'
    '                            c * input_height * input_width +\n'
    '                            y_in * input_width + x_in;\n'
    '                \n'
    '                // Write to HWC output\n'
    '                int idx_out = (yIndex * output_width + x) * channels + c;\n'
    '                \n'
    '                output[idx_out] = input[idx_in];\n'
    '            }\n'
    '        }\n'
    '    }\n'
    '}\n'
)
# RawKernel handle for the RGB-specialised reconstruction kernel.
# NOTE(review): presumably compiled from ``reconstruct_vectorized_kernel_code``
# defined right after it -- confirm in the implementation.
reconstruct_vectorized_kernel: cupy._core.raw.RawKernel  # value = <cupy._core.raw.RawKernel object>

# CUDA C source for ``reconstruct_vectorized_kernel`` (one literal per source
# line; adjacent literals concatenate to the exact original string).
#
# What the source does:
#   * one thread per output pixel ``(y, x)``; out-of-range threads return;
#   * the sub-image index is ``(y % dim) * dim + (x % dim)`` and the source
#     pixel inside it is ``(y / dim, x / dim)``;
#   * the channel count is hard-coded to 3: three planes are read at a
#     stride of ``input_height * input_width`` and written as three
#     consecutive floats in the output;
#   * a ``float3`` local holds the pixel, but each component is loaded and
#     stored with its own scalar access.
reconstruct_vectorized_kernel_code: str = (
    '\n'
    'extern "C" __global__ void reconstruct_vectorized_kernel(\n'
    '    const float* __restrict__ input,  // NCHW format\n'
    '    float* __restrict__ output,       // HWC format\n'
    '    const int input_height,\n'
    '    const int input_width,\n'
    '    const int output_height,\n'
    '    const int output_width,\n'
    '    const int dim\n'
    ') {\n'
    '    const int x = blockIdx.x * blockDim.x + threadIdx.x;\n'
    '    const int y = blockIdx.y * blockDim.y + threadIdx.y;\n'
    '    \n'
    '    if (x >= output_width || y >= output_height) return;\n'
    '    \n'
    '    // Calculate subsample indices\n'
    '    const int dy = y % dim;\n'
    '    const int dx = x % dim;\n'
    '    const int subsample_idx = dy * dim + dx;\n'
    '    const int y_in = y / dim;\n'
    '    const int x_in = x / dim;\n'
    '    \n'
    '    // Calculate base indices\n'
    '    const int input_pixel_base = subsample_idx * 3 * input_height * input_width +\n'
    '                                y_in * input_width + x_in;\n'
    '    const int output_pixel_base = (y * output_width + x) * 3;\n'
    '    \n'
    '    // Vectorized load and store for RGB\n'
    '    float3 pixel;\n'
    '    pixel.x = input[input_pixel_base];\n'
    '    pixel.y = input[input_pixel_base + input_height * input_width];\n'
    '    pixel.z = input[input_pixel_base + 2 * input_height * input_width];\n'
    '    \n'
    '    output[output_pixel_base] = pixel.x;\n'
    '    output[output_pixel_base + 1] = pixel.y;\n'
    '    output[output_pixel_base + 2] = pixel.z;\n'
    '}\n'
)
