import os.path
import pickle
from abc import ABC, abstractmethod
import numpy as np
import pandas as pd
from cvkit import MAGIC_NUMBER
from cvkit.pose_estimation import Skeleton, Part
class DataStoreInterface(ABC):
    """
    Interface for data reader. This class can be implemented to integrate data files from other toolkits.

    :param body_parts: list of column names
    :param path: path to data file
    :param dimension: data dimension
    """

    #: Acts as an ID for the class. File with "CVKit3D" indicates that the file should be opened with :py:class:`CVKitDataStore3D`
    FLAVOR = "Abstract"
    #: Dimension of the data
    DIMENSIONS = 3
    #: File Separator
    SEP = ','
    #: Character to separate multiple behaviours
    BEHAVIOUR_SEP = '~'
    #: Magic number to represent invalid data
    MAGIC_NUMBER = MAGIC_NUMBER

    def __init__(self, body_parts, path, dimension=3):
        self.body_parts = body_parts
        # Populated by concrete subclasses; presumably a pandas DataFrame given the
        # ``.index`` / ``.loc`` / ``to_csv`` usage below -- TODO confirm per subclass.
        self.data = None
        self.path = path
        # Path without extension; used to locate the ``<base>_stats.bin`` side-car file.
        self.base_file_path = os.path.splitext(self.path)[0] if self.path is not None else None
        # Instance-level override of the class-wide default dimension.
        self.DIMENSIONS = dimension
        self.stats: DataStoreStats = self._load_stats(body_parts)

    def _stats_file_path(self):
        """Return the path of the pickled statistics side-car file, or None if the store has no path."""
        if self.base_file_path is None:
            return None
        return f'{self.base_file_path}_stats.bin'

    def _load_stats(self, body_parts):
        """
        Load cached statistics from the side-car file, falling back to fresh statistics.

        :param body_parts: body parts used to build a fresh :py:class:`DataStoreStats` when no cache is usable.
        :return: the unpickled statistics object, or a new :py:class:`DataStoreStats`.
        """
        stats_path = self._stats_file_path()
        if stats_path is not None:
            try:
                # NOTE: pickle.load can execute arbitrary code; only open stats files from trusted sources.
                with open(stats_path, 'rb') as stats_file:
                    return pickle.load(stats_file)
            except Exception:
                # Best effort: a missing, unreadable or incompatible cache simply means
                # the statistics have to be recomputed.
                pass
        return DataStoreStats(body_parts)

    def get_skeleton(self, index) -> Skeleton:
        """
        Generates and returns a skeleton object for the frame defined by index.

        :param index: The index number pointing to the data corresponding to the frame.
        :return: :py:class:`Skeleton` object; an empty skeleton (see :py:meth:`build_empty_skeleton`)
            if ``index`` is not present in the data.
        """
        if index in self.data.index:
            return self.build_skeleton(self.data.loc[index])
        return self.build_empty_skeleton()

    def set_skeleton(self, index, skeleton: Skeleton, force_insert=False) -> None:
        """
        Set pose data from :py:class:`~cvkit.pose_estimation.skeleton.Skeleton` object at given index.

        A row for an index that does not exist yet is only created when at least one part of the
        skeleton compares greater than 0 (i.e. carries valid data), unless ``force_insert`` is set.

        :param index: The index at which the data will be inserted.
        :param skeleton: :py:class:`Skeleton` object containing the pose data.
        :param force_insert: Bypass index validation.
        """
        if not force_insert and index not in self.data.index:
            # Do not create new rows for skeletons without any valid part.
            if not any(skeleton[part] > 0 for part in self.body_parts):
                return
        for part in self.body_parts:
            self.set_part(index, skeleton[part])
        self.set_behaviour(index, skeleton.behaviour)

    def get_numpy(self, index):
        """
        Generate nxd Numpy array from given index where n is number of body parts and d is the dimension.
        The order of data follows :attr:`.DataStoreInterface.body_parts`.

        :param index: The index from which data will be retrieved.
        :return: nxd Numpy array (filled with the magic number if the index is missing).
        """
        skeleton = self.get_skeleton(index)
        return np.array([np.array(skeleton[part]) for part in self.body_parts])

    def delete_skeleton(self, index):
        """
        Deletes the pose data of every body part at the location pointed by the index.
        Behaviour data is not touched.

        :param index: The index of the data to be deleted
        """
        if index in self.data.index:
            for part in self.body_parts:
                self.delete_part(index, part, True)

    @abstractmethod
    def set_behaviour(self, index, behaviour: list[str]) -> None:
        """
        Set behaviour data for current index.

        :param index: The index where behaviour data will be inserted.
        :param behaviour: List of behaviours
        """
        pass

    @abstractmethod
    def get_behaviour(self, index) -> list[str]:
        """
        Get behaviour at given index.

        :param index: The index of the data to be retrieved.
        :return: List of behaviours
        """
        pass

    @abstractmethod
    def get_part_slice(self, slice_indices: list[int], name: str) -> np.ndarray:
        """
        Get slice of data for given part as a Numpy array.

        :param slice_indices: List of two integers defining starting and ending point (non-inclusive) of the slice.
        :param name: Name of the body part
        :return: Numpy array of dimension nxd where n is the size of the slice and d is the dimension of the data.
        """
        pass

    @abstractmethod
    def set_part_slice(self, slice_indices: list, name: str, data: np.ndarray) -> None:
        """
        Set a slice of data for given part.

        :param slice_indices: List of two integers defining starting and ending point (non-inclusive) of the slice.
        :param name: Name of the body part
        :param data: nxd dimensional numpy array where n is the size of the slice and d is the dimension of the data.
        """
        pass

    def row_iterator(self):
        """
        Generates an iterator which yields index and corresponding :py:class:`Skeleton` sequentially.

        Example Usage:

        .. highlight:: python
        .. code-block:: python

            for index, skeleton in data_store.row_iterator():
                print(index,skeleton)
        """
        for index, row in self.data.iterrows():
            yield index, self.build_skeleton(row)

    def part_iterator(self, part):
        """
        Generates an iterator which yields index and corresponding :py:class:`Part` sequentially.

        Example Usage:

        .. highlight:: python
        .. code-block:: python

            for index, snout in data_store.part_iterator('snout'):
                print(index,snout)

        :param part: Target body part.
        """
        for index, row in self.data[part].items():
            yield index, self.build_part(row, part)

    @abstractmethod
    def get_part(self, index, name) -> Part:
        """
        Get :py:class:`Part` object at given index.

        :param index: The index from which the data will be retrieved.
        :param name: Name of the target body part.
        :return: :py:class:`Part`
        """
        pass

    @abstractmethod
    def set_part(self, index, part: Part) -> None:
        """
        Set :py:class:`Part` object at given index.

        :param index: The index at which the data will be inserted.
        :param part: :py:class:`Part` to be inserted.
        """
        pass

    @abstractmethod
    def delete_part(self, index, name, force_remove=False):
        """
        Deletes part at given index.

        :param index: The index from which the part will be deleted.
        :param name: Name of the target part.
        :param force_remove: Bypass index validation.
        """
        pass

    @abstractmethod
    def build_skeleton(self, row) -> Skeleton:
        """
        Build skeleton from internal row representation.

        :param row: row of a dataframe.
        """
        pass

    @abstractmethod
    def build_part(self, row, name) -> Part:
        """
        Build part from internal row representation

        :param row: row of a dataframe
        :param name: Name of the part
        :return: :py:class:`Part` Object
        """
        pass

    def save_file(self, path: str = None) -> None:
        """
        Save data to a file. The data is sorted by index (in place) before writing.

        :param path: Path of the file. If None, the original file (``self.path``) is overwritten.
        """
        if path is None:
            path = self.path
        self.data.sort_index(inplace=True)
        self.data.to_csv(path, sep=self.SEP)

    def set_stats(self, stats):
        """
        Set datastore statistics object (:py:class:`DataStoreStats`).

        The object is only adopted (and persisted next to the data file, if the store has a path)
        when ``stats.register`` accepts the current data hash, i.e. when it was not registered yet.

        :param stats: :py:class:`DataStoreStats` object
        """
        if stats.register(self.compute_data_hash()):
            self.stats = stats
            stats_path = self._stats_file_path()
            # Without a data path there is nowhere sensible to persist the statistics.
            if stats_path is not None:
                with open(stats_path, 'wb') as stats_file:
                    pickle.dump(self.stats, stats_file)

    def build_empty_skeleton(self):
        """
        Builds empty skeleton object from the class-level ``MAGIC_NUMBER``.

        :return: Empty :py:class:`Skeleton` with every coordinate set to ``MAGIC_NUMBER``,
            zero likelihood and behaviour ``['']``.
        """
        part_map = {}
        likelihood_map = {}
        for name in self.body_parts:
            part_map[name] = [self.MAGIC_NUMBER] * self.DIMENSIONS
            likelihood_map[name] = 0.0
        return Skeleton(self.body_parts, part_map=part_map, likelihood_map=likelihood_map, behaviour=[''],
                        dims=self.DIMENSIONS)

    def __len__(self):
        return len(self.data)

    def compute_data_hash(self):
        """
        Computes a hash value of the dataframe. Used to detect changes.

        :return: hash value (sum of the per-row pandas hashes)
        """
        return int(pd.util.hash_pandas_object(self.data).sum())

    def verify_stats(self):
        """
        Verify whether current datastore statistics are valid.

        Statistics are valid only if they were computed for the current data (matching hash)
        AND for the same body parts. Invalid statistics are marked as unregistered.

        :return: datastore statistics validity
        :rtype: boolean
        """
        hash_matches = self.compute_data_hash() == self.stats.data_frame_hash
        parts_match = self.stats.body_parts == self.body_parts
        # Previously written as ``not A and B``, which reported mismatching body parts as valid.
        if not (hash_matches and parts_match):
            self.stats.registered = False
            return False
        return True

    @staticmethod
    @abstractmethod
    def convert_to_list(index, skeleton, threshold=0.8):
        """
        Generates a list of parts for :py:class:`csv.writer` module. The structure of the list depends upon output data format.
        The data not crossing the threshold will not be included. Can be used to convert one data flavor to another.

        :param index: Target Index
        :param skeleton: Target skeleton
        :param threshold: Threshold for including data.
        """
        pass

    def allocate(self, n):
        """
        Ensure rows with index labels ``0 .. n-1`` exist in the data, adding empty (NaN) rows where missing.

        Labels that already exist are left untouched, so calling this on non-empty data does not
        create duplicate index labels (which would make ``self.data.loc[index]`` return a frame
        instead of a single row). The data is sorted by index afterwards.

        :param n: number of leading index labels to allocate.
        """
        missing = pd.RangeIndex(n).difference(self.data.index)
        if len(missing) > 0:
            empty_rows = pd.DataFrame(columns=self.data.columns, index=missing)
            self.data = pd.concat([self.data, empty_rows])
        self.data.sort_index(inplace=True)
