Module hub.core.chunk_engine.tests.common

import pickle
from typing import List, Tuple, Dict

import numpy as np
import pytest

from hub.core.chunk_engine import read_array, write_array
from hub.core.storage import MemoryProvider, S3Provider
from hub.core.typing import StorageProvider
from hub.tests.common import TENSOR_KEY
from hub.util.array import normalize_and_batchify_shape
from hub.util.keys import get_chunk_key, get_index_map_key, get_meta_key
from hub.util.s3 import has_s3_credentials


def get_min_shape(batch: np.ndarray) -> Tuple:
    return tuple(np.minimum.reduce([sample.shape for sample in batch]))


def get_max_shape(batch: np.ndarray) -> Tuple:
    return tuple(np.maximum.reduce([sample.shape for sample in batch]))


def get_random_array(shape: Tuple[int, ...], dtype: str) -> np.ndarray:
    dtype = dtype.lower()

    if "int" in dtype:
        low = np.iinfo(dtype).min
        high = np.iinfo(dtype).max
        return np.random.randint(low=low, high=high, size=shape, dtype=dtype)

    if "float" in dtype:
        # np.random.uniform doesn't support a `dtype` argument, so bound the values by
        # float16 limits (which fit within any float dtype) and cast afterwards.
        low = np.finfo("float16").min
        high = np.finfo("float16").max
        return np.random.uniform(low=low, high=high, size=shape).astype(dtype)

    if "bool" in dtype:
        a = np.random.uniform(size=shape)
        return a > 0.5

    raise ValueError("Dtype %s not supported." % dtype)


def assert_meta_is_valid(meta: dict, expected_meta: dict):
    for k, v in expected_meta.items():
        assert k in meta
        assert v == meta[k]


def assert_chunk_sizes(
    key: str, index_map: List, chunk_size: int, storage: StorageProvider
):
    incomplete_chunk_names = set()
    complete_chunk_count = 0
    total_chunks = 0
    actual_chunk_lengths_dict: Dict[str, int] = {}
    for i, entry in enumerate(index_map):
        for j, chunk_name in enumerate(entry["chunk_names"]):
            chunk_key = get_chunk_key(key, chunk_name)
            chunk_length = len(storage[chunk_key])

            # exceeding chunk_size is never acceptable
            assert (
                chunk_length <= chunk_size
            ), 'Chunk "%s" exceeded chunk_size=%i (got %i) @ [%i, %i].' % (
                chunk_name,
                chunk_size,
                chunk_length,
                i,
                j,
            )

            if chunk_name in actual_chunk_lengths_dict:
                assert (
                    chunk_length == actual_chunk_lengths_dict[chunk_name]
                ), "Chunk size changed from one read to another."
            else:
                actual_chunk_lengths_dict[chunk_name] = chunk_length

            if chunk_length < chunk_size:
                incomplete_chunk_names.add(chunk_name)
            if chunk_length == chunk_size:
                complete_chunk_count += 1

            total_chunks += 1

    incomplete_chunk_count = len(incomplete_chunk_names)
    assert (
        incomplete_chunk_count <= 1
    ), "Incomplete chunk count should never exceed 1. Incomplete count: %i. Complete count: %i. Total: %i.\nIncomplete chunk names: %s" % (
        incomplete_chunk_count,
        complete_chunk_count,
        total_chunks,
        str(incomplete_chunk_names),
    )

    # assert that all chunks (except the last one) are of expected size (`chunk_size`)
    actual_chunk_lengths = np.array(list(actual_chunk_lengths_dict.values()))
    if len(actual_chunk_lengths) > 1:
        assert np.all(
            actual_chunk_lengths[:-1] == chunk_size
        ), "All chunks (except the last one) MUST be == `chunk_size`. chunk_size=%i\n\nactual chunk sizes: %s\n\nactual chunk names: %s" % (
            chunk_size,
            str(actual_chunk_lengths[:-1]),
            str(actual_chunk_lengths_dict.keys()),
        )


def run_engine_test(
    arrays: List[np.ndarray], storage: StorageProvider, batched: bool, chunk_size: int
):
    key = TENSOR_KEY

    for i, a_in in enumerate(arrays):
        write_array(
            a_in,
            key,
            storage,
            chunk_size,
            batched=batched,
        )

        index_map_key = get_index_map_key(key)
        index_map = pickle.loads(storage[index_map_key])

        assert_chunk_sizes(key, index_map, chunk_size, storage)

        # `write_array` implicitly normalizes/batchifies shape
        a_in = normalize_and_batchify_shape(a_in, batched=batched)

        a_out = read_array(key=key, storage=storage)

        meta_key = get_meta_key(key)
        assert meta_key in storage, "Meta was not found."
        # TODO: use get_meta
        meta = pickle.loads(storage[meta_key])

        assert_meta_is_valid(
            meta,
            {
                "chunk_size": chunk_size,
                "length": a_in.shape[0],
                "dtype": a_in.dtype.name,
                "min_shape": get_min_shape(a_in),
                "max_shape": get_max_shape(a_in),
            },
        )

        assert np.array_equal(a_in, a_out), "Array not equal @ batch_index=%i." % i

    clear_if_memory_provider(storage)


def benchmark_write(
    key, arrays, chunk_size, storage, batched, clear_memory_after_write=True
):
    clear_if_memory_provider(storage)

    for a_in in arrays:
        write_array(
            a_in,
            key,
            storage,
            chunk_size,
            batched=batched,
        )

    if clear_memory_after_write:
        clear_if_memory_provider(storage)


def benchmark_read(key: str, storage: StorageProvider):
    read_array(key, storage)


def skip_if_no_required_creds(storage: StorageProvider):
    """If `storage` is a StorageProvider that requires creds, and they are not found, skip the current test."""

    if type(storage) == S3Provider:
        if not has_s3_credentials():
            pytest.skip()


def clear_if_memory_provider(storage: StorageProvider):
    """If `storage` is memory-based, clear it."""

    if type(storage) == MemoryProvider:
        storage.clear()
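
The helpers above are meant to be composed inside the chunk-engine test suite. A minimal usage sketch, assuming MemoryProvider can be constructed without arguments; the chunk size, shapes, and dtypes below are illustrative, not taken from the real tests:

import pytest

from hub.core.chunk_engine.tests.common import get_random_array, run_engine_test
from hub.core.storage import MemoryProvider

CHUNK_SIZE = 4096                          # illustrative value
SHAPES = [(8, 16, 16), (8, 3, 100, 100)]   # illustrative batched shapes
DTYPES = ["uint8", "float32"]


@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_memory_round_trip(shape, dtype):
    storage = MemoryProvider()  # assumed no-arg construction
    arrays = [get_random_array(shape, dtype) for _ in range(3)]
    # writes each array, validates chunk sizes and meta, and checks read-back equality
    run_engine_test(arrays, storage, batched=True, chunk_size=CHUNK_SIZE)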

Functions

def assert_chunk_sizes(key: str, index_map: List, chunk_size: int, storage: StorageProvider)
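This helper enforces the chunking invariant: no chunk may exceed chunk_size, and at most one chunk (the trailing, partially filled one) may be smaller than chunk_size. A hypothetical illustration with chunk byte lengths for chunk_size=4096:

# hypothetical chunk lengths in bytes
ok      = [4096, 4096, 4096, 1300]  # exactly one partial chunk -> passes
bad_one = [4096, 2000, 4096, 1300]  # two partial chunks -> the incomplete-count assertion fails
bad_two = [4096, 5000]              # a chunk exceeds chunk_size -> the size assertion fails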
def assert_meta_is_valid(meta: dict, expected_meta: dict)
def benchmark_read(key: str, storage: StorageProvider)
def benchmark_write(key, arrays, chunk_size, storage, batched, clear_memory_after_write=True)
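benchmark_write clears memory-backed storage before writing (and optionally after), so repeated timing runs start from an empty provider. A sketch of driving it and benchmark_read with pytest-benchmark's benchmark fixture; the fixture, key, and sizes are assumptions for illustration:

import pytest

from hub.core.chunk_engine.tests.common import (
    benchmark_read,
    benchmark_write,
    get_random_array,
)
from hub.core.storage import MemoryProvider

KEY = "benchmark_tensor"    # hypothetical tensor key
CHUNK_SIZE = 16 * 2 ** 20   # illustrative 16 MB chunks


def test_write_speed(benchmark):
    storage = MemoryProvider()  # assumed no-arg construction
    arrays = [get_random_array((16, 256, 256), "uint8") for _ in range(4)]
    benchmark(benchmark_write, KEY, arrays, CHUNK_SIZE, storage, batched=True)


def test_read_speed(benchmark):
    storage = MemoryProvider()
    arrays = [get_random_array((16, 256, 256), "uint8") for _ in range(4)]
    # keep the written chunks in memory so the read benchmark has data to load
    benchmark_write(KEY, arrays, CHUNK_SIZE, storage, batched=True,
                    clear_memory_after_write=False)
    benchmark(benchmark_read, KEY, storage)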
def clear_if_memory_provider(storage: StorageProvider)

If storage is memory-based, clear it.

def get_max_shape(batch: numpy.ndarray) ‑> Tuple
def get_min_shape(batch: numpy.ndarray) ‑> Tuple
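get_min_shape and get_max_shape reduce element-wise over the per-sample shapes, which only differs from batch[0].shape when the batch is ragged. A small sketch, assuming an object-dtype batch of ragged samples:

import numpy as np

from hub.core.chunk_engine.tests.common import get_max_shape, get_min_shape

batch = np.empty(2, dtype=object)  # ragged batch of two samples
batch[0] = np.zeros((10, 10))
batch[1] = np.zeros((5, 20))

assert get_min_shape(batch) == (5, 10)   # element-wise minimum of (10, 10) and (5, 20)
assert get_max_shape(batch) == (10, 20)  # element-wise maximum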
def get_random_array(shape: Tuple[int, ...], dtype: str) ‑> numpy.ndarray
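A short sketch of how test inputs might be built with get_random_array; the shapes are illustrative:

from hub.core.chunk_engine.tests.common import get_random_array

images = get_random_array((32, 28, 28), "uint8")      # integers drawn from [iinfo.min, iinfo.max)
embeddings = get_random_array((32, 512), "float32")   # uniform floats within float16 bounds, cast to float32
masks = get_random_array((32, 28, 28), "bool")        # uniform noise thresholded at 0.5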
def run_engine_test(arrays: List[numpy.ndarray], storage: StorageProvider, batched: bool, chunk_size: int)
def skip_if_no_required_creds(storage: StorageProvider)

If storage is a StorageProvider that requires creds, and they are not found, skip the current test.

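A sketch of how the two provider guards might back a storage fixture parametrized over providers; the fixture and the S3Provider constructor argument are assumptions for illustration:

import pytest

from hub.core.chunk_engine.tests.common import (
    clear_if_memory_provider,
    skip_if_no_required_creds,
)
from hub.core.storage import MemoryProvider, S3Provider


@pytest.fixture(params=["memory", "s3"])
def storage(request):
    if request.param == "memory":
        provider = MemoryProvider()              # assumed no-arg construction
    else:
        provider = S3Provider("some-test-root")  # placeholder root; real signature may differ
    skip_if_no_required_creds(provider)  # skips the S3 run when credentials are missing
    yield provider
    clear_if_memory_provider(provider)   # only memory providers are cleared between tests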