diff --git a/include/svs/index/vamana/prune.h b/include/svs/index/vamana/prune.h
index aeab27ac..bbcb9df9 100644
--- a/include/svs/index/vamana/prune.h
+++ b/include/svs/index/vamana/prune.h
@@ -130,6 +130,9 @@ void heuristic_prune_neighbors(
     auto pruned = std::vector(poolsize, PruneState::Available);
     float current_alpha = 1.0f;
+    float anchor_dist = 0.0f;
+    bool anchor_set = false;
+    bool all_duplicates = true;
 
     while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
         size_t start = 0;
         while (result.size() < max_result_size && start < poolsize) {
@@ -145,6 +148,16 @@ void heuristic_prune_neighbors(
             const auto& query = accessor(dataset, id);
             distance::maybe_fix_argument(distance_function, query);
             result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));
+
+            if (all_duplicates) {
+                if (!anchor_set) {
+                    anchor_dist = pool[start].distance();
+                    anchor_set = true;
+                } else if (pool[start].distance() != anchor_dist) {
+                    all_duplicates = false;
+                }
+            }
+
             for (size_t t = start + 1; t < poolsize; ++t) {
                 if (excluded(pruned[t])) {
                     continue;
                 }
@@ -171,6 +184,35 @@ void heuristic_prune_neighbors(
         }
         current_alpha *= alpha;
     }
+
+    // Add a diversity edge if a duplicate cluster is detected
+    if (all_duplicates && anchor_set && !result.empty()) {
+        auto result_id = [](const I& r) -> size_t {
+            if constexpr (std::integral<I>) {
+                return static_cast<size_t>(r);
+            } else {
+                return static_cast<size_t>(r.id());
+            }
+        };
+        for (size_t t = 0; t < poolsize; ++t) {
+            const auto& candidate = pool[t];
+            auto cid = candidate.id();
+            if (cid == current_node_id || candidate.distance() == anchor_dist) {
+                continue;
+            }
+            bool in_result = false;
+            for (const auto& r : result) {
+                if (result_id(r) == static_cast<size_t>(cid)) {
+                    in_result = true;
+                    break;
+                }
+            }
+            if (!in_result) {
+                result.back() = detail::construct_as(lib::Type<I>(), candidate);
+                break;
+            }
+        }
+    }
 }
 
 template <
@@ -203,6 +245,9 @@ void heuristic_prune_neighbors(
     std::vector<float> pruned(poolsize, type_traits<float>::tombstone_v);
     float current_alpha = 1.0f;
+    float anchor_dist = 0.0f;
+    bool anchor_set = false;
+    bool all_duplicates = true;
 
     while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
         size_t start = 0;
         while (result.size() < max_result_size && start < poolsize) {
@@ -218,6 +263,16 @@ void heuristic_prune_neighbors(
             const auto& query = accessor(dataset, id);
             distance::maybe_fix_argument(distance_function, query);
             result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));
+
+            if (all_duplicates) {
+                if (!anchor_set) {
+                    anchor_dist = pool[start].distance();
+                    anchor_set = true;
+                } else if (pool[start].distance() != anchor_dist) {
+                    all_duplicates = false;
+                }
+            }
+
             for (size_t t = start + 1; t < poolsize; ++t) {
                 if (cmp(current_alpha, pruned[t])) {
                     continue;
                 }
@@ -236,6 +291,35 @@ void heuristic_prune_neighbors(
         }
         current_alpha *= alpha;
     }
+
+    // Add a diversity edge if a duplicate cluster is detected
+    if (all_duplicates && anchor_set && !result.empty()) {
+        auto result_id = [](const I& r) -> size_t {
+            if constexpr (std::integral<I>) {
+                return static_cast<size_t>(r);
+            } else {
+                return static_cast<size_t>(r.id());
+            }
+        };
+        for (size_t t = 0; t < poolsize; ++t) {
+            const auto& candidate = pool[t];
+            auto cid = candidate.id();
+            if (cid == current_node_id || candidate.distance() == anchor_dist) {
+                continue;
+            }
+            bool in_result = false;
+            for (const auto& r : result) {
+                if (result_id(r) == static_cast<size_t>(cid)) {
+                    in_result = true;
+                    break;
+                }
+            }
+            if (!in_result) {
+                result.back() = detail::construct_as(lib::Type<I>(), candidate);
+                break;
+            }
+        }
+    }
 }
 
 ///