class DataStoreStats:
    """
    Datastore-statistics class keeping track of clusters of accurate and non-accurate (NA) data.

    A *cluster* is a dict ``{'begin': first_index, 'end': last_index}`` describing a run of
    consecutive integer indices with both bounds inclusive. ``-2`` is used as a sentinel for
    "no cluster started yet".
    """

    def __init__(self, body_parts):
        """
        :param body_parts: body-part names for which NA clusters are tracked.
        """
        # Hash of the dataframe these statistics describe; set by register().
        self.data_frame_hash = 0
        self.body_parts = body_parts
        # Finalized NA clusters, one list per body part.
        self.na_data_points = {}
        # Finalized clusters of accurate data.
        self.accurate_data_points = []
        # Occupancy fractions appended via add_occupancy_data; list position acts as the index.
        self.occupancy_data = []
        # In-progress clusters grown by update_cluster_info and flushed by register().
        self._na_current_cluster = {}
        for column in body_parts:
            self.na_data_points[column] = []
            self._na_current_cluster[column] = {'begin': -2, 'end': -2}
        self._accurate_cluster = {'begin': -2, 'end': -2}
        self.registered = False

    def add_occupancy_data(self, fraction):
        """
        Append one occupancy value.

        :param fraction: occupancy value, compared against bounds in ``[0, 1]`` by :py:meth:`get_occupancy_clusters`.
        """
        self.occupancy_data.append(fraction)

    def update_cluster_info(self, index, part, accurate=False):
        """
        Extend the current cluster with ``index`` or start a new one.

        If ``index`` directly follows the end of the current cluster it extends it; otherwise the
        current cluster (if one was started) is stored and a new single-index cluster begins.
        Presumably called with increasing indices -- TODO confirm with callers.

        Must be called before :py:meth:`register`, which discards the in-progress cluster state.

        :param index: integer index to record.
        :param part: body part whose NA cluster is updated; ignored when ``accurate`` is True.
        :param accurate: update the accurate-data cluster instead of the NA cluster of ``part``.
        """
        if accurate:
            cluster, finished = self._accurate_cluster, self.accurate_data_points
        else:
            cluster, finished = self._na_current_cluster[part], self.na_data_points[part]
        if cluster['end'] + 1 == index:
            cluster['end'] = index
        else:
            if cluster['begin'] != -2:
                finished.append(cluster.copy())
            cluster['begin'] = cluster['end'] = index

    def register(self, data_frame_hash):
        """
        Finalize the statistics and bind them to a dataframe hash.

        Flushes the clusters still being built, drops the in-progress cluster state and marks the
        statistics as registered. Calling it on already registered statistics does nothing.

        :param data_frame_hash: hash of the dataframe the statistics describe.
        :return: True if this call registered the statistics, False if they were already registered.
        """
        if self.registered:
            return False
        # The in-progress state is gone if these stats were registered before and later
        # invalidated (``registered`` reset to False); pop with defaults so that case does not crash.
        pending_na = vars(self).pop('_na_current_cluster', {})
        pending_accurate = vars(self).pop('_accurate_cluster', None)
        for col, cluster in pending_na.items():
            if cluster['begin'] != -2:
                self.na_data_points[col].append(cluster.copy())
        if pending_accurate is not None and pending_accurate['begin'] != -2:
            self.accurate_data_points.append(pending_accurate.copy())
        self.data_frame_hash = data_frame_hash
        self.registered = True
        return True

    def iter_na_clusters(self, part):
        """Yield the finalized NA clusters of ``part``."""
        yield from self.na_data_points[part]

    def iter_accurate_clusters(self):
        """Yield the finalized accurate clusters."""
        yield from self.accurate_data_points

    def get_accurate_cluster_info(self, bin_width=20, max_bin=100):
        """
        Summarize the widths (``end - begin``) of the accurate clusters as a histogram.

        Bins are keyed by their exclusive upper edge: ``bin_width, 2*bin_width, ...`` up to ``max_bin``.
        A width is counted in the first bin whose key exceeds it; widths not below any key go to the last bin.

        :param bin_width: width of each bin (also the first key).
        :param max_bin: largest allowed bin key.
        :return: tuple ``(number_of_clusters, histogram, total_width)``.
        :raises ValueError: if the parameters produce no bins.
        """
        # The step used to be hard-coded to 20, ignoring ``bin_width``.
        histogram = {bucket: 0 for bucket in range(bin_width, max_bin + 1, bin_width)}
        if not histogram:
            raise ValueError(f'No histogram bins for bin_width={bin_width}, max_bin={max_bin}')
        last_bin = max(histogram)
        total = 0
        for cluster in self.accurate_data_points:
            width = cluster['end'] - cluster['begin']
            total += width
            target_bin = next((key for key in histogram if width < key), last_bin)
            histogram[target_bin] += 1
        return len(self.accurate_data_points), histogram, total

    def get_occupancy_clusters(self, min_occupancy, max_occupancy):
        """
        Group consecutive indices whose occupancy lies in ``[min_occupancy, max_occupancy)`` into clusters.

        NOTE(review): the upper bound is exclusive, so an occupancy equal to ``max_occupancy``
        (e.g. exactly 1.0) is never matched -- confirm this is intended.

        :param min_occupancy: inclusive lower bound, in ``[0, 1]``.
        :param max_occupancy: exclusive upper bound, in ``[min_occupancy, 1]``.
        :return: list of ``{'begin', 'end'}`` dicts (inclusive bounds) in index order.
        """
        assert 0 <= min_occupancy <= max_occupancy <= 1.0
        pose_data = []
        cluster = {'begin': -2, 'end': -2}
        for index, occupancy in enumerate(self.occupancy_data):
            if not min_occupancy <= occupancy < max_occupancy:
                continue
            if cluster['end'] + 1 == index:
                cluster['end'] = index
            else:
                if cluster['begin'] != -2:
                    pose_data.append(cluster.copy())
                cluster['begin'] = cluster['end'] = index
        # Flush the trailing cluster. The former check read ``pose_data[-1]`` (IndexError when the
        # trailing cluster was the only one) and dropped a trailing single-index cluster.
        if cluster['begin'] != -2:
            pose_data.append(cluster)
        return pose_data

    def intersect_accurate_data_points(self, accurate_clusters):
        """
        Intersect this object's accurate clusters with another list of clusters.

        Both lists are assumed to be sorted by ``begin`` and internally non-overlapping; the
        overlaps are computed with a two-pointer sweep.

        :param accurate_clusters: list of ``{'begin', 'end'}`` dicts (inclusive bounds).
        :return: list of overlapping ranges as ``{'begin', 'end'}`` dicts.
        """
        own = self.accurate_data_points
        output_accurate_cluster = []
        source_index = target_index = 0
        while source_index < len(own) and target_index < len(accurate_clusters):
            source = own[source_index]
            target = accurate_clusters[target_index]
            if source['end'] < target['begin']:
                source_index += 1
                continue
            if source['begin'] > target['end']:
                target_index += 1
                continue
            output_accurate_cluster.append({'begin': max(source['begin'], target['begin']),
                                            'end': min(source['end'], target['end'])})
            # Advance the source if its next cluster can still overlap the current target.
            if source_index + 1 < len(own) and own[source_index + 1]['begin'] <= target['end']:
                source_index += 1
            else:
                target_index += 1
        return output_accurate_cluster