Module shoji.io.filter_io
Expand source code
from typing import Tuple, Dict, List, Union
import shoji
import shoji.io
import fdb
import numpy as np
from .enums import Compartment
@fdb.transactional
def const_compare(tr, wsm: "shoji.WorkspaceManager", name: str, operator: str, const: Union[int, str, float]) -> np.ndarray:
    """
    Compare a tensor to a constant value, and return all indices that match

    Args:
        tr:         Database or transaction handle, as accepted by fdb.transactional
        wsm:        Workspace manager that holds the tensor and its index
        name:       Name of the tensor to compare
        operator:   One of "==", "!=", ">=", ">", "<=" or "<"
        const:      The constant; it is cast to the tensor's Python dtype before comparing

    Returns:
        An int64 array holding the second tuple element of every matching index key
        (presumably the row index -- confirm against the code that writes the index)

    Remarks:
        An unrecognized operator is not rejected; it selects the whole index.
    """
    # Code for range, equality and inequality filters
    tensor = shoji.io.get_tensor(tr, wsm, name)
    const = tensor.python_dtype()(const)  # Cast the const to string, float or int
    index = wsm._subdir[Compartment.TensorIndex][name]
    eq_range = index[const].range()  # Key range nested under index[const]
    all_range = index.range()  # Key range of the whole index for this tensor

    def read(start, stop) -> List[int]:
        # Range reads are half-open: keys k with start <= k < stop
        return [index.unpack(k)[1] for k, _ in tr[start:stop]]

    if operator == "!=":
        # Everything before the equal keys plus everything after them. Because the stop key
        # is exclusive, eq_range.start itself is the right upper bound for the lower part;
        # resolving the last key below it and using that as the stop would drop that key.
        below = np.array(read(all_range.start, eq_range.start), dtype="int64")
        above = np.array(read(eq_range.stop, all_range.stop), dtype="int64")
        return np.concatenate([below, above])

    start, stop = all_range.start, all_range.stop
    if operator == "==":
        start, stop = eq_range.start, eq_range.stop
    elif operator == ">=":
        start = eq_range.start
    elif operator == ">":
        start = eq_range.stop  # First possible key after the equal keys
    elif operator == "<=":
        stop = eq_range.stop
    elif operator == "<":
        stop = eq_range.start  # Exclusive bound, so every key before the equal keys is kept
    return np.array(read(start, stop), dtype="int64")
