diff --git a/causallearn/search/HiddenCausal/RLCD/Chi2RankTest.py b/causallearn/search/HiddenCausal/RLCD/Chi2RankTest.py new file mode 100644 index 00000000..d00c68a5 --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/Chi2RankTest.py @@ -0,0 +1,193 @@ +import numpy as np +from pdb import set_trace +from statsmodels.multivariate.cancorr import CanCorr +from math import log, pow +from scipy.stats import chi2 +from scipy.linalg import eigh + +from .logger import LOGGER + + +class Chi2RankTest(object): + + def __init__(self, data, N_scaling=1): + + self.data = data + self.data = self.data - self.data.mean(axis=0) + self.data = self.data/self.data.std(axis=0) + + self.N = data.shape[0] + self.N_scaling = N_scaling + self.unnormalized_crosscovs = self.data.T@self.data # data are zero mean + self.cca_cache_dict = {} + + def get_cachekey(self, pcols_, qcols_): + + pcols = pcols_.copy() + qcols = qcols_.copy() + pcols.sort() + qcols.sort() + + key = '' + for i in range(self.data.shape[1]): + if i in pcols and i in qcols: + key = key+str(3) + elif i in pcols and not i in qcols: + key = key+str(2) + elif not i in pcols and i in qcols: + key = key+str(1) + else: + key = key+str(0) + + return key + + def test(self, pcols, qcols, r, alpha): + ''' + Parameters + ---------- + pcols, qcols : column indices of data + r: null hypo that rank <= r + alpha: significance level + + Returns + ------- + if_fail_to_reject: 0 means reject and 1 means fail to reject + p : the p-value of the test + ''' + + + + cachekey = self.get_cachekey(pcols, qcols) + + if cachekey in self.cca_cache_dict: + cancorr = self.cca_cache_dict[cachekey] + else: + + X = self.data[:, pcols] + Y = self.data[:, qcols] + + unnormalized_crosscovs = [self.unnormalized_crosscovs[pcols,:][:,pcols], self.unnormalized_crosscovs[pcols,:][:,qcols], \ + self.unnormalized_crosscovs[qcols,:][:,pcols], self.unnormalized_crosscovs[qcols,:][:,qcols]] + + try: + comps = kcca_modified([X, Y], reg=0., + numCC=None, kernelcca=False, ktype='linear', + gausigma=1.0, degree=2, crosscovs = unnormalized_crosscovs) + + cancorr, _, _ = recon([X,Y], comps, kernelcca=False) + cancorr = cancorr[:,0,1] + except: + LOGGER.debug("calculating cancorr error %s %s, using slower implementation instead", pcols, qcols) + X_fallback = np.atleast_2d(X).reshape(X.shape[0], -1) + Y_fallback = np.atleast_2d(Y).reshape(Y.shape[0], -1) + cancorr = CanCorr(X_fallback, Y_fallback, tolerance=1e-8).cancorr + + self.cca_cache_dict[cachekey] = cancorr + + testStat = 0 + p = len(pcols) + q = len(qcols) + + l = cancorr[r:] + for li in l: + li = min(li, 1-1e-15) + testStat += log(1)-log(1-li*li) + ratio = 0 + for i in range(r): + li = cancorr[i] + ratio += 1/(li*li)-1 + + ratio += self.N*self.N_scaling - r - 0.5*(p+q+1) + testStat = testStat * ratio + + dfreedom = (p-r) * (q-r) + criticalValue = chi2.ppf(1-alpha, dfreedom) + p = 1 - chi2.cdf(testStat, dfreedom) + if_fail_to_reject = testStat<=criticalValue + + # due to numerical errors comparing criticalValue with testStat is more accurate + + return if_fail_to_reject + + +def kcca_modified( + data, reg=0.0, numCC=None, kernelcca=False, ktype="linear", gausigma=1.0, degree=2, crosscovs=None +): + """Set up and solve the kernel CCA eigenproblem""" + if kernelcca: + raise NotImplementedError + #kernel = [ + # _make_kernel(d, ktype=ktype, gausigma=gausigma, degree=degree) for d in data + #] + else: + kernel = [d.T for d in data] + + nDs = len(kernel) + nFs = [k.shape[0] for k in kernel] + numCC = min([k.shape[0] for k in kernel]) if numCC is None else numCC + + # Get the auto- and cross-covariance matrices + if crosscovs is None: + crosscovs = [np.dot(ki, kj.T) for ki in kernel for kj in kernel] + + # Allocate left-hand side (LH) and right-hand side (RH): + n = sum(nFs) + LH = np.zeros((n, n)) + RH = np.zeros((n, n)) + + # Fill the left and right sides of the eigenvalue problem + for i in range(nDs): + RH[ + sum(nFs[:i]): sum(nFs[: i + 1]), sum(nFs[:i]): sum(nFs[: i + 1]) + ] = crosscovs[i * (nDs + 1)] + reg * np.eye(nFs[i]) + + for j in range(nDs): + if i != j: + LH[ + sum(nFs[:j]): sum(nFs[: j + 1]), sum(nFs[:i]): sum(nFs[: i + 1]) + ] = crosscovs[nDs * j + i] + + LH = (LH + LH.T) / 2.0 + RH = (RH + RH.T) / 2.0 + + maxCC = LH.shape[0] + try: + r, Vs = eigh(LH, RH, subset_by_index=[maxCC - numCC, maxCC - 1]) + except TypeError: + r, Vs = eigh(LH, RH, eigvals=(maxCC - numCC, maxCC - 1)) + r[np.isnan(r)] = 0 + rindex = np.argsort(r)[::-1] + comp = [] + Vs = Vs[:, rindex] + for i in range(nDs): + comp.append(Vs[sum(nFs[:i]): sum(nFs[: i + 1]), :numCC]) + return comp + +def _listdot(d1, d2): + return [np.dot(x[0].T, x[1]) for x in zip(d1, d2)] + +def _listcorr(a): + """Returns pairwise row correlations for all items in array as a list of matrices""" + corrs = np.zeros((a[0].shape[1], len(a), len(a))) + for i in range(len(a)): + for j in range(len(a)): + if j > i: + corrs[:, i, j] = [ + np.nan_to_num(np.corrcoef(ai, aj)[0, 1]) + for (ai, aj) in zip(a[i].T, a[j].T) + ] + return corrs + +def recon(data, comp, corronly=False, kernelcca=False): + # Get canonical variates and CCs + if kernelcca: + ws = _listdot(data, comp) + else: + ws = comp + ccomp = _listdot([d.T for d in data], ws) + corrs = _listcorr(ccomp) + if corronly: + return corrs + else: + return corrs, ws, ccomp + diff --git a/causallearn/search/HiddenCausal/RLCD/Cover.py b/causallearn/search/HiddenCausal/RLCD/Cover.py new file mode 100644 index 00000000..edb4207f --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/Cover.py @@ -0,0 +1,186 @@ +from __future__ import annotations +from itertools import combinations + + +class Cover: + """ + Class to represent a Cover of latent variables. A Cover can either be + atomic or non-atomic. + + An atomic Cover is one where the variables cannot be split into + disjoint set of atomic Covers. E.g. if {L1, L2} forms a Cover, it is + not atomic if L1 and L2 individually are atomic Covers. + Note: Only atomic Covers can be children nodes. + + A Cover can also be temporary or not. A temporary Cover is one which was + introduced not because a rank deficiency was found but because we had to + introduce a temporary root variable to connect the remaining variables + when no more rank deficient sets may be found. + """ + + def __init__(self, varnames, atomic=True, temp=False, is_observed=True, is_leaf=None): + if isinstance(varnames, str): + self.vars = set([varnames]) + + elif isinstance(varnames, list): + self.vars = set(varnames) + + elif isinstance(varnames, set): + self.vars = varnames + else: + raise ValueError(f"{varnames} is neither str, list, set.") + + # if len(self.vars) > 0: + # v = next(iter(self.vars)) + # self.type = v[:1] + + self.atomic = atomic + self.temp = temp + self.is_observed=is_observed + self.is_leaf=is_leaf + + def __eq__(self, other): + if not isinstance(other, Cover): + return NotImplemented + return self.vars == other.vars + + # The set of variables in any minimalGroup should be unique + def __hash__(self): + s = "".join(sorted(list(self.vars))) + return hash(s) + + # Union with another Cover + def union(self, L): + self.vars = self.vars.union(L.vars) + + def __len__(self): + return len(self.vars) + + @property + def isAtomic(self): + return self.atomic + + @property + def isTemp(self): + return self.temp + + def takeOne(self): + return next(iter(self.vars)) + + def __str__(self): + if len(self.vars) == 1: + return next(iter(self.vars)) + else: + vars = ",".join(list(self.vars)) + return "{" + vars + "}" + + def __repr__(self): + return str(self) + + def isSubset(self, Bs: set[Cover] | Cover, strict=False): + if isinstance(Bs, set): + Bvars = getVars(Bs) + elif isinstance(Bs, Cover): + Bvars = Bs.vars + else: + raise ValueError("Argument must be set of Covers or Cover.") + if strict: + return self.vars < Bvars + return self.vars <= Bvars + + def intersection(self, B): + return self.vars.intersection(B.vars) + + +################################## +# Methods associated with Covers # +################################## + + +def setLength(Vs: set[Cover]): + """ + Determine ||Vs|| + """ + assert not isinstance(Vs, str), "Cannot be string." + return len(getVars(Vs)) + + +def setDifference(As: set[Cover], Bs: set[Cover]): + diff = As - Bs # first remove any common elements + newset = set() + while len(diff) > 0: + A = diff.pop() + newset.add(A) + for B in Bs: + if len(A.intersection(B)) > 0: + newset.remove(A) + break + return newset + + +def setOverlap(As: set[Cover], Bs: set[Cover]): + if len(As.intersection(Bs)) > 0: + return True + return len(setIntersection(As, Bs)) > 0 + + +def setIntersection(As: set[Cover], Bs: set[Cover]): + Avars = getVars(As) + Bvars = getVars(Bs) + return Avars.intersection(Bvars) + + +def getVars(As: set[Cover]): + """ + Get all variables from a set of Covers. + """ + vars = set() + for A in As: + assert isinstance(A, Cover), "Argument should be a set of Covers." + vars.update(A.vars) + return vars + +def getOrderedVarsString(As: set[Cover]|Cover): + """ + Get all variables from a set of Covers. + """ + + if isinstance(As, Cover): + As = {As} + + vars = set() + for A in As: + assert isinstance(A, Cover), "Argument should be a set of Covers." + vars.update(A.vars) + + vars = [x for x in vars] + vars.sort() + + return " ".join(vars) + +def deduplicate(Vs: set[Cover]): + """ + Deduplicate cases where Vs includes {L1, {L1, L3}} into just {{L1, L3}} + """ + newVs = set() + for Vi in Vs: + for Vj in Vs: + isDuplicate = False + if Vi.vars < Vj.vars: + isDuplicate = True + break + if not isDuplicate: + newVs.add(Vi) + return newVs + + +def pairwiseOverlap(Vs: set[Cover]): + """ + For each pair A, B in Vs, check if there are any variables overlapping. + Return True if so. + """ + for pair in combinations(Vs, 2): + A, B = list(pair) + if len(A.vars.intersection(B.vars)) > 0: + return True + return False diff --git a/causallearn/search/HiddenCausal/RLCD/DSU.py b/causallearn/search/HiddenCausal/RLCD/DSU.py new file mode 100644 index 00000000..c29d9ad1 --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/DSU.py @@ -0,0 +1,19 @@ +class DSU: + def __init__(self, N): + self.root = [i for i in range(N)] + + def find(self, k): + if self.root[k] == k: + return k + + fa = self.find(self.root[k]) + self.root[k] = fa + + return fa + + def union(self, a, b): + x = self.find(a) + y = self.find(b) + if x != y: + self.root[y] = x + return \ No newline at end of file diff --git a/causallearn/search/HiddenCausal/RLCD/FCI_CovRank.py b/causallearn/search/HiddenCausal/RLCD/FCI_CovRank.py new file mode 100644 index 00000000..8b73821a --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/FCI_CovRank.py @@ -0,0 +1,892 @@ +from __future__ import annotations + +import warnings +from queue import Queue +from typing import List, Set, Tuple, Dict +from numpy import ndarray + +from causallearn.graph.Edge import Edge +from causallearn.graph.Endpoint import Endpoint +from causallearn.graph.Graph import Graph +from causallearn.graph.GraphNode import GraphNode +from causallearn.graph.Node import Node +from causallearn.utils.ChoiceGenerator import ChoiceGenerator +from causallearn.utils.DepthChoiceGenerator import DepthChoiceGenerator +from causallearn.utils.cit import * +from causallearn.utils.FAS import fas +from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge + +def traverseSemiDirected(node: Node, edge: Edge) -> Node | None: + if node == edge.get_node1(): + if edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE: + return edge.get_node2() + elif node == edge.get_node2(): + if edge.get_endpoint2() == Endpoint.TAIL or edge.get_endpoint2() == Endpoint.CIRCLE: + return edge.get_node1() + return None + + +def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: + Q = Queue() + V = set() + + for node_u in G.get_adjacent_nodes(node_from): + edge = G.get_edge(node_from, node_u) + node_c = traverseSemiDirected(node_from, edge) + + if node_c is None: + continue + + if not V.__contains__(node_c): + V.add(node_c) + Q.put(node_c) + + while not Q.empty(): + node_t = Q.get_nowait() + if node_t == node_to: + return True + + for node_u in G.get_adjacent_nodes(node_t): + edge = G.get_edge(node_t, node_u) + node_c = traverseSemiDirected(node_t, edge) + + if node_c is None: + continue + + if not V.__contains__(node_c): + V.add(node_c) + Q.put(node_c) + + return False + + +def existOnePathWithPossibleParents(previous, node_w: Node, node_x: Node, node_b: Node, graph: Graph) -> bool: + if node_w == node_x: + return True + + p = previous.get(node_w) + if p is None: + return False + + for node_r in p: + if node_r == node_b or node_r == node_x: + continue + + if existsSemiDirectedPath(node_r, node_x, graph) or existsSemiDirectedPath(node_r, node_b, graph): + return True + + return False + + +def getPossibleDsep(node_x: Node, node_y: Node, graph: Graph, maxPathLength: int) -> List[Node]: + dsep = set() + + Q = Queue() + V = set() + + previous = {node_x: None} + + e = None + distance = 0 + + adjacentNodes = set(graph.get_adjacent_nodes(node_x)) + + for node_b in adjacentNodes: + if node_b == node_y: + continue + edge = (node_x, node_b) + if e is None: + e = edge + Q.put(edge) + V.add(edge) + + # addToSet + node_list = previous.get(node_x) + if node_list is None: + previous[node_x] = set() + node_list = previous.get(node_x) + node_list.add(node_b) + previous[node_x] = node_list + + dsep.add(node_b) + + while not Q.empty(): + t = Q.get_nowait() + if e == t: + e = None + distance += 1 + if distance > 0 and distance > (1000 if maxPathLength == -1 else maxPathLength): + break + node_a, node_b = t + + if existOnePathWithPossibleParents(previous, node_b, node_x, node_b, graph): + dsep.add(node_b) + + for node_c in graph.get_adjacent_nodes(node_b): + if node_c == node_a: + continue + if node_c == node_x: + continue + if node_c == node_y: + continue + + # addToSet + node_list = previous.get(node_c) + if node_list is None: + previous[node_c] = set() + node_list = previous.get(node_c) + node_list.add(node_b) + previous[node_c] = node_list + + if graph.is_def_collider(node_a, node_b, node_c) or graph.is_adjacent_to(node_a, node_c): + u = (node_a, node_c) + if V.__contains__(u): + continue + + V.add(u) + Q.put(u) + + if e is None: + e = u + + if dsep.__contains__(node_x): + dsep.remove(node_x) + if dsep.__contains__(node_y): + dsep.remove(node_y) + + _dsep = list(dsep) + _dsep.sort(reverse=True) + return _dsep + + +def fci_orient_bk(bk: BackgroundKnowledge | None, graph: Graph): + if bk is None: + return + print("Starting BK Orientation.") + edges = graph.get_graph_edges() + for edge in edges: + if bk.is_forbidden(edge.get_node1(), edge.get_node2()): + graph.remove_edge(edge) + graph.add_directed_edge(edge.get_node2(), edge.get_node1()) + print("Orienting edge (Knowledge): " + str(graph.get_edge(edge.get_node2(), edge.get_node1()))) + elif bk.is_forbidden(edge.get_node2(), edge.get_node1()): + graph.remove_edge(edge) + graph.add_directed_edge(edge.get_node1(), edge.get_node2()) + print("Orienting edge (Knowledge): " + str(graph.get_edge(edge.get_node2(), edge.get_node1()))) + elif bk.is_required(edge.get_node1(), edge.get_node2()): + graph.remove_edge(edge) + graph.add_directed_edge(edge.get_node1(), edge.get_node2()) + print("Orienting edge (Knowledge): " + str(graph.get_edge(edge.get_node2(), edge.get_node1()))) + elif bk.is_required(edge.get_node2(), edge.get_node1()): + graph.remove_edge(edge) + graph.add_directed_edge(edge.get_node2(), edge.get_node1()) + print("Orienting edge (Knowledge): " + str(graph.get_edge(edge.get_node2(), edge.get_node1()))) + print("Finishing BK Orientation.") + + +def is_arrow_point_allowed(node_x: Node, node_y: Node, graph: Graph, knowledge: BackgroundKnowledge | None) -> bool: + if graph.get_endpoint(node_x, node_y) == Endpoint.ARROW: + return True + if graph.get_endpoint(node_x, node_y) == Endpoint.TAIL: + return False + if graph.get_endpoint(node_y, node_x) == Endpoint.ARROW: + if knowledge is not None and knowledge.is_forbidden(node_x, node_y): + return False + if graph.get_endpoint(node_y, node_x) == Endpoint.TAIL: + if knowledge is not None and knowledge.is_forbidden(node_x, node_y): + return False + return graph.get_endpoint(node_x, node_y) == Endpoint.CIRCLE + + +def rule0(graph: Graph, nodes: List[Node], sep_sets: Dict[Tuple[int, int], Set[int]], + knowledge: BackgroundKnowledge | None, + verbose: bool): + reorientAllWith(graph, Endpoint.CIRCLE) + fci_orient_bk(knowledge, graph) + for node_b in nodes: + adjacent_nodes = graph.get_adjacent_nodes(node_b) + if len(adjacent_nodes) < 2: + continue + + cg = ChoiceGenerator(len(adjacent_nodes), 2) + combination = cg.next() + while combination is not None: + node_a = adjacent_nodes[combination[0]] + node_c = adjacent_nodes[combination[1]] + combination = cg.next() + + if graph.is_adjacent_to(node_a, node_c): + continue + if graph.is_def_collider(node_a, node_b, node_c): + continue + # check if is collider + sep_set = sep_sets.get((graph.get_node_map()[node_a], graph.get_node_map()[node_c])) + if sep_set is not None and not sep_set.__contains__(graph.get_node_map()[node_b]): + if not is_arrow_point_allowed(node_a, node_b, graph, knowledge): + continue + if not is_arrow_point_allowed(node_c, node_b, graph, knowledge): + continue + + edge1 = graph.get_edge(node_a, node_b) + graph.remove_edge(edge1) + graph.add_edge(Edge(node_a, node_b, edge1.get_proximal_endpoint(node_a), Endpoint.ARROW)) + + edge2 = graph.get_edge(node_c, node_b) + graph.remove_edge(edge2) + graph.add_edge(Edge(node_c, node_b, edge2.get_proximal_endpoint(node_c), Endpoint.ARROW)) + + if verbose: + print( + "Orienting collider: " + node_a.get_name() + " *-> " + node_b.get_name() + " <-* " + node_c.get_name()) + + +def reorientAllWith(graph: Graph, endpoint: Endpoint): + # reorient all edges with CIRCLE Endpoint + ori_edges = graph.get_graph_edges() + for ori_edge in ori_edges: + graph.remove_edge(ori_edge) + ori_edge.set_endpoint1(endpoint) + ori_edge.set_endpoint2(endpoint) + graph.add_edge(ori_edge) + + +def ruleR1(node_a: Node, node_b: Node, node_c: Node, graph: Graph, bk: BackgroundKnowledge | None, changeFlag: bool, + verbose: bool = False) -> bool: + if graph.is_adjacent_to(node_a, node_c): + return changeFlag + + if graph.get_endpoint(node_a, node_b) == Endpoint.ARROW and graph.get_endpoint(node_c, node_b) == Endpoint.CIRCLE: + if not is_arrow_point_allowed(node_b, node_c, graph, bk): + return changeFlag + + edge1 = graph.get_edge(node_c, node_b) + graph.remove_edge(edge1) + graph.add_edge(Edge(node_c, node_b, Endpoint.ARROW, Endpoint.TAIL)) + + changeFlag = True + + if verbose: + print("Orienting edge (Away from collider):" + graph.get_edge(node_b, node_c).__str__()) + + return changeFlag + + +def ruleR2(node_a: Node, node_b: Node, node_c: Node, graph: Graph, bk: BackgroundKnowledge | None, changeFlag: bool, + verbose=False) -> bool: + if graph.is_adjacent_to(node_a, node_c) and graph.get_endpoint(node_a, node_c) == Endpoint.CIRCLE: + if graph.get_endpoint(node_a, node_b) == Endpoint.ARROW and \ + graph.get_endpoint(node_b, node_c) == Endpoint.ARROW and \ + (graph.get_endpoint(node_b, node_a) == Endpoint.TAIL or + graph.get_endpoint(node_c, node_b) == Endpoint.TAIL): + if not is_arrow_point_allowed(node_a, node_c, graph, bk): + return changeFlag + + edge1 = graph.get_edge(node_a, node_c) + graph.remove_edge(edge1) + graph.add_edge(Edge(node_a, node_c, edge1.get_proximal_endpoint(node_a), Endpoint.ARROW)) + + if verbose: + print("Orienting edge (Away from ancestor): " + graph.get_edge(node_a, node_c).__str__()) + + changeFlag = True + + return changeFlag + + +def rulesR1R2cycle(graph: Graph, bk: BackgroundKnowledge | None, changeFlag: bool, verbose: bool = False) -> bool: + nodes = graph.get_nodes() + for node_B in nodes: + adj = graph.get_adjacent_nodes(node_B) + + if len(adj) < 2: + continue + + cg = ChoiceGenerator(len(adj), 2) + combination = cg.next() + + while combination is not None: + node_A = adj[combination[0]] + node_C = adj[combination[1]] + combination = cg.next() + + changeFlag = ruleR1(node_A, node_B, node_C, graph, bk, changeFlag, verbose) + changeFlag = ruleR1(node_C, node_B, node_A, graph, bk, changeFlag, verbose) + changeFlag = ruleR2(node_A, node_B, node_C, graph, bk, changeFlag, verbose) + changeFlag = ruleR2(node_C, node_B, node_A, graph, bk, changeFlag, verbose) + + return changeFlag + + +def isNoncollider(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], node_i: Node, node_j: Node, + node_k: Node) -> bool: + sep_set = sep_sets[(graph.get_node_map()[node_i], graph.get_node_map()[node_k])] + return sep_set is not None and sep_set.__contains__(graph.get_node_map()[node_j]) + + +def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: BackgroundKnowledge | None, changeFlag: bool, + verbose: bool = False) -> bool: + nodes = graph.get_nodes() + for node_B in nodes: + intoBArrows = graph.get_nodes_into(node_B, Endpoint.ARROW) + intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE) + + for node_D in intoBCircles: + if len(intoBArrows) < 2: + continue + gen = ChoiceGenerator(len(intoBArrows), 2) + choice = gen.next() + + while choice is not None: + node_A = intoBArrows[choice[0]] + node_C = intoBArrows[choice[1]] + choice = gen.next() + + if graph.is_adjacent_to(node_A, node_C): + continue + + if (not graph.is_adjacent_to(node_A, node_D)) or (not graph.is_adjacent_to(node_C, node_D)): + continue + + if not isNoncollider(graph, sep_sets, node_A, node_D, node_C): + continue + + if graph.get_endpoint(node_A, node_D) != Endpoint.CIRCLE: + continue + + if graph.get_endpoint(node_C, node_D) != Endpoint.CIRCLE: + continue + + if not is_arrow_point_allowed(node_D, node_B, graph, bk): + continue + + edge1 = graph.get_edge(node_D, node_B) + graph.remove_edge(edge1) + graph.add_edge(Edge(node_D, node_B, edge1.get_proximal_endpoint(node_D), Endpoint.ARROW)) + + if verbose: + print("Orienting edge (Double triangle): " + graph.get_edge(node_D, node_B).__str__()) + + changeFlag = True + return changeFlag + + +def getPath(node_c: Node, previous) -> List[Node]: + l = [] + node_p = previous[node_c] + if node_p is not None: + l.append(node_p) + while node_p is not None: + node_p = previous.get(node_p) + if node_p is not None: + l.append(node_p) + return l + + +def doDdpOrientation(node_d: Node, node_a: Node, node_b: Node, node_c: Node, previous, graph: Graph, data, + independence_test_method, alpha: float, sep_sets: Dict[Tuple[int, int], Set[int]], + change_flag: bool, bk, verbose: bool = False): + """ + Orients the edges inside the definite discriminating path triangle. Takes + the left endpoint, and a,b,c as arguments. + """ + if graph.is_adjacent_to(node_d, node_c): + raise Exception("illegal argument!") + path = getPath(node_d, previous) + + X, Y = graph.get_node_map()[node_d], graph.get_node_map()[node_c] + condSet = tuple([graph.get_node_map()[nn] for nn in path]) + p_value = independence_test_method(X, Y, condSet) + ind = p_value > alpha + + path2 = list(path) + path2.remove(node_b) + + X, Y = graph.get_node_map()[node_d], graph.get_node_map()[node_c] + condSet = tuple([graph.get_node_map()[nn2] for nn2 in path2]) + p_value2 = independence_test_method(X, Y, condSet) + ind2 = p_value2 > alpha + + if not ind and not ind2: + sep_set = sep_sets.get((graph.get_node_map()[node_d], graph.get_node_map()[node_c])) + if verbose: + message = "Sepset for d = " + node_d.get_name() + " and c = " + node_c.get_name() + " = [ " + if sep_set is not None: + for ss in sep_set: + message += graph.get_nodes()[ss].get_name() + " " + message += "]" + print(message) + + if sep_set is None: + if verbose: + print( + "Must be a sepset: " + node_d.get_name() + " and " + node_c.get_name() + "; they're non-adjacent.") + return False, change_flag + + ind = sep_set.__contains__(graph.get_node_map()[node_b]) + + if ind: + edge = graph.get_edge(node_c, node_b) + graph.remove_edge(edge) + graph.add_edge(Edge(node_c, node_b, edge.get_proximal_endpoint(node_c), Endpoint.TAIL)) + + if verbose: + print( + "Orienting edge (Definite discriminating path d = " + node_d.get_name() + "): " + graph.get_edge(node_b, + node_c).__str__()) + + change_flag = True + return True, change_flag + else: + if not is_arrow_point_allowed(node_a, node_b, graph, bk): + return False, change_flag + + if not is_arrow_point_allowed(node_c, node_b, graph, bk): + return False, change_flag + + edge1 = graph.get_edge(node_a, node_b) + graph.remove_edge(edge1) + graph.add_edge(Edge(node_a, node_b, edge1.get_proximal_endpoint(node_a), Endpoint.ARROW)) + + edge2 = graph.get_edge(node_c, node_b) + graph.remove_edge(edge2) + graph.add_edge(Edge(node_c, node_b, edge2.get_proximal_endpoint(node_c), Endpoint.ARROW)) + + if verbose: + print( + "Orienting collider (Definite discriminating path.. d = " + node_d.get_name() + "): " + node_a.get_name() + " *-> " + node_b.get_name() + " <-* " + node_c.get_name()) + + change_flag = True + return True, change_flag + + +def ddpOrient(node_a: Node, node_b: Node, node_c: Node, graph: Graph, maxPathLength: int, data: ndarray, + independence_test_method, alpha: float, sep_sets: Dict[Tuple[int, int], Set[int]], change_flag: bool, + bk: BackgroundKnowledge | None, verbose: bool = False) -> bool: + """ + a method to search "back from a" to find a DDP. It is called with a + reachability list (first consisting only of a). This is breadth-first, + utilizing "reachability" concept from Geiger, Verma, and Pearl 1990. + The body of a DDP consists of colliders that are parents of c. + """ + Q = Queue() + V = set() + e = None + distance = 0 + previous = {} + + cParents = graph.get_parents(node_c) + + Q.put(node_a) + V.add(node_a) + V.add(node_b) + previous[node_a] = node_b + + while not Q.empty(): + node_t = Q.get_nowait() + + if e is None or e == node_t: + e = node_t + distance += 1 + if distance > 0 and distance > (1000 if maxPathLength == -1 else maxPathLength): + return change_flag + + nodesInTo = graph.get_nodes_into(node_t, Endpoint.ARROW) + + for node_d in nodesInTo: + if V.__contains__(node_d): + continue + + previous[node_d] = node_t + node_p = previous[node_t] + + if not graph.is_def_collider(node_d, node_t, node_p): + continue + + previous[node_d] = node_t + + if not graph.is_adjacent_to(node_d, node_c) and node_d != node_c: + res, change_flag = \ + doDdpOrientation(node_d, node_a, node_b, node_c, previous, graph, data, + independence_test_method, alpha, sep_sets, change_flag, bk, verbose) + + if res: + return change_flag + + if cParents.__contains__(node_d): + Q.put(node_d) + V.add(node_d) + return change_flag + + +def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_method, alpha: float, + sep_sets: Dict[Tuple[int, int], Set[int]], + change_flag: bool, bk: BackgroundKnowledge | None, + verbose: bool = False) -> bool: + nodes = graph.get_nodes() + + for node_b in nodes: + possA = graph.get_nodes_out_of(node_b, Endpoint.ARROW) + possC = graph.get_nodes_into(node_b, Endpoint.CIRCLE) + + for node_a in possA: + for node_c in possC: + if not graph.is_parent_of(node_a, node_c): + continue + + if graph.get_endpoint(node_b, node_c) != Endpoint.ARROW: + continue + + change_flag = ddpOrient(node_a, node_b, node_c, graph, maxPathLength, data, independence_test_method, + alpha, sep_sets, change_flag, bk, verbose) + return change_flag + + +def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool: + if path.__contains__(node_a): + return False + + path.append(node_a) + + if node_a == node_b: + return True + + for node_D in graph.get_nodes_into(node_a, Endpoint.ARROW): + if graph.is_parent_of(node_D, node_c): + return True + + if not graph.is_def_collider(node_D, node_c, node_a): + continue + elif not graph.is_parent_of(node_c, node_b): + continue + + if visibleEdgeHelperVisit(graph, node_D, node_c, node_b, path): + return True + + path.pop() + return False + + +def visibleEdgeHelper(node_A: Node, node_B: Node, graph: Graph) -> bool: + path = [node_A] + + for node_C in graph.get_nodes_into(node_A, Endpoint.ARROW): + if graph.is_parent_of(node_C, node_A): + return True + + if visibleEdgeHelperVisit(graph, node_C, node_A, node_B, path): + return True + + return False + + +def defVisible(edge: Edge, graph: Graph) -> bool: + if graph.contains_edge(edge): + if edge.get_endpoint1() == Endpoint.TAIL: + node_A = edge.get_node1() + node_B = edge.get_node2() + else: + node_A = edge.get_node2() + node_B = edge.get_node1() + + for node_C in graph.get_adjacent_nodes(node_A): + if node_C != node_B and not graph.is_adjacent_to(node_C, node_B): + e = graph.get_edge(node_C, node_A) + + if e.get_proximal_endpoint(node_A) == Endpoint.ARROW: + return True + + return visibleEdgeHelper(node_A, node_B, graph) + else: + raise Exception("Given edge is not in the graph.") + + +def get_color_edges(graph: Graph) -> List[Edge]: + edges = graph.get_graph_edges() + for edge in edges: + if (edge.get_endpoint1() == Endpoint.TAIL and edge.get_endpoint2() == Endpoint.ARROW) or \ + (edge.get_endpoint1() == Endpoint.ARROW and edge.get_endpoint2() == Endpoint.TAIL): + if edge.get_endpoint1() == Endpoint.TAIL: + node_x = edge.get_node1() + node_y = edge.get_node2() + else: + node_x = edge.get_node2() + node_y = edge.get_node1() + + graph.remove_edge(edge) + + if not existsSemiDirectedPath(node_x, node_y, graph): + edge.properties.append(Edge.Property.dd) # green + else: + edge.properties.append(Edge.Property.pd) + + graph.add_edge(edge) + + if defVisible(edge, graph): + edge.properties.append(Edge.Property.nl) # bold + print(edge) + else: + edge.properties.append(Edge.Property.pl) + return edges + + +def removeByPossibleDsep(graph: Graph, independence_test_method: CIT, alpha: float, + sep_sets: Dict[Tuple[int, int], Set[int]]): + def _contains_all(set_a: Set[Node], set_b: Set[Node]): + for node_b in set_b: + if not set_a.__contains__(node_b): + return False + return True + + edges = graph.get_graph_edges() + for edge in edges: + node_a = edge.get_node1() + node_b = edge.get_node2() + + possibleDsep = getPossibleDsep(node_a, node_b, graph, -1) + gen = DepthChoiceGenerator(len(possibleDsep), len(possibleDsep)) + + choice = gen.next() + while choice is not None: + origin_choice = choice + choice = gen.next() + if len(origin_choice) < 2: + continue + sepset = tuple([possibleDsep[index] for index in origin_choice]) + if _contains_all(set(graph.get_adjacent_nodes(node_a)), set(sepset)): + continue + if _contains_all(set(graph.get_adjacent_nodes(node_b)), set(sepset)): + continue + X, Y = graph.get_node_map()[node_a], graph.get_node_map()[node_b] + condSet_index = tuple([graph.get_node_map()[possibleDsep[index]] for index in origin_choice]) + p_value = independence_test_method(X, Y, condSet_index) + independent = p_value > alpha + if independent: + graph.remove_edge(edge) + sep_sets[(X, Y)] = set(condSet_index) + break + + if graph.contains_edge(edge): + possibleDsep = getPossibleDsep(node_b, node_a, graph, -1) + gen = DepthChoiceGenerator(len(possibleDsep), len(possibleDsep)) + + choice = gen.next() + while choice is not None: + origin_choice = choice + choice = gen.next() + if len(origin_choice) < 2: + continue + sepset = tuple([possibleDsep[index] for index in origin_choice]) + if _contains_all(set(graph.get_adjacent_nodes(node_a)), set(sepset)): + continue + if _contains_all(set(graph.get_adjacent_nodes(node_b)), set(sepset)): + continue + X, Y = graph.get_node_map()[node_a], graph.get_node_map()[node_b] + condSet_index = tuple([graph.get_node_map()[possibleDsep[index]] for index in origin_choice]) + p_value = independence_test_method(X, Y, condSet_index) + independent = p_value > alpha + if independent: + graph.remove_edge(edge) + sep_sets[(X, Y)] = set(condSet_index) + break + +def fci_true_cov_rank(fake_data: ndarray, independence_test_method, alpha: float = 0.05, depth: int = -1, + max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True, + **kwargs) -> Tuple[Graph, List[Edge]]: + """ + Perform Fast Causal Inference (FCI) algorithm for causal discovery + + Parameters + ---------- + dataset: data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of + samples and n_features is the number of features. + independence_test_method: str, name of the function of the independence test being used + [fisherz, chisq, gsq, kci] + - fisherz: Fisher's Z conditional independence test + - chisq: Chi-squared conditional independence test + - gsq: G-squared conditional independence test + - kci: Kernel-based conditional independence test + alpha: float, desired significance level of independence tests (p_value) in (0,1) + depth: The depth for the fast adjacency search, or -1 if unlimited + max_path_length: the maximum length of any discriminating path, or -1 if unlimited. + verbose: True is verbose output should be printed or logged + background_knowledge: background knowledge + + Returns + ------- + graph : a GeneralGraph object, where graph.graph[j,i]=1 and graph.graph[i,j]=-1 indicates i --> j , + graph.graph[i,j] = graph.graph[j,i] = -1 indicates i --- j, + graph.graph[i,j] = graph.graph[j,i] = 1 indicates i <-> j, + graph.graph[j,i]=1 and graph.graph[i,j]=2 indicates i o-> j. + edges : list + Contains graph's edges properties. + If edge.properties have the Property 'nl', then there is no latent confounder. Otherwise, + there are possibly latent confounders. + If edge.properties have the Property 'dd', then it is definitely direct. Otherwise, + it is possibly direct. + If edge.properties have the Property 'pl', then there are possibly latent confounders. Otherwise, + there is no latent confounder. + If edge.properties have the Property 'pd', then it is possibly direct. Otherwise, + it is definitely direct. + """ + + #if dataset.shape[0] < dataset.shape[1]: + # warnings.warn("The number of features is much larger than the sample size!") + + ## ------- check parameters ------------ + if (depth is None) or type(depth) != int: + raise TypeError("'depth' must be 'int' type!") + if (background_knowledge is not None) and type(background_knowledge) != BackgroundKnowledge: + raise TypeError("'background_knowledge' must be 'BackgroundKnowledge' type!") + if type(max_path_length) != int: + raise TypeError("'max_path_length' must be 'int' type!") + ## ------- end check parameters ------------ + + + nodes = [] + for i in range(fake_data.shape[1]): + node = GraphNode(f"X{i + 1}") + node.add_attribute("id", i) + nodes.append(node) + + # FAS (“Fast Adjacency Search”) is the adjacency search of the PC algorithm, used as a first step for the FCI algorithm. + graph, sep_sets, test_results = fas(fake_data, nodes, independence_test_method=independence_test_method, alpha=alpha, + knowledge=background_knowledge, depth=depth, verbose=verbose, show_progress=show_progress) + + reorientAllWith(graph, Endpoint.CIRCLE) + + rule0(graph, nodes, sep_sets, background_knowledge, verbose) + + removeByPossibleDsep(graph, independence_test_method, alpha, sep_sets) + + reorientAllWith(graph, Endpoint.CIRCLE) + rule0(graph, nodes, sep_sets, background_knowledge, verbose) + + change_flag = True + first_time = True + + while change_flag: + change_flag = False + change_flag = rulesR1R2cycle(graph, background_knowledge, change_flag, verbose) + change_flag = ruleR3(graph, sep_sets, background_knowledge, change_flag, verbose) + + if change_flag or (first_time and background_knowledge is not None and + len(background_knowledge.forbidden_rules_specs) > 0 and + len(background_knowledge.required_rules_specs) > 0 and + len(background_knowledge.tier_map.keys()) > 0): + change_flag = ruleR4B(graph, max_path_length, fake_data, independence_test_method, alpha, sep_sets, + change_flag, + background_knowledge, verbose) + + first_time = False + + if verbose: + print("Epoch") + + graph.set_pag(True) + + edges = get_color_edges(graph) + + return graph, edges + + +def fci_cov_rank(data: ndarray, alpha: float = 0.05, rescale_rank_test=1, depth: int = -1, + max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True, + **kwargs) -> Tuple[Graph, List[Edge]]: + """ + Perform Fast Causal Inference (FCI) algorithm for causal discovery + + Parameters + ---------- + dataset: data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of + samples and n_features is the number of features. + independence_test_method: str, name of the function of the independence test being used + [fisherz, chisq, gsq, kci] + - fisherz: Fisher's Z conditional independence test + - chisq: Chi-squared conditional independence test + - gsq: G-squared conditional independence test + - kci: Kernel-based conditional independence test + alpha: float, desired significance level of independence tests (p_value) in (0,1) + depth: The depth for the fast adjacency search, or -1 if unlimited + max_path_length: the maximum length of any discriminating path, or -1 if unlimited. + verbose: True is verbose output should be printed or logged + background_knowledge: background knowledge + + Returns + ------- + graph : a GeneralGraph object, where graph.graph[j,i]=1 and graph.graph[i,j]=-1 indicates i --> j , + graph.graph[i,j] = graph.graph[j,i] = -1 indicates i --- j, + graph.graph[i,j] = graph.graph[j,i] = 1 indicates i <-> j, + graph.graph[j,i]=1 and graph.graph[i,j]=2 indicates i o-> j. + edges : list + Contains graph's edges properties. + If edge.properties have the Property 'nl', then there is no latent confounder. Otherwise, + there are possibly latent confounders. + If edge.properties have the Property 'dd', then it is definitely direct. Otherwise, + it is possibly direct. + If edge.properties have the Property 'pl', then there are possibly latent confounders. Otherwise, + there is no latent confounder. + If edge.properties have the Property 'pd', then it is possibly direct. Otherwise, + it is definitely direct. + """ + + #if dataset.shape[0] < dataset.shape[1]: + # warnings.warn("The number of features is much larger than the sample size!") + + from .PC_CovRank import CovRank + independence_test_method = CovRank(data, alpha, rescale_rank_test, **kwargs) + + ## ------- check parameters ------------ + if (depth is None) or type(depth) != int: + raise TypeError("'depth' must be 'int' type!") + if (background_knowledge is not None) and type(background_knowledge) != BackgroundKnowledge: + raise TypeError("'background_knowledge' must be 'BackgroundKnowledge' type!") + if type(max_path_length) != int: + raise TypeError("'max_path_length' must be 'int' type!") + ## ------- end check parameters ------------ + + + nodes = [] + for i in range(data.shape[1]): + node = GraphNode(f"X{i + 1}") + node.add_attribute("id", i) + nodes.append(node) + + # FAS (“Fast Adjacency Search”) is the adjacency search of the PC algorithm, used as a first step for the FCI algorithm. + graph, sep_sets, test_results = fas(data, nodes, independence_test_method=independence_test_method, alpha=alpha, + knowledge=background_knowledge, depth=depth, verbose=verbose, show_progress=show_progress) + + reorientAllWith(graph, Endpoint.CIRCLE) + + rule0(graph, nodes, sep_sets, background_knowledge, verbose) + + removeByPossibleDsep(graph, independence_test_method, alpha, sep_sets) + + reorientAllWith(graph, Endpoint.CIRCLE) + rule0(graph, nodes, sep_sets, background_knowledge, verbose) + + change_flag = True + first_time = True + + while change_flag: + change_flag = False + change_flag = rulesR1R2cycle(graph, background_knowledge, change_flag, verbose) + change_flag = ruleR3(graph, sep_sets, background_knowledge, change_flag, verbose) + + if change_flag or (first_time and background_knowledge is not None and + len(background_knowledge.forbidden_rules_specs) > 0 and + len(background_knowledge.required_rules_specs) > 0 and + len(background_knowledge.tier_map.keys()) > 0): + change_flag = ruleR4B(graph, max_path_length, data, independence_test_method, alpha, sep_sets, + change_flag, + background_knowledge, verbose) + + first_time = False + + if verbose: + print("Epoch") + + graph.set_pag(True) + + edges = get_color_edges(graph) + + return graph, edges \ No newline at end of file diff --git a/causallearn/search/HiddenCausal/RLCD/GraphDrawer.py b/causallearn/search/HiddenCausal/RLCD/GraphDrawer.py new file mode 100644 index 00000000..b45eb343 --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/GraphDrawer.py @@ -0,0 +1,181 @@ +import pydot +import json + + +PARAMS = { + "default_node_colour": "black", + "refined_node_colour": "red", + "default_edge_colour": "black", + "directed_edge_colour": "black", +} +#with open("params.json") as f: +# PARAMS.update(json.load(f)) + + +class DotGraph: + """ + A class used to construct a directed/undirected graph and parses the graph + into a .dot file for pydot plotting. + """ + + def __init__( + self, + default_node_colour: str = PARAMS["default_node_colour"], + refined_node_colour: str = PARAMS["refined_node_colour"], + default_edge_colour: str = PARAMS["default_edge_colour"], + directed_edge_colour: str = PARAMS["directed_edge_colour"], + ): + self.default_node_colour = default_node_colour + self.refined_node_colour = refined_node_colour + self.default_edge_colour = default_edge_colour + self.directed_edge_colour = directed_edge_colour + self.nodes = set() + self.dirEdges = set() + self.undirEdges = set() + self.nodecolor = {} + + def addNode(self, V, refined=False): + self.nodes.add(V) + if refined: + self.nodecolor[V] = self.refined_node_colour + else: + self.nodecolor[V] = self.default_node_colour + + def addNodeByColor(self, V, Color): + self.nodes.add(V) + self.nodecolor[V] = Color + + # 0 for undirected, 1 for directed + def addEdge(self, u, v, type=0): + if type == 0: + self.undirEdges.add(frozenset([u, v])) + + elif type == 1: + self.dirEdges.add((u, v)) + + def edges(self, V, type=0): + edgelist = [] + if type == 0: + for edge in self.undirEdges: + if V in edge: + edgelist.append(edge) + + if type == 1: + for edge in self.dirEdges: + if V in edge: + edgelist.append(edge) + return edgelist + + def removeUndirEdgesFromNode(self, V): + edgesToRemove = set() + for edgeSet in self.undirEdges: + if V in edgeSet: + edgesToRemove.add(edgeSet) + self.undirEdges = self.undirEdges - edgesToRemove + + # def toDot(self, outpath: str): + # # TODO: Improve plotting for phase III with undirected edges + # text = "digraph {\n" + + # # Add nodes + # for node in self.nodes: + # text += f"{node} [color = {self.nodecolor[node]}]; " + # text += "\n" + + # # Add undirected edges + # text += "subgraph Undirected {\n" + # text += f"edge [dir=none, color={self.default_edge_colour}]\n" + # for edgeSet in self.undirEdges: + # edgeSet = list(edgeSet) + # text += f"{edgeSet[0]} -> {edgeSet[1]}\n" + + # text += "}\n\n" + + # # Add directed Edges + # text += "subgraph Directed {\n" + # text += f"edge [color={self.directed_edge_colour}]\n" + # for edgeSet in self.dirEdges: + # edgeSet = list(edgeSet) + # text += f"{edgeSet[0]} -> {edgeSet[1]}\n" + + # text += "}\n\n" + # text += "}\n" + # with open(outpath, "w") as f: + # f.write(text) + + def toDot(self, outpath: str): + # UTF-8 support + text = "digraph {\n" + text += 'charset="UTF-8";\n' + + # Add nodes - ensure node names are quoted to handle special characters + for node in self.nodes: + text += f'"{node}" [color = {self.nodecolor[node]}];\n' + text += "\n" + + # Add undirected edges + text += "subgraph Undirected {\n" + text += f"edge [dir=none, color={self.default_edge_colour}]\n" + for edgeSet in self.undirEdges: + edgeSet = list(edgeSet) + text += f'"{edgeSet[0]}" -> "{edgeSet[1]}"\n' + + text += "}\n\n" + + # Add directed Edges + text += "subgraph Directed {\n" + text += f"edge [color={self.directed_edge_colour}]\n" + for edgeSet in self.dirEdges: + edgeSet = list(edgeSet) + text += f'"{edgeSet[0]}" -> "{edgeSet[1]}"\n' + + text += "}\n\n" + text += "}\n" + + # 指定 UTF-8 编码写入文件 + with open(outpath, "w", encoding="utf-8") as f: + f.write(text) + + +def printGraph(O: object, outpath="plots/test.png", layout="dot", res=100): + """ + Function to plot a graph object using pydot from various types of graph + objects. + """ + if isinstance(O, DotGraph): + dotGraph = O + else: + try: + dotGraph = O.getDotGraph() + except: + raise TypeError(f"{O} is of type {type(O)}, not supported.") + + dot_path = outpath.replace(".png", ".dot") + + dotGraph.toDot(dot_path) + graphs = pydot.graph_from_dot_file(dot_path) + graphs[0].set_size(f'"{res},{res}!"') + graphs[0].set_layout(layout) + graphs[0].write_png(outpath) + + +def AdjToGraph(Adj, varnames_for_Adj): + + dotGraph = DotGraph() + for x in varnames_for_Adj: + if x.startswith("L"): + dotGraph.addNodeByColor(x, 'red') + else: + dotGraph.addNodeByColor(x, 'blue') + + for i in range(len(Adj)): + for j in range(i+1, len(Adj)): + + if (Adj[i,j]==1 and Adj[j,i]==1) or Adj[i,j]==-1 and Adj[j,i]==-1: + dotGraph.addEdge(varnames_for_Adj[i], varnames_for_Adj[j]) + elif Adj[i,j]==-1 and Adj[j,i]==1: + dotGraph.addEdge(varnames_for_Adj[i], varnames_for_Adj[j], type=1) + elif Adj[j,i]==-1 and Adj[i,j]==1: + dotGraph.addEdge(varnames_for_Adj[j], varnames_for_Adj[i], type=1) + + return dotGraph \ No newline at end of file diff --git a/causallearn/search/HiddenCausal/RLCD/LatentGroups.py b/causallearn/search/HiddenCausal/RLCD/LatentGroups.py new file mode 100644 index 00000000..3d0e4dca --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/LatentGroups.py @@ -0,0 +1,1414 @@ +from __future__ import annotations + +import pickle +import random +from collections import deque +from copy import deepcopy +from pdb import set_trace +import networkx as nx +import math +from . import misc as M +from .GraphDrawer import DotGraph +from .Cover import Cover, setLength, getVars, setDifference, setIntersection, deduplicate +from .logger import LOGGER +import numpy as np + +# Class to store discovered latent groups +class LatentGroups: + def __init__(self, X, Xns, all_nb_set, nb_set_dict, local_Adj):# X_ns is a list of potential observed non sink variables + self.i = 1 + self.X = set([Cover(x, is_observed=True) for x in X]) + self.activeSet = set([Cover(x, is_observed=True) for x in X]) + self.ChildrenOfNonAtomicsSet=set() + self.latentDict = {} + self.rankDefSets = {} + self.clusters = {} + self.nonAtomics = [] + self.Xns = set([x for x in Xns]) + self.activeNonSinkSet = set([x for x in Xns]) + self.X_dict = {x.takeOne():x for x in self.X} + + self.X_names = X + + self.nb_set_dict = nb_set_dict + self.all_nb_set = all_nb_set + + self.local_Adj = local_Adj + self.x_list_for_local_Adj = X + + def update_X_dict(self):# need to update whenever X changes + self.X_dict = {x.takeOne():x for x in self.X} + + def get_observed_cover_by_str(self, str): + if str in self.X_dict: + return self.X_dict[str] + else: + return None + + def addRankDefSet(self, Vs, k=1, used_nonsinks=[]): + """ + Save a rankDefSet of variables Vs, for merging into clusters later. + """ + if not k in self.rankDefSets: + self.rankDefSets[k] = [] + self.rankDefSets[k].append([frozenset(Vs), set(used_nonsinks)]) + + def determineClusters(self): + """ + From the saved rankDefSets, merge any pair of sets with a common + element to derive the clusters. + """ + k = min(list(self.rankDefSets)) + clusters = self.rankDefSets.pop(k) + clusters_um = clusters.copy() + + n = len(clusters) + + for i in range(len(clusters)): + clusters[i].append([clusters[i][0]]) + + while True: + i = 0 + j = 1 + while j < len(clusters): + set1 = clusters[i][0] + set1_nonsinks = clusters[i][1] + + set2 = clusters[j][0] + set2_nonsinks = clusters[j][1] + + # Merge overlapping sets + if len(setIntersection(set1, set2)) >= min(len(set1), len(set2))-1 and set1_nonsinks==set2_nonsinks: + Vs = set1 | set2 + clusters[i][0] = Vs + #clusters[i][1] = set1_nonsinks | set2_nonsinks + clusters[i][1] = set1_nonsinks + + clusters[i][2] = clusters[i][2] + clusters[j][2] + + clusters.pop(j) + + if j >= len(clusters) - 1: + i += 1 + j = i + 1 + else: + j += 1 + + if n == len(clusters): + break + else: + n = len(clusters) + + self.clusters[k] = clusters + + # Other rankDefSets of higher cardinality are discarded + self.rankDefSets = {} + + def confirmClusters(self): + """ + Given the clusters that we have determined, add each cluster of Vs + as children of new latent Covers. + Returns: + success: boolean - whether any new cluster was successfully added + """ + k = min(list(self.clusters)) + success = False + clusters = self.clusters.pop(k) + for Vs, used_nonsinks, fullVs in clusters: + current_success = self.addCluster(Vs, fullVs, k, list(used_nonsinks)) + + if current_success: + + temp_X_ls = list(self.X) + for i, x in enumerate(temp_X_ls): + if len(x.vars.intersection(set(used_nonsinks)))>0: + if x.is_leaf is None: + temp_X_ls[i].is_leaf = False + self.Xns = self.Xns - x.vars + + for V in Vs: + if len(V.vars)==1: + for i, x in enumerate(temp_X_ls): + if len(x.vars.intersection(V.vars))>0: + if x.is_leaf is None: + temp_X_ls[i].is_leaf = True + self.Xns = self.Xns - x.vars + + self.X = set(temp_X_ls) + self.update_X_dict() + + LOGGER.info(f"Current Xns {self.Xns}") + #LOGGER.info(f"Current X set {self.X}") + + success = success or current_success + + self.clusters = {} + return success + + def splitfullVs(self, Vs, fullVs): + + local_Adj = self.local_Adj + x_list_for_local_Adj = self.x_list_for_local_Adj + local_Adj = np.abs(local_Adj+np.identity(local_Adj.shape[0])) + + Vmeasures = self.pickAllMeasures(Vs) + + Vs1 = set() + Vs2 = set() + + Vs_idx_list = [] + for V in Vmeasures: + Vs_idx_list.append(x_list_for_local_Adj.index(list(V.vars)[0])) + + for V in Vmeasures: + idx = x_list_for_local_Adj.index(list(V.vars)[0]) + if (local_Adj[idx].T[Vs_idx_list].T==1).all(): + Vs1.add(V) + else: + Vs2.add(V) + + LOGGER.debug("split Vs result: Vs=%s, Vs1=%s, Vs2=%s", Vs, Vs1, Vs2) + + return Vs1, Vs2 + + def addCluster(self, Vs, fullVs, k, used_nonsinks=[]): + """ + For a discovered cluster Vs, create a new latent Cover over it. + Returns: a boolean indicating whether the addition of a new latent + relationship was successful or not (i.e. contradiction) + """ + + Vs1, Vs2 = self.splitfullVs(Vs, fullVs) + + #if setLength(Vs1) < (k-len(used_nonsinks)+1): + # return False + + # check cycle + for used_nonsink in used_nonsinks: + set1 = set() + set2 = set() + for x in self.findParents(self.get_observed_cover_by_str(used_nonsink)): + set1=set1|x.vars + for x in Vs: + set2=set2|x.vars + + if len(set1.intersection(set2))!=0: + LOGGER.info( + f"Rejecting {Vs} as a cluster because {used_nonsink} is a child of {Vs}") + return False + + parents = self.findParents(Vs) + parentsSize = setLength(parents) + gap = k - parentsSize + LOGGER.info(f"Trying to add to Dict {Vs} with k={k}") + + # If gap < 0, it means that there is a contradiction, i.e. the + # current parents of Vs has higher cardinality than the actual rank of + # testing these Vs together. + # However, we can just ignore this set of Vs in this case, and + # hopefully refineClusters will correct the error later on. + if gap < 0: + LOGGER.info( + f"Rejecting {Vs} as a cluster because it is rank {k}" + f" but has parents of cardinality {parentsSize}" + ) + return False + + # decide elements in the new Cover -> newCover_ls + parents_str_set = set() + for parent in parents: + for V in parent.vars: + parents_str_set.add(V) + additional_used_nonsinks = list(set(used_nonsinks) - parents_str_set) + newCover_ls = list(parents_str_set) + if gap {{L1, L3}} + Vs = deduplicate(Vs) + + success = self.addOrUpdateCover(L=newCover, children=Vs, fake_children=Vs2) + return success + + def findParents(self, Vs: set[Cover] | Cover, atomic=False, non_atomic=False): + """ + Find parents of Vs. Returns empty set if no parents found. + + Args: + atomic: Whether to take only atomic parents. + non_atomic: Whether to take only non-atomic parents. + """ + assert not ( + atomic and non_atomic + ), "Can only specify atomic or non_atomic, not both." + parents = set() + if isinstance(Vs, Cover): + Vs = set([Vs]) + + for parent, values in self.latentDict.items(): + if atomic and not parent.isAtomic: + continue + if non_atomic and parent.isAtomic: + continue + for V in Vs: + if V in values["children"]: + parents.add(parent) + parents = deduplicate(parents) + + if non_atomic and len(Vs) == 1: + assert ( + len(parents) <= 1 + ), f"{next(iter(Vs))} should not have more than one non-atomic parent." + return parents + + def findAtomicParent(self, L): + """ + Find the atomic parent of L. + + Raises an error if more than one atomic parent is found, which should + not be the case. + Returns None is L is root. + """ + Ps = self.findParents(L, atomic=True) + assert len(Ps) <= 1, "Nodes cannot have more than 1 atomic parent." + P = next(iter(Ps)) if len(Ps) == 1 else None + return P + + # Check if an AtomicGroup L has observed children + def hasObservedChildren(self, L): + for child in self.latentDict[L]["children"]: + if not child.isLatent: + return True + return False + + def updateactiveNonSinkSet(self): + self.activeNonSinkSet=self.Xns.copy() + #for cover in self.activeSet: + # if len(cover.vars)==1 and cover.takeOne() in self.Xns: + # self.activeNonSinkSet.add(cover.takeOne()) + + def updateActiveSet(self, if_for_finish=False): + """ + Refresh the activeSet after new Covers are added. + """ + self.activeSet = set() + self.ChildrenOfNonAtomicsSet = set() + # Add all measures to the activeSet + for X in self.X: + self.activeSet.add(X) + + # Add all atomic Covers to the activeSet + # Non-atomic covers are never added, since the atomic covers within + # would already be in. + for P in self.latentDict.keys(): + if if_for_finish: + self.activeSet.add(P) + else:# normal mode + if P.isAtomic: + self.activeSet.add(P) + + # Remove variables that are children of Covers from activeSet + for P, val in self.latentDict.items(): + if P.isAtomic: + self.activeSet = setDifference(self.activeSet, val["children"]) + else: + self.activeSet = setDifference(self.activeSet, val["children"]) + self.ChildrenOfNonAtomicsSet |= val["children"] + + self.activeSet = deduplicate(self.activeSet) + LOGGER.info(f"Active Set (if_for_finish:{if_for_finish}): {self.activeSet}") + return + + + def removeCover(self, L: Cover): + """ + Remove an atomic Cover from the latentDict and activeSet. + activeSet will be updated at the end to include Children of the Cover. + """ + #assert not L.is_observed, "Can only remove latent Cover." + #assert L.isAtomic, "Can only remove atomic Cover." + + # Get the atomicSuperCover of L and remove it + # e.g. if L=L1 and {L1, L2} is also atomic, we must remove {L1, L2}. + L = self.findAtomicSuperCover(L) + + # Remove all subsets of L which are also AtomicGroups + # e.g. {L1, L2} is atomic, and L1 is atomic, so remove both {L1, L2} + # and L1 from latentDict + subsets = self.subsets(L) + for subset in subsets: + self.latentDict.pop(subset) + + for k in self.latentDict.keys(): + self.latentDict[k]["subcovers"] -= subsets + self.latentDict[k]["children"] -= subsets + + # L may be a subset of a non-atomic Cover. + # For this case, we only need to remove the non-atomic Cover, but the + # other atomic Covers within can remain. + # E.g. if L=L1 and L2 is also atomic, such that {L1, L2} is non-atomic. + # then we only remove {L1, L2} as a Cover but allow L2 to remain. + nonAtomics, latentDict = self.findNonAtomics(L) + self.latentDict = latentDict + + def findNonAtomics(self, L): + """ + Find all nonAtomics associated with L. + """ + latentDict = {} + nonAtomics = {} + for Lp, value in reversed(tuple(self.latentDict.items())): + #for Lp, value in reversed(self.latentDict.items()): + if not Lp.isAtomic: + if L.vars < Lp.vars: + nonAtomics[Lp] = value + continue + latentDict[Lp] = value + return nonAtomics, latentDict + + def dissolveNode(self, L): + """ + Dissolve a latent cover L by: + 1. Making it root + 2. Remove it and L's parent, and add their respective children (in the + graph where L is root) into the activeSet + """ + assert isinstance(L, Cover), f"{L} must be Cover" + assert ( + len(self.activeSet) == 1 + ), f"activeSet is {self.activeSet} but should only have root variable." + + P = self.findAtomicParent(L) + if P is None: + # If L root, make another refined node root before continuing + for V in self.latentDict: + if V.isAtomic and self.isRefined(V): + LOGGER.info(f"{L} is root, making {V} root instead..") + self.makeRoot(V) + #printGraph(self) + P = self.findAtomicParent(L) + break + assert ( + P is not None + ), f"Trying to refine root {L} but no other variable available to set as root." + + # If L is an atomic cover which is a subcover of another atomicCover + # We should dissolve the larger one instead. + L = self.findAtomicSuperCover(L) + + LOGGER.info(f"dissolveNode {L}...") + + # Remove L and parent + #printGraph(self) + LOGGER.info(f"Finding non-atomics for {L}..") + self.nonAtomics.extend(self.logNonAtomics(L)) + LOGGER.info(f"Finding non-atomics for {P}..") + self.nonAtomics.extend(self.logNonAtomics(P)) + self.makeRoot(L) + self.removeCover(L) + self.removeCover(P) + self.updateActiveSet() + #M.display(self) + #printGraph(self) + return True + + # Get all AtomicGroups in a non-AtomicGroup + def getAtomicsFromGroup(self, Ls): + assert not Ls.isMinimal(), "Ls must not be minimal" + groups = set() + for subcover in self.latentDict[Ls]["subcovers"]: + if subcover.vars <= Ls.vars: + groups.add(subcover) + return groups + + # Make a new connection between parent and child + def connectNodes(self, parents, children): + for parent in parents: + #assert not parent.isLatent, "Parent must be latent" + self.addOrUpdateCover(parent, children) + + # Disconnect all linkages between parent and children + def disconnectNodes(self, parents, children, bidirectional=False): + for parent in parents: + self.latentDict[parent]["children"] -= children + if bidirectional: + # Remove edges in the other direction as well + for child in children: + self.latentDict[child]["children"] -= parents + + # Check if a latent has already been refined + def isRefined(self, L): + return self.latentDict[L].get("refined", False) + + # Reduce a list of variable sets by merging them into + # the minimal set of non-overlapping variable sets + def mergeList(self, Vlist): + out = [] + mergeSuccess = False + + while len(Vlist) > 0: + first = Vlist.pop() + newVlist = [] + for i, Vs in enumerate(Vlist): + commonVs = Vs.intersection(first) + commonVs = [V for V in commonVs if not self.inLatentDict(V)] + if len(commonVs) > 0: + mergeSuccess = True + first |= Vs + else: + newVlist.append(Vs) + Vlist = newVlist + out.append(first) + + if not mergeSuccess: + return out + else: + return self.mergeList(out) + + def containsCluster(self, Vs, nonsinks: list[str]): + """ + Test whether the set of Covers Vs contains any subset such that the + subset contains > k elements from an existing k-cluster. + """ + for L, values in self.latentDict.items(): + #if L.isAtomic: + k = len(L) + children = self.findChildren(L) + + if len(setIntersection(Vs, children)) + len(set(nonsinks).intersection(L.vars))> k: + return True + return False + + def containsonlyaCluster(self, Vs, nonsinks: list[str]): + """ + Test whether the set of Covers Vs contains any subset such that the + subset contains > k elements from an existing k-cluster. + """ + for L, values in self.latentDict.items(): + #if L.isAtomic: + k = len(L) + children = self.findChildren(L) + + if len(setIntersection(Vs, children))+len(setIntersection(Vs, {L}))==setLength(Vs): + # all Vs are in children or L + #if len(setIntersection(Vs, children))+len(setIntersection(Vs, {L}))==setLength(Vs) and len(set(nonsinks).intersection(L.vars))==len(nonsinks): + # all Vs are in children or L and all nonsinks are in L + return True + return False + + def checkNonSinksAreAsChildren(self, Vs, nonsinks: list[str]): + """ + Test whether the set of Covers Vs contains any subset such that the + subset contains > k elements from an existing k-cluster. + """ + for L, values in self.latentDict.items(): + if len(setIntersection(Vs, {L}))>0: + children = self.findChildren(L) + children_str_set = set() + for ch_cover in children: + children_str_set|=ch_cover.vars + + if len(set(nonsinks).intersection(children_str_set))>0: + return True + + return False + + def checkAsAllAdjacent(self, As): + + local_Adj = self.local_Adj + x_list_for_local_Adj = self.x_list_for_local_Adj + + local_Adj = np.abs(local_Adj+np.identity(local_Adj.shape[0])) + + Ameasures = self.pickAllMeasures(As) + idxlist = [] + for A in Ameasures: + for x in A.vars: + idxlist.append(x_list_for_local_Adj.index(x)) + + if ((local_Adj[idxlist].T[idxlist].T)==1).all(): + return True + else: + return False + + def MeassuredHasNonSinks(self, As, nonsinks): + for temp1 in self.pickAllMeasures(As): + temp2 = temp1.vars + if len(temp2.intersection(set(nonsinks)))>0: + return True + return False + + def overlapPaCh(self, Vs: set[Cover]): + + # if any V1 in Vs has intersection with any parent of V2 in Vs then return True. + + for V in Vs: + PaV = self.findParents(V) + temp = Vs.copy() + temp.discard(V) + if len(setIntersection(temp, PaV)) > 0: + return True + + return False + + def parentCardinality(self, Vs): + """ + To compute the cardinality of Vs after we replace any cluster within Vs + by their latent parents. + Requires a recursive call as Vs may contain nested clusters. + """ + Vs = deepcopy(Vs) + k1 = setLength(Vs) + for L, _ in self.latentDict.items(): + if L.isAtomic: + k = len(L) + children = self.findChildren(L) + if len(setIntersection(Vs, children)) + len(getVars(Vs).intersection(L.vars))> k: + Vs -= children + Vs -= {L} + for str in getVars(Vs).intersection(L.vars): + if str in self.X_dict: + Vs.discard(self.X_dict[str]) + Vs.add(L) + k2 = setLength(Vs) + if k2 < k1: + return self.parentCardinality(Vs) + else: + return k1 + + # Check if a variable V already belongs to an AtomicGroup + def inLatentDict(self, V): + for _, values in self.latentDict.items(): + if V in values["children"]: + return True + return False + + # Given child and parent, reverse their parentage direction + # i.e. make child the parent instead + def reverseParentage(self, child, parent): + #assert not parent.is_observed, "Parent is not latent" + #assert not child.is_observed, "Child is not latent" + # print(f"Reversing parentage! Parent:{parent} Child:{child}") + + # Remove child as a child of parent + self.latentDict[parent]["children"] -= set([child]) + + # Add parent as a child of child + self.latentDict[child]["children"].add(parent) + + # Recursive function for use in makeRoot + def makeRootRecursive(self, Ls, G=None): + + # Make a copy of self, to modify + if G is None: + G = deepcopy(self) + + # Parents of L + # Note: We are finding parents of L based on `self`, not the modified + # graph G that is passed around. This is so that we don't end up + # in an infinite loop making L -> P then P -> L forever. + parents = set() + for L in Ls: + parents.update(self.findParents(L, atomic=True)) + + # If no parents, L is root. Do nothing. + if len(parents) == 0: + return G + + # Reverse Direction to parents + for parent in parents: + for L in Ls: + G.reverseParentage(L, parent) + G = self.makeRootRecursive(parents, G) + return G + + def makeRoot(self, L: Cover): + """ + Re-orient latentDict such that L becomes the root node of the graph. + + Note that this procedure does not affect non-atomic Covers. + """ + assert isinstance(L, Cover), f"{L} must be a Cover." + G = self.makeRootRecursive(set([L])) + self.latentDict = G.latentDict + self.activeSet = set([L]) + + # Find all Groups that are a superset of L + # Including L itself + def supersets(self, L): + groups = set() + for group in self.latentDict: + if group.vars >= L.vars: + groups.add(group) + return groups + + # Find the largest superset for L + def supersetLargest(self, L): + k = len(L) + largest = None + for group in self.latentDict: + if group.vars >= L.vars and len(group) >= k: + largest = group + k = len(group) + return largest + + # Find all Groups that are a subset of L + # Including L itself + def subsets(self, L): + groups = set() + for group in self.latentDict: + if group.vars <= L.vars: + groups.add(group) + return groups + + def findAncestor(self, L: Cover): + ancestor_set = set() + parents_set = self.findParents(L) + for pa in parents_set: + ancestor_set.add(pa) + ancestor_set |= self.findAncestor(pa) + + return ancestor_set + + def findChildrenOfAllSubSets(self, Ls: set): + """ + Recursive search for all immediate children of an atomic Cover + """ + + children = set() + + LsVars = set() + for x in Ls: + LsVars |= x.vars + + for key in self.latentDict.keys(): + if key.vars.issubset(LsVars): + children = children | self.latentDict[key]["children"] + #for subcover in self.latentDict[key]["subcovers"]: + # children = children | self.findChildrenOfAllSubSets(subcover) + + return children + + + def findChildren(self, L: Cover, rigorous=True): + """ + Recursive search for all immediate children of an atomic Cover + """ + # everything is recorded in latentDict + assert L is not None, "Should not look for None." + #assert L.isAtomic, f"{L} should be an atomic Cover." + #assert not L.isLeaf, f"{L} should be NonLeaf." + children = set() + + if rigorous==True: + if L in self.latentDict: + children = children | self.latentDict[L]["children"] + for subcover in self.latentDict[L]["subcovers"]: + children = children | self.findChildren(subcover, rigorous) + + return children + + else: + #if L in self.latentDict: + # for subcover in self.latentDict[L]["subcovers"]: + # children = children | self.findChildren(subcover, rigorous) + + for key in self.latentDict.keys(): + if len(set.intersection(L.vars, key.vars))!=0: + children = children | self.latentDict[key]["children"] + + return children + + def findDescendants(self, L: Cover, rigorous=True, visited=None): + if visited is None: + visited = set() + if L in visited: + return set() + visited.add(L) + + descendants = set() + children = self.findChildren(L, rigorous=rigorous) + descendants |= children + + for ch in children: + descendants |= self.findDescendants(ch, rigorous=rigorous, visited=visited) + + return descendants + + def findMeassuredSubset(self, L: Cover): + + assert L is not None, "Should not look for None." + MeassuredSubset=set() + + for str in L.vars: + if str in self.X_dict: + MeassuredSubset.add(self.X_dict[str]) + return MeassuredSubset + + def findNonAtomicChildren(self, L: Cover): + """ + Find all children of non-Atomic Covers of which L is a subcover. + """ + children = set() + for cover in self.latentDict: + if cover.isAtomic: + continue + if L in self.latentDict[cover]["subcovers"]: + children.update(self.latentDict[cover]["children"]) + return children + + def isRoot(self, L: Cover): + """ + Check if L is root in that it has no parents. + """ + parents = self.findParents(L) + return len(parents) == 0 + + def findRandomLatentChild(self, L: Cover): + children = self.findChildren(L) + latents = [child for child in children if child.isLatent] + child = random.sample(latents, k=1)[0] + return child + + # For a given atomic cover L, find the largest atomicCover of which L is a + # subset. If L is not a subset of any atomicCover, returns L itself. + def findAtomicSuperCover(self, L): + assert isinstance(L, Cover), "L must be a Cover." + largestCover = L + superCoverFound = False + for Lp in self.latentDict: + if (L.vars < Lp.vars) and Lp.isAtomic: + largestCover = Lp + superCoverFound = True + break + if superCoverFound: + return self.findAtomicSuperCover(largestCover) + else: + return L + + def bypassSingleChild(self, L): + """ + Given a latent variable L, this function removes the only child of L + from the graph and connects L to its grandchildren + """ + children = self.findChildren(L) + assert len(children) == 1, "Can only perform if single child." + C = next(iter(children)) + grandchildren = self.findChildren(C) + self.removeCover(C) + self.connectNodes(set([L]), grandchildren) # Connect to grandchild + + def pickAllMeasures(self, Ls): + visitedA, visitedNA, measures = self._pickAllMeasures(Ls) + return measures + + def _pickAllMeasures(self, Ls, visitedA=set(), visitedNA=set(), measures=set()): + """ + Given a set of latent Covers, get all the measured descendants. + This includes descendants of non-atomic covers. + """ + visitedA = visitedA.copy() + visitedNA = visitedNA.copy() + measures = measures.copy() + Q = deque() # FIFO queue for BFS + for L in Ls: + if L.is_observed: + measures.add(L) + else: + for l_str in L.vars: + temp = self.get_observed_cover_by_str(l_str) + if temp is not None: + measures.add(temp) + Q.append(L) + + # BFS amongst atomic descendants of Ls + while len(Q) > 0: + L = Q.popleft() + visitedA.add(L) + + for C in self.findChildren(L): + if C.is_observed: + measures.add(C) + else: + for C_str in C.vars: + temp = self.get_observed_cover_by_str(C_str) + if temp is not None: + measures.add(temp) + Q.append(C) + + # Now check if the visited nodes contain non-atomic covers + # If yes, do DFS on the children of each non-atomic cover + for cover in self.latentDict: + if cover.isAtomic or (cover in visitedNA): + continue + if cover.isSubset(visitedA): + visitedNA.add(cover) + Cs = set() + for C in self.latentDict[cover]["children"]: + if C.is_observed: + measures.add(C) + else: + Cs.add(C) + visitedA, visitedNA, measures = self._pickAllMeasures( + Cs, visitedA, visitedNA, measures + ) + + return visitedA, visitedNA, measures + + def saveLatentGroup(self, path): + with open(path, "wb") as f: + pickle.dump(self, f) + + def addOrUpdateCover(self, L: Cover, children: set[Cover] = set(), fake_children: set[Cover] = set()): + """ + Add a new cover to latentDict with the specified children. + If cover exists, add the specified children to it. + + Handles the logic of adding subcovers, nonAtomic cover etc. + """ + + subcovers = self.findSubcovers(L) + L.atomic = not self.isNonAtomic(L) + if L in self.latentDict: + + if children==self.latentDict[L]["children"] and subcovers==self.latentDict[L]["subcovers"] \ + and fake_children==self.latentDict[L]["fake_children"]: + return False + else: + self.latentDict[L]["children"].update(children) + self.latentDict[L]["subcovers"].update(subcovers) + self.latentDict[L]["fake_children"].update(fake_children) + return True + else: + self.latentDict[L] = { + "children": children, + "subcovers": subcovers, + 'fake_children': fake_children, + "refined": False, + } + return True + + # If L is a rediscovered non-atomic, we might need to override edge(s) + # e.g. if we re-discover P1, P2 -> C, but C -> P1, we remove the latter + if L.atomic: + return + for C, v in self.latentDict.items(): + for P in subcovers: + if (P in v["children"]) and (C in children): + LOGGER.info(f"Removing {P} as a child of {C}..") + self.latentDict[C]["children"].remove(P) + + def introduceTempRoot(self): + """ + Add a temporary root over all active variables such that no rank + deficiency is introduced. + + For n := len(activeSet), we can get rank deficiency of rank k only if + n >= 2k + 2. So the minimal k to have no rank deficiency is n < 2k + 2, + i.e. k > n/2 - 1. + """ + assert len(self.activeSet) > 1, "activeSet should have > 1 Cover." + LOGGER.info(f"Introducing a temporary root over {self.activeSet}..") + n = setLength(self.activeSet) + k = math.ceil(n / 2 - 1 + 0.1) + self.addCluster(Vs=self.activeSet, k=k) + self.updateActiveSet() + assert len(self.activeSet) == 1, "Root should be a single Cover." + tempRoot = next(iter(self.activeSet)) + tempRoot.temp = True + self.addOrUpdateCover(tempRoot) + + def connectVariablesChain(self, Vs: set[Cover]): + """ + Connect variables Vs in a chain structure. + E.g. Vs={L1, L2, L3, L4}, then: + - {L1} -> {L2} + - {L1, L2} -> {L3} + - {L1, L2, L3} -> {L4} + """ + Ls = [V for V in Vs if V.isLatent] + Xs = [X for X in Vs if not X.isLatent] + j = 1 + while j < len(Ls): + Ps = set(Ls[0:j]) # Set all but first cover as Parents + C = Ls[j] + newCover = Cover(getVars(Ps)) + LOGGER.info(f"Setting {newCover} -----> {C}") + self.addOrUpdateCover(newCover, children=set([C])) + j += 1 + + # Set any measured variables as children of all latents + if len(Xs) > 0: + Ps = set(Ls) + Cs = set(Xs) + newCover = Cover(getVars(Ps)) + LOGGER.info(f"Setting {newCover} -----> {Cs}") + self.addOrUpdateCover(newCover, children=Cs) + + def logNonAtomics(self, L): + """ + Find nonAtomics associated with L and store their information, namely + the set of measures that define each atomic cover within the nonAtomic. + This info will be used to re-identify the nonAtomic variables later. + """ + nonAtomics, _ = self.findNonAtomics(L) + infos = [] + for Lp in nonAtomics: + spouses = [] + LOGGER.info(f"Storing info for non-atomic cover {Lp}..") + for C in self.latentDict[Lp]["subcovers"]: + LOGGER.info(f"Storing {C}..") + + # If C is root, we need to set a random child to be root + # before we record C's measures. Otherwise, C will have all + # measures recorded which carries no information. + if self.isRoot(C): + child = self.findRandomLatentChild(C) + Gp = deepcopy(self) + Gp.makeRoot(child) + LOGGER.info(f"{C} is root, setting {child} to be root.") + measures = Gp.pickAllMeasures(set([C])) + else: + measures = self.pickAllMeasures(set([C])) + k = len(C) + spouses.append((k, measures)) + LOGGER.info(f" For {C}: {measures} measures..") + infos.append( + { + "spouses": spouses, + "children": self.latentDict[Lp]["children"], + } + ) + return infos + + def reconnectNonAtomics(self): + """ + Rediscover the nonAtomic Cover(s) found in self.nonAtomics. + """ + + def _rediscover(nonAtomics=[], retryList=[]): + """ + Rediscover the nonAtomic Cover corresponding to the item with + smallest cardinality of measures, and add it back to latentDict. + """ + # Terminate when queue is empty + if len(nonAtomics) == 0: + return retryList + + d, nonAtomics = _findSmallestNonAtomic(nonAtomics) + + # Find the set of covers corresponding to each atomic + spouses, children = d["spouses"], d["children"] + failed = False + discoveredCovers = set() + for (k, measures) in spouses: + U = set([next(iter(x.vars)) for x in measures]) + covers = UG.findMinimalSepSet(U, k) + LOGGER.info(f"Finding Min Sep Set for {U}...") + + # This attempt fails if we fail to find covers for any spouse + if len(covers) == 0: + failed = True + retryList.append(d) + break + + discoveredCovers.update(covers) + LOGGER.info(f" Found {covers} as covers over {U}...") + + # Create the nonAtomic Cover and add it to latentDict + if not failed: + coverVars = set() + for cover in discoveredCovers: + coverVars.update(cover.vars) + nonAtomicCover = Cover(coverVars, atomic=False) + self.addOrUpdateCover(nonAtomicCover, children) + LOGGER.info( + f"Rediscovered {nonAtomicCover} as a non-atomic" + f" parent of {children}.." + ) + + # Continue search with nonAtomics + nonAtomics = _rediscover(nonAtomics, retryList) + return retryList + + def _findSmallestNonAtomic(nonAtomics): + """ + Find the nonAtomicCover with the smallest cardinality of measures. + We should identify the nonAtomicCover for him first because it is + easiest to find. + """ + newlist = [] + lowest = 1e9 + index = None + + # Find smallest cardinality + for i, d in enumerate(nonAtomics): + spouses, children = d["spouses"], d["children"] + card = 0 + for (k, measures) in spouses: + card += len(measures) + if card < lowest: + index = i + + # Pop the smallest guy + for i, v in enumerate(nonAtomics): + if i != index: + newlist.append(v) + return nonAtomics[index], newlist + + if len(self.nonAtomics) == 0: + return + + LOGGER.info("Finding nonAtomic Covers...") + + # 1. Create a copy of the graph + # 2. Fully connect the remaining variables in activeSet + # 3. Use the UndirectedGraph for rediscovering the nonAtomics + Gp = deepcopy(self) + Gp.introduceTempRoot() + Gp.updateActiveSet() + UG = UndirectedGraph(Gp) + self.nonAtomics = _rediscover(self.nonAtomics) + + def findAdjacentNodes(self, L: Cover): + """ + Find adjacent atomicCovers to L. + """ + Ns = set() + Gp = deepcopy(self) + Gp.makeRoot(L) + for C in Gp.findChildren(L): + if C.isTemp: + Ns.update(Gp.findChildren(C)) + else: + Ns.add(C) + return set([V for V in Ns if V.isAtomic]) + + def findSubcovers(self, L: Cover, only_atomic=False): + """ + Find all subcovers of L in latentDict. + """ + subcovers = set() + temp = set(self.latentDict.keys())|self.X + for cover in temp: + if only_atomic and not cover.isAtomic: + continue + if cover.isSubset(L, strict=True): + subcovers.add(cover) + return subcovers + + def isNonAtomic(self, L: Cover): + """ + Determine if L is nonAtomic, i.e. it can be subdivided into a disjoint + set of atomic Covers. + + Assumption: no pair of atomic Covers has overlapping variables. + """ + subcovers = self.findSubcovers(L, only_atomic=True) + subcoverVars = getVars(subcovers) + return subcoverVars == L.vars + + def disconnectForNonAtomicParents(self, G: LatentGroups, P: Cover): + """ + When testing for independence at a child of nonAtomic parent P, we need + to represent each cover within P with its own disjoint set of variables. + Hence we need to disconnect the graph at suitable points to achieve this. + + We use BFS to visit descendants of each subcover of P. If we visit the + same node again, we disconnect all edges to that node. + + Returns: + a modified LatentGroups graph. + """ + assert not P.isAtomic, f"{P} should be non atomic." + Gp = deepcopy(G) + subcovers = Gp.latentDict[P]["subcovers"] + visited = set() + commonNodes = set() + + # First pass to find commonNodes + for subcover in Gp.latentDict[P]["subcovers"]: + Gp.makeRoot(subcover) + Q = deque() + Q.append(subcover) + while len(Q) > 0: + L = Q.pop() + children = Gp.findChildren(L) + for child in children: + if child in subcovers | visited: + commonNodes.add(child) + else: + if child.isLatent: + Q.append(child) + visited.add(child) + + # Second pass to disconnect edges + for subcover in Gp.latentDict[P]["subcovers"]: + Gp.makeRoot(subcover) + Q = deque() + Q.append(subcover) + while len(Q) > 0: + L = Q.pop() + children = Gp.findChildren(L) + for child in children: + if child in subcovers | commonNodes: + Gp.disconnectNodes(set([L]), set([child])) + else: + if child.isLatent: + Q.append(child) + return Gp + + def getDotGraph(self): + """ + Parse a LatentGroups object into a DotGraph. + """ + + def addParentToGraph(dotGraph, parent, childrenSet): + + # Add edges from children to new parents + for P in parent.vars: + for childGroup in childrenSet: + for child in childGroup.vars: + dotGraph.addEdge(P, child, type=1) + + G = deepcopy(self) + Xvars = G.X + + # Add X variables + dotGraph = DotGraph() + for X in Xvars: + X = next(iter(X.vars)) + dotGraph.addNode(X) + + # Add nonAtomics first + for cover in G.latentDict: + if cover.isAtomic: + continue + for L in cover.vars: + dotGraph.addNode(L, refined=False) + + # Add atomics second so that refined gets reflected correctly + for cover in G.latentDict: + if not cover.isAtomic: + continue + refined = G.latentDict[cover].get("refined", False) + for L in cover.vars: + dotGraph.addNode(L, refined=refined) + + # Work iteratively through the Graph Dictionary + while len(G.latentDict) > 0: + parent, values = G.latentDict.popitem() + addParentToGraph(dotGraph, parent, values["children"]) + + return dotGraph + + def pruneControlSet(self, G: LatentGroups, As: set[Cover], Bs: set[Cover]): + """ + When doing tests for independence, there may exist backdoor connections + from variables in As to variables in Bs, hence we need to remove those + variables with backdoor from Bs to prevent messing up the test. + + Returns: + Bs: A pruned control set. + """ + Gp = deepcopy(G) + toPrune = set() + for A in As: + visited = set() + Q = deque() + Q.append(A) + while len(Q) > 0: + V = Q.pop() + visited.add(V) + if V in Bs: + toPrune.add(V) + if V.isLatent: + for C in Gp.findChildren(V) | Gp.findNonAtomicChildren(V): + if not C in visited: + Q.append(C) + return Bs - toPrune + + def disconnectAllEdgestoCover(self, G: LatentGroups, L: Cover): + """ + Disconnect all edges to L. + + This differs from removeCover in that if L is part of a non-atomic + cover {L, L2}->C, we want to retain edge from L2->C but remove edge + from L->C. + + Returns: + Gp: A modified LatentGroups object + """ + assert L.isAtomic, f"{L} is not atomic." + Gp = deepcopy(G) + for cover, v in G.latentDict.items(): + + # Remove L as an atomic parent + if cover == L: + Gp.latentDict.pop(L) + + # Remove L as a child of any atomic/non-atomic parent + if L in v["children"]: + Gp.latentDict[cover]["children"].remove(L) + + # Remove L as a co-parent, but retain remaining parents + if not cover.isAtomic: + if L in v["subcovers"]: + subcovers = G.latentDict[cover]["subcovers"] - set([L]) + newCover = Cover(getVars(subcovers)) + Gp.addOrUpdateCover(newCover, v["children"]) + Gp.latentDict.pop(cover) + + return Gp + + def toNetworkX(self): + """ + Convert graph into a networkx undirected graph. + """ + NG = nx.Graph() + for L, v in self.latentDict.items(): + if L.isAtomic: + for C in v["children"]: + NG.add_edge(L, C) + else: + for S in v["subcovers"]: + for C in v["children"]: + NG.add_edge(S, C) + return NG + + +def pruneGraph(G: LatentGroups, Vs: set[Cover]): + """ + Prune away all nodes in the graph G that are descendants of Vs. + + Note that this will result in Vs becoming leaf nodes in the pruned graph. + + Returns: + Gp: A pruned graph. + """ + Gp = deepcopy(G) + for V in Vs: + assert V.isAtomic, f"{V} is not atomic." + + nodesToDrop = set() + # BFS to add nodes + Q = deque() + for V in Vs: + Q.append(V) + + while len(Q) > 0: + A = Q.popleft() + if A not in Vs: + nodesToDrop.add(A) + if A.isLatent: + for C in Gp.findChildren(A) | Gp.findNonAtomicChildren(A): + if not C in nodesToDrop | Vs: + Q.append(C) + + for node in nodesToDrop: + if node.isLatent: + Gp.removeCover(node) + else: + for P in Gp.findParents(node): + if P in Gp.latentDict: + Gp.latentDict[P]["children"] -= set([node]) + + Gp.X = Gp.X.intersection(Vs) + Gp.updateActiveSet() + return Gp + + + +def getLfromLatentGroups(G: LatentGroups, xvars: list): + vars_set = set() + for cover in G.latentDict.keys(): + vars_set |= cover.vars + + lvars_set = vars_set - set(xvars) + + vars_ls = xvars + list(lvars_set) + L = np.zeros((len(vars_ls), len(vars_ls))) + + for cover, val in G.latentDict.items(): + + if cover.atomic: + fa_list = [] + for var in cover.vars: + index_fa = vars_ls.index(var) + fa_list.append(index_fa) + + for fa_index1 in fa_list: + for fa_index2 in fa_list: + if fa_index1!=fa_index2: + L[fa_index1][fa_index2]=-2 + + for var in cover.vars: + index_fa = vars_ls.index(var) + for chcover in val['children']: + for chvar in chcover.vars: + index_ch = vars_ls.index(chvar) + + if cover.atomic: + #if chcover in val['fake_children']: + # L[index_fa][index_ch] = 1 + # L[index_ch][index_fa] = -1 + #else: + # L[index_fa][index_ch] = 1 + # L[index_ch][index_fa] = 1 + L[index_fa][index_ch] = -1 + L[index_ch][index_fa] = 1 + #L[index_fa][index_ch] = 1 + #L[index_ch][index_fa] = 1 + else: + L[index_fa][index_ch] = -1 + L[index_ch][index_fa] = 1 + + + return L \ No newline at end of file diff --git a/causallearn/search/HiddenCausal/RLCD/PC_CovRank.py b/causallearn/search/HiddenCausal/RLCD/PC_CovRank.py new file mode 100644 index 00000000..7ada9021 --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/PC_CovRank.py @@ -0,0 +1,329 @@ +from __future__ import annotations +import os, json, codecs, time, hashlib +import numpy as np +from math import log, sqrt +from collections.abc import Iterable +from scipy.stats import chi2, norm + +from causallearn.utils.cit import CIT_Base + +CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5 +NO_SPECIFIED_PARAMETERS_MSG = "NO SPECIFIED PARAMETERS" + +import time +import warnings +from itertools import combinations, permutations +from typing import Dict, List, Tuple + +import networkx as nx +import numpy as np +from numpy import ndarray + +from causallearn.graph.GraphClass import CausalGraph +from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge +from causallearn.utils.cit import * +from causallearn.utils.PCUtils import Helper, Meek, SkeletonDiscovery, UCSepset +from causallearn.utils.PCUtils.BackgroundKnowledgeOrientUtils import \ + orient_by_background_knowledge + + +class CovRank(CIT_Base): + def __init__(self, data, alpha, rescale_rank_test, **kwargs): + super().__init__(data, **kwargs) + self.check_cache_method_consistent('CovRank', NO_SPECIFIED_PARAMETERS_MSG) + self.assert_input_data_is_valid() + + self.data = (data-data.mean())/data.std() + self.alpha = alpha + + try: + from .CCARankTester import CCARankTester + except ImportError as exc: + raise ImportError("CCARankTester is required for pc_cov_rank but is not included in scm-identify's StructureLearning/RLCD package.") from exc + alpha_dict = {i:self.alpha for i in range(data.shape[1])} + self.RankTester = CCARankTester(self.data, alpha_dict=alpha_dict, rescale_rank_test=rescale_rank_test) + + def __call__(self, X, Y, condition_set=None): + ''' + Perform an independence test using Fisher-Z's test. + + Parameters + ---------- + X, Y and condition_set : column indices of data + + Returns + ------- + p : the p-value of the test + ''' + Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set) + if cache_key in self.pvalue_cache: return self.pvalue_cache[cache_key] + + _, p = self.RankTester.test(Xs+condition_set, Ys+condition_set, r=len(condition_set)) + self.pvalue_cache[cache_key] = p + + return p + + +def pc_true_cov_rank( + fake_data: ndarray, + independence_test_method, + alpha=0.05, + stable: bool = True, + uc_rule: int = 0, + uc_priority: int = 2, + mvpc: bool = False, + correction_name: str = 'MV_Crtn_Fisher_Z', + background_knowledge: BackgroundKnowledge | None = None, + verbose: bool = False, + show_progress: bool = True, + node_names: List[str] | None = None, + **kwargs +): + + return pc_alg_true_cov_rank(fake_data=fake_data, independence_test_method=independence_test_method, node_names=node_names, alpha=alpha, stable=stable, uc_rule=uc_rule, + uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose, + show_progress=show_progress, **kwargs) + + + +def pc_alg_true_cov_rank( + fake_data: ndarray, + independence_test_method, + node_names: List[str] | None, + alpha: float, + stable: bool, + uc_rule: int, + uc_priority: int, + background_knowledge: BackgroundKnowledge | None = None, + verbose: bool = False, + show_progress: bool = True, + **kwargs +) -> CausalGraph: + """ + Perform Peter-Clark (PC) algorithm for causal discovery + + Parameters + ---------- + data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features. + node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name) + alpha : float, desired significance level of independence tests (p_value) in (0, 1) + indep_test : str, the name of the independence test being used + ["fisherz", "chisq", "gsq", "kci"] + - "fisherz": Fisher's Z conditional independence test + - "chisq": Chi-squared conditional independence test + - "gsq": G-squared conditional independence test + - "kci": Kernel-based conditional independence test + stable : run stabilized skeleton discovery if True (default = True) + uc_rule : how unshielded colliders are oriented + 0: run uc_sepset + 1: run maxP + 2: run definiteMaxP + uc_priority : rule of resolving conflicts between unshielded colliders + -1: whatever is default in uc_rule + 0: overwrite + 1: orient bi-directed + 2. prioritize existing colliders + 3. prioritize stronger colliders + 4. prioritize stronger* colliers + background_knowledge : background knowledge + verbose : True iff verbose output should be printed. + show_progress : True iff the algorithm progress should be show in console. + + Returns + ------- + cg : a CausalGraph object, where cg.G.graph[j,i]=1 and cg.G.graph[i,j]=-1 indicates i --> j , + cg.G.graph[i,j] = cg.G.graph[j,i] = -1 indicates i --- j, + cg.G.graph[i,j] = cg.G.graph[j,i] = 1 indicates i <-> j. + + """ + + start = time.time() + + indep_test = independence_test_method + + cg_1 = SkeletonDiscovery.skeleton_discovery(fake_data, alpha, indep_test, stable, + background_knowledge=background_knowledge, verbose=verbose, + show_progress=show_progress, node_names=node_names) + + if background_knowledge is not None: + orient_by_background_knowledge(cg_1, background_knowledge) + + if uc_rule == 0: + if uc_priority != -1: + cg_2 = UCSepset.uc_sepset(cg_1, uc_priority, background_knowledge=background_knowledge) + else: + cg_2 = UCSepset.uc_sepset(cg_1, background_knowledge=background_knowledge) + cg = Meek.meek(cg_2, background_knowledge=background_knowledge) + + elif uc_rule == 1: + if uc_priority != -1: + cg_2 = UCSepset.maxp(cg_1, uc_priority, background_knowledge=background_knowledge) + else: + cg_2 = UCSepset.maxp(cg_1, background_knowledge=background_knowledge) + cg = Meek.meek(cg_2, background_knowledge=background_knowledge) + + elif uc_rule == 2: + if uc_priority != -1: + cg_2 = UCSepset.definite_maxp(cg_1, alpha, uc_priority, background_knowledge=background_knowledge) + else: + cg_2 = UCSepset.definite_maxp(cg_1, alpha, background_knowledge=background_knowledge) + cg_before = Meek.definite_meek(cg_2, background_knowledge=background_knowledge) + cg = Meek.meek(cg_before, background_knowledge=background_knowledge) + else: + raise ValueError("uc_rule should be in [0, 1, 2]") + end = time.time() + + cg.PC_elapsed = end - start + + return cg + + +def pc_cov_rank( + data: ndarray, + alpha=0.05, + rescale_rank_test=1, + stable: bool = True, + uc_rule: int = 0, + uc_priority: int = 2, + mvpc: bool = False, + correction_name: str = 'MV_Crtn_Fisher_Z', + background_knowledge: BackgroundKnowledge | None = None, + verbose: bool = False, + show_progress: bool = True, + node_names: List[str] | None = None, + **kwargs +): + if data.shape[0] < data.shape[1]: + warnings.warn("The number of features is much larger than the sample size!") + + + return pc_alg_cov_rank(data=data, node_names=node_names, alpha=alpha, rescale_rank_test=rescale_rank_test, stable=stable, uc_rule=uc_rule, + uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose, + show_progress=show_progress, **kwargs) + + + +def pc_alg_cov_rank( + data: ndarray, + node_names: List[str] | None, + alpha: float, + rescale_rank_test: float, + stable: bool, + uc_rule: int, + uc_priority: int, + background_knowledge: BackgroundKnowledge | None = None, + verbose: bool = False, + show_progress: bool = True, + **kwargs +) -> CausalGraph: + """ + Perform Peter-Clark (PC) algorithm for causal discovery + + Parameters + ---------- + data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features. + node_names: Shape [n_features]. The name for each feature (each feature is represented as a Node in the graph, so it's also the node name) + alpha : float, desired significance level of independence tests (p_value) in (0, 1) + indep_test : str, the name of the independence test being used + ["fisherz", "chisq", "gsq", "kci"] + - "fisherz": Fisher's Z conditional independence test + - "chisq": Chi-squared conditional independence test + - "gsq": G-squared conditional independence test + - "kci": Kernel-based conditional independence test + stable : run stabilized skeleton discovery if True (default = True) + uc_rule : how unshielded colliders are oriented + 0: run uc_sepset + 1: run maxP + 2: run definiteMaxP + uc_priority : rule of resolving conflicts between unshielded colliders + -1: whatever is default in uc_rule + 0: overwrite + 1: orient bi-directed + 2. prioritize existing colliders + 3. prioritize stronger colliders + 4. prioritize stronger* colliers + background_knowledge : background knowledge + verbose : True iff verbose output should be printed. + show_progress : True iff the algorithm progress should be show in console. + + Returns + ------- + cg : a CausalGraph object, where cg.G.graph[j,i]=1 and cg.G.graph[i,j]=-1 indicates i --> j , + cg.G.graph[i,j] = cg.G.graph[j,i] = -1 indicates i --- j, + cg.G.graph[i,j] = cg.G.graph[j,i] = 1 indicates i <-> j. + + """ + + start = time.time() + indep_test = CovRank(data, alpha, rescale_rank_test, **kwargs) + cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable, + background_knowledge=background_knowledge, verbose=verbose, + show_progress=show_progress, node_names=node_names) + + if background_knowledge is not None: + orient_by_background_knowledge(cg_1, background_knowledge) + + if uc_rule == 0: + if uc_priority != -1: + cg_2 = UCSepset.uc_sepset(cg_1, uc_priority, background_knowledge=background_knowledge) + else: + cg_2 = UCSepset.uc_sepset(cg_1, background_knowledge=background_knowledge) + cg = Meek.meek(cg_2, background_knowledge=background_knowledge) + + elif uc_rule == 1: + if uc_priority != -1: + cg_2 = UCSepset.maxp(cg_1, uc_priority, background_knowledge=background_knowledge) + else: + cg_2 = UCSepset.maxp(cg_1, background_knowledge=background_knowledge) + cg = Meek.meek(cg_2, background_knowledge=background_knowledge) + + elif uc_rule == 2: + if uc_priority != -1: + cg_2 = UCSepset.definite_maxp(cg_1, alpha, uc_priority, background_knowledge=background_knowledge) + else: + cg_2 = UCSepset.definite_maxp(cg_1, alpha, background_knowledge=background_knowledge) + cg_before = Meek.definite_meek(cg_2, background_knowledge=background_knowledge) + cg = Meek.meek(cg_before, background_knowledge=background_knowledge) + else: + raise ValueError("uc_rule should be in [0, 1, 2]") + end = time.time() + + cg.PC_elapsed = end - start + + return cg + + +class FisherZ(CIT_Base): + def __init__(self, data, **kwargs): + super().__init__(data, **kwargs) + self.check_cache_method_consistent('fisherz', NO_SPECIFIED_PARAMETERS_MSG) + self.assert_input_data_is_valid() + self.correlation_matrix = np.corrcoef(data.T) + + def __call__(self, X, Y, condition_set=None): + ''' + Perform an independence test using Fisher-Z's test. + + Parameters + ---------- + X, Y and condition_set : column indices of data + + Returns + ------- + p : the p-value of the test + ''' + Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set) + if cache_key in self.pvalue_cache: return self.pvalue_cache[cache_key] + var = Xs + Ys + condition_set + sub_corr_matrix = self.correlation_matrix[np.ix_(var, var)] + try: + inv = np.linalg.inv(sub_corr_matrix) + except np.linalg.LinAlgError: + raise ValueError('Data correlation matrix is singular. Cannot run fisherz test. Please check your data.') + r = -inv[0, 1] / sqrt(abs(inv[0, 0] * inv[1, 1])) + if abs(r) >= 1: r = (1. - np.finfo(float).eps) * np.sign(r) # may happen when samplesize is very small or relation is deterministic + Z = 0.5 * log((1 + r) / (1 - r)) + X = sqrt(self.sample_size - len(condition_set) - 3) * abs(Z) + p = 2 * (1 - norm.cdf(abs(X))) + self.pvalue_cache[cache_key] = p + return p \ No newline at end of file diff --git a/causallearn/search/HiddenCausal/RLCD/RLCD_alg.py b/causallearn/search/HiddenCausal/RLCD/RLCD_alg.py new file mode 100644 index 00000000..94e479e1 --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/RLCD_alg.py @@ -0,0 +1,886 @@ +import pandas as pd +import importlib +from .logger import LOGGER +from .DSU import DSU +from communities.algorithms import bron_kerbosch +import numpy as np +import copy +import os +from causallearn.graph.Edge import Edge +from causallearn.graph.Endpoint import Endpoint +from causallearn.graph.GeneralGraph import GeneralGraph +from causallearn.graph.GraphClass import CausalGraph +from causallearn.graph.GraphNode import GraphNode +from causallearn.graph.NodeType import NodeType +from .LatentGroups import LatentGroups, getLfromLatentGroups +from .Chi2RankTest import Chi2RankTest +from . import misc as M +from .misc import Independences, Edges, powerset +from .Cover import getOrderedVarsString, setLength, setDifference, Cover, pairwiseOverlap, getVars +from itertools import combinations +from .GraphDrawer import DotGraph +from joblib import delayed, Parallel + + +def _adjacency_to_causal_graph(adjacency, var_names): + """Convert RLCD's adjacency matrix into causal-learn's CausalGraph wrapper.""" + nodes = [] + for name in var_names: + node = GraphNode(name) + if name.startswith("L"): + node.set_node_type(NodeType.LATENT) + nodes.append(node) + + graph = GeneralGraph(nodes) + for i in range(len(adjacency)): + for j in range(i + 1, len(adjacency)): + if adjacency[i, j] == 0 and adjacency[j, i] == 0: + continue + if adjacency[i, j] == -1 and adjacency[j, i] == 1: + graph.add_directed_edge(nodes[i], nodes[j]) + elif adjacency[i, j] == 1 and adjacency[j, i] == -1: + graph.add_directed_edge(nodes[j], nodes[i]) + else: + graph.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.TAIL)) + + cg = CausalGraph(len(var_names), var_names) + cg.G = graph + return cg + + +def _rlcd_impl( + sample, + xvars: list = None, + df: pd.DataFrame = None, + input_parameters: dict = None, +): + + parameters = { + "xvars": xvars, + "alpha_dict": {0: 0.01, 1: 0.01, 2: 0.01, 3:0.01}, + "maxk": 3, + "allow_nonleafx": True, + "unfold_covers": True, + "check_v": True, + "stages": 2, + "stage1_method": "fges", + "stage1_ges_sparsity": 2, + "stage1_CI_alpha": 0.01, + "stage1_partition_thres": 3, + "ranktest_method": None, + "citest_method": None, + } + + if input_parameters is not None: + parameters.update(input_parameters) + parameters['sample'] = sample + xvars = parameters["xvars"] + if xvars is None: + if df is not None: + xvars = list(df.columns) + elif parameters["ranktest_method"] is not None and hasattr(parameters["ranktest_method"], "data"): + xvars = [f"X{i + 1}" for i in range(parameters["ranktest_method"].data.shape[1])] + else: + raise ValueError("xvars must be provided when neither df nor ranktest_method.data is available.") + parameters["xvars"] = xvars + + if parameters['stages']>=1: + if not parameters['sample']: + if parameters['stage1_method']=='all': + Adj_stage1 = np.ones((len(xvars),len(xvars))) + partition = [xvars] + elif parameters['stage1_method']=='fci': + from .FCI_CovRank import fci_true_cov_rank + G, edges = fci_true_cov_rank(np.zeros((1, len(xvars))), parameters['citest_method']) + Adj_stage1 = process_fci_result(G.graph) + partition = getPartition(xvars, abs(Adj_stage1), parameters['stage1_partition_thres']) + else: + if parameters['stage1_method']=='all': + Adj_stage1 = np.ones((len(xvars),len(xvars))) + partition = [xvars] + elif parameters['stage1_method']=='fges': + jpype = importlib.import_module("jpype") + importlib.import_module("jpype.imports") + try: + jpype.startJVM(classpath=[f"./pytetrad/tetrad-current_old.jar"]) + #current_dirname = os.path.dirname(__file__) + #jpype.startJVM(classpath=[f"{current_dirname}/../../utils/pytetrad/tetrad-current.jar"]) + LOGGER.info("JVM started") + except OSError: + LOGGER.info("JVM already started") + LOGGER.info('running fges') + TetradSearch = importlib.import_module("pytetrad.TetradSearch_old").TetradSearch + pytetrad_search = TetradSearch(df) + pytetrad_search.set_verbose(False) + pytetrad_search.use_sem_bic(penalty_discount=parameters['stage1_ges_sparsity']) + pytetrad_search.run_fges() + cg = pytetrad_search.get_causal_learn() + Adj_stage1 = cg.graph + #Adj_stage1_dag = pdag2dag(cg).graph + + partition = getPartition(xvars, abs(Adj_stage1), parameters['stage1_partition_thres']) + + elif parameters['stage1_method']=='ges': + LOGGER.info('running ges') + from causallearn.search.ScoreBased.GES import ges + Record = ges(df.to_numpy(), parameters={'lambda':parameters['stage1_ges_sparsity']}) + Adj_stage1 = Record['G'].graph + #Adj_dag = pdag2dag(Record['G']).graph + + partition = getPartition(xvars, abs(Adj_stage1), parameters['stage1_partition_thres']) + else: + raise NotImplementedError + + LOGGER.info("Partition of Cliques") + for group in partition: + LOGGER.info(group) + + Adj = Adj_stage1 + + if parameters['stages']>=2: + + for current_xvars in partition: + + current_xvars_idx = [xvars.index(x) for x in current_xvars] + + def get_neighbour_set(all_xvars, current_xvars, Adj): + nb_set = set() + + for xvar1 in current_xvars: + xvar1_idx = all_xvars.index(xvar1) + + for xvar2 in all_xvars: + if xvar2 not in current_xvars: + xvar2_idx = all_xvars.index(xvar2) + + if Adj[xvar1_idx, xvar2_idx]!=0: #adjacent + nb_set.add(xvar2) + + return nb_set + + neighbour_set = get_neighbour_set(xvars, current_xvars, Adj) + + #local_Adj = Adj[current_xvars_idx].T[current_xvars_idx].T + local_Adj = Adj[np.ix_(current_xvars_idx, current_xvars_idx)] + + current_G = LatentGroups(X=current_xvars, Xns=current_xvars, all_nb_set=neighbour_set, nb_set_dict={x_var:set() for x_var in current_xvars}, \ + local_Adj=local_Adj) + + current_G = rlcd_find_latent(current_G, parameters) + current_output_Adj = getLfromLatentGroups(current_G, current_xvars) + current_output_Adj = getReducedAdj(current_output_Adj, [i for i in range(len(current_xvars))]) + + num_new_latent = current_output_Adj.shape[0] - len(current_xvars) + + if num_new_latent>0: + + # pad Adj by num_new_latent + temp = np.zeros((Adj.shape[0]+num_new_latent, Adj.shape[0]+num_new_latent)) + temp[:Adj.shape[0],:Adj.shape[0]] = Adj + Adj = temp + + def copy_by_idx(A_, B, indexes1_in_A, indexes2_in_A): + A=A_.copy() + assert(len(indexes1_in_A)==B.shape[0]) + assert(len(indexes2_in_A)==B.shape[1]) + for idx1_B, idx1_A in enumerate(indexes1_in_A): + for idx2_B, idx2_A in enumerate(indexes2_in_A): + A[idx1_A,idx2_A] = B[idx1_B,idx2_B] + return A + + # update x*x + Adj = copy_by_idx(Adj, current_output_Adj[:len(current_xvars), :len(current_xvars)], current_xvars_idx, current_xvars_idx) + # update x*l, l*x, and l*l + current_lvars_idx = [x for x in range(Adj.shape[0]-num_new_latent, Adj.shape[0])] + Adj = copy_by_idx(Adj, current_output_Adj[:len(current_xvars), len(current_xvars):], current_xvars_idx, current_lvars_idx) + Adj = copy_by_idx(Adj, current_output_Adj[len(current_xvars):, :len(current_xvars)], current_lvars_idx, current_xvars_idx) + Adj = copy_by_idx(Adj, current_output_Adj[len(current_xvars):, len(current_xvars):], current_lvars_idx, current_lvars_idx) + + # ending + all_vars = [x for x in xvars] + for i in range(Adj.shape[0]-len(xvars)): + all_vars.append(f"L{i+1}") + + atomic_mask = np.zeros_like(Adj) + for i in range(len(Adj)): + for j in range(len(Adj)): + if Adj[i,j]==-2: + Adj[i,j]=1 + atomic_mask[i,j]=1 + + Adj_combined = Adj.copy() + for i in range(len(Adj_combined)): + for j in range(len(Adj_combined)): + if atomic_mask[i,j]==1: + Adj_combined[i,j]=0 + + cg = _adjacency_to_causal_graph(Adj_combined, all_vars) + stage1_cg = _adjacency_to_causal_graph(Adj_stage1, xvars) + + return cg, stage1_cg, Adj_combined, all_vars + + + +def RLCD( + data, + ranktest_method=None, + stage1_method="ges", + alpha_dict=None, + maxk=3, + node_names=None, + stage1_ges_sparsity=2, + stage1_partition_thres=3, + allow_nonleafx=True, + **kwargs, +): + """Run Rank-based Latent Causal Discovery. + + Parameters + ---------- + data : numpy.ndarray + Data matrix with shape (n_samples, n_features). + ranktest_method : object, optional + Rank test object with a ``test(pcols, qcols, r, alpha)`` method. If + omitted, ``Chi2RankTest(data)`` is used. + stage1_method : str, default="ges" + Stage-1 method used to partition observed variables. Supported values + are inherited from RLCD's structure-learning implementation. + alpha_dict : dict, optional + Significance levels for rank tests by rank. + maxk : int, default=3 + Maximum rank-search cardinality. + node_names : list, optional + Names for observed variables in the returned graph. If omitted, + variables are named X1, X2, ... + + Returns + ------- + cg : CausalGraph + Learned graph over observed and latent variables, where + cg.G.graph[j, i] = 1 and cg.G.graph[i, j] = -1 indicate i --> j. + Additional RLCD outputs are attached as ``stage1_cg``, ``adjacency``, + and ``all_vars``. + """ + data = np.asarray(data) + if data.ndim != 2: + raise ValueError("data must be a 2-dimensional array.") + + if node_names is None: + node_names = [f"X{i + 1}" for i in range(data.shape[1])] + if len(node_names) != data.shape[1]: + raise ValueError("node_names must have the same length as the number of columns in data.") + + if alpha_dict is None: + alpha_dict = {0: 0.01, 1: 0.01, 2: 0.01, 3: 0.01} + if ranktest_method is None: + ranktest_method = Chi2RankTest(data) + + input_parameters = { + "ranktest_method": ranktest_method, + "stage1_method": stage1_method, + "alpha_dict": alpha_dict, + "maxk": maxk, + "allow_nonleafx": allow_nonleafx, + "stage1_ges_sparsity": stage1_ges_sparsity, + "stage1_partition_thres": stage1_partition_thres, + } + input_parameters.update(kwargs) + + df = pd.DataFrame(data, columns=node_names) + cg, stage1_cg, adjacency, all_vars = _rlcd_impl( + sample=True, + xvars=node_names, + df=df, + input_parameters=input_parameters, + ) + cg.stage1_cg = stage1_cg + cg.adjacency = adjacency + cg.all_vars = all_vars + return cg + +def getReducedAdj(Adj, cid_ls): + + num_vars = len(Adj) + + def dfs(start_id, current_id, travel_record): + + result_id_set = set() + + if current_id in cid_ls and current_id!=start_id: + result_id_set |= set([i for i, x in enumerate(travel_record) if x==1]) + return result_id_set + + for j in range(num_vars): + if Adj[current_id, j]!=0 and travel_record[j]==0: + travel_record_new = travel_record.copy() + travel_record_new[j]=1 + + result_id_set = result_id_set | dfs(start_id, j, travel_record_new) + + return result_id_set + + + result_id_set = set(cid_ls) + + for start_id in cid_ls: + travel_record = [0 for i in range(num_vars)] + travel_record[start_id] = 1 + result_id_set |= dfs(start_id, start_id, travel_record) + + result_id_list = list(result_id_set) + result_id_list.sort() + + return Adj[result_id_list,:][:,result_id_list] + + + +def getPartition(xvars, Adj, clique_size_thres, direct_mode=False): + + def checkRelationBetweenCliques(clique1, clique2): + + common = set.intersection(clique1, clique2) + if len(common)<2: + return False + else: + return True + + # Adj is pc's result + partition = [] + communities = bron_kerbosch(abs(Adj), pivot=True) #if c1 subset c2 then it does not output c1 + communities = [x for x in communities if len(x)>=clique_size_thres] + + if direct_mode: + for clique in communities: + if len(clique)>=clique_size_thres+1: + temp = {xvars[i] for i in clique} + partition.append(temp) + LOGGER.info(f"Put {temp} with length {len(temp)} into queue") + return partition + + else: + dsu = DSU(len(communities)) + for i, clique1 in enumerate(communities): + for j, clique2 in enumerate(communities): + if i!=j and checkRelationBetweenCliques(clique1, clique2): + dsu.union(i, j) + + fa_set = set() + for i in range(len(communities)): + fa_set.add(dsu.find(i)) + + partition = [] + for fa in fa_set: + cliques = [] + for i in range(len(communities)): + if dsu.find(i)==fa: + cliques.append(communities[i]) + temp = {xvars[i] for i in set.union(*cliques)} + if len(temp)>=4: + partition.append(list(temp)) + LOGGER.info(f"Put {temp} with length {len(temp)} into queue") + + return partition + +def process_fci_result(adj_L): + result_L = np.zeros_like(adj_L) + for i in range(adj_L.shape[0]): + for j in range(adj_L.shape[1]): + if j= 2: + G, _ = findClusters(G, parameters) + + #if parameters['stages'] >= 3: + # G = refineClusters(G) + + return G + + +def generateLatentPowersetFromActiveSet(G): + """ + Generate an iterator over powerset of active latents, + Including the combination where all latents are included. + """ + + Ls = set([V for V in G.activeSet if V.is_leaf==False]) + Ls_ordered = list(Ls) + Ls_ordered.sort(key=lambda x: getOrderedVarsString(x)) + Lsubsets = reversed(list(M.powerset(Ls_ordered))) + + return [x for x in Lsubsets] + +def getVarNames(As): + measuredVars = [] + for A in As: + if not A.is_observed: + assert False, "A is not a measured var set" + for temp in A.vars: + measuredVars.append(temp) + return measuredVars + +def structuralRankTest(xvars, ranktest_method, alpha_dict, G: LatentGroups, As, Bs, k, nonLeafs): + """ + Test if As forms a cluster by seeing if rank(subcov[A,B]) <= k. + + Returns tuple of whether rank is deficient and lowest rank tested. + """ + + Ameasures = G.pickAllMeasures(As) + Ameasures = getVarNames(Ameasures) + Bmeasures = G.pickAllMeasures(Bs) + Bmeasures = getVarNames(Bmeasures) + + Ameasures += nonLeafs + Bmeasures += nonLeafs + + Ameasures = list(set(Ameasures)) + Bmeasures = list(set(Bmeasures)) + + pcols = [xvars.index(a) for a in Ameasures] + qcols = [xvars.index(b) for b in Bmeasures] + + fail_to_reject = ranktest_method.test(pcols, qcols, k, alpha_dict[k]) + + if fail_to_reject==False: + return (fail_to_reject, None) + else: + min_rank = k + for h0_k in range(k-1, -1, -1): + fail_to_reject_h0_k = ranktest_method.test(pcols, qcols, h0_k, alpha_dict[h0_k]) + if fail_to_reject_h0_k: + min_rank = h0_k + else: + break + return fail_to_reject, min_rank + +def findClusters_at_k_by_nonsinks(G: LatentGroups, k, nonsinks, parameters): + """ + Internal method for searchClusters. + """ + terminate = False # Whether we ran out of variables to test + found = False # Whether we found any clusters + res_for_add = [] + + num_nonsinks = len(nonsinks) + + current_activeSet = G.activeSet.copy() + current_ChildrenOfNonAtomicsSet = G.ChildrenOfNonAtomicsSet.copy() + + for temp in nonsinks: + current_activeSet.discard(G.X_dict[temp]) + current_ChildrenOfNonAtomicsSet.discard(G.X_dict[temp]) + + # Terminate if not enough active variables + # To test, we need n >= 2k+2 + # So terminate if n < 2k+2 + # i.e. k > n/2 - 1 + if k-num_nonsinks > setLength(current_activeSet) / 2 - 1: + terminate = True + return (found, terminate, res_for_add) + + if k!=len(nonsinks): # could induce latent then do not consider those neighbours in active set + for temp in G.all_nb_set: + if temp in G.X_dict and G.X_dict[temp] in current_activeSet: + current_activeSet.discard(G.X_dict[temp]) + + allSubsets = [x for x in M.generateSubsetMinimal(current_activeSet, k-num_nonsinks)] + + + for v in current_activeSet: + if len(v)>=k-num_nonsinks+1 and k-num_nonsinks!=0: # more than or eq + tempset = current_activeSet.copy() + tempset.remove(v) + additionalls = [x for x in M.generateSubsetMinimal(tempset, 0)] + for x in additionalls: + x.add(v) + allSubsets=allSubsets+additionalls + + + # If no subsets can be generated, terminate + if allSubsets == [set()]: + terminate = True + return (found, terminate, res_for_add) + + #for As in allSubsets: + for As in reversed(allSubsets): + #As = set(As) # test set + + effective_ChildrenOfNonAtomicsSet = current_ChildrenOfNonAtomicsSet.copy() + temp_set = As | {G.X_dict[t] for t in nonsinks} + + for cover in temp_set: + #effective_ChildrenOfNonAtomicsSet = effective_ChildrenOfNonAtomicsSet - G.findDescendants(cover, rigorous=False) + effective_ChildrenOfNonAtomicsSet = effective_ChildrenOfNonAtomicsSet - G.findDescendants(cover, rigorous=False) + + Bs = setDifference(current_activeSet | effective_ChildrenOfNonAtomicsSet, As) # control set + #Bs = setDifference(current_activeSet, As) # control set + #BBs = setDifference(current_activeSet, As) # control set + observed_vars_in_As = {x.__str__() for x in As if x.is_observed} + observed_vars_in_As_and_nonsinks = observed_vars_in_As.union(set(nonsinks)) + observed_vars_in_As_and_nonsinks = list(observed_vars_in_As_and_nonsinks) + observed_vars_in_As_and_nonsinks_idx_in_local_adj = [G.x_list_for_local_Adj.index(x) for x in observed_vars_in_As_and_nonsinks] + + temp_local_adj = G.local_Adj[observed_vars_in_As_and_nonsinks_idx_in_local_adj,:][:,observed_vars_in_As_and_nonsinks_idx_in_local_adj] + + def check_dsu(adj): + num_var = len(adj) + dsu = DSU(num_var) + for i in range(num_var): + for j in range(num_var): + if i!=j and (adj[i,j]!=0 or adj[j,i]!=0): + dsu.union(i, j) + fa_set = set() + for i in range(num_var): + fa_set.add(dsu.find(i)) + + if len(fa_set)==1: + return True + else: + return False + + if not check_dsu(temp_local_adj): + continue + + if setLength(Bs)<=k-len(nonsinks): + continue + #if len(As) > setLength(Bs): + # continue + + # As must not contain more than k elements from + # any atomic Cover with cardinality <= k-1 + if G.containsCluster(As, nonsinks): # toask + continue + + #if G.containsonlyaCluster(Bs, nonsinks): # toask + # continue + + if G.overlapPaCh(As): + continue + + if G.MeassuredHasNonSinks(As, nonsinks): + continue + + if G.checkNonSinksAreAsChildren(As, nonsinks): + continue + + # Bs parentCardinality cannot be < k+1, since otherwise + # we get rank <= k regardless of what As is + if G.parentCardinality(Bs) <= k - num_nonsinks: # dxs seems important to LLHCase2 + continue + #if len(unfolded)==1 and Bs.issubset(G.findChildren(unfolded[0])): + # print("allow parentCardinality(Bs) <= k - num_nonsinks") + #else: + # continue + + fail_to_reject, rk = structuralRankTest(parameters['xvars'], parameters['ranktest_method'], parameters['alpha_dict'], G, As, Bs, k, list(nonsinks)) + + + if fail_to_reject: + LOGGER.info(f" {As} is rank deficient! given {nonsinks}, Bs:{Bs}") + + v_structure_found = False + + if parameters['check_v']: + # check v structure + for num_colider in range(1,k-num_nonsinks+1): #|As|=k-num_nonsinks+1 + num_subAs = k-num_nonsinks+1-num_colider + for subAs in M.generateSubsetMinimal(As, num_subAs-1): + #test_subAs, rk_subAs = self.structuralRankTest(G, subAs, Bs - current_ChildrenOfNonAtomicsSet, num_nonsinks+num_subAs-1, list(nonsinks)) + test_subAs, rk_subAs = \ + structuralRankTest(parameters['xvars'], parameters['ranktest_method'], parameters['alpha_dict'], G, subAs, Bs, num_nonsinks+num_subAs-1, list(nonsinks)) + if test_subAs: + LOGGER.info(f" {As} has v structure! subAs:{subAs} given {nonsinks}, Bs:{Bs}") + v_structure_found = True + + if v_structure_found == False: + res_for_add.append((As, rk, nonsinks)) + #G.addRankDefSet(As, rk, used_nonsinks=nonsinks) + found = True + + return (found, terminate, res_for_add) + + +def findClusters_at_k_mp(G: LatentGroups, k, parameters, n_jobs=-1): + """ + Run one round of search for clusters of size k. + """ + LOGGER.info(f"Starting searchClusters k={k}...") + global_terminate=True + global_found=False + found_deficiency = False + + input_list = [] + + for num_nonsinks in range(k, -1, -1): # [k,k-1,...,0] + + temp_activeNonSink_ls = sorted(list(G.activeNonSinkSet), reverse=True) + #temp_activeNonSink_ls = list(G.activeNonSinkSet) + nonsinks_ls = list(combinations(temp_activeNonSink_ls, num_nonsinks)) + + for nonsinks in nonsinks_ls: + input_list.append(list(nonsinks).copy()) + + output_list = Parallel(n_jobs=n_jobs, backend='loky')( + delayed(findClusters_at_k_by_nonsinks)(G, k, nonsinks, parameters) + for nonsinks in input_list + ) + + for output in output_list: + current_found_deficiency, current_terminate, res_for_add = output + found_deficiency = found_deficiency or current_found_deficiency + global_terminate = global_terminate and current_terminate + + for i in range(len(res_for_add)): + G.addRankDefSet(res_for_add[i][0], res_for_add[i][1], used_nonsinks=res_for_add[i][2]) + + if found_deficiency: + G.determineClusters() # all the input deficient set are based on the same nonLeafs + found = G.confirmClusters() + global_found = global_found or found + + if global_found: + G.updateActiveSet() + G.updateactiveNonSinkSet() + M.display(G) + #printGraph(G) + return G, (global_found, global_terminate) + + return G, (global_found, global_terminate) + +def findClusters_at_k(G: LatentGroups, k, parameters): + """ + Run one round of search for clusters of size k. + """ + LOGGER.info(f"Starting searchClusters k={k}...") + global_terminate=True + global_found=False + found_deficiency = False + + for num_nonsinks in range(k, -1, -1): # [k,k-1,...,0] + + temp_activeNonSink_ls = sorted(list(G.activeNonSinkSet), reverse=True) + #temp_activeNonSink_ls = list(G.activeNonSinkSet) + nonsinks_ls = list(combinations(temp_activeNonSink_ls, num_nonsinks)) + + for nonsinks in nonsinks_ls: + + current_found_deficiency, current_terminate, res_for_add = findClusters_at_k_by_nonsinks(G, k, list(nonsinks), parameters) + found_deficiency = found_deficiency or current_found_deficiency + global_terminate = global_terminate and current_terminate + + for i in range(len(res_for_add)): + G.addRankDefSet(res_for_add[i][0], res_for_add[i][1], used_nonsinks=res_for_add[i][2]) + + + if found_deficiency: + G.determineClusters() # all the input deficient set are based on the same nonLeafs + found = G.confirmClusters() + global_found = global_found or found + + if global_found: + G.updateActiveSet() + G.updateactiveNonSinkSet() + M.display(G) + #printGraph(G) + return G, (global_found, global_terminate) + + + return G, (global_found, global_terminate) + + +def findClusters(G: LatentGroups, parameters): + + prevCovers = set(G.latentDict.keys()) # Record current latent Covers + + k = 1 + while True: + LOGGER.info(f"{'-'*15} Test Cardinality now k={k} {'-'*15}") + + if parameters['unfold_covers']: + LPowerSet = generateLatentPowersetFromActiveSet(G) + else: + LPowerSet = [()] + activeSetCopy = copy.deepcopy(G.activeSet) + activeNonSinkSetCopy = copy.deepcopy(G.activeNonSinkSet) + + # Select a combination of latents, and replace their place in + # the activeSet with their children for the search + for i, Ls in enumerate(LPowerSet): + + if i==0: + all_unfolded=True + else: + all_unfolded=False + + Vprime = copy.deepcopy(activeSetCopy) + Tprime = copy.deepcopy(activeNonSinkSetCopy) + + for L in Ls: + # L is a Cover + #children = G.findChildren(L, rigorous=False) # + children = G.findChildren(L, rigorous=True) # + meassured_subset_L = G.findMeassuredSubset(L) + for cover in meassured_subset_L: + Tprime |= cover.vars + + # If children of L is just one or zero variable, do not replace + if len(children) + len(meassured_subset_L) > 1: + #if len(children)>= 1: + Vprime = Vprime - set([L]) + Vprime |= children + Vprime |= meassured_subset_L + + children = G.findChildrenOfAllSubSets(Ls) # + Vprime |= children + + for cover in G.activeSet: + if len(cover.vars)==1 and cover.takeOne() in G.X_dict: + if G.X_dict[cover.takeOne()].is_leaf!=True: + Tprime |= cover.vars + + + G.activeSet = Vprime + + if parameters['allow_nonleafx']: + G.activeNonSinkSet = Tprime + else: + G.activeNonSinkSet = set() + + LOGGER.info(f"Unfolding {Ls}") + LOGGER.info(f"G.activeSet {G.activeSet}") + LOGGER.info(f"G.activeNonSinkSet {G.activeNonSinkSet}") + G, (found, terminate) = findClusters_at_k_mp(G, k, parameters, n_jobs=-1) + #G, (found, terminate) = findClusters_at_k(G, k, parameters) + + if found: + #G.reconnectNonAtomics() + G.updateActiveSet() # toask + G.updateactiveNonSinkSet() + + # CASE 2 + if terminate: + break + + # CASE 1 + if found: + k = 1 + break + + # CASE 3 + if not found: + LOGGER.info("Nothing found!") + k += 1 + + # CASE 2 toask + #if (k > self.maxk) or terminate: + if (k > parameters['maxk']): + LOGGER.info(f"Procedure ending...") + break + + + G = findClusters_finish(G, parameters) + G.updateActiveSet(if_for_finish=True) + # 1 + + newCovers = set(G.latentDict.keys()) - prevCovers + return G, newCovers + + +def findClusters_finish(G: LatentGroups, parameters): + """ + Procedure for completing the graph when no more clusters may be found, + by introducing a temporary root latent variable. + """ + LOGGER.info(f"{'-'*15} Check If Introducing temporary root variable ... {'-'*15}") + G.updateActiveSet(if_for_finish=True) + if len(G.activeSet) == 1: + pass + #elif len(G.activeSet) >2: + # G.introduceTempRoot() + # LOGGER.info(f"{'-'*15} Introduced temporary root variable ... {'-'*15}") + else: + + G.updateActiveSet(if_for_finish=False) + remain_covers = list(G.activeSet) + remain_latent_covers = [] + remain_observed_covers = [] + + for i in range(len(remain_covers)): + if not remain_covers[i].is_observed: + if remain_covers[i].atomic: + remain_latent_covers.append(remain_covers[i]) + else: + remain_observed_covers.append(remain_covers[i]) + + if len(remain_observed_covers)!=0: + for i in range(len(remain_observed_covers)): + to_add = set() + for j in range(len(remain_latent_covers)): + fail_to_reject, rk = structuralRankTest(parameters['xvars'], parameters['ranktest_method'], parameters['alpha_dict'], G, \ + set([remain_observed_covers[i]]), set([remain_latent_covers[j]]), 0, []) + if not fail_to_reject: + to_add.add(remain_latent_covers[j]) + + if len(to_add)!=0: + G.addOrUpdateCover(remain_observed_covers[i], to_add) + G.updateActiveSet() + + else: + G.updateActiveSet(if_for_finish=True) + remain_covers = list(G.activeSet) + remain_latent_covers = [] + + for i in range(len(remain_covers)): + if not remain_covers[i].is_observed: + remain_latent_covers.append(remain_covers[i]) + + if len(remain_latent_covers)>=2: + for i in range(len(remain_latent_covers)-1): + G.addOrUpdateCover(remain_latent_covers[i], set([remain_latent_covers[i+1]])) + G.updateActiveSet() + + M.display(G) + #printGraph(G) + #assert len(G.activeSet) == 1, "The graph should have one root variable." + return G \ No newline at end of file diff --git a/causallearn/search/HiddenCausal/RLCD/__init__.py b/causallearn/search/HiddenCausal/RLCD/__init__.py new file mode 100644 index 00000000..e1177d06 --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/__init__.py @@ -0,0 +1,4 @@ +from .RLCD_alg import RLCD +from .Chi2RankTest import Chi2RankTest + +__all__ = ["RLCD", "Chi2RankTest"] diff --git a/causallearn/search/HiddenCausal/RLCD/logger.py b/causallearn/search/HiddenCausal/RLCD/logger.py new file mode 100644 index 00000000..de06ae9d --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/logger.py @@ -0,0 +1,6 @@ +import logging + +log_format = "%(asctime)s | %(levelname)-5s %(funcName)-30s %(message)s" +log_level = logging.DEBUG +LOGGER = logging.getLogger(__name__) +LOGGER.addHandler(logging.NullHandler()) \ No newline at end of file diff --git a/causallearn/search/HiddenCausal/RLCD/misc.py b/causallearn/search/HiddenCausal/RLCD/misc.py new file mode 100644 index 00000000..16515652 --- /dev/null +++ b/causallearn/search/HiddenCausal/RLCD/misc.py @@ -0,0 +1,425 @@ +from __future__ import annotations +from copy import deepcopy +from math import factorial as fac +from math import sqrt +import numpy as np +from numpy.linalg import matrix_rank +from scipy.stats import norm +from itertools import combinations, chain, combinations_with_replacement +import pdb +import os +import glob +from .GraphDrawer import DotGraph +from .logger import LOGGER +from .Cover import Cover + +def generateSubsetMinimal(vset, k=1): + """ + Given a set of Covers, generate all minimum subsets s.t. cardinality > k. + """ + + def recursiveSearch(d, gap, currSubset=set()): + thread = f"currSubset: {currSubset}, d: {d}, gap is {gap}" + d = deepcopy(d) + currSubset = deepcopy(currSubset) + + # Terminate if empty list + if len(d) == 0: + return set() + + # Pop one Cover with largest cardinality + maxDim = max(d) + v = d[maxDim].pop() + if len(d[maxDim]) == 0: + d.pop(maxDim) + + # Branch to consider all cases + # Continue current search without this element + if len(d) > 0: + yield from recursiveSearch(d, gap, currSubset) + + # Add this group + if not groupInLatentSet(v, currSubset): + currSubset.add(v) + gap -= maxDim + + # Continue search if gap not met + if gap >= 0 and len(d) > 0: + yield from recursiveSearch(d, gap, currSubset) + + # End of search tree + if gap < 0: + yield currSubset + + #if k == 0: + # return set() + + # Create dictionary where key is in descending dimension size + # and v is a list of frozensets of variables + d = {} + + ordered_list = list(vset) + ordered_list.sort(key=lambda x: x.__hash__()) + + for v in ordered_list: + assert isinstance(v, Cover), "Should be Cover." + n = len(v) + d[n] = d.get(n, set()).union([v]) + + # Run recursive search + yield from recursiveSearch(d, k) + + +# Check if new group of latent vars exists in a current +# list of latent vars +def groupInLatentSet(V: Cover, currSubset: set): + for group in currSubset: + if len(V.vars.intersection(group.vars)) > 0: + return True + return False + + +# Centre the mean of data +def meanCentre(df): + n = df.shape[0] + return df - df.sum(axis=0) / n + + +# Return n choose r +def numCombinations(n, r): + return fac(n) // fac(r) // fac(n - r) + + +def getAllMeasures(latentDict, subgroups): + measures = set() + + for subgroup in subgroups: + values = latentDict[subgroup] + childrenP = values["children"] + subgroupsP = values["subgroups"] + + for child in childrenP: + if not child.isLatent(): + measures.update(childrenP) + + if len(subgroupsP) > 0: + measures.update(getAllMeasures(latentDict, subgroupsP)) + + return measures + + +# Given a set of Children, try the exact same set in a dictionary +def findEntry(latentDict, refChildren, subgroupMeasures): + for group in latentDict: + values = latentDict[group] + children = values["children"] + if children == refChildren: + subMeasures = getAllMeasures(latentDict, values["subgroups"]) + if subMeasures == subgroupMeasures: + return True + return False + + +#!! TESTS + +# S: Sample Covariance +# I, J: Disjoint index sets, |I| = |J| +def traceMatrixCompound(S, I, J, k): + X = I + J # Union of I and J + SijInv = np.linalg.inv(S[np.ix_(X, X)]) + Inew = [X.index(i) for i in I] + Jnew = [X.index(j) for j in J] + Sij = SijInv[np.ix_(Inew, Jnew)] + Sji = S[np.ix_(J, I)] + A = Sji @ Sij + + m = A.shape[0] + Sum = 0 + for Vs in combinations(range(m), k): + Vs = list(Vs) + Atemp = A[np.ix_(Vs, Vs)] + Sum += np.linalg.det(Atemp) + Sum = pow(-1, k) * Sum + # print(f"traceMatrixCompound is {Sum}") + return Sum + + +def determinantVariance(S, I, J, n): + assert len(I) == len(J), "I and J must be same length" + m = len(I) + X = I + J + SijDet = np.linalg.det(S[np.ix_(I, J)]) + SijijDet = np.linalg.det(S[np.ix_(X, X)]) + # print(f"SijDet is {SijDet}") + + Sum = 0 + for k in range(m): + Sum += ( + fac(m - k) * fac(n + 2) / fac(n + 2 - k) * traceMatrixCompound(S, I, J, k) + ) + firstTerm = ( + fac(n) + / fac(n - m) + * pow(SijDet, 2) + * (fac(n + 2) / fac(n + 2 - m) - fac(n) / fac(n - m)) + ) + secondTerm = fac(n) / fac(n - m) * SijijDet * Sum + variance = firstTerm + secondTerm + + # Heuristic (better way to handle negative variance?) + if variance < 0: + return 1 + else: + return variance + + +def determinantMean(S, I, J, n): + Scatter = S * n + x = np.linalg.det(Scatter[np.ix_(I, J)]) + return x + + +# Returns p value +def determinantTest(S, I, J, n): + detMean = determinantMean(S, I, J, n) + detVar = determinantVariance(S, I, J, n) + zStat = abs(detMean) / sqrt(detVar) + pValue = (1 - norm.cdf(zStat)) * 2 + return pValue + + +# Return true if fail to reject null +def bonferroniTest(plist, alpha): + m = len(plist) + return not any([p < alpha / m for p in plist]) + + +# Return true if fail to reject null +def bonferroniHolmTest(plist, alpha): + plist = sorted(plist) + m = len(plist) + tests = [p < alpha / (m + 1 - k) for k, p in enumerate(plist)] + # print(sum(tests)) + # print(len(plist)) + # print(plist[0]) + return not any(tests) + + +# Given data df, bootstrap sample and make new covariance +def bootStrapCovariance(data): + n = data.shape[0] + index = np.random.randint(low=0, high=n, size=n) + bootstrap = data.values[index] + cov = 1 / (n - 1) * bootstrap.T @ bootstrap + return cov + + +def scombinations(elements, k): + combs = [set(X) for X in combinations(elements, k)] + if len(combs) == 0: + return [set()] + return combs + + +# Given two lists of sets, get the cartesian product +def cartesian(list1, list2): + result = [] + for l1 in list1: + for l2 in list2: + result.append(l1.union(l2)) + return result + + +# Compare two rankDicts +def cmpDict(d1, d2): + mismatches = {} + same = True + for key in d1: + if d1[key] != d2[key]: + mismatches[key] = (d1[key], d2[key]) + same = False + + if len(mismatches) > 50: + break + return same, mismatches + + +# Equivalent graphs must have equal rank on all subcovariances +# Test each combination for equivalence +def compareGraphs(g1, g2): + if not set(g1.xvars) == set(g2.xvars): + print(g1.xvars, g2.xvars) + pdb.set_trace() + raise ValueError("X variables must be the same") + n = len(g1.xvars) + numbers = list(range(2, n - 1)) + combns = list(combinations_with_replacement(numbers, 2)) + + for i, j in combns: + LOGGER.info(f"Testing i={i} vs j={j}...") + Asets = list(combinations(g1.xvars, i)) + Bsets = list(combinations(g1.xvars, j)) + for A in Asets: + for B in Bsets: + Aset = frozenset(A) + Bset = frozenset(B) + + if len(Aset.intersection(Bset)) > 0: + continue + + A = sorted(A) + B = sorted(B) + cov1 = g1.subcovariance(A, B) + rk1 = matrix_rank(cov1) + cov2 = g2.subcovariance(A, B) + rk2 = matrix_rank(cov2) + + if rk1 != rk2: + return (False, (A, B, rk1, rk2)) + + return (True, None) + + +def display(G): + """ + Display a prettified version of the latentDict. + """ + LOGGER.info(f"{'='*10} Printing Current LatentDict: {'='*10}") + LOGGER.info(f"Active Set: {','.join([str(V) for V in G.activeSet])}") + + for P, v in G.latentDict.items(): + subcovers = v["subcovers"] + Cs = v["children"] + Ctext = ",".join([str(C) for C in Cs]) + Ptext = str(P) + + text = f"{Ptext} : {Ctext}" + if len(subcovers) > 0: + text += " | " + for subcover in subcovers: + text += f"[{str(subcover)}]" + + if v.get("refined", False): + text += " - Refined!" + LOGGER.info(f" {text}") + LOGGER.info("=" * 50) + + +def powerset(iterable): + """ + Return powerset over the elements in iterable, with the first element being + the empty set and the final element being the full set of elements. + """ + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def clearOutputFolder(): + # Clear output folder + files = glob.glob("output/*.pkl") + glob.glob("output/*.png") + for f in files: + os.remove(f) + + +# Insert a new item k:v in the position of key in d +def insertItemToDict(d, oldkey, newkey, newvalue): + d1 = {} + for k, v in d.items(): + if k == oldkey: + d1[newkey] = newvalue + else: + d1[k] = v + return d1 + + +def extractNumber(string): + return int("".join(filter(str.isdigit, string))) + + +def reorderCovers(covers): + """ + Given a set of covers, reorder them such that the lowest number is first. + Reduces some randomness. + """ + covers = list(covers.copy()) + cover_dict = {} + for cover in covers: + digits = [extractNumber(v) for v in cover.vars] + min_digit = min(digits) + cover_dict[min_digit] = cover + covers = [v for k, v in sorted(cover_dict.items())] + return covers + + +def displayI(I: dict): + for k, v in I.items(): + A, B = list(k) + condition = [x for x in v["condition"]] + LOGGER.info(f" {A} indep {B} | {condition}") + + +class Independences(dict): + """ + Simple wrapper over dict to do conditional update of independence relns. + + Specifically, if we already found A indep B with details (e.g. setA, setB), + we do not want to overwrite these details when performing update. + """ + + def update(self, other: Independences): + for k, v in other.items(): + if k in self: + if not isinstance(self[k], dict): + self[k] = v + elif len(self[k].get("setA", set())) > 0: + continue + else: + self[k] = v + else: + self[k] = v + + +class Edges(dict): + """ + Simple wrapper over dict to do conditional update of edges. + + Specifically, if we already found an edge A -> B, we do not want to + overwrite it with A - B. + """ + + def update(self, other: Edges): + for k, v in other.items(): + if k in self: + if not isinstance(self[k], dict): + self[k] = v + elif self[k][2] == 1: + continue + else: + self[k] = v + else: + self[k] = v + + def getDotGraph(self): + """ + Parse an Edges object into a DotGraph. + """ + + def addParentToGraph(G, parent, childrenSet, directed=0): + for P in parent.vars: + for childGroup in childrenSet: + for child in childGroup.vars: + G.addEdge(P, child, type=directed) + + E = deepcopy(self) + G = DotGraph() + + for AB in E: + for V in list(AB): + for v in V.vars: + G.addNode(v) + + while len(E) > 0: + _, (A, B, direction) = E.popitem() + addParentToGraph(G, A, set([B]), direction) + + return G diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index c67b5a31..b43d6d2a 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -17,6 +17,8 @@ Requirements * scikit-learn * statsmodels * pydot +* communities +* joblib (For visualization) diff --git a/docs/source/search_methods_index/Hidden causal representation learning/index.rst b/docs/source/search_methods_index/Hidden causal representation learning/index.rst index 715bf12b..a6923657 100644 --- a/docs/source/search_methods_index/Hidden causal representation learning/index.rst +++ b/docs/source/search_methods_index/Hidden causal representation learning/index.rst @@ -1,7 +1,7 @@ Hidden causal representation learning ============================================== In this section, we would like to introduce methods in hidden causal representation learning, such as -generalized independent noise (GIN [1]_) condition-based method. +generalized independent noise (GIN [1]_) condition-based method and rank-based latent causal discovery (RLCD [2]_). Contents: @@ -10,5 +10,7 @@ Contents: :maxdepth: 2 gin + rlcd -.. [1] Xie, F., Cai, R., Huang, B., Glymour, C., Hao, Z., & Zhang, K. (2020, January). Generalized Independent Noise Condition for Estimating Latent Variable Causal Graphs. In NeurIPS. \ No newline at end of file +.. [1] Xie, F., Cai, R., Huang, B., Glymour, C., Hao, Z., & Zhang, K. (2020, January). Generalized Independent Noise Condition for Estimating Latent Variable Causal Graphs. In NeurIPS. +.. [2] Dong, X., Huang, B., Ng, I., Song, X., Zheng, Y., Jin, S., Legaspi, R., Spirtes, P., & Zhang, K. (2024). A versatile causal discovery framework to allow causally-related hidden variables. In International Conference on Learning Representations, vol. 2024, pp. 43084-43118. \ No newline at end of file diff --git a/docs/source/search_methods_index/Hidden causal representation learning/rlcd.rst b/docs/source/search_methods_index/Hidden causal representation learning/rlcd.rst new file mode 100644 index 00000000..ecd0dabb --- /dev/null +++ b/docs/source/search_methods_index/Hidden causal representation learning/rlcd.rst @@ -0,0 +1,55 @@ +.. _rlcd: + +Rank-based Latent Causal Discovery (RLCD) +============================================================= + +Algorithm Introduction +----------------------------------------------------------- + +RLCD [1]_ learns causal structures with causally-related hidden variables from rank constraints in partially observed linear causal models. + +This implementation includes the structure learning part of RLCD from ``scm-identify``. It provides the main RLCD search routine and the rank-test helper used for sample data. + +Usage +----------------------------------------------------------- +.. code-block:: python + + from causallearn.search.HiddenCausal.RLCD import RLCD + + # default parameters + cg = RLCD(data) + + # or customized parameters + cg = RLCD(data, ranktest_method, stage1_method, alpha_dict, maxk, node_names) + + # visualization using pydot + cg.draw_pydot_graph() + + # or save the graph + from causallearn.utils.GraphUtils import GraphUtils + + pyd = GraphUtils.to_pydot(cg.G) + pyd.write_png('rlcd_result.png') + +Visualization using pydot is recommended. If specific label names are needed, please refer to this `usage example `_ (e.g., 'cg.draw_pydot_graph(labels=["A", "B", "C"])' or 'GraphUtils.to_pydot(cg.G, labels=["A", "B", "C"])'). + +Parameters +----------------------------------------------------------- +**data**: numpy.ndarray, shape (n_samples, n_features). Data, where n_samples is the number of samples +and n_features is the number of features. + +**ranktest_method**: rank test object, optional. The rank test object should provide a ``test(pcols, qcols, r, alpha)`` method. If not provided, ``Chi2RankTest(data)`` is used. + +**stage1_method**: str. Stage-1 method used to partition observed variables. Default: 'ges'. + +**alpha_dict**: dict, optional. Significance levels for rank tests by rank. Default: ``{0: 0.01, 1: 0.01, 2: 0.01, 3: 0.01}``. + +**maxk**: int. Maximum rank-search cardinality. Default: 3. + +**node_names**: list, optional. Names of observed variables in the returned graph. If not provided, variables are named ``X1``, ``X2``, ... Latent variables are named ``L1``, ``L2``, ... + +Returns +----------------------------------------------------------- +**cg**: CausalGraph. Learned graph over observed and latent variables, where ``cg.G.graph[j,i]=1`` and ``cg.G.graph[i,j]=-1`` indicate ``i --> j``; ``cg.G.graph[i,j] = cg.G.graph[j,i] = -1`` indicate ``i --- j``; ``cg.G.graph[i,j] = cg.G.graph[j,i] = 1`` indicates ``i <-> j``. + +.. [1] Dong, X., Huang, B., Ng, I., Song, X., Zheng, Y., Jin, S., Legaspi, R., Spirtes, P., & Zhang, K. (2024). A versatile causal discovery framework to allow causally-related hidden variables. In International Conference on Learning Representations, vol. 2024, pp. 43084-43118. diff --git a/setup.py b/setup.py index 0b502da0..a0c9de50 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,9 @@ 'networkx', 'pydot', 'tqdm', - 'momentchi2' + 'momentchi2', + 'communities', + 'joblib' ], url='https://github.com/py-why/causal-learn', packages=setuptools.find_packages(), diff --git a/tests/TestRLCD.py b/tests/TestRLCD.py new file mode 100644 index 00000000..4835a0d3 --- /dev/null +++ b/tests/TestRLCD.py @@ -0,0 +1,56 @@ +import unittest + +import numpy as np + +from causallearn.graph.GraphClass import CausalGraph +from causallearn.search.HiddenCausal.RLCD import Chi2RankTest, RLCD +from causallearn.utils.GraphUtils import GraphUtils + + +class TestRLCD(unittest.TestCase): + def test_rlcd_recovers_linear_gaussian_hidden_structure(self): + rng = np.random.default_rng(1) + sample_size = 2000 + L1 = rng.normal(size=sample_size) + L2 = 0.8 * L1 + rng.normal(size=sample_size) + X1 = 1.2 * L1 + 0.05 * rng.normal(size=sample_size) + X2 = 1.4 * L1 + 0.05 * rng.normal(size=sample_size) + X3 = 1.6 * L1 + 0.05 * rng.normal(size=sample_size) + X4 = 1.1 * L2 + 0.05 * rng.normal(size=sample_size) + X5 = 1.3 * L2 + 0.05 * rng.normal(size=sample_size) + X6 = 1.5 * L2 + 0.05 * rng.normal(size=sample_size) + data = np.column_stack([X1, X2, X3, X4, X5, X6]) + data = (data - data.mean(axis=0)) / data.std(axis=0) + + cg = RLCD( + data, + ranktest_method=Chi2RankTest(data), + stage1_method="all", + alpha_dict={0: 0.01, 1: 0.01, 2: 0.01, 3: 0.01}, + maxk=2, + ) + + self.assertIsInstance(cg, CausalGraph) + self.assertIsInstance(cg.stage1_cg, CausalGraph) + self.assertEqual(cg.all_vars[:6], ["X1", "X2", "X3", "X4", "X5", "X6"]) + self.assertEqual(len(cg.all_vars), 8) + + graph = cg.G.graph + + def get_latent_parent(children): + for latent_idx in range(6, 8): + if np.all(graph[children, latent_idx] == 1) and np.all(graph[latent_idx, children] == -1): + return latent_idx + return None + + l1_parent = get_latent_parent([0, 1, 2]) + l2_parent = get_latent_parent([3, 4, 5]) + self.assertIsNotNone(l1_parent) + self.assertIsNotNone(l2_parent) + self.assertNotEqual(l1_parent, l2_parent) + self.assertIn((graph[l2_parent, l1_parent], graph[l1_parent, l2_parent]), [(1, -1), (-1, 1)]) + self.assertIsNotNone(GraphUtils.to_pydot(cg.G)) + + +if __name__ == "__main__": + unittest.main()