This repository contains code and datasets for the paper *CTRL Your Shift: Clustered Transfer Residual Learning for Many Small Datasets*.
- Main pipeline: `runscript.sh` orchestrates the complete pipeline
- Baselines: `baselines.R` provides results for the JTT and RWG baselines
- `Functions/`: reusable R functions for estimation, prediction, and data handling
- `Data/`: 4 datasets with different characteristics and domains
- RWA calculation: `get_RWA.R` computes the RWA score used in the paper for each method
The repository includes 4 datasets, each requiring specific configuration parameters. We provide adapted versions of the datasets to specifically study source-level heterogeneity; links to the full original datasets are also provided:
| Dataset | Dataset Name | Group Variable | Outcome Variable | Data Source |
|---|---|---|---|---|
| Synthetic | `synthetic` | `group` | `outcome` | Generated by `generate_synthetic_data.py` |
| Education | `education` | `STATE` | `Education` | Educational Outcomes Dataset |
| Health | `health` | `group` | `has_chronic_condition` | Dissecting Bias Health Dataset |
| UK Refugee Asylum | `uk_refugee_asylum` | `nationality` | `outcome_binary` | UK Refugee Asylum Dataset |
- Install required R packages:

  ```r
  install.packages(c("dplyr", "data.table", "xgboost", "glmnet", "readr", "tidyr", "tidyverse", "parallel"))
  ```

- Install Julia dependencies (optional, for advanced optimization):

  ```julia
  using Pkg
  Pkg.add(["JuMP", "Gurobi", "DataFrames", "CSV"])
  ```

- Clone the repository:

  ```shell
  git clone https://github.com/Gjain234/CTRLYourShift.git
  cd CTRLYourShift
  ```
First, make the script executable, then run it:

```shell
chmod +x runscript.sh
./runscript.sh
```

The main pipeline requires you to configure the following parameters in `runscript.sh`:

```shell
# Configuration section in runscript.sh
export data_type="synthetic"   # Dataset name (see table above)
export grouping_var="group"    # Group variable name
export outcome_var="outcome"   # Outcome variable name
export num_iter=250            # Number of iterations
export max_top_k=10            # Maximum top-k for ranking
```

Example configurations for each dataset:
Synthetic Dataset:

```shell
export data_type="synthetic"
export grouping_var="group"
export outcome_var="outcome"
```

Education Dataset:

```shell
export data_type="education"
export grouping_var="STATE"
export outcome_var="Education"
```

Health Dataset:

```shell
export data_type="health"
export grouping_var="group"
export outcome_var="has_chronic_condition"
```

UK Refugee Asylum Dataset:

```shell
export data_type="uk_refugee_asylum"
export grouping_var="nationality"
export outcome_var="outcome_binary"
```

`runscript.sh` executes the following stages:
1. Generate Weights (R): creates prediction weights for each iteration
2. Julia Weight Optimization: advanced optimization-based weight generation
3. Generate Ranks: computes rankings based on the generated weights
4. Model Comparison: evaluates performance across different approaches
5. Assignment Computation: parallel assignment computation for each method
6. Reward Computation: final reward computation and plotting
Note: The reward generation in Stage 6 computes a debiased reward estimate. This estimate is not reported in the paper itself, but it is a commonly used way to estimate rewards in the absence of counterfactuals. For the RWA estimate used in the paper, see `get_RWA.R`.
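To illustrate the idea behind a debiased reward estimate, here is a minimal doubly robust sketch in Python. This is not the repository's actual implementation (which lives in the R pipeline); all names and the exact estimator form are illustrative assumptions. It combines a model-based outcome prediction with an inverse-propensity correction on the observed data:

```python
import numpy as np

def debiased_reward(outcomes, assigned, propensity, mu_hat):
    """Doubly robust reward estimate for one method's assignments (sketch).

    outcomes   : observed outcomes y_i
    assigned   : 1 where the observed action matches the method's choice, else 0
    propensity : probability the observed action was the method's choice
    mu_hat     : model-based estimate of the outcome under the chosen action
    """
    # Model prediction everywhere, plus an importance-weighted residual
    # correction on the rows where the chosen action was actually observed.
    correction = assigned * (outcomes - mu_hat) / propensity
    return float(np.mean(mu_hat + correction))

# Toy example with four units and a constant propensity of 0.5
y = np.array([1.0, 0.0, 1.0, 0.0])
a = np.array([1, 1, 0, 1])
print(debiased_reward(y, a, 0.5, np.full(4, 0.5)))  # prints 0.25
```

When the outcome model `mu_hat` is accurate, the correction term shrinks toward zero; when it is biased, the propensity-weighted residuals compensate, which is why this family of estimators is called "debiased".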
Run the baseline script to get results for the JTT and RWG baselines:

```shell
Rscript baselines.R
```

Note: You may need to modify the dataset configuration in `baselines.R`:
```r
data_type = "synthetic"  # Change to your dataset
grouping_var = 'group'   # Change to your group variable
outcome_var = 'outcome'  # Change to your outcome variable
```

- Plots: reward and MSE plots are saved in the `plots/` folder
- Results: CSV files with detailed results are saved in the `results/` folder
- Logs: execution logs are saved in the `logs/` folder
- `baselines.R`: computes JTT and RWG baseline results
- `get_RWA.R`: after running the pipeline, run this script to compute the RWA reward for each method
You can generate custom synthetic datasets using `generate_synthetic_data.py`. While we provide a pre-generated synthetic dataset in `data/synthetic_dataset.csv`, you can create your own datasets with different parameters:

```shell
python generate_synthetic_data.py
```

The script allows you to customize:

- `total_size`: total number of samples
- `min_group_size`: minimum group size
- `n_groups`: number of groups
- `feature_dim`: number of features
- `global_weight_scale`: balance between global and local effects
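As a sketch of how these parameters could interact, the snippet below generates a dataset in the same spirit. The parameter names come from the list above, but the generating process itself is an assumption for illustration, not the actual logic of `generate_synthetic_data.py`:

```python
import numpy as np
import pandas as pd

def make_synthetic(total_size=1000, n_groups=5, min_group_size=50,
                   feature_dim=3, global_weight_scale=0.5, seed=0):
    """Illustrative generator: blend a shared global effect with per-group
    local effects, controlled by global_weight_scale in [0, 1]."""
    rng = np.random.default_rng(seed)
    # Guarantee min_group_size rows per group, then fill the rest at random
    groups = np.concatenate([np.full(min_group_size, g) for g in range(n_groups)])
    groups = np.concatenate([groups, rng.integers(0, n_groups, total_size - groups.size)])
    X = rng.normal(size=(total_size, feature_dim))
    w_global = rng.normal(size=feature_dim)            # shared across groups
    w_local = rng.normal(size=(n_groups, feature_dim))  # one per group
    # Convex combination of global and group-specific linear effects
    logits = (global_weight_scale * X @ w_global
              + (1 - global_weight_scale) * np.einsum('ij,ij->i', X, w_local[groups]))
    y = (rng.random(total_size) < 1 / (1 + np.exp(-logits))).astype(int)
    df = pd.DataFrame(X, columns=[f"x{i}" for i in range(feature_dim)])
    df["group"] = groups
    df["outcome"] = y
    return df
```

Setting `global_weight_scale=1` makes every group follow the same outcome model, while `global_weight_scale=0` gives each group an independent model, which is the kind of source-level heterogeneity the adapted datasets are meant to study.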
- The pipeline can be run in parallel, depending on your machine's capabilities
- All scripts are designed to work with the standardized dataset format: `data/{dataset_name}_dataset.csv`
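Before running the pipeline on a new dataset, it can help to verify that the file follows this naming convention and contains the configured group and outcome columns. The helper below is hypothetical (not part of the repository):

```python
import csv
import os

def check_dataset(data_dir, dataset_name, grouping_var, outcome_var):
    """Check that {data_dir}/{dataset_name}_dataset.csv exists and its header
    contains the configured group and outcome columns (illustrative helper)."""
    path = os.path.join(data_dir, f"{dataset_name}_dataset.csv")
    with open(path, newline="") as f:
        header = next(csv.reader(f))  # first row is assumed to be the header
    missing = [col for col in (grouping_var, outcome_var) if col not in header]
    if missing:
        raise ValueError(f"{path} is missing columns: {missing}")
    return header
```

For example, `check_dataset("data", "health", "group", "has_chronic_condition")` would confirm that `data/health_dataset.csv` matches the configuration in the table above.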
