Module hub.core.chunk_engine.tests.common
import pickle
from typing import List, Tuple, Dict
import numpy as np
import pytest
from hub.core.chunk_engine import read_array, write_array
from hub.core.storage import MemoryProvider, S3Provider
from hub.core.typing import StorageProvider
from hub.tests.common import TENSOR_KEY
from hub.util.array import normalize_and_batchify_shape
from hub.util.keys import get_chunk_key, get_index_map_key, get_meta_key
from hub.util.s3 import has_s3_credentials
def get_min_shape(batch: np.ndarray) -> Tuple:
    return tuple(np.minimum.reduce([sample.shape for sample in batch]))


def get_max_shape(batch: np.ndarray) -> Tuple:
    return tuple(np.maximum.reduce([sample.shape for sample in batch]))


def get_random_array(shape: Tuple[int], dtype: str) -> np.ndarray:
    dtype = dtype.lower()

    if "int" in dtype:
        low = np.iinfo(dtype).min
        high = np.iinfo(dtype).max
        return np.random.randint(low=low, high=high, size=shape, dtype=dtype)

    if "float" in dtype:
        # get float16 because np.random.uniform doesn't support the `dtype` argument.
        low = np.finfo("float16").min
        high = np.finfo("float16").max
        return np.random.uniform(low=low, high=high, size=shape).astype(dtype)

    if "bool" in dtype:
        a = np.random.uniform(size=shape)
        return a > 0.5

    raise ValueError("Dtype %s not supported." % dtype)


def assert_meta_is_valid(meta: dict, expected_meta: dict):
    for k, v in expected_meta.items():
        assert k in meta
        assert v == meta[k]


def assert_chunk_sizes(
    key: str, index_map: List, chunk_size: int, storage: StorageProvider
):
    incomplete_chunk_names = set()
    complete_chunk_count = 0
    total_chunks = 0
    actual_chunk_lengths_dict: Dict[str, int] = {}

    for i, entry in enumerate(index_map):
        for j, chunk_name in enumerate(entry["chunk_names"]):
            chunk_key = get_chunk_key(key, chunk_name)
            chunk_length = len(storage[chunk_key])

            # exceeding chunk_size is never acceptable
            assert (
                chunk_length <= chunk_size
            ), 'Chunk "%s" exceeded chunk_size=%i (got %i) @ [%i, %i].' % (
                chunk_name,
                chunk_size,
                chunk_length,
                i,
                j,
            )

            if chunk_name in actual_chunk_lengths_dict:
                assert (
                    chunk_length == actual_chunk_lengths_dict[chunk_name]
                ), "Chunk size changed from one read to another."
            else:
                actual_chunk_lengths_dict[chunk_name] = chunk_length

            if chunk_length < chunk_size:
                incomplete_chunk_names.add(chunk_name)
            if chunk_length == chunk_size:
                complete_chunk_count += 1

            total_chunks += 1

    incomplete_chunk_count = len(incomplete_chunk_names)
    assert (
        incomplete_chunk_count <= 1
    ), "Incomplete chunk count should never exceed 1. Incomplete count: %i. Complete count: %i. Total: %i.\nIncomplete chunk names: %s" % (
        incomplete_chunk_count,
        complete_chunk_count,
        total_chunks,
        str(incomplete_chunk_names),
    )

    # assert that all chunks (except the last one) are of expected size (`chunk_size`)
    actual_chunk_lengths = np.array(list(actual_chunk_lengths_dict.values()))
    if len(actual_chunk_lengths) > 1:
        assert np.all(
            actual_chunk_lengths[:-1] == chunk_size
        ), "All chunks (except the last one) MUST be == `chunk_size`. chunk_size=%i\n\nactual chunk sizes: %s\n\nactual chunk names: %s" % (
            chunk_size,
            str(actual_chunk_lengths[:-1]),
            str(actual_chunk_lengths_dict.keys()),
        )


def run_engine_test(
    arrays: List[np.ndarray], storage: StorageProvider, batched: bool, chunk_size: int
):
    key = TENSOR_KEY

    for i, a_in in enumerate(arrays):
        write_array(
            a_in,
            key,
            storage,
            chunk_size,
            batched=batched,
        )

        index_map_key = get_index_map_key(key)
        index_map = pickle.loads(storage[index_map_key])

        assert_chunk_sizes(key, index_map, chunk_size, storage)

        # `write_array` implicitly normalizes/batchifies shape
        a_in = normalize_and_batchify_shape(a_in, batched=batched)

        a_out = read_array(key=key, storage=storage)

        meta_key = get_meta_key(key)
        assert meta_key in storage, "Meta was not found."
        # TODO: use get_meta
        meta = pickle.loads(storage[meta_key])

        assert_meta_is_valid(
            meta,
            {
                "chunk_size": chunk_size,
                "length": a_in.shape[0],
                "dtype": a_in.dtype.name,
                "min_shape": get_min_shape(a_in),
                "max_shape": get_max_shape(a_in),
            },
        )

        assert np.array_equal(a_in, a_out), "Array not equal @ batch_index=%i." % i

    clear_if_memory_provider(storage)


def benchmark_write(
    key, arrays, chunk_size, storage, batched, clear_memory_after_write=True
):
    clear_if_memory_provider(storage)

    for a_in in arrays:
        write_array(
            a_in,
            key,
            storage,
            chunk_size,
            batched=batched,
        )

    if clear_memory_after_write:
        clear_if_memory_provider(storage)


def benchmark_read(key: str, storage: StorageProvider):
    read_array(key, storage)


def skip_if_no_required_creds(storage: StorageProvider):
    """If `storage` is a StorageProvider that requires creds, and they are not found, skip the current test."""

    if type(storage) == S3Provider:
        if not has_s3_credentials():
            pytest.skip()


def clear_if_memory_provider(storage: StorageProvider):
    """If `storage` is memory-based, clear it."""

    if type(storage) == MemoryProvider:
        storage.clear()
Functions
def assert_chunk_sizes(key: str, index_map: List, chunk_size: int, storage: StorageProvider)
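Validates the chunk layout recorded in an index map: no chunk may exceed `chunk_size`, at most one chunk may be incomplete, and every chunk except the last must be exactly `chunk_size` bytes. A minimal usage sketch, assuming `MemoryProvider()` takes no constructor arguments and mirroring how `run_engine_test` loads the index map:

import pickle
import numpy as np
from hub.core.chunk_engine import write_array
from hub.core.storage import MemoryProvider
from hub.util.keys import get_index_map_key
from hub.core.chunk_engine.tests.common import assert_chunk_sizes

storage = MemoryProvider()  # assumed no-arg constructor
write_array(np.zeros((10, 28, 28), dtype="int32"), "my_tensor", storage, 4096, batched=True)

# the index map records which chunks back each sample
index_map = pickle.loads(storage[get_index_map_key("my_tensor")])
assert_chunk_sizes("my_tensor", index_map, 4096, storage)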
def assert_meta_is_valid(meta: dict, expected_meta: dict)
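Checks that every key in `expected_meta` exists in `meta` with an equal value; keys present only in `meta` are ignored. For example:

from hub.core.chunk_engine.tests.common import assert_meta_is_valid

meta = {"chunk_size": 4096, "length": 10, "dtype": "int32", "min_shape": (28, 28)}
assert_meta_is_valid(meta, {"chunk_size": 4096, "length": 10})  # passes: both pairs match
assert_meta_is_valid(meta, {"dtype": "float32"})                # raises AssertionError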
def benchmark_read(key: str, storage: StorageProvider)
def benchmark_write(key, arrays, chunk_size, storage, batched, clear_memory_after_write=True)
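`benchmark_read` and `benchmark_write` are thin wrappers around `read_array`/`write_array` intended for timing. A hedged sketch of timing them directly, assuming `MemoryProvider()` takes no constructor arguments; note that `clear_memory_after_write=False` is needed so the subsequent read still finds data:

import time
import numpy as np
from hub.core.storage import MemoryProvider
from hub.core.chunk_engine.tests.common import benchmark_read, benchmark_write

arrays = [np.random.rand(16, 256, 256) for _ in range(4)]
storage = MemoryProvider()  # assumed no-arg constructor

start = time.perf_counter()
benchmark_write("my_tensor", arrays, 4096, storage, batched=True, clear_memory_after_write=False)
print("write took %.3fs" % (time.perf_counter() - start))

start = time.perf_counter()
benchmark_read("my_tensor", storage)
print("read took %.3fs" % (time.perf_counter() - start))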
def clear_if_memory_provider(storage: StorageProvider)
If `storage` is memory-based, clear it.
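Handy between test cases so state written to an in-memory provider does not leak into the next case; non-memory providers are left untouched. A small sketch, assuming `MemoryProvider()` takes no constructor arguments:

from hub.core.storage import MemoryProvider
from hub.core.chunk_engine.tests.common import clear_if_memory_provider

storage = MemoryProvider()  # assumed no-arg constructor
# ... write test data into `storage` ...
clear_if_memory_provider(storage)  # wipes the in-memory store; an S3Provider would be left as-is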
def get_max_shape(batch: numpy.ndarray) -> Tuple
def get_min_shape(batch: numpy.ndarray) -> Tuple
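Both shape helpers reduce over the shapes of the samples in a batch; for a regular (non-ragged) batch they return the same tuple, which is what `run_engine_test` compares against the stored `min_shape`/`max_shape` meta. For example:

import numpy as np
from hub.core.chunk_engine.tests.common import get_max_shape, get_min_shape

batch = np.zeros((10, 28, 28, 3))  # 10 samples of shape (28, 28, 3)
assert get_min_shape(batch) == (28, 28, 3)
assert get_max_shape(batch) == (28, 28, 3)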
def get_random_array(shape: Tuple[int], dtype: str) -> numpy.ndarray
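Generates random test data for int, float, and bool dtype strings; any other dtype raises a ValueError. Note that for "bool" the returned array comes from thresholding a uniform sample, so its dtype is `bool` regardless of the exact string passed in. For example:

from hub.core.chunk_engine.tests.common import get_random_array

a = get_random_array((4, 32, 32), "int32")  # random int32 values
b = get_random_array((100,), "float16")     # uniform floats cast to float16
c = get_random_array((8, 8), "bool")        # random booleans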
def run_engine_test(arrays: List[numpy.ndarray], storage: StorageProvider, batched: bool, chunk_size: int)
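The main entry point used by the chunk-engine tests: for each input array it writes under `TENSOR_KEY`, re-reads, and validates chunk sizes and meta, then clears memory-based storage at the end. A minimal sketch, assuming `MemoryProvider()` takes no constructor arguments:

import numpy as np
from hub.core.storage import MemoryProvider
from hub.core.chunk_engine.tests.common import get_random_array, run_engine_test

arrays = [get_random_array((4, 64, 64), "int64") for _ in range(3)]
run_engine_test(arrays, MemoryProvider(), batched=True, chunk_size=4096)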
def skip_if_no_required_creds(storage: StorageProvider)
If `storage` is a StorageProvider that requires creds, and they are not found, skip the current test.
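Meant to be called at the top of a test that is parametrized over storage providers, so S3-backed runs are skipped when no credentials are available. A hedged sketch (the `storage` fixture is an assumption for illustration, not part of this module):

from hub.core.chunk_engine.tests.common import (
    get_random_array,
    run_engine_test,
    skip_if_no_required_creds,
)

def test_write_and_read(storage):        # hypothetical fixture yielding a StorageProvider
    skip_if_no_required_creds(storage)   # skips only when storage is an S3Provider without creds
    arrays = [get_random_array((2, 16, 16), "uint8")]
    run_engine_test(arrays, storage, batched=True, chunk_size=2048)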