Source code for sleepless.data.dataset

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import csv
import importlib.abc
import json
import logging
import os
import pathlib
import typing

from collections.abc import Mapping

from .sample import DelayedSample

logger = logging.getLogger(__name__)


class JSONDataset:
    """Generic multi-protocol/subset filelist dataset that yields samples.

    To create a new dataset, you need to provide one or more JSON formatted
    filelists (one per protocol) with the following contents:

    .. code-block:: json

       {
           "subset1": [
               [
                   "value1",
                   "value2",
                   "value3"
               ],
               [
                   "value4",
                   "value5",
                   "value6"
               ]
           ],
           "subset2": [
           ]
       }

    Your dataset may contain any number of subsets, but all sample entries
    must contain the same number of fields.

    :param protocols: Either an iterable of ``(source, parameters)`` pairs, or
        a dictionary mapping protocol names to such pairs.  ``source`` is the
        path (``str``, ``bytes``, path-like or traversable resource) of a JSON
        file, or an already opened file object; ``parameters`` is a mapping
        handed over, untouched, to ``loader``.  When an iterable is given,
        each protocol is named after the basename of its source, without the
        extension.

    :param fieldnames: An iterable over the field names (strings) to assign to
        each entry in the JSON file.  It should have as many items as fields
        in each entry of the JSON file.

    :param loader: A callable invoked once per :py:meth:`subsets` call as
        ``loader(parameters, prefix, protocol)``, in which ``prefix`` is the
        part of the protocol name that precedes the first ``-``.  The returned
        object must expose a ``_loader(context, fields)`` method, in which
        ``context`` is a dictionary with the keys ``protocol``, ``subset`` and
        ``order``, and ``fields`` is a dictionary with ``{fieldname: value}``
        entries.  Whatever ``_loader`` returns is used as the sample;
        :py:meth:`check` expects it to have 2 attributes:

        * ``key``: which must be a unique string for every sample across
          subsets in a protocol, and
        * ``data``: which contains the data associated with this sample
    """

    _protocols: dict[
        str,
        tuple[
            str | pathlib.Path | importlib.abc.Traversable,
            Mapping,
        ],
    ]

    def __init__(
        self,
        protocols: (
            typing.Iterable[
                tuple[
                    str | pathlib.Path | importlib.abc.Traversable,
                    Mapping,
                ]
            ]
            | dict[
                str,
                tuple[
                    str | pathlib.Path | importlib.abc.Traversable,
                    Mapping,
                ],
            ]
        ),
        fieldnames: typing.Iterable[str],
        loader: typing.Callable,
    ) -> None:
        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            self._protocols = {
                self._protocol_name(entry): entry for entry in protocols
            }

        # The field names are zipped against every single sample: a one-shot
        # iterator (e.g. a generator) would be exhausted after the first one,
        # so it is materialized here.  Re-iterable inputs are stored as given.
        if iter(fieldnames) is fieldnames:
            fieldnames = tuple(fieldnames)
        self.fieldnames = fieldnames
        self._loader = loader

    @staticmethod
    def _protocol_name(entry) -> str:
        """Derives a protocol name from one entry of a list-style input.

        The name is computed from the source only (first item of the
        ``(source, parameters)`` pair), never from the textual representation
        of the whole pair, which could contain dots or path separators coming
        from the parameters.
        """
        source = entry[0] if isinstance(entry, (tuple, list)) else entry
        return os.path.splitext(os.path.basename(str(source)))[0]

    @staticmethod
    def _read_json(source):
        """Returns the decoded contents of the JSON document at ``source``.

        :param source: A filesystem path (``str``, ``bytes`` or path-like), a
            traversable resource (anything with an ``open()`` method, but no
            ``read()``), or an opened file object.  Opened file objects are
            rewound after reading, so they can be read again later.
        """
        if isinstance(source, (str, bytes, os.PathLike)):
            # paths are handed over as-is: converting ``bytes`` through
            # ``str()`` would yield the literal "b'...'" instead of the path
            with open(source, encoding="utf-8") as f:
                return json.load(f)

        if hasattr(source, "open") and not hasattr(source, "read"):
            # traversable resources are not necessarily on the filesystem
            # (e.g. inside a zip file), so they must open themselves
            with source.open("r", encoding="utf-8") as f:
                return json.load(f)

        data = json.load(source)
        source.seek(0)
        return data

    def check(self, limit: int = 0) -> int:
        """For each protocol, check if all data can be correctly accessed.

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.

        :param limit: Maximum number of samples to check (in each
            protocol/subset combination) in this dataset.  If set to zero,
            then check everything.

        :return: Number of errors during check
        """
        logger.info("Checking dataset...")
        errors = 0
        for proto in self._protocols:
            logger.info(f"Checking protocol '{proto}'...")
            for name, samples in self.subsets(proto).items():
                logger.info(f"Checking subset '{name}'...")
                if limit:
                    logger.info(f"Checking at most first '{limit}' samples...")
                    samples = samples[:limit]
                for pos, sample in enumerate(samples):
                    try:
                        sample.data  # may trigger data loading
                        logger.info(f"{sample.key}: OK")
                    except Exception as e:
                        # deliberately broad: any failure while loading a
                        # sample is reported and counted, not propagated
                        logger.error(
                            f"Found error loading entry {pos} in subset {name} "
                            f"of protocol {proto} from file "
                            f"'{self._protocols[proto]}': {e}"
                        )
                        errors += 1
        return errors

    def subsets(self, protocol: str) -> dict[str, list[DelayedSample]]:
        """Returns all subsets in a protocol.

        This method will load JSON information for a given protocol and return
        all subsets of the given protocol after converting each entry through
        the loader.

        :param protocol: Name of the protocol data to load

        :return: A dictionary mapping subset names to lists of objects
            (respecting the ``key``, ``data`` interface).
        """
        source, preproc_params = self._protocols[protocol]
        data = self._read_json(source)

        # the loader is instantiated once per call, using the part of the
        # protocol name that precedes the first dash
        csv_subset = protocol.split("-")[0]
        protocol_loader = self._loader(preproc_params, csv_subset, protocol)

        return {
            subset: [
                protocol_loader._loader(
                    dict(protocol=protocol, subset=subset, order=n),
                    dict(zip(self.fieldnames, entry)),
                )
                for n, entry in enumerate(samples)
            ]
            for subset, samples in data.items()
        }
