diff --git a/synthpop/processor/data_processor.py b/synthpop/processor/data_processor.py index 2116cb6..1d86b7d 100644 --- a/synthpop/processor/data_processor.py +++ b/synthpop/processor/data_processor.py @@ -55,15 +55,24 @@ def _preprocess(self, data: pd.DataFrame) -> pd.DataFrame: for col, dtype in self.metadata.items(): if dtype == "categorical": - # Use Label Encoding for small categories, OneHot for larger - encoder = LabelEncoder() if len(data[col].unique()) < 10 else OneHotEncoder(sparse=False, drop="first") + # Choose encoder based on cardinality + n_unique = len(data[col].unique()) + if n_unique < 10: + encoder = LabelEncoder() + elif n_unique < 50: + encoder = OneHotEncoder(sparse=False, drop="first") + else: + # Frequency encoding + value_counts = data[col].value_counts(normalize=True) + encoder = {'type': 'frequency', 'mapping': value_counts.to_dict()} + transformed_data = self._encode_categorical(data[col], encoder) self.encoders[col] = encoder data.drop(columns=[col], inplace=True) data = pd.concat([data, transformed_data], axis=1) elif dtype == "numerical": - scaler = StandardScaler(with_mean= False, with_std= False) + scaler = StandardScaler(with_mean=False, with_std=False) data[col] = scaler.fit_transform(data[[col]]) self.scalers[col] = scaler @@ -71,7 +80,9 @@ def _preprocess(self, data: pd.DataFrame) -> pd.DataFrame: data[col] = data[col].astype(int) # Convert True/False to 1/0 elif dtype == "datetime": - data[col] = data[col].apply(lambda x: x.timestamp() if pd.notnull(x) else np.nan) # Convert to Unix timestamp + data[col] = data[col].apply( + lambda x: x.timestamp() if pd.notnull(x) else np.nan + ) # Convert to Unix timestamp elif dtype == "timedelta": data[col] = pd.to_timedelta(data[col]).dt.total_seconds() @@ -125,11 +136,23 @@ def validate(self, data: pd.DataFrame): def _encode_categorical(self, series: pd.Series, encoder): """Encode categorical columns.""" if isinstance(encoder, LabelEncoder): - return pd.DataFrame(encoder.fit_transform(series), columns=[series.name]) + return pd.DataFrame( + encoder.fit_transform(series), + columns=[series.name] + ) elif isinstance(encoder, OneHotEncoder): encoded_array = encoder.fit_transform(series.values.reshape(-1, 1)) - encoded_df = pd.DataFrame(encoded_array, columns=encoder.get_feature_names_out([series.name])) + encoded_df = pd.DataFrame( + encoded_array, + columns=encoder.get_feature_names_out([series.name]) + ) return encoded_df + elif isinstance(encoder, dict) and encoder['type'] == 'frequency': + # Frequency encoding + encoded_values = series.map(encoder['mapping']) + return pd.DataFrame(encoded_values, columns=[series.name]) + else: + raise TypeError(f"Unsupported encoder type: {type(encoder)}") def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder): """ @@ -170,6 +193,23 @@ def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder): cats = encoder.categories_[0] return pd.Series(cats[idx], index=getattr(encoded, "index", None)) + # FREQUENCY ENCODER CASE + elif isinstance(encoder, dict) and encoder['type'] == 'frequency': + # For frequency encoding, we need to map the encoded values back to categories + # We'll use the inverse mapping (frequency -> category) + inverse_mapping = {v: k for k, v in encoder['mapping'].items()} + # Find the closest frequency for each encoded value + encoded_values = encoded.values.flatten() + decoded = [] + for val in encoded_values: + if pd.isna(val): + decoded.append(np.nan) + else: + # Find the category with the closest frequency + closest_freq = min(inverse_mapping.keys(), key=lambda x: abs(x - val)) + decoded.append(inverse_mapping[closest_freq]) + return pd.Series(decoded, index=getattr(encoded, "index", None)) + else: raise TypeError(f"Unsupported encoder type: {type(encoder)}")