Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions causallearn/search/HiddenCausal/RLCD/Chi2RankTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import numpy as np
from pdb import set_trace
from statsmodels.multivariate.cancorr import CanCorr
from math import log, pow
from scipy.stats import chi2
from scipy.linalg import eigh

from .logger import LOGGER


class Chi2RankTest(object):

def __init__(self, data, N_scaling=1):

self.data = data
self.data = self.data - self.data.mean(axis=0)
self.data = self.data/self.data.std(axis=0)

self.N = data.shape[0]
self.N_scaling = N_scaling
self.unnormalized_crosscovs = self.data.T@self.data # data are zero mean
self.cca_cache_dict = {}

def get_cachekey(self, pcols_, qcols_):

pcols = pcols_.copy()
qcols = qcols_.copy()
pcols.sort()
qcols.sort()

key = ''
for i in range(self.data.shape[1]):
if i in pcols and i in qcols:
key = key+str(3)
elif i in pcols and not i in qcols:
key = key+str(2)
elif not i in pcols and i in qcols:
key = key+str(1)
else:
key = key+str(0)

return key

def test(self, pcols, qcols, r, alpha):
'''
Parameters
----------
pcols, qcols : column indices of data
r: null hypo that rank <= r
alpha: significance level

Returns
-------
if_fail_to_reject: 0 means reject and 1 means fail to reject
p : the p-value of the test
'''



cachekey = self.get_cachekey(pcols, qcols)

if cachekey in self.cca_cache_dict:
cancorr = self.cca_cache_dict[cachekey]
else:

X = self.data[:, pcols]
Y = self.data[:, qcols]

unnormalized_crosscovs = [self.unnormalized_crosscovs[pcols,:][:,pcols], self.unnormalized_crosscovs[pcols,:][:,qcols], \
self.unnormalized_crosscovs[qcols,:][:,pcols], self.unnormalized_crosscovs[qcols,:][:,qcols]]

try:
comps = kcca_modified([X, Y], reg=0.,
numCC=None, kernelcca=False, ktype='linear',
gausigma=1.0, degree=2, crosscovs = unnormalized_crosscovs)

cancorr, _, _ = recon([X,Y], comps, kernelcca=False)
cancorr = cancorr[:,0,1]
except:
LOGGER.debug("calculating cancorr error %s %s, using slower implementation instead", pcols, qcols)
X_fallback = np.atleast_2d(X).reshape(X.shape[0], -1)
Y_fallback = np.atleast_2d(Y).reshape(Y.shape[0], -1)
cancorr = CanCorr(X_fallback, Y_fallback, tolerance=1e-8).cancorr

self.cca_cache_dict[cachekey] = cancorr

testStat = 0
p = len(pcols)
q = len(qcols)

l = cancorr[r:]
for li in l:
li = min(li, 1-1e-15)
testStat += log(1)-log(1-li*li)
ratio = 0
for i in range(r):
li = cancorr[i]
ratio += 1/(li*li)-1

ratio += self.N*self.N_scaling - r - 0.5*(p+q+1)
testStat = testStat * ratio

dfreedom = (p-r) * (q-r)
criticalValue = chi2.ppf(1-alpha, dfreedom)
p = 1 - chi2.cdf(testStat, dfreedom)
if_fail_to_reject = testStat<=criticalValue

# due to numerical errors comparing criticalValue with testStat is more accurate

return if_fail_to_reject


def kcca_modified(
data, reg=0.0, numCC=None, kernelcca=False, ktype="linear", gausigma=1.0, degree=2, crosscovs=None
):
"""Set up and solve the kernel CCA eigenproblem"""
if kernelcca:
raise NotImplementedError
#kernel = [
# _make_kernel(d, ktype=ktype, gausigma=gausigma, degree=degree) for d in data
#]
else:
kernel = [d.T for d in data]

nDs = len(kernel)
nFs = [k.shape[0] for k in kernel]
numCC = min([k.shape[0] for k in kernel]) if numCC is None else numCC

# Get the auto- and cross-covariance matrices
if crosscovs is None:
crosscovs = [np.dot(ki, kj.T) for ki in kernel for kj in kernel]

# Allocate left-hand side (LH) and right-hand side (RH):
n = sum(nFs)
LH = np.zeros((n, n))
RH = np.zeros((n, n))

# Fill the left and right sides of the eigenvalue problem
for i in range(nDs):
RH[
sum(nFs[:i]): sum(nFs[: i + 1]), sum(nFs[:i]): sum(nFs[: i + 1])
] = crosscovs[i * (nDs + 1)] + reg * np.eye(nFs[i])

for j in range(nDs):
if i != j:
LH[
sum(nFs[:j]): sum(nFs[: j + 1]), sum(nFs[:i]): sum(nFs[: i + 1])
] = crosscovs[nDs * j + i]

LH = (LH + LH.T) / 2.0
RH = (RH + RH.T) / 2.0

maxCC = LH.shape[0]
try:
r, Vs = eigh(LH, RH, subset_by_index=[maxCC - numCC, maxCC - 1])
except TypeError:
r, Vs = eigh(LH, RH, eigvals=(maxCC - numCC, maxCC - 1))
r[np.isnan(r)] = 0
rindex = np.argsort(r)[::-1]
comp = []
Vs = Vs[:, rindex]
for i in range(nDs):
comp.append(Vs[sum(nFs[:i]): sum(nFs[: i + 1]), :numCC])
return comp

