Module hub.util.array
Expand source code
import os
import numpy as np
from typing import Tuple
def normalize_and_batchify_shape(array: np.ndarray, batched: bool) -> np.ndarray:
    """Squeeze every length-1 axis except the batch axis, adding a batch axis if needed.

    Example 1:
        input_array.shape = (10, 1, 5)
        batched = False
        output_array.shape = (1, 10, 5)  # batch axis is added

    Example 2:
        input_array.shape = (1, 100, 1, 1, 3)
        batched = True
        output_array.shape = (1, 100, 3)  # batch axis is preserved

    Example 3:
        input_array.shape = (4, 1, 1)
        batched = True
        output_array.shape = (4, 1)  # one trailing axis of length 1 is kept

    Args:
        array (np.ndarray): Array whose shape will be normalized/batchified.
        batched (bool): If True, `array.shape[0]` is assumed to be the batch axis and is
            never squeezed, even when it is 1. If False, a leading axis is added such
            that `out_array.shape[0] == 1`.

    Returns:
        np.ndarray: Array with a guaranteed batch dimension and at least 2 dimensions.
            No axis in `out_array.shape[1:]` has length 1, except when every non-batch
            axis was squeezed away; the result then has shape `(batch, 1)`.

    Raises:
        ValueError: If `batched` is True but `array` is 0-dimensional, because there is
            no axis that could serve as the batch axis.
    """
    if not batched:
        # Adding the batch axis first lets both cases share the squeeze logic below.
        array = np.expand_dims(array, axis=0)
    elif array.ndim == 0:
        raise ValueError(
            "A 0-dimensional array has no batch axis; pass batched=False instead."
        )

    # Don't squeeze the primary (batch) axis, even if it's 1.
    squeeze_axes = tuple(
        axis for axis, size in enumerate(array.shape) if axis > 0 and size == 1
    )
    array = array.squeeze(squeeze_axes)

    # If we squeezed everything except the primary axis, append one dimension of length 1.
    if array.ndim == 1:
        array = np.expand_dims(array, axis=1)
    return array
Functions
def normalize_and_batchify_shape(array: numpy.ndarray, batched: bool) ‑> numpy.ndarray
-
Remove all
array.shape
axes with a value of 1 & add a batch dimension if needed. Example 1: input_array.shape = (10, 1, 5) batched = False output_array.shape = (1, 10, 5) # batch axis is added
Example 2: input_array.shape = (1, 100, 1, 1, 3) batched = True output_array.shape = (1, 100, 3) # batch axis is preserved
Args
array
:np.ndarray
- Array that will have its shape normalized/batchified.
batched
:bool
- If True,
array.shape[0]
is assumed to be the batch axis. If False, an axis will be added such that array.shape[0] == 1
.
Returns
np.ndarray
- Array with a guaranteed batch dimension.
out_array.shape[1:]
will always be > 1 (except for a single trailing axis of length 1 when every other axis was squeezed). out_array.shape[0]
may be >= 1.
Expand source code
def normalize_and_batchify_shape(array: np.ndarray, batched: bool) -> np.ndarray:
    """Squeeze every length-1 axis except the batch axis, adding a batch axis if needed.

    Example 1:
        input_array.shape = (10, 1, 5)
        batched = False
        output_array.shape = (1, 10, 5)  # batch axis is added

    Example 2:
        input_array.shape = (1, 100, 1, 1, 3)
        batched = True
        output_array.shape = (1, 100, 3)  # batch axis is preserved

    Example 3:
        input_array.shape = (4, 1, 1)
        batched = True
        output_array.shape = (4, 1)  # one trailing axis of length 1 is kept

    Args:
        array (np.ndarray): Array whose shape will be normalized/batchified.
        batched (bool): If True, `array.shape[0]` is assumed to be the batch axis and is
            never squeezed, even when it is 1. If False, a leading axis is added such
            that `out_array.shape[0] == 1`.

    Returns:
        np.ndarray: Array with a guaranteed batch dimension and at least 2 dimensions.
            No axis in `out_array.shape[1:]` has length 1, except when every non-batch
            axis was squeezed away; the result then has shape `(batch, 1)`.

    Raises:
        ValueError: If `batched` is True but `array` is 0-dimensional, because there is
            no axis that could serve as the batch axis.
    """
    if not batched:
        # Adding the batch axis first lets both cases share the squeeze logic below.
        array = np.expand_dims(array, axis=0)
    elif array.ndim == 0:
        raise ValueError(
            "A 0-dimensional array has no batch axis; pass batched=False instead."
        )

    # Don't squeeze the primary (batch) axis, even if it's 1.
    squeeze_axes = tuple(
        axis for axis, size in enumerate(array.shape) if axis > 0 and size == 1
    )
    array = array.squeeze(squeeze_axes)

    # If we squeezed everything except the primary axis, append one dimension of length 1.
    if array.ndim == 1:
        array = np.expand_dims(array, axis=1)
    return array