def const_compare_non_transactional(wsm: "shoji.WorkspaceManager", name: str, operator: str, const: Union[int, str, float]) -> np.ndarray:
    """
    Compare a tensor to a constant value, and return all indices that match

    Unlike a single-transaction read, the result (except for "!=") is collected in
    batches, starting a fresh transaction and halving the batch size whenever
    FoundationDB reports that a read took too long or was too large.

    Args:
        wsm:        Workspace manager that holds the tensor and its index
        name:       Name of the tensor to compare
        operator:   One of "==", "!=", ">=", ">", "<=" or "<"
        const:      The constant; it is cast to the tensor's Python dtype before comparing

    Returns:
        An int64 array holding the second tuple element of every matching index key
        (presumably the row index -- confirm against the code that writes the index)

    Raises:
        fdb.impl.FDBError: for any error other than the retried codes, or when a
            retried code still occurs with a batch size of one

    Remarks:
        An unrecognized operator is not rejected; it selects the whole index.
    """
    # Code for range, equality and inequality filters
    tensor = wsm._get_tensor(name)
    const = tensor.python_dtype()(const)  # Cast the const to string, float or int
    index = wsm._subdir[Compartment.TensorIndex][name]
    eq_range = index[const].range()  # Key range nested under index[const]
    all_range = index.range()  # Key range of the whole index for this tensor

    if operator == "!=":
        # This will typically (except inside a Transaction scope) be set to the database,
        # so that each time it's used it will create a separate transaction
        tr = wsm._db.transaction
        # Range reads are half-open (start <= k < stop), so the boundaries of eq_range can
        # be used directly; resolving "the last key below eq_range.start" and using it as
        # the stop would drop that key from the result.
        below = np.array([index.unpack(k)[1] for k, _ in tr[all_range.start:eq_range.start]], dtype="int64")
        above = np.array([index.unpack(k)[1] for k, _ in tr[eq_range.stop:all_range.stop]], dtype="int64")
        return np.concatenate([below, above])

    start, stop = all_range.start, all_range.stop
    if operator == "==":
        start, stop = eq_range.start, eq_range.stop
    elif operator == ">=":
        start = eq_range.start
    elif operator == ">":
        start = eq_range.stop  # First possible key after the equal keys
    elif operator == "<=":
        stop = eq_range.stop
    elif operator == "<":
        stop = eq_range.start  # Exclusive bound, so every key before the equal keys is kept

    tr = wsm._db.create_transaction()
    n = 100_000  # Current batch size; halved after each retried error
    result = []
    while start < stop:
        try:
            keys = [k for k, _ in tr.get_range(start, stop, limit=n)]
        except fdb.impl.FDBError as e:
            # 1004 timed_out, 1007 transaction_too_old, 1031 transaction_timed_out, 2101 transaction_too_large
            if e.code in (1004, 1007, 1031, 2101) and n > 1:  # Too many bytes or too long time, so try again with less
                n = max(1, n // 2)
                tr = wsm._db.create_transaction()
                continue
            raise
        result.extend(index.unpack(k)[1] for k in keys)
        if len(keys) < n:
            # A short (possibly empty) batch means the range is exhausted
            break
        # The smallest key strictly greater than the last one read
        start = keys[-1] + b"\x00"
    return np.array(result, dtype="int64")
def get_filtered_indices(wsm: "shoji.WorkspaceManager", tensor: "shoji.Tensor", filters: List["shoji.Filter"], axis: int, n_rows: int) -> np.ndarray:
    """
    Return the sorted indices that the filters select along one axis of a tensor

    Args:
        wsm:        Workspace manager, passed through to each filter's get_rows()
        tensor:     The tensor being filtered (its dims and name are inspected)
        filters:    Filters to consider
        axis:       The axis of the tensor for which indices are wanted
        n_rows:     Length of that axis

    Returns:
        The sorted rows of the last filter that applies to this axis, or
        np.arange(n_rows) when no filter applies

    Remarks:
        A filter applies when the axis is a named dimension equal to the filter's dim,
        or when it is a tensor bool/indices/slice filter bound to this tensor and axis.
        Filters are not intersected; a later applicable filter replaces an earlier one.
    """
    selected = None
    for flt in filters:
        dim = tensor.dims[axis]
        if isinstance(dim, str) and dim == flt.dim:
            # Filter on the named dimension of this axis
            selected = np.sort(flt.get_rows(wsm))
            continue
        if not isinstance(flt, (shoji.TensorBoolFilter, shoji.TensorIndicesFilter, shoji.TensorSliceFilter)):
            continue
        if flt.tensor.name == tensor.name and flt.axis == axis:
            # Filter addressed directly at this tensor and axis
            selected = np.sort(flt.get_rows(wsm, n_rows))
    return np.arange(n_rows) if selected is None else selected
def read_filtered(wsm: "shoji.WorkspaceManager", name: str, filters: List["shoji.Filter"]) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Read the part of a tensor that is selected by the filters

    Args:
        wsm:        Workspace manager that holds the tensor
        name:       Name of the tensor to read
        filters:    Filters that restrict the indices read along each axis

    Returns:
        For a regular tensor, the result of a single shoji.io.read_at_indices() call.
        For a jagged tensor, a list with one such result per selected row.
    """
    tensor = wsm._get_tensor(name)
    subspace = wsm._subdir

    if not tensor.jagged:
        per_axis = [get_filtered_indices(wsm, tensor, filters, ax, tensor.shape[ax]) for ax in range(tensor.rank)]
        return shoji.io.read_at_indices(wsm, name, per_axis, tensor.chunks, False)

    # Jagged tensor: every row has its own shape, so rows are read one at a time
    arrays = []
    for row in get_filtered_indices(wsm, tensor, filters, 0, tensor.shape[0]):
        shape_key = subspace.pack((Compartment.TensorRowShapes, name, int(row)))
        row_shape = fdb.tuple.unpack(wsm._db.transaction[shape_key])
        per_axis = [np.array([row])]
        for ax in range(1, tensor.rank):
            # row_shape describes the axes after the first, hence the offset of one
            per_axis.append(get_filtered_indices(wsm, tensor, filters, ax, row_shape[ax - 1]))
        arrays.append(shoji.io.read_at_indices(wsm, name, per_axis, tensor.chunks, False))
    return arrays
@fdb.transactional
def write_filtered(tr: fdb.impl.Transaction, wsm: "shoji.WorkspaceManager", name: str, vals: Union[np.ndarray, List[np.ndarray]], filters: List["shoji.Filter"]) -> None:
    """
    Write values to the part of a tensor that is selected by the filters

    Args:
        tr:         Database or transaction handle, as accepted by fdb.transactional
        wsm:        Workspace manager that holds the tensor
        name:       Name of the tensor to write
        vals:       A numpy array, or a list or tuple of numpy arrays for a jagged tensor
        filters:    Filters that restrict the indices written along each axis

    Remarks:
        Row shapes of jagged tensors are read through wsm._db.transaction, whereas
        the values are written through tr.
    """
    tensor: shoji.Tensor = wsm._get_tensor(name)
    subspace = wsm._subdir
    assert isinstance(vals, (np.ndarray, list, tuple)), f"Value assigned to '{name}' is not a numpy array or a list or tuple of numpy arrays"
    target = (Compartment.TensorValues, name)

    if not tensor.jagged:
        per_axis = [get_filtered_indices(wsm, tensor, filters, ax, tensor.shape[ax]) for ax in range(tensor.rank)]
        shoji.io.write_at_indices(tr, wsm, target, per_axis, tensor.chunks, vals)
        return

    # Jagged tensor: every row has its own shape, so rows are written one at a time
    for row in get_filtered_indices(wsm, tensor, filters, 0, tensor.shape[0]):
        shape_key = subspace.pack((Compartment.TensorRowShapes, name, int(row)))
        row_shape = fdb.tuple.unpack(wsm._db.transaction[shape_key])
        per_axis = [np.array([row])]
        for ax in range(1, tensor.rank):
            # row_shape describes the axes after the first, hence the offset of one
            per_axis.append(get_filtered_indices(wsm, tensor, filters, ax, row_shape[ax - 1]))
        # NOTE(review): vals is indexed by the absolute row number, not by the position among
        # the selected rows -- confirm that callers pass a sequence covering all rows
        shoji.io.write_at_indices(tr, wsm, target, per_axis, tensor.chunks, vals[row])
Functions
def const_compare(tr, wsm: shoji.WorkspaceManager, name: str, operator: str, const: Tuple[int, str, float]) ‑> numpy.ndarray
-
Compare a tensor to a constant value, and return all indices that match
Expand source code
@fdb.transactional
def const_compare(tr, wsm: "shoji.WorkspaceManager", name: str, operator: str, const: Union[int, str, float]) -> np.ndarray:
    """
    Compare a tensor to a constant value, and return all indices that match

    Args:
        tr:         Database or transaction handle, as accepted by fdb.transactional
        wsm:        Workspace manager that holds the tensor and its index
        name:       Name of the tensor to compare
        operator:   One of "==", "!=", ">=", ">", "<=" or "<"
        const:      The constant; it is cast to the tensor's Python dtype before comparing

    Returns:
        An int64 array holding the second tuple element of every matching index key
        (presumably the row index -- confirm against the code that writes the index)

    Remarks:
        An unrecognized operator is not rejected; it selects the whole index.
    """
    # Code for range, equality and inequality filters
    tensor = shoji.io.get_tensor(tr, wsm, name)
    const = tensor.python_dtype()(const)  # Cast the const to string, float or int
    index = wsm._subdir[Compartment.TensorIndex][name]
    eq_range = index[const].range()  # Key range nested under index[const]
    all_range = index.range()  # Key range of the whole index for this tensor

    def read(start, stop) -> List[int]:
        # Range reads are half-open: keys k with start <= k < stop
        return [index.unpack(k)[1] for k, _ in tr[start:stop]]

    if operator == "!=":
        # Everything before the equal keys plus everything after them. Because the stop key
        # is exclusive, eq_range.start itself is the right upper bound for the lower part;
        # resolving the last key below it and using that as the stop would drop that key.
        below = np.array(read(all_range.start, eq_range.start), dtype="int64")
        above = np.array(read(eq_range.stop, all_range.stop), dtype="int64")
        return np.concatenate([below, above])

    start, stop = all_range.start, all_range.stop
    if operator == "==":
        start, stop = eq_range.start, eq_range.stop
    elif operator == ">=":
        start = eq_range.start
    elif operator == ">":
        start = eq_range.stop  # First possible key after the equal keys
    elif operator == "<=":
        stop = eq_range.stop
    elif operator == "<":
        stop = eq_range.start  # Exclusive bound, so every key before the equal keys is kept
    return np.array(read(start, stop), dtype="int64")
def const_compare_non_transactional(wsm: shoji.WorkspaceManager, name: str, operator: str, const: Tuple[int, str, float]) ‑> numpy.ndarray
-
Compare a tensor to a constant value, and return all indices that match
Expand source code
def const_compare_non_transactional(wsm: "shoji.WorkspaceManager", name: str, operator: str, const: Union[int, str, float]) -> np.ndarray:
    """
    Compare a tensor to a constant value, and return all indices that match

    Unlike a single-transaction read, the result (except for "!=") is collected in
    batches, starting a fresh transaction and halving the batch size whenever
    FoundationDB reports that a read took too long or was too large.

    Args:
        wsm:        Workspace manager that holds the tensor and its index
        name:       Name of the tensor to compare
        operator:   One of "==", "!=", ">=", ">", "<=" or "<"
        const:      The constant; it is cast to the tensor's Python dtype before comparing

    Returns:
        An int64 array holding the second tuple element of every matching index key
        (presumably the row index -- confirm against the code that writes the index)

    Raises:
        fdb.impl.FDBError: for any error other than the retried codes, or when a
            retried code still occurs with a batch size of one

    Remarks:
        An unrecognized operator is not rejected; it selects the whole index.
    """
    # Code for range, equality and inequality filters
    tensor = wsm._get_tensor(name)
    const = tensor.python_dtype()(const)  # Cast the const to string, float or int
    index = wsm._subdir[Compartment.TensorIndex][name]
    eq_range = index[const].range()  # Key range nested under index[const]
    all_range = index.range()  # Key range of the whole index for this tensor

    if operator == "!=":
        # This will typically (except inside a Transaction scope) be set to the database,
        # so that each time it's used it will create a separate transaction
        tr = wsm._db.transaction
        # Range reads are half-open (start <= k < stop), so the boundaries of eq_range can
        # be used directly; resolving "the last key below eq_range.start" and using it as
        # the stop would drop that key from the result.
        below = np.array([index.unpack(k)[1] for k, _ in tr[all_range.start:eq_range.start]], dtype="int64")
        above = np.array([index.unpack(k)[1] for k, _ in tr[eq_range.stop:all_range.stop]], dtype="int64")
        return np.concatenate([below, above])

    start, stop = all_range.start, all_range.stop
    if operator == "==":
        start, stop = eq_range.start, eq_range.stop
    elif operator == ">=":
        start = eq_range.start
    elif operator == ">":
        start = eq_range.stop  # First possible key after the equal keys
    elif operator == "<=":
        stop = eq_range.stop
    elif operator == "<":
        stop = eq_range.start  # Exclusive bound, so every key before the equal keys is kept

    tr = wsm._db.create_transaction()
    n = 100_000  # Current batch size; halved after each retried error
    result = []
    while start < stop:
        try:
            keys = [k for k, _ in tr.get_range(start, stop, limit=n)]
        except fdb.impl.FDBError as e:
            # 1004 timed_out, 1007 transaction_too_old, 1031 transaction_timed_out, 2101 transaction_too_large
            if e.code in (1004, 1007, 1031, 2101) and n > 1:  # Too many bytes or too long time, so try again with less
                n = max(1, n // 2)
                tr = wsm._db.create_transaction()
                continue
            raise
        result.extend(index.unpack(k)[1] for k in keys)
        if len(keys) < n:
            # A short (possibly empty) batch means the range is exhausted
            break
        # The smallest key strictly greater than the last one read
        start = keys[-1] + b"\x00"
    return np.array(result, dtype="int64")
def get_filtered_indices(wsm: shoji.WorkspaceManager, tensor: shoji.Tensor, filters: List[ForwardRef('shoji.Filter')], axis: int, n_rows: int) ‑> numpy.ndarray
-
Expand source code
def get_filtered_indices(wsm: "shoji.WorkspaceManager", tensor: "shoji.Tensor", filters: List["shoji.Filter"], axis: int, n_rows: int) -> np.ndarray:
    """
    Return the sorted indices that the filters select along one axis of a tensor

    Args:
        wsm:        Workspace manager, passed through to each filter's get_rows()
        tensor:     The tensor being filtered (its dims and name are inspected)
        filters:    Filters to consider
        axis:       The axis of the tensor for which indices are wanted
        n_rows:     Length of that axis

    Returns:
        The sorted rows of the last filter that applies to this axis, or
        np.arange(n_rows) when no filter applies

    Remarks:
        A filter applies when the axis is a named dimension equal to the filter's dim,
        or when it is a tensor bool/indices/slice filter bound to this tensor and axis.
        Filters are not intersected; a later applicable filter replaces an earlier one.
    """
    selected = None
    for flt in filters:
        dim = tensor.dims[axis]
        if isinstance(dim, str) and dim == flt.dim:
            # Filter on the named dimension of this axis
            selected = np.sort(flt.get_rows(wsm))
            continue
        if not isinstance(flt, (shoji.TensorBoolFilter, shoji.TensorIndicesFilter, shoji.TensorSliceFilter)):
            continue
        if flt.tensor.name == tensor.name and flt.axis == axis:
            # Filter addressed directly at this tensor and axis
            selected = np.sort(flt.get_rows(wsm, n_rows))
    return np.arange(n_rows) if selected is None else selected
def read_filtered(wsm: shoji.WorkspaceManager, name: str, filters: List[ForwardRef('shoji.Filter')]) ‑> Union[numpy.ndarray, List[numpy.ndarray]]
-
Expand source code
def read_filtered(wsm: "shoji.WorkspaceManager", name: str, filters: List["shoji.Filter"]) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Read the part of a tensor that is selected by the filters

    Args:
        wsm:        Workspace manager that holds the tensor
        name:       Name of the tensor to read
        filters:    Filters that restrict the indices read along each axis

    Returns:
        For a regular tensor, the result of a single shoji.io.read_at_indices() call.
        For a jagged tensor, a list with one such result per selected row.
    """
    tensor = wsm._get_tensor(name)
    subspace = wsm._subdir

    if not tensor.jagged:
        per_axis = [get_filtered_indices(wsm, tensor, filters, ax, tensor.shape[ax]) for ax in range(tensor.rank)]
        return shoji.io.read_at_indices(wsm, name, per_axis, tensor.chunks, False)

    # Jagged tensor: every row has its own shape, so rows are read one at a time
    arrays = []
    for row in get_filtered_indices(wsm, tensor, filters, 0, tensor.shape[0]):
        shape_key = subspace.pack((Compartment.TensorRowShapes, name, int(row)))
        row_shape = fdb.tuple.unpack(wsm._db.transaction[shape_key])
        per_axis = [np.array([row])]
        for ax in range(1, tensor.rank):
            # row_shape describes the axes after the first, hence the offset of one
            per_axis.append(get_filtered_indices(wsm, tensor, filters, ax, row_shape[ax - 1]))
        arrays.append(shoji.io.read_at_indices(wsm, name, per_axis, tensor.chunks, False))
    return arrays
def write_filtered(tr: fdb.impl.Transaction, wsm: shoji.WorkspaceManager, name: str, vals: Union[numpy.ndarray, List[numpy.ndarray]], filters: List[ForwardRef('shoji.Filter')]) ‑> NoneType
-
Expand source code
@fdb.transactional
def write_filtered(tr: fdb.impl.Transaction, wsm: "shoji.WorkspaceManager", name: str, vals: Union[np.ndarray, List[np.ndarray]], filters: List["shoji.Filter"]) -> None:
    """
    Write values to the part of a tensor that is selected by the filters

    Args:
        tr:         Database or transaction handle, as accepted by fdb.transactional
        wsm:        Workspace manager that holds the tensor
        name:       Name of the tensor to write
        vals:       A numpy array, or a list or tuple of numpy arrays for a jagged tensor
        filters:    Filters that restrict the indices written along each axis

    Remarks:
        Row shapes of jagged tensors are read through wsm._db.transaction, whereas
        the values are written through tr.
    """
    tensor: shoji.Tensor = wsm._get_tensor(name)
    subspace = wsm._subdir
    assert isinstance(vals, (np.ndarray, list, tuple)), f"Value assigned to '{name}' is not a numpy array or a list or tuple of numpy arrays"
    target = (Compartment.TensorValues, name)

    if not tensor.jagged:
        per_axis = [get_filtered_indices(wsm, tensor, filters, ax, tensor.shape[ax]) for ax in range(tensor.rank)]
        shoji.io.write_at_indices(tr, wsm, target, per_axis, tensor.chunks, vals)
        return

    # Jagged tensor: every row has its own shape, so rows are written one at a time
    for row in get_filtered_indices(wsm, tensor, filters, 0, tensor.shape[0]):
        shape_key = subspace.pack((Compartment.TensorRowShapes, name, int(row)))
        row_shape = fdb.tuple.unpack(wsm._db.transaction[shape_key])
        per_axis = [np.array([row])]
        for ax in range(1, tensor.rank):
            # row_shape describes the axes after the first, hence the offset of one
            per_axis.append(get_filtered_indices(wsm, tensor, filters, ax, row_shape[ax - 1]))
        # NOTE(review): vals is indexed by the absolute row number, not by the position among
        # the selected rows -- confirm that callers pass a sequence covering all rows
        shoji.io.write_at_indices(tr, wsm, target, per_axis, tensor.chunks, vals[row])