Skip to content

Commit 732890f

Browse files
committed
add debug interface
1 parent 3f012b7 commit 732890f

7 files changed

Lines changed: 492 additions & 63 deletions

File tree

include/abstract_index.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "index_config.h"
77
#include "index_build_params.h"
88
#include "percentile_stats.h"
9+
#include "debug_utils.h"
910
#include <any>
1011

1112
namespace diskann
@@ -90,6 +91,30 @@ class AbstractIndex
9091
float *distances,
9192
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);
9293

94+
// Debug interface: retrieve the raw embedding at internal location index.
95+
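// Caller must pre-allocate vec with room for at least as many elements as the index dimension.
// NOTE(review): the Index<T>::get_embedding declaration asks for get_aligned_dim() elements instead - confirm which size applies.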
96+
template <typename data_type>
97+
void get_embedding(uint32_t location, data_type *vec);
98+
99+
// Debug search: runs ANN search and records every traversed node in debug_info.
100+
template <typename data_type, typename IDType>
101+
std::pair<uint32_t, uint32_t> debug_search(
102+
const data_type *query, const size_t K, const uint32_t L,
103+
IDType *indices, float *distances,
104+
DebugTraversalInfo &debug_info,
105+
const uint32_t maxLperSeller = 0,
106+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
107+
108+
// Debug filtered search: same as debug_search with label filtering.
109+
template <typename IDType>
110+
std::pair<uint32_t, uint32_t> debug_search_with_filters(
111+
const DataType &query, const std::vector<std::string> &raw_labels,
112+
const size_t K, const uint32_t L,
113+
IDType *indices, float *distances,
114+
DebugTraversalInfo &debug_info,
115+
const uint32_t maxLperSeller = 0,
116+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
117+
93118
// insert points with labels, labels should be present for filtered index
94119
template <typename data_type, typename tag_type>
95120
int insert_point(const data_type *point, const tag_type tag, const std::vector<std::string> &labels);
@@ -148,5 +173,18 @@ class AbstractIndex
148173
const std::vector<std::string>& filter_labels) = 0;
149174
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
150175
virtual void _set_universal_label(const LabelType universal_label) = 0;
176+
virtual void _get_embedding(uint32_t location, DataType &vec) = 0;
177+
virtual std::pair<uint32_t, uint32_t> _debug_search(const DataType &query, const size_t K, const uint32_t L,
178+
std::any &indices, float *distances,
179+
DebugTraversalInfo &debug_info,
180+
const uint32_t maxLperSeller,
181+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) = 0;
182+
virtual std::pair<uint32_t, uint32_t> _debug_search_with_filters(const DataType &query,
183+
const std::vector<std::string> &raw_labels,
184+
const size_t K, const uint32_t L,
185+
std::any &indices, float *distances,
186+
DebugTraversalInfo &debug_info,
187+
const uint32_t maxLperSeller,
188+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) = 0;
151189
};
152190
} // namespace diskann

include/debug_utils.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT license.
3+
4+
#pragma once
5+
6+
#include <cstdint>
7+
#include <vector>
8+
#include <limits>
9+
10+
namespace diskann
11+
{
12+
13+
// Reason why a node visited during ANN graph traversal was or was not included
14+
// in the final result set. Scoped enum stored in a single byte (uint8_t).
// NOTE(review): DebugTraversalInfo in this header records a label_rejected flag
// and a distance per node rather than a FilterReason; no function in this header
// maps a recorded entry to one of these values.
15+
enum class FilterReason : uint8_t
16+
{
17+
InResult = 0, // Node was kept in the final top-K result
18+
DistanceTooLarge, // Node was visited but its distance was too large for top-K (implicit value 1)
19+
LabelMismatch, // Node was skipped because its label did not match the filter (implicit value 2)
20+
};
21+
22+
// Collects per-node traversal information during a debug ANN search.
//
// The search entry points that accept a `DebugTraversalInfo *debug_info`
// argument (iterate_to_fixed_point / cached_beam_search) are expected to fill
// this structure when the pointer is non-null.
//
// The three vectors are parallel: entry i of each vector describes the same
// encountered node. Always append through record_visited() or
// record_label_rejected() so the vectors stay the same length.
struct DebugTraversalInfo
{
    std::vector<uint32_t> ids;           // Internal location index of each encountered node
    std::vector<float> distances;        // Distance to the query; float max when label-rejected
    std::vector<uint8_t> label_rejected; // 1 if skipped due to label mismatch, 0 if evaluated

    // Number of recorded nodes (all three vectors share this length when
    // entries are only added through the record_* methods).
    std::size_t size() const noexcept
    {
        return ids.size();
    }

    // True when no node has been recorded.
    bool empty() const noexcept
    {
        return ids.empty();
    }

    // Pre-allocates room for `count` entries in every parallel vector so that
    // recording during a search does not reallocate repeatedly.
    void reserve(std::size_t count)
    {
        ids.reserve(count);
        distances.reserve(count);
        label_rejected.reserve(count);
    }

    // Drops all recorded entries; vector capacity is retained for reuse.
    void clear() noexcept
    {
        ids.clear();
        distances.clear();
        label_rejected.clear();
    }

    // Records a node that was skipped because of a label mismatch. No distance
    // was computed for it, so the largest finite float is stored as a sentinel.
    void record_label_rejected(uint32_t id)
    {
        ids.push_back(id);
        distances.push_back(std::numeric_limits<float>::max());
        label_rejected.push_back(1);
    }

    // Records a node whose distance to the query was evaluated.
    void record_visited(uint32_t id, float dist)
    {
        ids.push_back(id);
        distances.push_back(dist);
        label_rejected.push_back(0);
    }
};
53+
54+
} // namespace diskann

include/index.h

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "quantized_distance.h"
3030
#include "pq_data_store.h"
31+
#include "debug_utils.h"
3132

3233
#define OVERHEAD_FACTOR 1.1
3334
#define EXPAND_IF_FULL 0
@@ -145,7 +146,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
145146
template <typename IDType>
146147
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query, const size_t K, const uint32_t L,
147148
IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0,
148-
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);
149+
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
150+
DebugTraversalInfo *debug_info = nullptr);
149151

150152
template <typename IDType>
151153
std::pair<uint32_t, uint32_t> diverse_search(const T* query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType* indices,
@@ -157,6 +159,30 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
157159
float *distances, std::vector<T *> &res_vectors, bool use_filters,
158160
const std::vector<std::string>& filter_labels);
159161

162+
// Debug interface: retrieve the raw embedding stored at the given internal location index.
163+
// Caller must allocate vec with at least get_aligned_dim() elements of type T.
164+
DISKANN_DLLEXPORT void get_embedding(uint32_t location, T *vec) const;
165+
166+
// Debug search: runs ANN search and records every traversed node.
167+
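// debug_info is populated in traversal order. NOTE(review): debug_utils.h defines the FilterReason enum but no classification helpers; entries carry only id, distance and a label_rejected flag - confirm where the helpers live.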
168+
template <typename IDType>
169+
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> debug_search(
170+
const T *query, const size_t K, const uint32_t L,
171+
IDType *indices, float *distances,
172+
DebugTraversalInfo &debug_info,
173+
const uint32_t maxLperSeller = 0,
174+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
175+
176+
// Debug filtered search: same as debug_search but applies label filtering.
177+
template <typename IDType>
178+
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> debug_search_with_filters(
179+
const T *query, const std::vector<LabelT> &filter_labels,
180+
const size_t K, const uint32_t L,
181+
IDType *indices, float *distances,
182+
DebugTraversalInfo &debug_info,
183+
const uint32_t maxLperSeller = 0,
184+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
185+
160186
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
161187
std::any& indices, float* distances = nullptr,
162188
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) override;
@@ -166,7 +192,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
166192
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const std::vector<LabelT> &filter_labels,
167193
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
168194
IndexType *indices, float *distances,
169-
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);
195+
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
196+
DebugTraversalInfo *debug_info = nullptr);
170197

171198
// Will fail if tag already in the index or if tag=0.
172199
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
@@ -235,6 +262,22 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
235262
float *distances,
236263
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) override;
237264

265+
virtual void _get_embedding(uint32_t location, DataType &vec) override;
266+
267+
virtual std::pair<uint32_t, uint32_t> _debug_search(const DataType &query, const size_t K, const uint32_t L,
268+
std::any &indices, float *distances,
269+
DebugTraversalInfo &debug_info,
270+
const uint32_t maxLperSeller,
271+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) override;
272+
273+
virtual std::pair<uint32_t, uint32_t> _debug_search_with_filters(const DataType &query,
274+
const std::vector<std::string> &raw_labels,
275+
const size_t K, const uint32_t L,
276+
std::any &indices, float *distances,
277+
DebugTraversalInfo &debug_info,
278+
const uint32_t maxLperSeller,
279+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) override;
280+
238281
virtual int _insert_point(const DataType &data_point, const TagType tag) override;
239282
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) override;
240283

@@ -293,7 +336,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
293336
// The query to use is placed in scratch->aligned_query
294337
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(InMemQueryScratch<T> *scratch, const uint32_t Lindex,
295338
const std::vector<uint32_t> &init_ids, bool use_filter,
296-
const std::vector<LabelT> &filters, bool search_invocation, uint32_t maxLperSeller = 0);
339+
const std::vector<LabelT> &filters, bool search_invocation,
340+
uint32_t maxLperSeller = 0,
341+
DebugTraversalInfo *debug_info = nullptr);
297342

298343
void search_for_point_and_prune(int location, uint32_t Lindex, std::vector<uint32_t> &pruned_list,
299344
InMemQueryScratch<T> *scratch, bool use_filter = false,

include/pq_flash_index.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "tsl/robin_set.h"
1818
#include "label_bitmask.h"
1919
#include "integer_label_vector.h"
20+
#include "debug_utils.h"
2021

2122
#define FULL_PRECISION_REORDER_MULTIPLIER 3
2223

@@ -81,10 +82,11 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
8182
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
8283
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
8384
const bool use_filter, const std::vector<LabelT> &filter_labels,
84-
const uint32_t io_limit, uint32_t maxLperSeller = 0,
85+
const uint32_t io_limit, uint32_t maxLperSeller = 0,
8586
const bool use_reorder_data = false,
8687
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
87-
QueryStats *stats = nullptr);
88+
QueryStats *stats = nullptr,
89+
DebugTraversalInfo *debug_info = nullptr);
8890

8991
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);
9092

@@ -117,6 +119,25 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
117119
DISKANN_DLLEXPORT std::vector<std::uint8_t> get_pq_vector(std::uint64_t vid);
118120
DISKANN_DLLEXPORT uint64_t get_num_points();
119121

122+
// Debug interface: retrieve full-precision embedding for a given internal node ID.
123+
// Caller must pre-allocate vec with room for at least get_data_dim() elements (the index dimension).
124+
DISKANN_DLLEXPORT void get_embedding(uint32_t id, T *vec);
125+
126+
// Debug search: runs ANN search and records every traversed node in debug_info (id, distance, label-rejected flag).
127+
DISKANN_DLLEXPORT void debug_search(
128+
const T *query, const uint64_t k_search, const uint64_t l_search,
129+
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
130+
DebugTraversalInfo &debug_info,
131+
uint32_t maxLperSeller = 0);
132+
133+
// Debug filtered search: same as debug_search but applies label filtering.
134+
DISKANN_DLLEXPORT void debug_search_with_filters(
135+
const T *query, const uint64_t k_search, const uint64_t l_search,
136+
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
137+
const std::vector<LabelT> &filter_labels,
138+
DebugTraversalInfo &debug_info,
139+
uint32_t maxLperSeller = 0);
140+
120141
protected:
121142
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
122143
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);

src/abstract_index.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,83 @@ template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag<tag_uint128, int
378378
template DISKANN_DLLEXPORT void AbstractIndex::set_universal_label<uint16_t>(const uint16_t label);
379379
template DISKANN_DLLEXPORT void AbstractIndex::set_universal_label<uint32_t>(const uint32_t label);
380380

381+
// Debug interface template implementations
//
// get_embedding: typed front end for the virtual _get_embedding(). Wraps the
// caller-supplied output pointer `vec` in a DataType and forwards it together
// with `location`. Only the pointer is wrapped; no buffer is allocated here, so
// the required size of `vec` is whatever the concrete override writes.
// NOTE(review): DataType is presumably the std::any alias used for type erasure
// across this interface - confirm against abstract_index.h.
382+
template <typename data_type>
383+
void AbstractIndex::get_embedding(uint32_t location, data_type *vec)
384+
{
385+
DataType any_vec(vec);
386+
_get_embedding(location, any_vec);
387+
}
388+
389+
// debug_search: typed front end for the virtual _debug_search().
// Packs the query pointer and the result-index pointer into std::any objects
// and forwards every other argument unchanged; debug_info is passed through by
// reference for the implementation to use. Returns whatever _debug_search()
// returns.
template <typename data_type, typename IDType>
390+
std::pair<uint32_t, uint32_t> AbstractIndex::debug_search(
391+
const data_type *query, const size_t K, const uint32_t L,
392+
IDType *indices, float *distances,
393+
DebugTraversalInfo &debug_info,
394+
const uint32_t maxLperSeller,
395+
std::function<float(const std::uint8_t *, size_t)> rerank_fn)
396+
{
397+
auto any_query = std::any(query); // holds the pointer itself, not a copy of the query data
398+
auto any_indices = std::any(indices); // holds the caller's output pointer
399+
return _debug_search(any_query, K, L, any_indices, distances,
400+
debug_info, maxLperSeller, std::move(rerank_fn)); // rerank_fn was taken by value, so it is moved on
401+
}
402+
403+
// debug_search_with_filters: typed front end for the virtual
// _debug_search_with_filters(). The query arrives already wrapped as a
// DataType, so only the result-index pointer is packed into a std::any here.
// raw_labels and all remaining arguments are forwarded unchanged. Returns
// whatever _debug_search_with_filters() returns.
template <typename IDType>
404+
std::pair<uint32_t, uint32_t> AbstractIndex::debug_search_with_filters(
405+
const DataType &query, const std::vector<std::string> &raw_labels,
406+
const size_t K, const uint32_t L,
407+
IDType *indices, float *distances,
408+
DebugTraversalInfo &debug_info,
409+
const uint32_t maxLperSeller,
410+
std::function<float(const std::uint8_t *, size_t)> rerank_fn)
411+
{
412+
auto any_indices = std::any(indices); // holds the caller's output pointer
413+
return _debug_search_with_filters(query, raw_labels, K, L, any_indices, distances,
414+
debug_info, maxLperSeller, std::move(rerank_fn)); // rerank_fn was taken by value, so it is moved on
415+
}
416+
417+
// Explicit instantiations for get_embedding
418+
template DISKANN_DLLEXPORT void AbstractIndex::get_embedding<float>(uint32_t location, float *vec);
419+
template DISKANN_DLLEXPORT void AbstractIndex::get_embedding<uint8_t>(uint32_t location, uint8_t *vec);
420+
template DISKANN_DLLEXPORT void AbstractIndex::get_embedding<int8_t>(uint32_t location, int8_t *vec);
421+
422+
// Explicit instantiations for debug_search (float/uint8_t/int8_t × uint32_t/uint64_t indices)
423+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search<float, uint32_t>(
424+
const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances,
425+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
426+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
427+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search<uint8_t, uint32_t>(
428+
const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances,
429+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
430+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
431+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search<int8_t, uint32_t>(
432+
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances,
433+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
434+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
435+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search<float, uint64_t>(
436+
const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances,
437+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
438+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
439+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search<uint8_t, uint64_t>(
440+
const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances,
441+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
442+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
443+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search<int8_t, uint64_t>(
444+
const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances,
445+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
446+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
447+
448+
// Explicit instantiations for debug_search_with_filters (uint32_t/uint64_t indices)
449+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search_with_filters<uint32_t>(
450+
const DataType &query, const std::vector<std::string> &raw_labels,
451+
const size_t K, const uint32_t L, uint32_t *indices, float *distances,
452+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
453+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
454+
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::debug_search_with_filters<uint64_t>(
455+
const DataType &query, const std::vector<std::string> &raw_labels,
456+
const size_t K, const uint32_t L, uint64_t *indices, float *distances,
457+
DebugTraversalInfo &debug_info, const uint32_t maxLperSeller,
458+
std::function<float(const std::uint8_t *, size_t)> rerank_fn);
459+
381460
} // namespace diskann

0 commit comments

Comments
 (0)