RoaDs is a novel framework that uses the dataset to align structural priors and employs multi-task learning (MTL) to resolve the conflict between data-driven and knowledge-driven optimization objectives under imperfect structural constraints.
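The README does not say how RoaDs's multi-task learning reconciles the two objectives. As a hedged illustration of one well-known way to handle conflicting task gradients (PCGrad-style gradient projection, not necessarily what RoaDs does), the sketch below combines two hypothetical gradients: `g_data` for a data-fit loss and `g_prior` for a constraint/prior loss. The function name `pcgrad_two_tasks` is illustrative only.

```python
import numpy as np


def pcgrad_two_tasks(g_data, g_prior):
    """Combine two task gradients with PCGrad-style projection (illustrative sketch).

    If the gradients conflict (negative dot product), each one is projected
    onto the normal plane of the other task's original gradient, then summed.
    """
    g_data = np.asarray(g_data, dtype=float)
    g_prior = np.asarray(g_prior, dtype=float)
    dot = float(g_data @ g_prior)
    if dot >= 0:
        # No conflict (this also covers a zero gradient): plain sum.
        return g_data + g_prior
    # Remove from each gradient its component along the other original gradient.
    g_data_pc = g_data - dot / float(g_prior @ g_prior) * g_prior
    g_prior_pc = g_prior - dot / float(g_data @ g_data) * g_data
    return g_data_pc + g_prior_pc


# Hypothetical conflicting gradients for the data-driven and knowledge-driven losses.
combined = pcgrad_two_tasks(np.array([1.0, 0.0]), np.array([-1.0, 1.0]))
print(combined)  # [0.5 1.5]
```

After projection, neither task's update pushes directly against the other's original gradient, which is the sense in which such methods "resolve" an optimization conflict.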
This project implements the RoaDs algorithm, which improves causal discovery accuracy by incorporating imperfect edge-constraint information. Key features include:
- Support for linear and nonlinear causal relationship discovery
- Multiple graph structure generation methods (ER, SF, Real)
- Support for multiple noise types
- Imperfect constraint generation
- Constraint information integration
- Multi-task learning framework
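To make the "ER" graph type and the noise options concrete, here is a minimal sketch of generating a random DAG and linear SEM data with NumPy. It is an assumption-laden illustration, not the project's generator: gcastle (a listed dependency) ships its own simulators, which RoaDs may use instead. The function names, the weight range ±[0.5, 2.0], and the unit noise scales are hypothetical choices; only ER graphs and the four linear noise types are covered.

```python
import numpy as np


def simulate_er_dag(n_nodes, n_edges, rng):
    """Sample a DAG with exactly n_edges edges (Erdos-Renyi G(n, M)-style).

    Edges are drawn only from the upper triangle, which guarantees acyclicity;
    node labels are then shuffled. adj[i, j] == 1 means i -> j.
    """
    rows, cols = np.triu_indices(n_nodes, k=1)
    if n_edges > len(rows):
        raise ValueError("n_edges exceeds the maximum for a DAG on n_nodes nodes")
    chosen = rng.choice(len(rows), size=n_edges, replace=False)
    adj = np.zeros((n_nodes, n_nodes), dtype=int)
    adj[rows[chosen], cols[chosen]] = 1
    perm = rng.permutation(n_nodes)
    return adj[np.ix_(perm, perm)]


def simulate_linear_sem(adj, n_samples, rng, noise="exp"):
    """Generate data from the linear SEM X = X W + E on the given DAG."""
    d = adj.shape[0]
    # Hypothetical edge weights: magnitude in [0.5, 2.0] with a random sign.
    magnitudes = rng.uniform(0.5, 2.0, size=(d, d))
    signs = rng.choice([-1.0, 1.0], size=(d, d))
    W = adj * magnitudes * signs
    samplers = {
        "gauss": lambda size: rng.normal(0.0, 1.0, size),
        "exp": lambda size: rng.exponential(1.0, size),
        "gumbel": lambda size: rng.gumbel(0.0, 1.0, size),
        "uniform": lambda size: rng.uniform(-1.0, 1.0, size),
    }
    if noise not in samplers:
        raise ValueError(f"unsupported noise type: {noise}")
    E = samplers[noise]((n_samples, d))
    # W is nilpotent for a DAG, so I - W is invertible: X = E (I - W)^{-1}.
    X = E @ np.linalg.inv(np.eye(d) - W)
    return X, W


# Sizes mirror the documented defaults (20 nodes, 40 edges, 40 samples, 'exp' noise).
rng = np.random.default_rng(0)
adj = simulate_er_dag(20, 40, rng)
X, W = simulate_linear_sem(adj, 40, rng, noise="exp")
print(X.shape)  # (40, 20)
```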
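- Python 3.9+ (the pinned SciPy 1.11.1 and PyTorch 2.5.1 releases do not support older interpreters)
- CUDA support (optional, for GPU acceleration)

Install the dependencies with `pip install -r requirements.txt`.

Main dependencies:

- gcastle==1.0.4
- numpy==1.24.3
- pandas==2.0.3
- scikit_learn==1.3.0
- scipy==1.11.1
- torch==2.5.1
- tqdm==4.66.5
- joblib==1.4.2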
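A quick way to confirm that an environment matches these pins is to query installed distribution metadata with the standard-library `importlib.metadata` module. The script below is a hypothetical helper, not part of the project; it strips local version labels such as `+cu121`, which CUDA-enabled PyTorch wheels append.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the dependency list above; keys are PyPI distribution names.
PINNED = {
    "gcastle": "1.0.4",
    "numpy": "1.24.3",
    "pandas": "2.0.3",
    "scikit-learn": "1.3.0",
    "scipy": "1.11.1",
    "torch": "2.5.1",
    "tqdm": "4.66.5",
    "joblib": "1.4.2",
}


def check_pins(pins):
    """Map each distribution name to (expected_version, installed_version_or_None)."""
    report = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # not installed in this environment
        report[name] = (expected, installed)
    return report


if __name__ == "__main__":
    for name, (expected, installed) in check_pins(PINNED).items():
        # Drop local labels such as "+cu121" before comparing.
        base = installed.split("+")[0] if installed else None
        status = "ok" if base == expected else "MISMATCH"
        print(f"{name:<14} expected {expected:<8} installed {installed or 'missing':<14} [{status}]")
```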
| Parameter | Type | Default | Description |
|---|---|---|---|
| `--n_nodes` | int | 20 | Number of nodes in the DAG |
| `--n_edges` | int | 40 | Number of edges in the DAG |
| `--n_dataset` | int | 40 | Number of data samples |
| `--present_rate` | int | 30 | Rate of positive (required-edge) constraints |
| `--forbidden_ratio` | int | 1 | Ratio of negative (forbidden-edge) constraints |
| `--wrong_rate` | int | 30 | Rate of flawed (incorrect) constraints |
| `--dataset_method` | str | `'linear'` | Data generation method (`'linear'`, `'nonlinear'`) |
| `--sem_type` | str | `'exp'` | SEM type: noise distribution for linear data (`'gauss'`, `'exp'`, `'gumbel'`, `'uniform'`) or function class for nonlinear data (`'mlp'`, `'gp'`) |
| `--graph_type` | str | `'ER'` | Graph type (`'ER'`, `'SF'`, `'Real'`) |
| `--eq_variances` | int | 0 | Whether to assume equal noise variances (0 = False, 1 = True) |
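The table does not define exactly how `--present_rate`, `--forbidden_ratio`, and `--wrong_rate` are applied. The sketch below adopts one plausible reading purely as an assumption: rates are percentages, `present_rate` percent of the true edges become required-edge constraints, `forbidden_ratio` forbidden constraints are drawn per required one, and `wrong_rate` percent of each constraint set is made incorrect (required edges absent from the true DAG, forbidden edges present in it). The function `generate_imperfect_constraints` is hypothetical and is not taken from the RoaDs code.

```python
import numpy as np


def generate_imperfect_constraints(adj, present_rate=30, forbidden_ratio=1,
                                   wrong_rate=30, seed=0):
    """Sample partially incorrect edge constraints from a true DAG.

    adj[i, j] == 1 means i -> j. Returns (required, forbidden), two lists of
    (i, j) pairs. ASSUMED semantics (hypothetical, not taken from RoaDs):
      present_rate:    percentage of true edges offered as required edges
      forbidden_ratio: forbidden constraints per required constraint
      wrong_rate:      percentage of each constraint set that is incorrect
    """
    rng = np.random.default_rng(seed)
    d = adj.shape[0]
    true_edges = [(int(i), int(j)) for i, j in np.argwhere(adj == 1)]
    non_edges = [(i, j) for i in range(d) for j in range(d)
                 if i != j and adj[i, j] == 0]

    n_required = int(round(len(true_edges) * present_rate / 100))
    n_forbidden = min(n_required * forbidden_ratio, len(non_edges))
    n_wrong_req = int(round(n_required * wrong_rate / 100))
    n_wrong_forb = int(round(n_forbidden * wrong_rate / 100))
    n_ok_req = n_required - n_wrong_req
    n_ok_forb = n_forbidden - n_wrong_forb

    # Shuffle each pool once and take disjoint slices so no pair is reused.
    edge_order = rng.permutation(len(true_edges))
    non_order = rng.permutation(len(non_edges))
    req_ok = [true_edges[k] for k in edge_order[:n_ok_req]]
    forb_wrong = [true_edges[k] for k in edge_order[n_ok_req:n_ok_req + n_wrong_forb]]
    forb_ok = [non_edges[k] for k in non_order[:n_ok_forb]]
    req_wrong = [non_edges[k] for k in non_order[n_ok_forb:n_ok_forb + n_wrong_req]]

    # Wrong required edges are absent from the DAG; wrong forbidden edges are present in it.
    return req_ok + req_wrong, forb_ok + forb_wrong


adj = np.zeros((20, 20), dtype=int)
for i in range(19):
    adj[i, i + 1] = 1  # chain 0 -> 1 -> ... -> 19 (19 edges)
required, forbidden = generate_imperfect_constraints(adj)
print(len(required), len(forbidden))  # 6 6
```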
- Linear causal relationship discovery (default settings): `python main.py`
- Nonlinear causal relationship discovery: `python main.py --dataset_method nonlinear --sem_type mlp`
- Using a real dataset (Sachs dataset): `python main.py --graph_type Real --n_nodes 11`

After running, the program outputs the following information:
- Current Settings: Display all parameter configurations
- Evaluation Metrics, including but not limited to:
  - SHD (Structural Hamming Distance)
  - F1 score
  - Precision and recall
  - Execution time
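For readers unfamiliar with these metrics, the sketch below computes SHD, precision, recall, and F1 from two 0/1 adjacency matrices using one common convention: a predicted edge is a true positive only if its direction matches, and each node pair whose edge state differs (missing, extra, or reversed edge) adds 1 to SHD. gcastle, a listed dependency, provides a `MetricsDAG` class for these metrics; its exact conventions may differ from this hypothetical `dag_metrics` helper.

```python
import numpy as np


def dag_metrics(pred, true):
    """SHD, precision, recall and F1 for 0/1 adjacency matrices (adj[i, j] = 1 means i -> j)."""
    pred = (np.asarray(pred) != 0).astype(int)
    true = (np.asarray(true) != 0).astype(int)

    # Edge-level counts: an edge is correct only if its direction matches.
    tp = int(np.sum((pred == 1) & (true == 1)))
    n_pred, n_true = int(pred.sum()), int(true.sum())
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # SHD: count node pairs {i, j} whose edge state differs (missing, extra,
    # or reversed edge); each differing pair costs 1, so a reversal counts once.
    diff = pred != true
    shd = int(np.triu(diff | diff.T, k=1).sum())
    return {"shd": shd, "precision": precision, "recall": recall, "f1": f1}


true_adj = np.array([[0, 1, 0],
                     [0, 0, 1],
                     [0, 0, 0]])  # 0 -> 1 -> 2
pred_adj = np.array([[0, 0, 0],
                     [1, 0, 1],
                     [0, 0, 0]])  # 1 -> 0 (reversed), 1 -> 2 (correct)
print(dag_metrics(pred_adj, true_adj))
# {'shd': 1, 'precision': 0.5, 'recall': 0.5, 'f1': 0.5}
```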