diff --git a/synthpop/synthpop.py b/synthpop/synthpop.py index d70fcf2..2d89784 100644 --- a/synthpop/synthpop.py +++ b/synthpop/synthpop.py @@ -77,13 +77,48 @@ def post_postprocessing(self,syn_df): return syn_df + def _infer_dtypes(self, df): + """Automatically infer data types from DataFrame. + + Args: + df: pandas DataFrame + + Returns: + dict: Mapping of column names to inferred types ('int', 'float', 'datetime', 'category', 'bool') + """ + dtypes = {} + for column in df.columns: + pd_dtype = str(df[column].dtype) + + if pd_dtype.startswith('int'): + dtypes[column] = 'int' + elif pd_dtype.startswith('float'): + dtypes[column] = 'float' + elif pd_dtype.startswith('datetime'): + dtypes[column] = 'datetime' + elif pd_dtype.startswith('bool'): + dtypes[column] = 'bool' + else: + # For object or string dtypes, check if it should be categorical + dtypes[column] = 'category' + + return dtypes + def fit(self, df, dtypes=None): - # TODO check df and check/EXTRACT dtypes - # - all column names of df are unique - # - all columns data of df are consistent - # - all dtypes of df are correct ('int', 'float', 'datetime', 'category', 'bool'; no object) - # - can map dtypes (if given) correctly to df - # should create map col: dtype (self.df_dtypes) + """Fit the synthetic data generator. + + Args: + df: pandas DataFrame to learn from + dtypes: Optional dict mapping column names to types. If not provided, types will be inferred. + """ + # Infer dtypes if not provided + if dtypes is None: + dtypes = self._infer_dtypes(df) + + # Validate DataFrame + if not df.columns.is_unique: + raise ValueError("DataFrame column names must be unique") + df,dtypes = self.pre_preprocess(df,dtypes,-8) self.df_columns = df.columns.tolist() diff --git a/tests/test_synthpop.py b/tests/test_synthpop.py index 91f57f5..4604040 100644 --- a/tests/test_synthpop.py +++ b/tests/test_synthpop.py @@ -4,12 +4,12 @@ from datasets.adult import df, dtypes def test_synthpop_default_parameters(): - """Test Synthpop with default parameters using Adult dataset.""" + """Test Synthpop with default parameters and automatic type inference.""" # Initialize Synthpop spop = Synthpop() - # Fit the model - spop.fit(df, dtypes) + # Fit the model with automatic type inference + spop.fit(df) # Generate synthetic data synth_df = spop.generate(len(df)) @@ -20,6 +20,11 @@ def test_synthpop_default_parameters(): # Verify the synthetic dataframe has the same columns as original assert all(synth_df.columns == df.columns) + # Verify inferred dtypes match expected types + assert spop.df_dtypes['age'] == 'int' + assert spop.df_dtypes['workclass'] == 'category' + assert spop.df_dtypes['education'] == 'category' + # Verify the method attribute contains expected default values assert isinstance(spop.method, pd.Series) assert 'age' in spop.method.index @@ -37,6 +42,25 @@ def test_synthpop_default_parameters(): assert all(spop.predictor_matrix.index == df.columns) assert all(spop.predictor_matrix.columns == df.columns) +def test_synthpop_with_manual_dtypes(): + """Test Synthpop with manually specified dtypes.""" + # Initialize Synthpop + spop = Synthpop() + + # Fit the model with explicit dtypes + spop.fit(df, dtypes) + + # Verify the dtypes were set correctly + for col, dtype in dtypes.items(): + assert spop.df_dtypes[col] == dtype + + # Generate synthetic data + synth_df = spop.generate(len(df)) + + # Verify the synthetic dataframe has the same shape and columns + assert synth_df.shape == df.shape + assert all(synth_df.columns == df.columns) + def test_synthpop_custom_visit_sequence(): """Test Synthpop with custom visit sequence using Adult dataset.""" # Define custom visit sequence @@ -45,8 +69,8 @@ def test_synthpop_custom_visit_sequence(): # Initialize Synthpop with custom visit sequence spop = Synthpop(visit_sequence=visit_sequence) - # Fit the model - spop.fit(df, dtypes) + # Fit the model with automatic type inference + spop.fit(df) # Generate synthetic data synth_df = spop.generate(len(df))