def _listdot(d1, d2):
return [np.dot(x[0].T, x[1]) for x in zip(d1, d2)]

def _listcorr(a):
"""Returns pairwise row correlations for all items in array as a list of matrices"""
corrs = np.zeros((a[0].shape[1], len(a), len(a)))
for i in range(len(a)):
for j in range(len(a)):
if j > i:
corrs[:, i, j] = [
np.nan_to_num(np.corrcoef(ai, aj)[0, 1])
for (ai, aj) in zip(a[i].T, a[j].T)
]
return corrs

def recon(data, comp, corronly=False, kernelcca=False):
# Get canonical variates and CCs
if kernelcca:
ws = _listdot(data, comp)
else:
ws = comp
ccomp = _listdot([d.T for d in data], ws)
corrs = _listcorr(ccomp)
if corronly:
return corrs
else:
return corrs, ws, ccomp

186 changes: 186 additions & 0 deletions causallearn/search/HiddenCausal/RLCD/Cover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from __future__ import annotations
from itertools import combinations


class Cover:
"""
Class to represent a Cover of latent variables. A Cover can either be
atomic or non-atomic.

An atomic Cover is one where the variables cannot be split into
disjoint set of atomic Covers. E.g. if {L1, L2} forms a Cover, it is
not atomic if L1 and L2 individually are atomic Covers.
Note: Only atomic Covers can be children nodes.

A Cover can also be temporary or not. A temporary Cover is one which was
introduced not because a rank deficiency was found but because we had to
introduce a temporary root variable to connect the remaining variables
when no more rank deficient sets may be found.
"""

def __init__(self, varnames, atomic=True, temp=False, is_observed=True, is_leaf=None):
if isinstance(varnames, str):
self.vars = set([varnames])

elif isinstance(varnames, list):
self.vars = set(varnames)

elif isinstance(varnames, set):
self.vars = varnames
else:
raise ValueError(f"{varnames} is neither str, list, set.")

# if len(self.vars) > 0:
# v = next(iter(self.vars))
# self.type = v[:1]

self.atomic = atomic
self.temp = temp
self.is_observed=is_observed
self.is_leaf=is_leaf

def __eq__(self, other):
if not isinstance(other, Cover):
return NotImplemented
return self.vars == other.vars

# The set of variables in any minimalGroup should be unique
def __hash__(self):
s = "".join(sorted(list(self.vars)))
return hash(s)

# Union with another Cover
def union(self, L):
self.vars = self.vars.union(L.vars)

def __len__(self):
return len(self.vars)

@property
def isAtomic(self):
return self.atomic

@property
def isTemp(self):
return self.temp

def takeOne(self):
return next(iter(self.vars))

def __str__(self):
if len(self.vars) == 1:
return next(iter(self.vars))
else:
vars = ",".join(list(self.vars))
return "{" + vars + "}"

def __repr__(self):
return str(self)

def isSubset(self, Bs: set[Cover] | Cover, strict=False):
if isinstance(Bs, set):
Bvars = getVars(Bs)
elif isinstance(Bs, Cover):
Bvars = Bs.vars
else:
raise ValueError("Argument must be set of Covers or Cover.")
if strict:
return self.vars < Bvars
return self.vars <= Bvars

def intersection(self, B):
return self.vars.intersection(B.vars)


##################################
# Methods associated with Covers #
##################################


def setLength(Vs: set[Cover]):
"""
Determine ||Vs||
"""
assert not isinstance(Vs, str), "Cannot be string."
return len(getVars(Vs))


def setDifference(As: set[Cover], Bs: set[Cover]):
diff = As - Bs # first remove any common elements
newset = set()
while len(diff) > 0:
A = diff.pop()
newset.add(A)
for B in Bs:
if len(A.intersection(B)) > 0:
newset.remove(A)
break
return newset


def setOverlap(As: set[Cover], Bs: set[Cover]):
if len(As.intersection(Bs)) > 0:
return True
return len(setIntersection(As, Bs)) > 0


def setIntersection(As: set[Cover], Bs: set[Cover]):
Avars = getVars(As)
Bvars = getVars(Bs)
return Avars.intersection(Bvars)


def getVars(As: set[Cover]):
"""
Get all variables from a set of Covers.
"""
vars = set()
for A in As:
assert isinstance(A, Cover), "Argument should be a set of Covers."
vars.update(A.vars)
return vars

def getOrderedVarsString(As: set[Cover]|Cover):
"""
Get all variables from a set of Covers.
"""

if isinstance(As, Cover):
As = {As}

vars = set()
for A in As:
assert isinstance(A, Cover), "Argument should be a set of Covers."
vars.update(A.vars)

vars = [x for x in vars]
vars.sort()

return " ".join(vars)

def deduplicate(Vs: set[Cover]):
"""
Deduplicate cases where Vs includes {L1, {L1, L3}} into just {{L1, L3}}
"""
newVs = set()
for Vi in Vs:
for Vj in Vs:
isDuplicate = False
if Vi.vars < Vj.vars:
isDuplicate = True
break
if not isDuplicate:
newVs.add(Vi)
return newVs


def pairwiseOverlap(Vs: set[Cover]):
"""
For each pair A, B in Vs, check if there are any variables overlapping.
Return True if so.
"""
for pair in combinations(Vs, 2):
A, B = list(pair)
if len(A.vars.intersection(B.vars)) > 0:
return True
return False
Loading
Loading