From 63c9f46e9633b98a093eede5d5c139f6bb47a5b3 Mon Sep 17 00:00:00 2001 From: Kendrick Boyd Date: Wed, 10 Apr 2024 10:50:49 -0600 Subject: [PATCH] RDS-37: Support variable length sequences in dgan Co-authored-by: Piotr Mlocek GitOrigin-RevId: 767d2da3a06f5352b7ca7371fc27e80614e23804 --- src/gretel_synthetics/timeseries_dgan/dgan.py | 441 +++++++++----- .../timeseries_dgan/transformations.py | 307 +++++----- tests/timeseries_dgan/test_dgan.py | 558 ++++++++++++++---- tests/timeseries_dgan/test_transformations.py | 78 ++- 4 files changed, 977 insertions(+), 407 deletions(-) diff --git a/src/gretel_synthetics/timeseries_dgan/dgan.py b/src/gretel_synthetics/timeseries_dgan/dgan.py index 95374e3c..e362356d 100644 --- a/src/gretel_synthetics/timeseries_dgan/dgan.py +++ b/src/gretel_synthetics/timeseries_dgan/dgan.py @@ -50,7 +50,9 @@ import logging import math -from typing import Callable, Dict, List, Optional, Tuple +from collections import Counter +from itertools import cycle +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -65,19 +67,16 @@ from gretel_synthetics.timeseries_dgan.transformations import ( create_additional_attribute_outputs, create_outputs_from_data, - inverse_transform, + inverse_transform_attributes, + inverse_transform_features, Output, - transform, + transform_attributes, + transform_features, ) logger = logging.getLogger(__name__) -logging.basicConfig( - format="%(asctime)s : %(threadName)s : %(levelname)s : %(message)s", - level=logging.INFO, -) - -AttributeFeaturePair = Tuple[Optional[np.ndarray], np.ndarray] +AttributeFeaturePair = Tuple[Optional[np.ndarray], list[np.ndarray]] NumpyArrayTriple = Tuple[np.ndarray, np.ndarray, np.ndarray] NAN_ERROR_MESSAGE = """ @@ -174,7 +173,7 @@ def __init__( def train_numpy( self, - features: np.ndarray, + features: Union[np.ndarray, list[np.ndarray]], feature_types: Optional[List[OutputType]] = None, attributes: Optional[np.ndarray] = None, attribute_types: Optional[List[OutputType]] = None, @@ -183,11 +182,13 @@ def train_numpy( """Train DGAN model on data in numpy arrays. Training data is passed in 2 numpy arrays, one for attributes (2d) and - one for features (3d). This data should be in the original space and is - not transformed. If the data is already transformed into the internal - DGAN representation (continuous variable scaled to [0,1] or [-1,1] and - discrete variables one-hot or binary encoded), use the internal _train() - function instead of train_numpy(). + one for features (3d), features may be a ragged array with variable + length sequences, and then it is a list of numpy arrays. This data + should be in the original space and is not transformed. If the data is + already transformed into the internal DGAN representation (continuous + variable scaled to [0,1] or [-1,1] and discrete variables one-hot or + binary encoded), use the internal _train() function instead of + train_numpy(). In standard usage, attribute_types and feature_types may be provided on the first call to train() to setup the model structure. 
If not @@ -199,7 +200,9 @@ def train_numpy( Args: features: 3-d numpy array of time series features for the training, size is (# of training examples) X max_sequence_len X (# of - features) + features) OR list of 2-d numpy arrays with one sequence per + numpy array, each numpy array should then have size seq_len X (# + of features) where seq_len <= max_sequence_len feature_types (Optional): Specification of Discrete or Continuous type for each variable of the features. If None, assume continuous variables for floats and integers, and discrete for @@ -215,8 +218,17 @@ def train_numpy( passing *output params at initialization or because train_* was called previously. """ + # To make the rest of code simpler, ensure features is split into a a + # list of 2d numpy arrays, one element per sequence. That representation + # is basically needed for variable length sequences, and even for fixed + # length sequences, we switch to that for easier code. If we're looking + # for efficiency and memory improvements in the future, better handling + # of these objects is a good place to start. + if isinstance(features, np.ndarray): + features = [seq for seq in features] + logging.info( - f"features shape={features.shape}, dtype={features.dtype}", + f"features length={len(features)}, first sequence shape={features[0].shape}, dtype={features[0].dtype}", extra={"user_log": True}, ) if attributes is not None: @@ -226,21 +238,11 @@ def train_numpy( ) if attributes is not None: - if attributes.shape[0] != features.shape[0]: + if attributes.shape[0] != len(features): raise InternalError( "First dimension of attributes and features must be the same length, i.e., the number of training examples." # noqa ) - if features.shape[1] != self.config.max_sequence_len: - raise ParameterError( - "The time dimension of the training data " - f"({features.shape[1]}) does not match " - "config.max_sequence_len " - f"({self.config.max_sequence_len}). This often happens when " - "the chosen config.df_style of wide or long does not match " - "the input data." - ) - if attributes is not None and attribute_types is None: # Automatically determine attribute types attribute_types = [] @@ -255,6 +257,10 @@ def train_numpy( # behavior. And we can look into a better fix in the future, # maybe using # of distinct values, and having an explicit # integer type so we appropriately round the final output. + + # This snippet is only detecting types to construct + # feature_types, not making any changes to elements of + # features. attributes[:, i].astype("float") attribute_types.append(OutputType.CONTINUOUS) except ValueError: @@ -263,11 +269,16 @@ def train_numpy( if feature_types is None: # Automatically determine feature types feature_types = [] - for i in range(features.shape[2]): + for i in range(features[0].shape[1]): try: # Here we treat integer columns as continuous, see above # comment. - features[:, :, i].astype("float") + + # This snippet is only detecting types to construct + # feature_types, not making any changes to elements of + # features. + for seq in features: + seq[:, i].astype("float") feature_types.append(OutputType.CONTINUOUS) except ValueError: feature_types.append(OutputType.DISCRETE) @@ -310,11 +321,12 @@ def train_numpy( # Find valid examples based on minimal number of nans. valid_examples = validation_check( - features[:, :, continuous_features_ind].astype("float") + features, + continuous_features_ind, ) # Only use valid examples for the entire dataset. 
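        # For instance (hypothetical values): valid_examples =
        # np.array([True, False, True]) keeps sequences 0 and 2. The same
        # boolean mask filters the ragged feature list below and indexes the
        # 2d attributes array, so the two stay aligned.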
- features = features[valid_examples] + features = [seq for valid, seq in zip(valid_examples, features) if valid] if attributes is not None: attributes = attributes[valid_examples] @@ -324,37 +336,36 @@ def train_numpy( ) # Apply linear interpolations to replace nans for continuous # features: - features[:, :, continuous_features_ind] = nan_linear_interpolation( - features[:, :, continuous_features_ind].astype("float") - ) + nan_linear_interpolation(features, continuous_features_ind) logger.info("Creating encoded array of features", extra={"user_log": True}) - if self.additional_attribute_outputs: - ( - internal_features, - internal_additional_attributes, - ) = transform(features, self.feature_outputs, variable_dim_index=2) + ( + internal_features, + internal_additional_attributes, + ) = transform_features( + features, self.feature_outputs, self.config.max_sequence_len + ) + if internal_additional_attributes is not None: if np.any(np.isnan(internal_additional_attributes)): raise InternalError( f"NaN found in internal additional attributes. {NAN_ERROR_MESSAGE}" ) - else: - internal_features = transform( - features, self.feature_outputs, variable_dim_index=2 - ) - internal_additional_attributes = torch.Tensor( - np.full((internal_features.shape[0], 1), np.nan) + internal_additional_attributes = np.full( + (internal_features.shape[0], 1), np.nan, dtype=np.float32 ) logger.info("Creating encoded array of attributes", extra={"user_log": True}) - internal_attributes = transform( - attributes, - self.attribute_outputs, - variable_dim_index=1, - num_examples=internal_features.shape[0], - ) + if attributes is not None and self.attribute_outputs is not None: + internal_attributes = transform_attributes( + attributes, + self.attribute_outputs, + ) + else: + internal_attributes = np.full( + (internal_features.shape[0], 1), np.nan, dtype=np.float32 + ) logger.info( f"internal_features shape={internal_features.shape}, dtype={internal_features.dtype}", @@ -498,6 +509,7 @@ def train_dataframe( self.data_frame_converter = _LongDataFrameConverter.create( df, + max_sequence_len=self.config.max_sequence_len, attribute_columns=attribute_columns, feature_columns=feature_columns, example_id_column=example_id_column, @@ -596,16 +608,25 @@ def generate_numpy( internal_features, ) = internal_data - attributes = inverse_transform( - internal_attributes, self.attribute_outputs, variable_dim_index=1 - ) + attributes = None + if internal_attributes is not None and self.attribute_outputs is not None: + attributes = inverse_transform_attributes( + internal_attributes, + self.attribute_outputs, + ) + + if internal_features is None: + raise InternalError( + "Received None instead of internal features numpy array" + ) - features = inverse_transform( + features = inverse_transform_features( internal_features, self.feature_outputs, - variable_dim_index=2, additional_attributes=internal_additional_attributes, ) + # Convert to list of numpy arrays to match primary input to train_numpy + features = [seq for seq in features] if n is not None: if attributes is None: @@ -1188,13 +1209,15 @@ def convert(self, df: pd.DataFrame) -> AttributeFeaturePair: @abc.abstractmethod def invert( - self, attributes: Optional[np.ndarray], features: np.ndarray + self, + attributes: Optional[np.ndarray], + features: list[np.ndarray], ) -> pd.DataFrame: """Invert from DGAN input format back to DataFrame. 
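
        For example (illustrative values), a long-format converter with one
        feature column and no attributes maps

            invert(None, [np.array([[1.0], [2.0]]), np.array([[3.0], [4.0]])])

        back to a DataFrame with one row per time point, four rows in total.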
Args: attributes: 2d numpy array of attributes - features: 3d numpy array of features + features: list of 2d numpy arrays Returns: DataFrame representing attributes and features in original format. @@ -1327,10 +1350,10 @@ def convert(self, df: pd.DataFrame) -> AttributeFeaturePair: features = np.expand_dims(df[self._feature_columns].to_numpy(), axis=-1) - return attributes, features + return attributes, [seq for seq in features] def invert( - self, attributes: Optional[np.ndarray], features: np.ndarray + self, attributes: Optional[np.ndarray], features: list[np.ndarray] ) -> pd.DataFrame: if self._attribute_columns: if attributes is None: @@ -1338,11 +1361,11 @@ def invert( "Data converter with attribute columns expects attributes array, received None" ) data = np.concatenate( - (attributes, features.reshape(features.shape[0], features.shape[1])), + (attributes, np.vstack([seq.reshape((1, -1)) for seq in features])), axis=1, ) else: - data = features.reshape(features.shape[0], features.shape[1]) + data = np.vstack([seq.reshape((1, -1)) for seq in features]) df = pd.DataFrame(data, columns=self._attribute_columns + self._feature_columns) @@ -1363,6 +1386,32 @@ def _state_dict(self) -> Dict: } +def _add_generation_flag( + sequence: np.ndarray, generation_flag_index: int +) -> np.ndarray: + """Adds column indicating continuing and end time points in sequence. + + Args: + sequence: 2-d numpy array of a single sequence + generation_flag_index: index of column to insert + + Returns: + New array including the generation flag column + """ + # Generation flag is all True + flag_column = np.full((sequence.shape[0], 1), True) + # except last value is False to indicate the end of the sequence + flag_column[-1, 0] = False + + return np.hstack( + ( + sequence[:, :generation_flag_index], + flag_column, + sequence[:, generation_flag_index:], + ) + ) + + class _LongDataFrameConverter(_DataFrameConverter): """Convert "long" format DataFrames. @@ -1381,6 +1430,7 @@ def __init__( attribute_types: List[OutputType], feature_types: List[OutputType], time_column_values: Optional[List[str]], + generation_flag_index: Optional[int] = None, ): super().__init__() self._attribute_columns = attribute_columns @@ -1392,11 +1442,13 @@ def __init__( self._attribute_types = attribute_types self._feature_types = feature_types self._time_column_values = time_column_values + self._generation_flag_index = generation_flag_index @classmethod def create( cls, df: pd.DataFrame, + max_sequence_len: int, attribute_columns: Optional[List[str]] = None, feature_columns: Optional[List[str]] = None, example_id_column: Optional[str] = None, @@ -1480,6 +1532,27 @@ def create( else: time_column_values = None + # generation_flag_index is the index in feature_types (and thus + # features) of the boolean variable indicating the end of sequence. + # generation_flag_index=None means there are no variable length + # sequences, so the indicator variable is not needed and no boolean + # feature is added. 
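+        #
+        # As a toy example: with max_sequence_len=5, a length-3 sequence gets
+        # a flag column of [True, True, False] from _add_generation_flag; the
+        # zero-padding added later in transform_features fills the remaining
+        # two time steps.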
+ generation_flag_index = None + if example_id_column: + id_counter = Counter(df[example_id_column]) + has_variable_length_sequences = False + for item in id_counter.most_common(): + if item[1] > max_sequence_len: + raise DataError( + f"Found sequence with length {item[1]}, longer than max_sequence_len={max_sequence_len}" + ) + elif item[1] < max_sequence_len: + has_variable_length_sequences = True + + if has_variable_length_sequences: + generation_flag_index = len(feature_types) + feature_types.append(OutputType.DISCRETE) + return cls( attribute_columns=attribute_columns, feature_columns=feature_columns, @@ -1490,6 +1563,7 @@ def create( attribute_types=attribute_types, feature_types=feature_types, time_column_values=time_column_values, + generation_flag_index=generation_flag_index, ) @property @@ -1515,13 +1589,10 @@ def convert(self, df: pd.DataFrame) -> AttributeFeaturePair: # Use example_id_column to split into separate time series df_features = sorted_df[self._feature_columns] - features = np.stack( - list( - df_features.groupby(sorted_df[self._example_id_column]).apply( - pd.DataFrame.to_numpy - ) - ), - axis=0, + features = list( + df_features.groupby(sorted_df[self._example_id_column]).apply( + pd.DataFrame.to_numpy + ) ) if self._attribute_columns: @@ -1570,9 +1641,7 @@ def custom_max(a): else: # No example_id column provided to create multiple examples, so we # create one example from all time points. - features = np.expand_dims( - sorted_df[self._feature_columns].to_numpy(), axis=0 - ) + features = [sorted_df[self._feature_columns].to_numpy()] # Check that attributes are the same for all rows (since they are # all implicitly in the same example) @@ -1589,60 +1658,102 @@ def custom_max(a): else: attributes = None + if self._generation_flag_index is not None: + features = [ + _add_generation_flag(seq, self._generation_flag_index) + for seq in features + ] return attributes, features def invert( - self, attributes: Optional[np.ndarray], features: np.ndarray + self, + attributes: Optional[np.ndarray], + features: list[np.ndarray], ) -> pd.DataFrame: - num_examples = features.shape[0] - num_time_points = features.shape[1] - num_features = features.shape[2] - - if num_features != len(self._feature_columns): - raise InternalError( - "Unable to invert features back to data frame, " - + f"converter expected {len(self._feature_columns)} features, " - + f"received numpy array with {features.shape[2]}" - ) + sequences = [] + for seq_index, seq in enumerate(features): + if self._generation_flag_index is not None: + # Remove generation flag and truncate sequences based on the values. + # The first value of False in the generation flag indicates the last + # time point. 
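+            #
+            # E.g. (hypothetical values), with the flag at index 1 a generated
+            # sequence [[1.0, True], [2.0, False], [9.0, True]] is truncated to
+            # its first two rows, because the first False marks the last real
+            # time point.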
+ try: + first_false = np.min( + np.argwhere(seq[:, self._generation_flag_index] == False) + ) + # Include the time point with the first False in generation + # flag + seq = seq[: (first_false + 1), :] + except ValueError: + # No False found in generation flag column, use all time + # points + pass - # Reshape so each time point is its own row in a 2d array - long_features = features.reshape(-1, num_features) + # Remove the generation flag column + seq = np.delete(seq, self._generation_flag_index, axis=1) - if self._attribute_columns: - if attributes is None: + if seq.shape[1] != len(self._feature_columns): raise InternalError( - "Data converter with attribute columns expects attributes array, received None" + "Unable to invert features back to data frame, " + + f"converter expected {len(self._feature_columns)} features, " + + f"received numpy array with {seq.shape[1]}" ) - # Repeat attribute rows for every time point in each example - long_attributes = np.repeat(attributes, num_time_points, axis=0) - df = pd.DataFrame( - np.hstack((long_attributes, long_features)), - columns=self._attribute_columns + self._feature_columns, - ) - else: - df = pd.DataFrame( - long_features, - columns=self._feature_columns, - ) + seq_column_parts = [seq] + if self._attribute_columns: + if attributes is None: + raise InternalError( + "Data converter with attribute columns expects attributes array, received None" + ) + seq_attributes = np.repeat( + attributes[seq_index : (seq_index + 1), :], seq.shape[0], axis=0 + ) + seq_column_parts.append(seq_attributes) - # Convert discrete columns to int where possible. - df = _discrete_cols_to_int(df, self._discrete_columns) + if self._example_id_column: + # TODO: match example_id style of original data somehow + seq_column_parts.append(np.full((seq.shape[0], 1), seq_index)) - if self._example_id_column: - # Use [0,1,2,...] for example_id - # This may not match the style of the originally converted data - df[self._example_id_column] = np.repeat( - range(num_examples), num_time_points - ) + if self._time_column: + if self._time_column_values is None: + raise InternalError( + "time_column is present, but not time_column_values" + ) + # TODO: do something better if time_column_values isn't long + # enough, for now we just repeat the time column values + values = [ + v + for _, v in zip( + range(seq.shape[0]), cycle(self._time_column_values) + ) + ] + seq_column_parts.append(np.array(values).reshape((-1, 1))) + + sequences.append(np.hstack(seq_column_parts)) + column_names = self._feature_columns + self._attribute_columns + + if self._example_id_column: + column_names.append(self._example_id_column) if self._time_column: - if self._time_column_values is None: - raise InternalError( - "time_column is present, but not time_column_values" - ) + column_names.append(self._time_column) + + df = pd.DataFrame(np.vstack(sequences), columns=column_names) - df[self._time_column] = np.tile(self._time_column_values, num_examples) + for c in df.columns: + try: + df[c] = df[c].astype("float64") + except ValueError: + continue + except TypeError: + continue + + # Convert discrete columns to int where possible. 
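+        # E.g. (hypothetical values), a discrete column holding [1.0, 2.0]
+        # becomes [1, 2], while a column of strings like ["aa", "bb"] is left
+        # unchanged.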
+ df = _discrete_cols_to_int( + df, + (self._discrete_columns), + ) + if self._example_id_column: + df = _discrete_cols_to_int(df, [self._example_id_column]) return df[self._df_column_order] @@ -1657,6 +1768,7 @@ def _state_dict(self) -> Dict: "attribute_types": self._attribute_types, "feature_types": self._feature_types, "time_column_values": self._time_column_values, + "generation_flag_index": self._generation_flag_index, } @@ -1666,7 +1778,7 @@ def _state_dict(self) -> Dict: } -def find_max_consecutive_nans(array: np.array) -> int: +def find_max_consecutive_nans(array: np.ndarray) -> int: """ Returns the maximum number of consecutive NaNs in an array. @@ -1685,12 +1797,13 @@ def find_max_consecutive_nans(array: np.array) -> int: def validation_check( - array: np.ndarray, + features: list[np.ndarray], + continuous_features_ind: list[int], invalid_examples_ratio_cutoff: float = 0.5, nans_ratio_cutoff: float = 0.1, consecutive_nans_max: int = 5, consecutive_nans_ratio_cutoff: float = 0.05, -) -> np.array: +) -> np.ndarray: """Checks if continuous features of examples are valid. Returns a 1-d numpy array of booleans with shape (#examples) indicating @@ -1704,39 +1817,63 @@ def validation_check( these are omitted from training. If there are too many, later, we error out. Args: - array: 3-d numpy array of continuous features with - shape (#examples,max_sequence_length, #continuous features). - invalid_examples_ratio_cutoff: Error out if the invalid examples ratio in the dataset - is higher than this value. - nans_ratio_cutoff: If the percentage of nans for any continuous feature in an example - is greater than this value, the example is invalid. - consecutive_nans_max: If the maximum number of consecutive nans in a continuous - feature is greater than this number, then that example is invalid. - consecutive_nans_ratio_cutoff: If the maximum number of consecutive nans in a - continuous feature is greater than this ratio times the length of the example - (number samples), then the example is invalid. + features: list of 2-d numpy arrays, each element is a sequence of + possibly varying length + continuous_features_ind: list of indices of continuous features to + analyze, indexes the 2nd dimension of the sequence arrays in + features + invalid_examples_ratio_cutoff: Error out if the invalid examples ratio + in the dataset is higher than this value. + nans_ratio_cutoff: If the percentage of nans for any continuous feature + in an example is greater than this value, the example is invalid. + consecutive_nans_max: If the maximum number of consecutive nans in a + continuous feature is greater than this number, then that example is + invalid. + consecutive_nans_ratio_cutoff: If the maximum number of consecutive nans + in a continuous feature is greater than this ratio times the length of + the example (number samples), then the example is invalid. Returns: - valid_examples : 1-d numpy array of booleans indicating valid examples with + valid_examples: 1-d numpy array of booleans indicating valid examples with shape (#examples). """ # Check for the nans ratio per examples and feature. # nan_ratio_feature is a 2-d numpy array of size (#examples,#features) + nan_ratio_feature = np.array( + [ + [ + np.mean(np.isnan(seq[:, ind].astype("float"))) + for ind in continuous_features_ind + ] + for seq in features + ] + ) - nan_ratio_feature = np.mean(np.isnan(array), axis=1) nan_ratio = nan_ratio_feature < nans_ratio_cutoff # Check for max number of consecutive NaN values per example and feature. 
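    # For instance (illustrative values):
    #     find_max_consecutive_nans(np.array([1.0, np.nan, np.nan, 5.0, np.nan]))
    # returns 2, the length of the longest NaN run.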
# cons_nans_feature is a 2-d numpy array of size (#examples,#features) - cons_nans_feature = np.apply_along_axis(find_max_consecutive_nans, 1, array) - cons_nans = cons_nans_feature < min( - consecutive_nans_max, - max(2, int(consecutive_nans_ratio_cutoff * array.shape[1])), + cons_nans_feature = np.array( + [ + [ + find_max_consecutive_nans(seq[:, ind].astype("float")) + for ind in continuous_features_ind + ] + for seq in features + ] ) + # With examples of variable sequence length, the threshold for allowable + # consecutive nans may be different for each example. + cons_nans_threshold = np.clip( + [consecutive_nans_ratio_cutoff * seq.shape[0] for seq in features], + a_min=2, + a_max=consecutive_nans_max, + ).reshape((-1, 1)) + cons_nans = cons_nans_feature < cons_nans_threshold # The two above checks should pass for a valid example for all features, otherwise - # the example is invalid. + # the example is invalid. valid_examples_per_feature = np.logical_and(nan_ratio, cons_nans) valid_examples = np.all(valid_examples_per_feature, axis=1) @@ -1754,27 +1891,25 @@ def validation_check( return valid_examples -def nan_linear_interpolation(arrays: np.ndarray) -> np.ndarray: +def nan_linear_interpolation( + features: list[np.ndarray], continuous_features_ind: list[int] +): """Replaces all NaNs via linear interpolation. - Args: - arrays: 3-d numpy array of continuous features, with shape - (#examples, max_sequence_length, #continuous features) - - Returns: - arrays: 3-d numpy array where NaNs are replaced via - linear interpolation. + Changes numpy arrays in features in place. + Args: + features: list of 2-d numpy arrays, each element is a sequence of shape + (sequence_len, #features) + continuous_features_ind: features to apply nan interpolation to, indexes + the 2nd dimension of the sequence arrays of features """ - examples = arrays.shape[0] - features = arrays.shape[2] - - for exp in range(examples): - for f in range(features): - array = arrays[exp, :, f] - if np.isnan(array).any(): - nans = np.isnan(array) + for seq in features: + for ind in continuous_features_ind: + continuous_feature = seq[:, ind].astype("float") + is_nan = np.isnan(continuous_feature) + if is_nan.any(): ind_func = lambda z: z.nonzero()[0] # noqa - array[nans] = np.interp(ind_func(nans), ind_func(~nans), array[~nans]) - - return arrays + seq[is_nan, ind] = np.interp( + ind_func(is_nan), ind_func(~is_nan), continuous_feature[~is_nan] + ) diff --git a/src/gretel_synthetics/timeseries_dgan/transformations.py b/src/gretel_synthetics/timeseries_dgan/transformations.py index 41d3c42b..b1488a05 100644 --- a/src/gretel_synthetics/timeseries_dgan/transformations.py +++ b/src/gretel_synthetics/timeseries_dgan/transformations.py @@ -10,6 +10,7 @@ from category_encoders import BinaryEncoder, OneHotEncoder from scipy.stats import mode +from gretel_synthetics.errors import ParameterError from gretel_synthetics.timeseries_dgan.config import Normalization, OutputType @@ -332,16 +333,19 @@ def _transform(self, column: np.ndarray) -> np.ndarray: """Apply continuous variable encoding/scaling. 
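
        For example (illustrative values), with apply_feature_scaling enabled,
        global_min=0.0, global_max=10.0, and ZERO_ONE normalization, the value
        2.5 is rescaled to 0.25; the result is returned as a column vector of
        shape (len(column), 1).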
Args: - column: numpy array + column: 1-d numpy array Returns: - numpy array of rescaled data + 2-d numpy array of rescaled data """ column = column.astype("float") + if self.apply_feature_scaling: - return rescale(column, self.normalization, self.global_min, self.global_max) + return rescale( + column, self.normalization, self.global_min, self.global_max + ).reshape((-1, 1)) else: - return column + return column.reshape((-1, 1)) def _inverse_transform(self, columns: np.ndarray) -> np.ndarray: """Invert continus variable encoding/scaling. @@ -355,14 +359,14 @@ def _inverse_transform(self, columns: np.ndarray) -> np.ndarray: if self.apply_feature_scaling: return rescale_inverse( columns, self.normalization, self.global_min, self.global_max - ) + ).flatten() else: - return columns + return columns.flatten() def create_outputs_from_data( attributes: Optional[np.ndarray], - features: np.ndarray, + features: list[np.ndarray], attribute_types: Optional[List[OutputType]], feature_types: Optional[List[OutputType]], normalization: Normalization, @@ -374,7 +378,7 @@ def create_outputs_from_data( Args: attributes: 2d numpy array of attributes - features: 3d numpy array of features + features: list of 2d numpy arrays, each element is one sequence attribute_types: variable type for each attribute, assumes continuous if None feature_types: variable type for each feature, assumes continuous if None normalization: internal representation for continuous variables, scale @@ -412,8 +416,8 @@ def create_outputs_from_data( ] if feature_types is None: - feature_types = [OutputType.CONTINUOUS] * features.shape[2] - elif len(feature_types) != features.shape[2]: + feature_types = [OutputType.CONTINUOUS] * features[0].shape[1] + elif len(feature_types) != features[0].shape[1]: raise RuntimeError( "feature_types must be the same length as the 3rd (last) dimemnsion of features" ) @@ -423,7 +427,7 @@ def create_outputs_from_data( create_output( index, t, - features[:, :, index], + np.hstack([seq[:, index] for seq in features]), normalization=normalization, apply_feature_scaling=apply_feature_scaling, apply_example_scaling=apply_example_scaling, @@ -449,7 +453,7 @@ def create_output( Args: index: index of variable within attributes or features t: type of output - data: numpy array of data just for this variable + data: 1-d numpy array of data just for this variable normalization: see documentation in create_outputs_from_data apply_feature_scaling: see documentation in create_outputs_from_data apply_example_scaling: see documentation in create_outputs_from_data @@ -458,6 +462,7 @@ def create_output( Returns: Output metadata instance """ + if t == OutputType.CONTINUOUS: output = ContinuousOutput( name="a" + str(index), @@ -542,13 +547,74 @@ def rescale_inverse( return ((transformed + 1.0) / 2.0) * range + global_min -def transform( - original_data: Optional[np.ndarray], +def transform_attributes( + original_data: np.ndarray, + outputs: List[Output], +) -> np.ndarray: + """Transform attributes to internal representation expected by DGAN. + + See transform_features pydoc for more details on how the original_data is + changed. + + Args: + original_data: data to transform as a 2d numpy array + outputs: Output metadata for each attribute + + Returns: + 2d numpy array of the internal representation of data. 
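+
+    Example (illustrative): one continuous attribute (dim 1) plus one one-hot
+    encoded attribute with 3 categories (dim 3) produce a float32 array of
+    shape (n_examples, 4).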
+ """ + parts = [] + for index, output in enumerate(outputs): + parts.append(output.transform(original_data[:, index])) + + return np.concatenate(parts, axis=1, dtype=np.float32) + + +def _grouped_min_and_max( + example_ids: np.ndarray, values: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """Compute min and max for each example. + + Sorts by example_ids, then values, and then indexes into the sorted values + to efficiently obtain min/max. Compute both min and max in one function to + reuse the sorted objects. + + Args: + example_ids: 1d numpy array of example ids, mapping each element + in values to an example/sequence + values: 1d numpy array + + Returns: + Tuple of min and max values for each example/sequence, each is a 1d + numpy array of size # of unique example_ids. The min and max values are + for the sorted example_ids, so the first element is the min/max of the + smallest example_id value, and so on. + """ + # lexsort primary key is last element, so sorts by example_ids first, then + # values + order = np.lexsort((values, example_ids)) + g = example_ids[order] + d = values[order] + # Construct index marking lower borders between examples to capture the min + # values + min_index = np.empty(len(g), dtype="bool") + min_index[0] = True + min_index[1:] = g[1:] != g[:-1] + # Construct index marking upper borders between groups to capture the max + # values + max_index = np.empty(len(g), dtype="bool") + max_index[-1] = True + max_index[:-1] = g[1:] != g[:-1] + + return d[min_index], d[max_index] + + +def transform_features( + original_data: list[np.ndarray], outputs: List[Output], - variable_dim_index: int, - num_examples: Optional[int] = None, -) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]: - """Transform data to internal representation expected by DoppelGANger. + max_sequence_len: int, +) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Transform features to internal representation expected by DGAN. Specifically, performs the following changes: @@ -559,30 +625,26 @@ def transform( apply_example_scaling is True Args: - original_data: data to transform, 2d or 3d numpy array, or None + original_data: data to transform as a list of 2d numpy + arrays, each element is a sequence outputs: Output metadata for each variable - variable_dim_index: dimension of numpy array that contains the - variables, for 2d numpy arrays this should be 1, for 3d should be 2 - num_examples: dimension of feature output array. If the original - data is None, we want the empty/none torch array to match the first - dimension of the feature output array. This makes sure that the - TensorDataset module works smoothly. If the first dimensions are different, - torch will give an error. + max_sequence_len: pad all sequences to max_sequence_len Returns: - Internal representation of data. A single numpy array if the input was a - 2d array or if no outputs have apply_example_scaling=True. A tuple of - features, additional_attributes is returned when transforming features - (a 3d numpy array) and example scaling is used. If the input data is - None, then a single numpy array filled with nan's that has the first - dimension shape of the number examples of the feature vector is - returned. + Internal representation of data. A tuple of 3d numpy array of features + and optional 2d numpy array of additional_attributes. 
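+
+    Example (illustrative): three sequences of lengths 2, 3, and 5 with
+    max_sequence_len=5 produce a features array of shape (3, 5, encoded dim),
+    the two shorter sequences zero-padded at the end. Each output with
+    apply_example_scaling=True contributes one midpoint column and one
+    half-range column to additional_attributes.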
""" - additional_attribute_parts = [] - parts = [] - if original_data is None: - return np.full((num_examples, 1), np.nan, dtype=np.float32) + sequence_lengths = [seq.shape[0] for seq in original_data] + if max(sequence_lengths) > max_sequence_len: + raise ParameterError( + f"Found sequence with length {max(sequence_lengths)}, longer than max_sequence_len={max_sequence_len}" + ) + example_ids = np.repeat(range(len(original_data)), sequence_lengths) + + long_data = np.vstack(original_data) + parts = [] + additional_attribute_parts = [] for index, output in enumerate(outputs): # NOTE: isinstance(output, DiscreteOutput) does not work consistently # with all import styles in jupyter notebooks, using string @@ -590,129 +652,118 @@ def transform( if "OneHotEncodedOutput" in str( output.__class__ ) or "BinaryEncodedOutput" in str(output.__class__): - - if variable_dim_index == 1: - original_column = original_data[:, index] - target_shape = (original_data.shape[0], -1) - elif variable_dim_index == 2: - original_column = original_data[:, :, index] - target_shape = (original_data.shape[0], original_data.shape[1], -1) - else: - raise RuntimeError( - f"Unsupported variable_dim_index={variable_dim_index}" - ) - - transformed_data = output.transform(original_column.flatten()) - - parts.append(transformed_data.reshape(target_shape)) - + transformed_data = output.transform(long_data[:, index]) + parts.append(transformed_data) elif "ContinuousOutput" in str(output.__class__): output = cast(ContinuousOutput, output) - if variable_dim_index == 1: - raw = original_data[:, index] - elif variable_dim_index == 2: - raw = original_data[:, :, index] - else: - raise RuntimeError( - f"Unsupported variable_dim_index={variable_dim_index}" - ) + raw = long_data[:, index] - feature_scaled = output.transform(raw.flatten()).reshape(raw.shape) + feature_scaled = output.transform(raw) if output.apply_example_scaling: - if variable_dim_index != 2: - raise RuntimeError( - "apply_example_scaling only applies to features, that is when the data has 3 dimensions" - ) - - mins = np.min(feature_scaled, axis=1) - maxes = np.max(feature_scaled, axis=1) + # Group-wise mins and maxes, dimension of each is (# examples,) + group_mins, group_maxes = _grouped_min_and_max( + example_ids, feature_scaled.flatten() + ) + # Project back to size of long data + mins = np.repeat(group_mins, sequence_lengths).reshape((-1, 1)) + maxes = np.repeat(group_maxes, sequence_lengths).reshape((-1, 1)) additional_attribute_parts.append( - ((mins + maxes) / 2).reshape(mins.shape[0], 1) + ((group_mins + group_maxes) / 2).reshape((-1, 1)) ) additional_attribute_parts.append( - ((maxes - mins) / 2).reshape(mins.shape[0], 1) - ) - - mins = np.broadcast_to( - mins.reshape(mins.shape[0], 1), - (mins.shape[0], feature_scaled.shape[1]), - ) - maxes = np.broadcast_to( - maxes.reshape(maxes.shape[0], 1), - (mins.shape[0], feature_scaled.shape[1]), + ((group_maxes - group_mins) / 2).reshape((-1, 1)) ) scaled = rescale(feature_scaled, output.normalization, mins, maxes) else: scaled = feature_scaled - if variable_dim_index == 1: - scaled = scaled.reshape((original_data.shape[0], 1)) - elif variable_dim_index == 2: - scaled = scaled.reshape( - (original_data.shape[0], original_data.shape[1], 1) - ) - parts.append(scaled) + parts.append(scaled.reshape(-1, 1)) else: raise RuntimeError(f"Unsupported output type, class={type(output)}'") + long_transformed = np.concatenate(parts, axis=1, dtype=np.float32) + + # Fit possibly jagged sequences into 3d numpy array. 
Pads shorter sequences + # with all 0s in the internal representation. + features_transformed = np.zeros( + (len(original_data), max_sequence_len, long_transformed.shape[1]), + dtype=np.float32, + ) + i = 0 + for example_index, length in enumerate(sequence_lengths): + features_transformed[example_index, 0:length, :] = long_transformed[ + i : (i + length), : + ] + i += length + + additional_attributes = None if additional_attribute_parts: - return ( - np.concatenate(parts, axis=variable_dim_index, dtype=np.float32), - np.concatenate(additional_attribute_parts, axis=1, dtype=np.float32), + additional_attributes = np.concatenate( + additional_attribute_parts, axis=1, dtype=np.float32 ) - else: - return np.concatenate(parts, axis=variable_dim_index, dtype=np.float32) + + return features_transformed, additional_attributes -def inverse_transform( +def inverse_transform_attributes( + transformed_data: np.ndarray, + outputs: list[Output], +) -> Optional[np.ndarray]: + """Inverse of transform_attributes to map back to original space. + + Args: + transformed_data: 2d numpy array of internal representation + outputs: Output metadata for each variable + """ + # TODO: we should not use nans as an indicator and just not call this + # method, or use a zero sized numpy array, to indicate no attributes. + if np.isnan(transformed_data).any(): + return None + parts = [] + transformed_index = 0 + for output in outputs: + original = output.inverse_transform( + transformed_data[:, transformed_index : (transformed_index + output.dim)] + ) + parts.append(original.reshape((-1, 1))) + transformed_index += output.dim + + return np.hstack(parts) + + +def inverse_transform_features( transformed_data: np.ndarray, outputs: List[Output], - variable_dim_index: int, additional_attributes: Optional[np.ndarray] = None, -) -> Optional[np.ndarray]: - """Invert transform to map back to original space. +) -> np.ndarray: + """Inverse of transform_features to map back to original space. Args: - transformed_data: internal representation data + transformed_data: 3d numpy array of internal representation data outputs: Output metadata for each variable - variable_dim_index: dimension of numpy array that contains the - variables, for 2d numpy arrays this should be 1, for 3d should be 2 additional_attributes: midpoint and half-ranges for outputs with apply_example_scaling=True Returns: - If the input data provided is a numpy array with no-nans, then a numpy array of - data in original space is returned. If the input data is nan-filled, the function - returns None. + List of numpy arrays, each element corresponds to one sequence with 2d + array of (time x variables). 
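+
+    Example (illustrative): for a continuous output with
+    apply_example_scaling=True, each sequence is first un-scaled with its
+    (midpoint, half-range) pair from additional_attributes, and the global
+    feature scaling is then inverted.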
""" - parts = [] transformed_index = 0 additional_attribute_index = 0 - if np.isnan(transformed_data).any(): - return None + parts = [] for output in outputs: if "OneHotEncodedOutput" in str( output.__class__ ) or "BinaryEncodedOutput" in str(output.__class__): - if variable_dim_index == 1: - v = transformed_data[ - :, transformed_index : (transformed_index + output.dim) - ] - target_shape = (transformed_data.shape[0], 1) - elif variable_dim_index == 2: - v = transformed_data[ - :, :, transformed_index : (transformed_index + output.dim) - ] - target_shape = (transformed_data.shape[0], transformed_data.shape[1], 1) - else: - raise RuntimeError( - f"Unsupported variable_dim_index={variable_dim_index}" - ) + v = transformed_data[ + :, :, transformed_index : (transformed_index + output.dim) + ] + target_shape = (transformed_data.shape[0], transformed_data.shape[1], 1) original = output.inverse_transform(v.reshape((-1, v.shape[-1]))) @@ -721,23 +772,11 @@ def inverse_transform( elif "ContinuousOutput" in str(output.__class__): output = cast(ContinuousOutput, output) - if variable_dim_index == 1: - transformed = transformed_data[:, transformed_index] - elif variable_dim_index == 2: - transformed = transformed_data[:, :, transformed_index] - else: - raise RuntimeError( - f"Unsupported variable_dim_index={variable_dim_index}" - ) + transformed = transformed_data[:, :, transformed_index] if output.apply_example_scaling: - if variable_dim_index != 2: - raise RuntimeError( - "apply_example_scaling only applies to features where the data has 3 dimensions" - ) - if additional_attributes is None: - raise RuntimeError( + raise ValueError( "Must provide additional_attributes if apply_example_scaling=True" ) @@ -770,7 +809,7 @@ def inverse_transform( else: raise RuntimeError(f"Unsupported output type, class={type(output)}'") - return np.concatenate(parts, axis=variable_dim_index) + return np.concatenate(parts, axis=2) def create_additional_attribute_outputs(feature_outputs: List[Output]) -> List[Output]: diff --git a/tests/timeseries_dgan/test_dgan.py b/tests/timeseries_dgan/test_dgan.py index d9a41e8b..6474feb7 100644 --- a/tests/timeseries_dgan/test_dgan.py +++ b/tests/timeseries_dgan/test_dgan.py @@ -1,6 +1,9 @@ import itertools import os.path +from collections import Counter +from typing import Any, Optional, Sequence, Union + import numpy as np import pandas as pd import pytest @@ -68,6 +71,37 @@ def config() -> DGANConfig: ) +def assert_attributes_features_shape( + attributes: Optional[np.ndarray], + features: list[np.ndarray], + attributes_shape: Optional[tuple[int, int]], + features_shape: tuple[int, int, int], +): + + if attributes_shape: + assert attributes is not None + assert attributes.shape == attributes_shape + + assert len(features) == features_shape[0] + assert all(seq.shape == features_shape[1:] for seq in features) + + +def assert_attributes_features( + attributes: Optional[np.ndarray], + features: list[np.ndarray], + expected_attributes: Optional[Union[np.ndarray, Sequence[Sequence[Any]]]], + expected_features: Union[ + np.ndarray, list[np.ndarray], Sequence[Sequence[Sequence[Any]]] + ], +): + if expected_attributes is not None: + assert attributes is not None + np.testing.assert_allclose(attributes, expected_attributes) + + for f, ef in zip(features, expected_features): + np.testing.assert_allclose(f, ef) + + def test_discrete_cols_to_int(): df = pd.DataFrame( data=zip(["1", "2", "3", "4"], ["one", "two", "three", "four"]), @@ -118,23 +152,27 @@ def test_generate(): # Check requesting 
various number of examples attributes, features = dg.generate_numpy(8) - assert attributes.shape == (8, 3) - assert features.shape == (8, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(8, 3), features_shape=(8, 20, 2) + ) attributes, features = dg.generate_numpy(64) - assert attributes.shape == (64, 3) - assert features.shape == (64, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(64, 3), features_shape=(64, 20, 2) + ) attributes, features = dg.generate_numpy(200) - assert attributes.shape == (200, 3) - assert features.shape == (200, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(200, 3), features_shape=(200, 20, 2) + ) attributes, features = dg.generate_numpy(1) - assert attributes.shape == (1, 3) - assert features.shape == (1, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(1, 3), features_shape=(1, 20, 2) + ) # Check passing noise vectors @@ -142,9 +180,9 @@ def test_generate(): attribute_noise=dg.attribute_noise_func(20), feature_noise=dg.feature_noise_func(20), ) - - assert attributes.shape == (20, 3) - assert features.shape == (20, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(20, 3), features_shape=(20, 20, 2) + ) def test_generate_example_normalized(): @@ -181,13 +219,15 @@ def test_generate_example_normalized(): ) attributes, features = dg.generate_numpy(8) - assert attributes.shape == (8, 3) - assert features.shape == (8, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(8, 3), features_shape=(8, 20, 2) + ) attributes, features = dg.generate_numpy(64) - assert attributes.shape == (64, 3) - assert features.shape == (64, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(64, 3), features_shape=(64, 20, 2) + ) @pytest.mark.parametrize( @@ -219,8 +259,9 @@ def test_train_numpy( attributes, features = dg.generate_numpy(18) - assert attributes.shape == (18, 2) - assert features.shape == (18, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=(18, 2), features_shape=(18, 20, 2) + ) @pytest.mark.parametrize( @@ -248,8 +289,9 @@ def test_train_numpy_no_attributes_1( attributes, features = dg.generate_numpy(18) - assert attributes is None - assert features.shape == (18, 20, 2) + assert_attributes_features_shape( + attributes, features, attributes_shape=None, features_shape=(18, 20, 2) + ) def test_train_numpy_no_attributes_2(config: DGANConfig): @@ -263,8 +305,12 @@ def test_train_numpy_no_attributes_2(config: DGANConfig): ) assert type(model_attributes_blank) == DGAN - assert synthetic_attributes is None - assert synthetic_features.shape == (n_samples, features.shape[1], features.shape[2]) + assert_attributes_features_shape( + synthetic_attributes, + synthetic_features, + attributes_shape=None, + features_shape=(n_samples, features.shape[1], features.shape[2]), + ) model_attributes_none = DGAN(config) model_attributes_none.train_numpy(attributes=None, features=features) @@ -273,8 +319,12 @@ def test_train_numpy_no_attributes_2(config: DGANConfig): ) assert type(model_attributes_none) == DGAN - assert synthetic_attributes is None - assert synthetic_features.shape == (n_samples, features.shape[1], features.shape[2]) + assert_attributes_features_shape( + synthetic_attributes, + synthetic_features, + attributes_shape=None, + features_shape=(n_samples, features.shape[1], features.shape[2]), + ) def 
test_train_numpy_batch_size_of_1(config: DGANConfig): @@ -295,9 +345,13 @@ def test_train_numpy_batch_size_of_1(config: DGANConfig): ) synthetic_attributes, synthetic_features = model.generate_numpy(11) - assert synthetic_attributes is not None - assert synthetic_attributes.shape == (11, 1) - assert synthetic_features.shape == (11, 20, 2) + + assert_attributes_features_shape( + synthetic_attributes, + synthetic_features, + attributes_shape=(11, 1), + features_shape=(11, 20, 2), + ) def test_train_dataframe_wide(config: DGANConfig): @@ -704,24 +758,52 @@ def test_nan_linear_interpolation(): # Inserting nans in different length and locations of a 3-d array. # np interpolation uses padding for values in the begining and the end of an array. - features = np.array( - [ - [[0.0, 1.0, 2.0], [np.nan, 7, 5.0], [np.nan, 4, 8.0], [8.0, 10.0, np.nan]], + features = [ + np.array( + [[0.0, 1.0, 2.0], [np.nan, 7, 5.0], [np.nan, 4, 8.0], [8.0, 10.0, np.nan]] + ), + np.array( [ [np.nan, 13.0, 14.0], [np.nan, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0], - ], - ] - ) + ] + ), + np.array([[5.0, np.nan, 85.0], [np.nan, 10.0, 80.0], [5.0, 10.0, np.nan]]), + ] + # Note, the interpolation is linear if there is a value before and after the + # section of nans. For nans at the beginning or end of a sequence, the + # interpolation assumes a diff of 0 and uses the first/last non-nan value as + # a constant to replace the nans. + expected_features = [ + np.array( + [ + [0.0, 1.0, 2.0], + [8.0 / 3.0, 7.0, 5.0], + [16.0 / 3.0, 4.0, 8.0], + [8.0, 10.0, 8.0], + ] + ), + np.array( + [ + [18.0, 13.0, 14.0], + [18.0, 16.0, 17.0], + [18.0, 19.0, 20.0], + [21.0, 22.0, 23.0], + ] + ), + np.array( + [[5.0, 10.0, 85.0], [5.0, 10.0, 80.0], [5.0, 10.0, 80.0]], + ), + ] - features = nan_linear_interpolation(features) + nan_linear_interpolation(list(features), continuous_features_ind=[0, 1, 2]) - assert (features[0, 1:3, 0] == np.array([8 / 3, (8 / 3 + 8) / 2])).all() - assert (np.diff(features[0, 2:, 2]) == 0).all() - assert (np.diff(features[1, 0:3, 0]) == 0).all() - assert np.isnan(features).sum() == 0 + assert all(np.isnan(seq).sum() == 0 for seq in features) + + for f, ef in zip(features, expected_features): + np.testing.assert_allclose(f, ef) def test_validation_check(): @@ -739,7 +821,7 @@ def test_validation_check(): invalid_examples = np.random.rand(n, 20, 3) invalid_examples[0:26, 2:4, 2] = np.nan with pytest.raises(DataError, match="NaN"): - validation_check(invalid_examples) + validation_check(list(invalid_examples), continuous_features_ind=[0, 1, 2]) # Set nans for various features. Features 1 and 2 have fixable invalid examples, # while feature 0 has 10 invalid examples which should be dropped (high consecutive nans) @@ -748,9 +830,14 @@ def test_validation_check(): invalid_examples_dropped[20:30, 10:20, 0] = np.nan invalid_examples_dropped[30:40, 15, 1] = np.nan - test_boolean = np.array([True] * n) - test_boolean[20:30] = False - assert (validation_check(invalid_examples_dropped) == test_boolean).all() + expected = np.array([True] * n) + expected[20:30] = False + np.testing.assert_equal( + validation_check( + list(invalid_examples_dropped), continuous_features_ind=[0, 1, 2] + ), + expected, + ) # inserting small number of nans for each feature, non should be dropped during # the check. 
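    # E.g. (illustrative): a single NaN at one time step of one feature gives
    # a NaN ratio of 1/20, below nans_ratio_cutoff, and a consecutive run of
    # 1, below the consecutive-NaN threshold, so the example stays valid.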
@@ -758,7 +845,9 @@ def test_validation_check(): valid_examples[5:7, 2, 2] = np.nan valid_examples[15:20, 15, 0] = np.nan valid_examples[-5:, 8, 1] = np.nan - assert validation_check(valid_examples).all() + assert validation_check( + list(valid_examples), continuous_features_ind=[0, 1, 2] + ).all() def test_train_numpy_nans(config: DGANConfig): @@ -992,14 +1081,17 @@ def test_train_numpy_with_strings(config: DGANConfig): synthetic_attributes, synthetic_features = dg.generate_numpy(5) - assert synthetic_attributes is None - assert synthetic_features.shape == (5, 5, 3) + assert_attributes_features_shape( + synthetic_attributes, + synthetic_features, + attributes_shape=None, + features_shape=(5, 5, 3), + ) expected_categories = set(["aa", "bb", "cc"]) - assert np.all( - [x in expected_categories for x in synthetic_features[:, :, 0].flatten()] - ) + for seq in synthetic_features: + assert all([x in expected_categories for x in seq[:, 0]]) def test_train_numpy_max_sequence_len_error(config: DGANConfig): @@ -1052,11 +1144,13 @@ def test_wide_data_frame_converter1(df_wide): ) attributes, features = converter.convert(df_wide) - assert attributes.shape == (6, 2) - assert features.shape == (6, 3, 1) + assert_attributes_features_shape( + attributes, features, attributes_shape=(6, 2), features_shape=(6, 3, 1) + ) - np.testing.assert_allclose(attributes, expected_attributes) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, features, expected_attributes, expected_features + ) # Check invert produces original df df_out = converter.invert(attributes, features) @@ -1090,11 +1184,16 @@ def test_wide_data_frame_converter2(df_wide): ) attributes, features = converter.convert(df_wide) - assert attributes.shape == (6, 1) - assert features.shape == (6, 3, 1) + assert_attributes_features_shape( + attributes, features, attributes_shape=(6, 1), features_shape=(6, 3, 1) + ) - np.testing.assert_allclose(attributes, expected_attributes) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, + features, + expected_attributes, + expected_features, + ) # Check invert produces original df df_out = converter.invert(attributes, features) @@ -1119,10 +1218,16 @@ def test_wide_data_frame_converter_no_attributes(df_wide): ) attributes, features = converter.convert(df_wide) - assert attributes is None - assert features.shape == (6, 3, 1) + assert_attributes_features_shape( + attributes, features, attributes_shape=None, features_shape=(6, 3, 1) + ) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, + features, + expected_attributes=None, + expected_features=expected_features, + ) # Check invert produces original df df_out = converter.invert(attributes, features) @@ -1145,10 +1250,16 @@ def test_wide_data_frame_converter_no_attributes_no_column_name(df_wide): converter = _WideDataFrameConverter.create(df_wide) attributes, features = converter.convert(df_wide) - assert attributes is None - assert features.shape == (6, 3, 1) + assert_attributes_features_shape( + attributes, features, attributes_shape=None, features_shape=(6, 3, 1) + ) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, + features, + expected_attributes=None, + expected_features=expected_features, + ) # Check invert produces original df df_out = converter.invert(attributes, features) @@ -1173,8 +1284,12 @@ def test_wide_data_frame_converter_save_and_load(df_wide): 
attributes, features = loaded_converter.convert(df_wide) - np.testing.assert_allclose(attributes, expected_attributes) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, + features, + expected_attributes, + expected_features, + ) df = loaded_converter.invert(attributes, features) @@ -1220,6 +1335,7 @@ def test_long_data_frame_converter1(df_long): converter = _LongDataFrameConverter.create( df_long, + max_sequence_len=3, attribute_columns=["a1", "a2"], feature_columns=["f1", "f2", "f3"], example_id_column="example_id", @@ -1228,26 +1344,44 @@ def test_long_data_frame_converter1(df_long): ) attributes, features = converter.convert(df_long) - assert attributes.shape == (2, 2) - assert features.shape == (2, 3, 3) + assert_attributes_features_shape( + attributes, + features, + attributes_shape=(2, 2), + features_shape=(2, 3, 3), + ) - np.testing.assert_allclose(attributes, expected_attributes) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, + features, + expected_attributes, + expected_features, + ) # Check works the same if feature column param is omitted converter = _LongDataFrameConverter.create( df_long, + max_sequence_len=3, attribute_columns=["a1", "a2"], example_id_column="example_id", time_column="time", discrete_columns=["a1"], ) attributes, features = converter.convert(df_long) - assert attributes.shape == (2, 2) - assert features.shape == (2, 3, 3) - np.testing.assert_allclose(attributes, expected_attributes) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features_shape( + attributes, + features, + attributes_shape=(2, 2), + features_shape=(2, 3, 3), + ) + + assert_attributes_features( + attributes, + features, + expected_attributes, + expected_features, + ) # Check the inverse returns the original df df_out = converter.invert(attributes, features) @@ -1279,6 +1413,7 @@ def test_long_data_frame_converter2(df_long): converter = _LongDataFrameConverter.create( df_long, + max_sequence_len=3, attribute_columns=["a1", "a2"], feature_columns=["f1", "f2", "f3"], example_id_column="example_id", @@ -1286,26 +1421,43 @@ def test_long_data_frame_converter2(df_long): ) attributes, features = converter.convert(df_long) - assert attributes.shape == (2, 2) - assert features.shape == (2, 3, 3) + assert_attributes_features_shape( + attributes, + features, + attributes_shape=(2, 2), + features_shape=(2, 3, 3), + ) - np.testing.assert_allclose(attributes, expected_attributes) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, + features, + expected_attributes, + expected_features, + ) # Check works the same if feature column param is omitted converter = _LongDataFrameConverter.create( df_long, + max_sequence_len=3, attribute_columns=["a1", "a2"], example_id_column="example_id", discrete_columns=["a1"], ) attributes, features = converter.convert(df_long) - assert attributes.shape == (2, 2) - assert features.shape == (2, 3, 3) + assert_attributes_features_shape( + attributes, + features, + attributes_shape=(2, 2), + features_shape=(2, 3, 3), + ) - np.testing.assert_allclose(attributes, expected_attributes) - np.testing.assert_allclose(features, expected_features) + assert_attributes_features( + attributes, + features, + expected_attributes, + expected_features, + ) # Check the inverse returns the original df df_out = converter.invert(attributes, features) @@ -1337,6 +1489,7 @@ def 
test_long_data_frame_converter3(df_long):
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=6,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         time_column="time",
@@ -1344,26 +1497,42 @@ def test_long_data_frame_converter3(df_long):
     )
 
     attributes, features = converter.convert(df_long)
-    assert attributes.shape == (1, 2)
-    assert features.shape == (1, 6, 3)
+    assert_attributes_features_shape(
+        attributes,
+        features,
+        attributes_shape=(1, 2),
+        features_shape=(1, 6, 3),
+    )
 
-    np.testing.assert_allclose(attributes, expected_attributes)
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
 
     # Check works the same if feature column param is omitted
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=6,
         attribute_columns=["a1", "a2"],
         time_column="time",
         discrete_columns=["a1"],
     )
 
     attributes, features = converter.convert(df_long)
+    assert_attributes_features_shape(
+        attributes,
+        features,
+        attributes_shape=(1, 2),
+        features_shape=(1, 6, 3),
+    )
 
-    assert attributes.shape == (1, 2)
-    assert features.shape == (1, 6, 3)
-
-    np.testing.assert_allclose(attributes, expected_attributes)
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
 
     # Check the inverse returns the original df
     df_out = converter.invert(attributes, features)
@@ -1398,31 +1567,49 @@ def test_long_data_frame_converter4(config, df_long):
 
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=6,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         discrete_columns=["a1"],
     )
 
     attributes, features = converter.convert(df_long)
-    assert attributes.shape == (1, 2)
-    assert features.shape == (1, 6, 3)
+    assert_attributes_features_shape(
+        attributes,
+        features,
+        attributes_shape=(1, 2),
+        features_shape=(1, 6, 3),
+    )
 
-    np.testing.assert_allclose(attributes, expected_attributes)
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
 
     # Check works the same if feature column param is omitted
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=6,
         attribute_columns=["a1", "a2"],
         discrete_columns=["a1"],
     )
 
     attributes, features = converter.convert(df_long)
-    assert attributes.shape == (1, 2)
-    assert features.shape == (1, 6, 3)
+    assert_attributes_features_shape(
+        attributes,
+        features,
+        attributes_shape=(1, 2),
+        features_shape=(1, 6, 3),
+    )
 
-    np.testing.assert_allclose(attributes, expected_attributes)
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
 
     # Check the inverse returns the original df
     df_out = converter.invert(attributes, features)
@@ -1447,6 +1634,7 @@ def test_long_data_frame_converter_extra_cols(df_long):
 
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=3,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         example_id_column="example_id",
     )
 
     attributes, features = converter.convert(df_long)
-    assert attributes.shape == (2, 2)
-    assert features.shape == (2, 3, 3)
+    assert_attributes_features_shape(
+        attributes,
+        features,
+        attributes_shape=(2, 2),
+        features_shape=(2, 3, 3),
+    )
 
-    np.testing.assert_allclose(attributes, expected_attributes)
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
 
     # Check the inverse returns the original df
     df_out = converter.invert(attributes, features)
@@ -1478,6 +1674,7 @@ def test_long_data_frame_converter_attribute_errors(df_long):
 
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=3,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         example_id_column="example_id",
@@ -1489,6 +1686,7 @@
     # Same if we don't use example_id where attributes should be constant
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=6,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         discrete_columns=["a1"],
@@ -1509,13 +1707,14 @@ def test_long_data_frame_converter_mixed_feature_types(df_long):
 
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=3,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         example_id_column="example_id",
     )
 
     _, features = converter.convert(df_long)
-    assert features.dtype == "float64"
+    assert all(seq.dtype == "float64" for seq in features)
 
 
 def test_long_data_frame_converter_example_id_object(df_long):
@@ -1530,6 +1729,7 @@
 
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=3,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         example_id_column="example_id",
@@ -1538,7 +1738,7 @@
     attributes, features = converter.convert(df_long)
     assert attributes is not None
     assert attributes.dtype == "float64"
-    assert features.dtype == "float64"
+    assert all(seq.dtype == "float64" for seq in features)
 
 
 def test_long_data_frame_converter_example_id_float():
@@ -1560,19 +1760,101 @@
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=2,
         example_id_column="example_id",
         time_column="time",
     )
 
     attributes, features = converter.convert(df_long)
-    assert attributes is None
-    assert features.dtype == "float64"
-    assert features.shape == (2, 2, 1)
+    assert_attributes_features_shape(
+        attributes,
+        features,
+        attributes_shape=None,
+        features_shape=(2, 2, 1),
+    )
+
+    assert all(seq.dtype == "float64" for seq in features)
+
+
+def test_long_data_frame_converter_variable_length():
+    df_long = pd.DataFrame(
+        {
+            "example_id": ["a", "b", "b", "c", "c", "c"],
+            "f": [2.0, 2.5, 3.0, 1.0, 1.5, 4.0],
+        }
+    )
+
+    expected_features = [
+        np.array([[2.0, 0.0]]),
+        np.array([[2.5, 1.0], [3.0, 0.0]]),
+        np.array([[1.0, 1.0], [1.5, 1.0], [4.0, 0.0]]),
+    ]
+    converter = _LongDataFrameConverter.create(
+        df_long,
+        max_sequence_len=3,
+        example_id_column="example_id",
+    )
+
+    attributes, features = converter.convert(df_long)
+    assert converter._feature_types == [OutputType.CONTINUOUS, OutputType.DISCRETE]
+
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes=None,
+        expected_features=expected_features,
+    )
+
+
+def test_long_data_frame_converter_variable_length_error(df_long):
+    with pytest.raises(DataError):
+        _LongDataFrameConverter.create(
+            df_long,
+            max_sequence_len=1,
+            example_id_column="example_id",
+        )
 
 
 def test_long_data_frame_converter_save_and_load(df_long):
     converter = _LongDataFrameConverter.create(
         df_long,
+        max_sequence_len=3,
+        attribute_columns=["a1", "a2"],
+        feature_columns=["f1", "f2", "f3"],
+        example_id_column="example_id",
+        time_column="time",
+        discrete_columns=["a1"],
+    )
+
+    expected_attributes, expected_features = converter.convert(df_long)
+
+    expected_df = converter.invert(expected_attributes, expected_features)
+
+    state = converter.state_dict()
+
+    loaded_converter = _DataFrameConverter.load_from_state_dict(state)
+
+    attributes, features = loaded_converter.convert(df_long)
+
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
+
+    df = loaded_converter.invert(attributes, features)
+
+    assert_frame_equal(df, expected_df)
+
+
+def test_long_data_frame_converter_save_and_load_variable_length(df_long):
+    # Remove first row so the first example has 2 time points and the second
+    # example has 3 time points
+    df_long = df_long[1:]
+    converter = _LongDataFrameConverter.create(
+        df_long,
+        max_sequence_len=5,
         attribute_columns=["a1", "a2"],
         feature_columns=["f1", "f2", "f3"],
         example_id_column="example_id",
@@ -1590,8 +1872,12 @@
     attributes, features = loaded_converter.convert(df_long)
 
-    np.testing.assert_allclose(attributes, expected_attributes)
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
 
     df = loaded_converter.invert(attributes, features)
 
@@ -1648,8 +1934,12 @@ def test_save_and_load(
         attribute_noise=attribute_noise, feature_noise=feature_noise
     )
 
-    np.testing.assert_allclose(attributes, expected_attributes)
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes,
+        expected_features,
+    )
 
 
 @pytest.mark.parametrize(
@@ -1698,10 +1988,12 @@ def test_save_and_load_no_attributes(
         attribute_noise=attribute_noise, feature_noise=feature_noise
    )
 
-    assert attributes is None
-    assert expected_attributes is None
-    assert features.shape == expected_features.shape
-    np.testing.assert_allclose(features, expected_features)
+    assert_attributes_features(
+        attributes,
+        features,
+        expected_attributes=expected_attributes,
+        expected_features=expected_features,
+    )
 
 
 def test_save_and_load_dataframe_with_attributes(config: DGANConfig, tmp_path):
@@ -1888,3 +2180,43 @@ def test_dataframe_long_one_and_partial_example(config: DGANConfig):
         df_style=DfStyle.LONG,
         discrete_columns=["a"],
     )
+
+
+def test_dataframe_variable_sequences(config: DGANConfig):
+    # Variable length sequences that dgan should automatically pad to
+    # max_sequence_len
+
+    # Build dataframe of variable length sequences
+    rows = []
+    for id, seq_length in enumerate([3, 6, 5, 1, 1, 8, 8, 3]):
+        a1 = np.random.choice(["x", "y", "z"])
+        for _ in range(seq_length):
+            rows.append(
+                (
+                    id,
+                    a1,
+                    np.random.random(),
+                    np.random.choice(["foo", "bar"]),
+                )
+            )
+    df = pd.DataFrame(rows, columns=["example_id", "a1", "f1", "f2"])
+
+    config.max_sequence_len = 8
+    config.sample_len = 1
+    config.epochs = 1
+
+    dg = DGAN(config=config)
+
+    dg.train_dataframe(
+        df=df,
+        df_style=DfStyle.LONG,
+        example_id_column="example_id",
+        attribute_columns=["a1"],
+    )
+
+    df_synth = dg.generate_dataframe(3)
+    assert df.shape[1] == df_synth.shape[1]
+    assert all(str(x) == str(y) for x, y in zip(df_synth.columns, df.columns))
+
+    for count in Counter(df_synth["example_id"]).most_common():
+        assert 1 <= count[1] <= 8
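The expected_features lists in test_long_data_frame_converter_variable_length above, together with the zero rows in the transform tests below, suggest one convention for variable length sequences: the converter appends an extra discrete feature that is 1.0 at every time point where the sequence continues and 0.0 at the final time point, and transform_features then zero-pads each sequence out to max_sequence_len. A minimal sketch of that convention; the helper names add_continuation_indicator and pad_to_max_len are illustrative, not part of this patch:

    import numpy as np

    def add_continuation_indicator(seq: np.ndarray) -> np.ndarray:
        # 1.0 marks "sequence continues", 0.0 marks the final time point,
        # matching the second column of expected_features in the test above.
        indicator = np.ones((seq.shape[0], 1))
        indicator[-1, 0] = 0.0
        return np.hstack([seq, indicator])

    def pad_to_max_len(seq: np.ndarray, max_sequence_len: int) -> np.ndarray:
        # Zero-pad short sequences so all examples share one tensor shape.
        padding = np.zeros((max_sequence_len - seq.shape[0], seq.shape[1]))
        return np.vstack([seq, padding])

    # Example "b" from the test: 2 time points padded out to 3.
    seq = np.array([[2.5], [3.0]])
    padded = pad_to_max_len(add_continuation_indicator(seq), max_sequence_len=3)
    # padded == [[2.5, 1.0], [3.0, 0.0], [0.0, 0.0]]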
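The additional attributes checked in test_transform_features_variable_length_example_scaling (in the test_transformations.py diff below) reduce to midpoint/halfrange arithmetic: each variable is first scaled by its global min/max, then each example is rescaled by its own min/max, with the example's midpoint and halfrange recorded as additional attributes. A small sketch of that arithmetic, assuming ZERO_ONE normalization and that global scaling happens before per-example scaling; example_scale is an illustrative helper, not the library API:

    import numpy as np

    def example_scale(values: np.ndarray, global_min: float, global_max: float):
        # Global scaling to [0, 1] first (a no-op when global min/max are 0/1).
        scaled = (values - global_min) / (global_max - global_min)
        # Per-example midpoint and halfrange become additional attributes.
        midpoint = (scaled.min() + scaled.max()) / 2
        halfrange = (scaled.max() - scaled.min()) / 2
        # Rescale within the example; all zeros when the example is constant.
        if halfrange > 0:
            per_example = (scaled - scaled.min()) / (2 * halfrange)
        else:
            per_example = np.zeros_like(scaled)
        return per_example, midpoint, halfrange

    # Example 0, Var 1 from the test: 5.0, 6.0, 7.0 with global min=3.0 and
    # max=7.0 scale to 0.5, 0.75, 1.0, giving midpoint=0.75, halfrange=0.25
    # and per-example values 0.0, 0.5, 1.0.
    vals, mid, half = example_scale(np.array([5.0, 6.0, 7.0]), 3.0, 7.0)
    assert np.allclose(vals, [0.0, 0.5, 1.0])
    assert np.allclose([mid, half], [0.75, 0.25])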
diff --git a/tests/timeseries_dgan/test_transformations.py b/tests/timeseries_dgan/test_transformations.py
index 9566dc95..89a2e9be 100644
--- a/tests/timeseries_dgan/test_transformations.py
+++ b/tests/timeseries_dgan/test_transformations.py
@@ -9,15 +9,17 @@
 from gretel_synthetics.timeseries_dgan.transformations import (
     BinaryEncodedOutput,
     ContinuousOutput,
-    inverse_transform,
+    inverse_transform_attributes,
+    inverse_transform_features,
     OneHotEncodedOutput,
     rescale,
     rescale_inverse,
-    transform,
+    transform_attributes,
+    transform_features,
 )
 
 
-def assert_array_equal(a: np.array, b: np.array):
+def assert_array_equal(a: np.ndarray, b: np.ndarray):
     """Custom assert to handle float and object numpy arrays with nans.
 
     Treats nan as a "normal" number where nan == nan is True.
@@ -290,6 +292,64 @@ def test_rescale_and_inverse_by_example():
         np.testing.assert_allclose(inversed, original)
 
 
+def test_transform_features_variable_length_example_scaling():
+    # Var 0 has no overall scaling (since global min=0.0, max=1.0), but does
+    # have per-example scaling. Var 1 has both overall and per-example
+    # scaling.
+    outputs = [
+        ContinuousOutput(
+            "a",
+            Normalization.ZERO_ONE,
+            apply_feature_scaling=True,
+            apply_example_scaling=True,
+            global_min=0.0,
+            global_max=1.0,
+        ),
+        ContinuousOutput(
+            "b", Normalization.ZERO_ONE, True, True, global_min=3.0, global_max=7.0
+        ),
+    ]
+    input = [
+        # Parenthesized values for variable 1 are the globally scaled
+        # versions of the min/max.
+        # Example 0:
+        #   Var 0: min=0.3, max=0.5
+        #   Var 1: min=5.0 (0.5), max=7.0 (1.0)
+        np.array([[0.5, 5.0], [0.35, 6.0], [0.3, 7.0]]),
+        # Example 1:
+        #   Var 0: min=0.5, max=0.5
+        #   Var 1: min=3.0 (0.0), max=3.0 (0.0)
+        np.array([[0.5, 3.0]]),
+        # Example 2:
+        #   Var 0: min=0.6, max=0.7
+        #   Var 1: min=4.0 (0.25), max=6.5 (0.875)
+        np.array([[0.7, 4.0], [0.7, 6.0], [0.7, 5.5], [0.6, 6.5]]),
+    ]
+    expected_transformed = np.array(
+        [
+            [[1.0, 0.0], [0.25, 0.5], [0.0, 1.0], [0.0, 0.0]],
+            [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+            [[1.0, 0.0], [1.0, 0.8], [1.0, 0.6], [0.0, 1.0]],
+        ]
+    )
+
+    # Columns: Var 0 midpoint, Var 0 halfrange, Var 1 midpoint, Var 1
+    # halfrange
+    expected_additional_attributes = np.array(
+        [
+            [0.4, 0.1, 0.75, 0.25],
+            [0.5, 0.0, 0.0, 0.0],
+            [0.65, 0.05, 0.5625, 0.3125],
+        ]
+    )
+
+    transformed, additional_attributes = transform_features(
+        input, outputs, max_sequence_len=4
+    )
+
+    np.testing.assert_allclose(transformed, expected_transformed)
+    assert additional_attributes is not None
+    np.testing.assert_allclose(additional_attributes, expected_additional_attributes)
+
+
 @pytest.mark.parametrize(
     "normalization", [Normalization.ZERO_ONE, Normalization.MINUSONE_ONE]
 )
@@ -320,12 +380,13 @@ def test_transform_and_inverse_attributes(normalization):
         ),
     ]
 
-    transformed = transform(attributes, outputs, 1)
+    transformed = transform_attributes(attributes, outputs)
     # 3 continuous + 2 for one hot encoded + 4 for binary encoded
     # (No idea why category encoders needs 4 bits to encode 5 unique values)
     assert transformed.shape == (n, 9)
 
-    inversed = inverse_transform(transformed, outputs, 1)
+    inversed = inverse_transform_attributes(transformed, outputs)
+    assert inversed is not None
     np.testing.assert_allclose(inversed, attributes, rtol=1e-04)
 
 
@@ -352,11 +413,14 @@ def test_transform_and_inverse_features(normalization):
     for index, output in enumerate(outputs):
         output.fit(features[:, :, index].flatten())
 
-    transformed, additional_attributes = transform(features, outputs, 2)
+    transformed, additional_attributes = transform_features(
+        list(features), outputs, max_sequence_len=10
+    )
 
     assert transformed.shape == (100, 10, 5)
+    assert additional_attributes is not None
     assert additional_attributes.shape == (100, 4)
 
-    inversed = inverse_transform(transformed, outputs, 2, additional_attributes)
+    inversed = inverse_transform_features(transformed, outputs, additional_attributes)
 
     # TODO: 1e-04 seems too lax of a tolerance for float32, but values very
     # close to 0.0 are failing the check at 1e-05, so going with this for now to
     # reduce flakiness. Could be something we can do in the calculations to have