Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions include/svs/index/vamana/prune.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ void heuristic_prune_neighbors(

auto pruned = std::vector<PruneState>(poolsize, PruneState::Available);
float current_alpha = 1.0f;
float anchor_dist = 0.0f;
bool anchor_set = false;
bool all_duplicates = true;
while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
size_t start = 0;
while (result.size() < max_result_size && start < poolsize) {
Expand All @@ -145,6 +148,16 @@ void heuristic_prune_neighbors(
const auto& query = accessor(dataset, id);
distance::maybe_fix_argument(distance_function, query);
result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));

if (all_duplicates) {
if (!anchor_set) {
anchor_dist = pool[start].distance();
anchor_set = true;
} else if (pool[start].distance() != anchor_dist) {
all_duplicates = false;
}
}

for (size_t t = start + 1; t < poolsize; ++t) {
if (excluded(pruned[t])) {
continue;
Expand All @@ -171,6 +184,35 @@ void heuristic_prune_neighbors(
}
current_alpha *= alpha;
}

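// Duplicate-cluster fallback: when every selected neighbor has the same distance,
// replace the last selected neighbor with the first pool candidate (other than the
// current node) whose distance differs, so that one diversity edge is kept.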
if (all_duplicates && anchor_set && !result.empty()) {
auto result_id = [](const I& r) -> size_t {
if constexpr (std::integral<I>) {
return static_cast<size_t>(r);
} else {
return static_cast<size_t>(r.id());
}
};
for (size_t t = 0; t < poolsize; ++t) {
const auto& candidate = pool[t];
auto cid = candidate.id();
if (cid == current_node_id || candidate.distance() == anchor_dist) {
continue;
}
bool in_result = false;
for (const auto& r : result) {
if (result_id(r) == static_cast<size_t>(cid)) {
in_result = true;
break;
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m confused about why `cid` could already be in `result` here. We just checked above that this candidate's distance differs from `anchor_dist`, so I wouldn’t expect it to be present, since every entry in `result` has the same distance. Should this be an assert instead (i.e., this candidate must not already be in `result`)? Or am I missing a scenario?

if (!in_result) {
result.back() = detail::construct_as(lib::Type<I>(), candidate);
break;
}
}
}
}

template <
Expand Down Expand Up @@ -203,6 +245,9 @@ void heuristic_prune_neighbors(
std::vector<float> pruned(poolsize, type_traits::tombstone_v<float, decltype(cmp)>);

float current_alpha = 1.0f;
float anchor_dist = 0.0f;
bool anchor_set = false;
bool all_duplicates = true;
while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
size_t start = 0;
while (result.size() < max_result_size && start < poolsize) {
Expand All @@ -218,6 +263,16 @@ void heuristic_prune_neighbors(
const auto& query = accessor(dataset, id);
distance::maybe_fix_argument(distance_function, query);
result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));

if (all_duplicates) {
if (!anchor_set) {
anchor_dist = pool[start].distance();
anchor_set = true;
} else if (pool[start].distance() != anchor_dist) {
all_duplicates = false;
}
}

for (size_t t = start + 1; t < poolsize; ++t) {
if (cmp(current_alpha, pruned[t])) {
continue;
Expand All @@ -236,6 +291,35 @@ void heuristic_prune_neighbors(
}
current_alpha *= alpha;
}

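// Duplicate-cluster fallback: when every selected neighbor has the same distance,
// replace the last selected neighbor with the first pool candidate (other than the
// current node) whose distance differs, so that one diversity edge is kept.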
if (all_duplicates && anchor_set && !result.empty()) {
auto result_id = [](const I& r) -> size_t {
if constexpr (std::integral<I>) {
return static_cast<size_t>(r);
} else {
return static_cast<size_t>(r.id());
}
};
for (size_t t = 0; t < poolsize; ++t) {
const auto& candidate = pool[t];
auto cid = candidate.id();
if (cid == current_node_id || candidate.distance() == anchor_dist) {
continue;
}
bool in_result = false;
for (const auto& r : result) {
if (result_id(r) == static_cast<size_t>(cid)) {
in_result = true;
break;
}
}
if (!in_result) {
result.back() = detail::construct_as(lib::Type<I>(), candidate);
break;
}
}
}
}

///
Expand Down
Loading