@@ -20,7 +20,7 @@ import warnings
2020
2121
2222cdef extern from " EMD.h" :
23- int EMD_wrap(int n1,int n2, double * X, double * Y,double * D, double * G, double * alpha, double * beta, double * cost, uint64_t maxIter) nogil
23+ int EMD_wrap(int n1,int n2, double * X, double * Y,double * D, double * G, double * alpha, double * beta, double * cost, uint64_t maxIter, double * alpha_init, double * beta_init ) nogil
2424 int EMD_wrap_omp(int n1,int n2, double * X, double * Y,double * D, double * G, double * alpha, double * beta, double * cost, uint64_t maxIter, int numThreads) nogil
2525 int EMD_wrap_sparse(int n1, int n2, double * X, double * Y, uint64_t n_edges, uint64_t * edge_sources, uint64_t * edge_targets, double * edge_costs, uint64_t * flow_sources_out, uint64_t * flow_targets_out, double * flow_values_out, uint64_t * n_flows_out, double * alpha, double * beta, double * cost, uint64_t maxIter) nogil
2626 int EMD_wrap_lazy(int n1, int n2, double * X, double * Y, double * coords_a, double * coords_b, int dim, int metric, double * G, double * alpha, double * beta, double * cost, uint64_t maxIter) nogil
@@ -42,7 +42,7 @@ def check_result(result_code):
4242
4343@ cython.boundscheck (False )
4444@ cython.wraparound (False )
45- def emd_c (np.ndarray[double , ndim = 1 , mode = " c" ] a, np.ndarray[double , ndim = 1 , mode = " c" ] b, np.ndarray[double , ndim = 2 , mode = " c" ] M, uint64_t max_iter , int numThreads ):
45+ def emd_c (np.ndarray[double , ndim = 1 , mode = " c" ] a, np.ndarray[double , ndim = 1 , mode = " c" ] b, np.ndarray[double , ndim = 2 , mode = " c" ] M, uint64_t max_iter , int numThreads , alpha_init = None , beta_init = None ):
4646 """
4747 Solves the Earth Movers distance problem and returns the optimal transport matrix
4848
@@ -81,6 +81,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
8181 max_iter : uint64_t
8282 The maximum number of iterations before stopping the optimization
8383 algorithm if it has not converged.
84+ alpha_init : (ns,) numpy.ndarray, float64, optional
85+ Initial dual potentials for sources (warmstart)
86+ beta_init : (nt,) numpy.ndarray, float64, optional
87+ Initial dual potentials for targets (warmstart)
8488
8589 Returns
8690 -------
@@ -101,6 +105,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
101105 cdef np.ndarray[double , ndim= 2 , mode= " c" ] G= np.zeros([0 , 0 ])
102106
103107 cdef np.ndarray[double , ndim= 1 , mode= " c" ] Gv= np.zeros(0 )
108+
109+ # Warmstart potentials
110+ cdef np.ndarray[double , ndim= 1 , mode= " c" ] alpha_init_c
111+ cdef np.ndarray[double , ndim= 1 , mode= " c" ] beta_init_c
112+ cdef double * alpha_init_ptr = NULL
113+ cdef double * beta_init_ptr = NULL
104114
105115 if not len (a):
106116 a= np.ones((n1,))/ n1
@@ -110,11 +120,18 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
110120
111121 # init OT matrix
112122 G= np.zeros([n1, n2])
123+
124+ # Setup warmstart pointers if provided
125+ if alpha_init is not None and beta_init is not None :
126+ alpha_init_c = np.ascontiguousarray(alpha_init, dtype = np.float64)
127+ beta_init_c = np.ascontiguousarray(beta_init, dtype = np.float64)
128+ alpha_init_ptr = < double * > alpha_init_c.data
129+ beta_init_ptr = < double * > beta_init_c.data
113130
114131 # calling the function
115132 with nogil:
116133 if numThreads == 1 :
117- result_code = EMD_wrap(n1, n2, < double * > a.data, < double * > b.data, < double * > M.data, < double * > G.data, < double * > alpha.data, < double * > beta.data, < double * > & cost, max_iter)
134+ result_code = EMD_wrap(n1, n2, < double * > a.data, < double * > b.data, < double * > M.data, < double * > G.data, < double * > alpha.data, < double * > beta.data, < double * > & cost, max_iter, alpha_init_ptr, beta_init_ptr )
118135 else :
119136 result_code = EMD_wrap_omp(n1, n2, < double * > a.data, < double * > b.data, < double * > M.data, < double * > G.data, < double * > alpha.data, < double * > beta.data, < double * > & cost, max_iter, numThreads)
120137 return G, cost, alpha, beta, result_code
0 commit comments