class CSVDataset:
    """Generic multi-subset filelist dataset that yields samples.

    To create a new dataset, you only need to provide a CSV formatted filelist
    with the following information:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    Notice that all rows must have the same number of entries.  Files are
    parsed with :py:func:`csv.reader` using its default dialect, i.e. fields
    are expected to be separated by commas.

    :param subsets: Paths to one or more CSV formatted files containing the
        various subsets to be recognized by this dataset, or a dictionary,
        mapping subset names to paths (or opened file objects) of CSV files.
        Internally, we save a dictionary where keys default to the basename of
        paths, without the extension (list input).

    :param fieldnames: An iterable over the field names (strings) to assign to
        each column in the CSV file.  It should have as many items as fields
        in each row of the CSV file(s).

    :param loader: A function that receives as input, a context dictionary
        (with a "subset" key indicating which subset is being served, and an
        "order" key with the row position), and a dictionary with
        ``{fieldname: value}`` entries.  Whatever it returns is used as the
        sample; :py:meth:`check` expects it to have a ``key`` and a ``data``
        attribute.
    """

    _subsets: dict[str, str]

    def __init__(
        self,
        subsets: typing.Iterable[str] | dict[str, str],
        fieldnames: typing.Iterable[str],
        loader: typing.Callable,
    ) -> None:
        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            self._subsets = {
                os.path.splitext(os.path.basename(k))[0]: k for k in subsets
            }

        # The field names are zipped against every single row: a one-shot
        # iterator (e.g. a generator) would be exhausted after the first one,
        # so it is materialized here.  Re-iterable inputs are stored as given.
        if iter(fieldnames) is fieldnames:
            fieldnames = tuple(fieldnames)
        self.fieldnames = fieldnames
        self._loader = loader

    @staticmethod
    def _read_rows(source) -> list[list[str]]:
        """Returns all rows of the CSV document at ``source``.

        :param source: A filesystem path (``str``, ``bytes`` or path-like), a
            traversable resource (anything with an ``open()`` method, but no
            ``read()``), or an opened file object.  Opened file objects are
            rewound after reading, so they can be read again later.
        """
        if isinstance(source, (str, bytes, os.PathLike)):
            # ``newline=""`` lets the csv module handle line endings itself;
            # the text encoding is left at the platform default
            with open(source, newline="") as f:
                return list(csv.reader(f))

        if hasattr(source, "open") and not hasattr(source, "read"):
            # traversable resources are not necessarily on the filesystem
            # (e.g. inside a zip file), so they must open themselves
            with source.open("r", newline="") as f:
                return list(csv.reader(f))

        rows = list(csv.reader(source))
        source.seek(0)
        return rows

    def check(self, limit: int = 0) -> int:
        """For each subset, check if all data can be correctly accessed.

        This function assumes each sample has a ``data`` and a ``key``
        attribute.  The ``key`` attribute should be a string, or representable
        as such.

        :param limit: Maximum number of samples to check (in each subset) in
            this dataset.  If set to zero, then check everything.

        :return: Number of errors during check
        """
        logger.info("Checking dataset...")
        errors = 0
        for name in self._subsets:
            logger.info(f"Checking subset '{name}'...")
            samples = self.samples(name)
            if limit:
                logger.info(f"Checking at most first '{limit}' samples...")
                samples = samples[:limit]
            for pos, sample in enumerate(samples):
                try:
                    sample.data  # may trigger data loading
                    logger.info(f"{sample.key}: OK")
                except Exception as e:
                    # deliberately broad: any failure while loading a sample
                    # is reported and counted, not propagated
                    logger.error(
                        f"Found error loading entry {pos} in subset {name} "
                        f"from file '{self._subsets[name]}': {e}"
                    )
                    errors += 1
        return errors

    def subsets(self) -> dict[str, list[DelayedSample]]:
        """Returns all available subsets at once.

        :return: A dictionary mapping subset names to lists of objects
            (respecting the ``key``, ``data`` interface).
        """
        return {name: self.samples(name) for name in self._subsets}

    def samples(self, subset: str) -> list[DelayedSample]:
        """Returns all samples in a subset.

        This method will load CSV information for a given subset and return
        all samples of the given subset after passing each entry through the
        loading function.

        :param subset: Name of the subset data to load

        :return: A list of objects (respecting the ``key``, ``data``
            interface).
        """
        rows = self._read_rows(self._subsets[subset])
        return [
            self._loader(
                dict(subset=subset, order=n), dict(zip(self.fieldnames, row))
            )
            for n, row in enumerate(rows)
        ]