Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions synthpop/processor/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,34 @@ def _preprocess(self, data: pd.DataFrame) -> pd.DataFrame:

for col, dtype in self.metadata.items():
if dtype == "categorical":
# Use Label Encoding for small categories, OneHot for larger
encoder = LabelEncoder() if len(data[col].unique()) < 10 else OneHotEncoder(sparse=False, drop="first")
# Choose encoder based on cardinality
n_unique = len(data[col].unique())
if n_unique < 10:
encoder = LabelEncoder()
elif n_unique < 50:
encoder = OneHotEncoder(sparse=False, drop="first")
else:
# Frequency encoding
value_counts = data[col].value_counts(normalize=True)
encoder = {'type': 'frequency', 'mapping': value_counts.to_dict()}

transformed_data = self._encode_categorical(data[col], encoder)
self.encoders[col] = encoder
data.drop(columns=[col], inplace=True)
data = pd.concat([data, transformed_data], axis=1)

elif dtype == "numerical":
scaler = StandardScaler(with_mean= False, with_std= False)
scaler = StandardScaler(with_mean=False, with_std=False)
data[col] = scaler.fit_transform(data[[col]])
self.scalers[col] = scaler

elif dtype == "boolean":
data[col] = data[col].astype(int) # Convert True/False to 1/0

elif dtype == "datetime":
data[col] = data[col].apply(lambda x: x.timestamp() if pd.notnull(x) else np.nan) # Convert to Unix timestamp
data[col] = data[col].apply(
lambda x: x.timestamp() if pd.notnull(x) else np.nan
) # Convert to Unix timestamp

elif dtype == "timedelta":
data[col] = pd.to_timedelta(data[col]).dt.total_seconds()
Expand Down Expand Up @@ -125,11 +136,23 @@ def validate(self, data: pd.DataFrame):
def _encode_categorical(self, series: pd.Series, encoder):
"""Encode categorical columns."""
if isinstance(encoder, LabelEncoder):
return pd.DataFrame(encoder.fit_transform(series), columns=[series.name])
return pd.DataFrame(
encoder.fit_transform(series),
columns=[series.name]
)
elif isinstance(encoder, OneHotEncoder):
encoded_array = encoder.fit_transform(series.values.reshape(-1, 1))
encoded_df = pd.DataFrame(encoded_array, columns=encoder.get_feature_names_out([series.name]))
encoded_df = pd.DataFrame(
encoded_array,
columns=encoder.get_feature_names_out([series.name])
)
return encoded_df
elif isinstance(encoder, dict) and encoder['type'] == 'frequency':
# Frequency encoding
encoded_values = series.map(encoder['mapping'])
return pd.DataFrame(encoded_values, columns=[series.name])
else:
raise TypeError(f"Unsupported encoder type: {type(encoder)}")

def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder):
"""
Expand Down Expand Up @@ -170,6 +193,23 @@ def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder):
cats = encoder.categories_[0]
return pd.Series(cats[idx], index=getattr(encoded, "index", None))

# FREQUENCY ENCODER CASE
elif isinstance(encoder, dict) and encoder['type'] == 'frequency':
# For frequency encoding, we need to map the encoded values back to categories
# We'll use the inverse mapping (frequency -> category)
inverse_mapping = {v: k for k, v in encoder['mapping'].items()}
# Find the closest frequency for each encoded value
encoded_values = encoded.values.flatten()
decoded = []
for val in encoded_values:
if pd.isna(val):
decoded.append(np.nan)
else:
# Find the category with the closest frequency
closest_freq = min(inverse_mapping.keys(), key=lambda x: abs(x - val))
decoded.append(inverse_mapping[closest_freq])
return pd.Series(decoded, index=getattr(encoded, "index", None))

else:
raise TypeError(f"Unsupported encoder type: {type(encoder)}")

Expand Down