Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions synthpop/synthpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,48 @@ def post_postprocessing(self,syn_df):

return syn_df

def _infer_dtypes(self, df):
"""Automatically infer data types from DataFrame.

Args:
df: pandas DataFrame

Returns:
dict: Mapping of column names to inferred types ('int', 'float', 'datetime', 'category', 'bool')
"""
dtypes = {}
for column in df.columns:
pd_dtype = str(df[column].dtype)

if pd_dtype.startswith('int'):
dtypes[column] = 'int'
elif pd_dtype.startswith('float'):
dtypes[column] = 'float'
elif pd_dtype.startswith('datetime'):
dtypes[column] = 'datetime'
elif pd_dtype.startswith('bool'):
dtypes[column] = 'bool'
else:
# For object or string dtypes, check if it should be categorical
dtypes[column] = 'category'

return dtypes

def fit(self, df, dtypes=None):
# TODO check df and check/EXTRACT dtypes
# - all column names of df are unique
# - all columns data of df are consistent
# - all dtypes of df are correct ('int', 'float', 'datetime', 'category', 'bool'; no object)
# - can map dtypes (if given) correctly to df
# should create map col: dtype (self.df_dtypes)
"""Fit the synthetic data generator.

Args:
df: pandas DataFrame to learn from
dtypes: Optional dict mapping column names to types. If not provided, types will be inferred.
"""
# Infer dtypes if not provided
if dtypes is None:
dtypes = self._infer_dtypes(df)

# Validate DataFrame
if not df.columns.is_unique:
raise ValueError("DataFrame column names must be unique")

df,dtypes = self.pre_preprocess(df,dtypes,-8)

self.df_columns = df.columns.tolist()
Expand Down
34 changes: 29 additions & 5 deletions tests/test_synthpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from datasets.adult import df, dtypes

def test_synthpop_default_parameters():
"""Test Synthpop with default parameters using Adult dataset."""
"""Test Synthpop with default parameters and automatic type inference."""
# Initialize Synthpop
spop = Synthpop()

# Fit the model
spop.fit(df, dtypes)
# Fit the model with automatic type inference
spop.fit(df)

# Generate synthetic data
synth_df = spop.generate(len(df))
Expand All @@ -20,6 +20,11 @@ def test_synthpop_default_parameters():
# Verify the synthetic dataframe has the same columns as original
assert all(synth_df.columns == df.columns)

# Verify inferred dtypes match expected types
assert spop.df_dtypes['age'] == 'int'
assert spop.df_dtypes['workclass'] == 'category'
assert spop.df_dtypes['education'] == 'category'

# Verify the method attribute contains expected default values
assert isinstance(spop.method, pd.Series)
assert 'age' in spop.method.index
Expand All @@ -37,6 +42,25 @@ def test_synthpop_default_parameters():
assert all(spop.predictor_matrix.index == df.columns)
assert all(spop.predictor_matrix.columns == df.columns)

def test_synthpop_with_manual_dtypes():
"""Test Synthpop with manually specified dtypes."""
# Initialize Synthpop
spop = Synthpop()

# Fit the model with explicit dtypes
spop.fit(df, dtypes)

# Verify the dtypes were set correctly
for col, dtype in dtypes.items():
assert spop.df_dtypes[col] == dtype

# Generate synthetic data
synth_df = spop.generate(len(df))

# Verify the synthetic dataframe has the same shape and columns
assert synth_df.shape == df.shape
assert all(synth_df.columns == df.columns)

def test_synthpop_custom_visit_sequence():
"""Test Synthpop with custom visit sequence using Adult dataset."""
# Define custom visit sequence
Expand All @@ -45,8 +69,8 @@ def test_synthpop_custom_visit_sequence():
# Initialize Synthpop with custom visit sequence
spop = Synthpop(visit_sequence=visit_sequence)

# Fit the model
spop.fit(df, dtypes)
# Fit the model with automatic type inference
spop.fit(df)

# Generate synthetic data
synth_df = spop.generate(len(df))
Expand Down
Loading