// Serialized size in bytes for arithmetic (fixed-width) item types.
// An empty sketch is only its preamble: PREAMBLE_LONGS_EMPTY << 3 bytes.
// Otherwise the total is: the warmup preamble (r_ == 0) or the full preamble,
// one double per H-region weight, a marks bitmap (one bit per H item, rounded
// up to whole bytes) when marks_ is allocated, and sizeof(TT) bytes for each
// of the h_ + r_ stored items.
295template<
typename T,
typename A>
296template<typename TT, typename SerDe, typename std::enable_if<std::is_arithmetic<TT>::value,
int>::type>
298 if (is_empty()) {
return PREAMBLE_LONGS_EMPTY << 3; }
// Preamble lengths are multiplied by 8 (<< 3) to convert them to bytes.
299 size_t num_bytes = (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL) << 3;
300 num_bytes += h_ *
sizeof(double);
301 if (marks_ !=
nullptr) {
// Ceiling of h_ / 8: marks are packed eight per byte.
302 num_bytes += (h_ / 8) + (h_ % 8 > 0);
// Items are fixed-width, so no per-item SerDe call is needed.
304 num_bytes += (h_ + r_) *
sizeof(TT);
// Serialized size in bytes for non-arithmetic (variable-size) item types.
// The preamble, H-region weights and optional marks bitmap are sized exactly
// as in the arithmetic overload. Item sizes are variable and come from the
// SerDe, one size_of_item() call per item.
309template<
typename T,
typename A>
310template<typename TT, typename SerDe, typename std::enable_if<!std::is_arithmetic<TT>::value,
int>::type>
312 if (is_empty()) {
return PREAMBLE_LONGS_EMPTY << 3; }
313 size_t num_bytes = (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL) << 3;
314 num_bytes += h_ *
sizeof(double);
315 if (marks_ !=
nullptr) {
// Ceiling of h_ / 8: marks are packed eight per byte.
316 num_bytes += (h_ / 8) + (h_ % 8 > 0);
// NOTE(review): `it` is presumably an (item, weight) pair produced by
// iterating the sketch; the loop header is not present here — confirm.
320 num_bytes += sd.size_of_item(it.first);
// Serializes the sketch into a new byte vector. The first header_size_bytes
// bytes are zeroed and left for the caller; the sketch image follows them.
// Fields appear in this order: first byte (preamble longs in the low 6 bits,
// resize factor in the top 2), serial version, family id, flags (gadget /
// empty), k, n, h, r, total R-region weight, the h_ H-region weights, a
// packed marks bitmap when marks_ is set, and then the H items followed by
// the R items, written through the SerDe.
// Throws std::logic_error if the bytes written differ from the precomputed
// header_size_bytes + get_serialized_size_bytes(sd).
324template<
typename T,
typename A>
325template<
typename SerDe>
326std::vector<uint8_t, AllocU8<A>> var_opt_sketch<T, A>::serialize(
unsigned header_size_bytes,
const SerDe& sd)
const {
327 const size_t size = header_size_bytes + get_serialized_size_bytes(sd);
328 std::vector<uint8_t, AllocU8<A>> bytes(size, 0, allocator_);
329 uint8_t* ptr = bytes.data() + header_size_bytes;
// The SerDe capacity (end_ptr - ptr) must be measured against the real end
// of the buffer. Computing `ptr + size` would point header_size_bytes past
// the end of `bytes`, and sd.serialize() could then write out of bounds
// before the final size check below ever ran.
330 uint8_t* end_ptr = bytes.data() + size;
332 bool empty = is_empty();
333 uint8_t preLongs = (empty ? PREAMBLE_LONGS_EMPTY
334 : (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL));
// Low 6 bits: preamble longs; top 2 bits: resize factor.
335 uint8_t first_byte = (preLongs & 0x3F) | ((
static_cast<uint8_t
>(rf_)) << 6);
336 uint8_t flags = (marks_ !=
nullptr ? GADGET_FLAG_MASK : 0);
339 flags |= EMPTY_FLAG_MASK;
343 uint8_t ser_ver(SER_VER);
344 uint8_t family(FAMILY_ID);
345 ptr += copy_to_mem(first_byte, ptr);
346 ptr += copy_to_mem(ser_ver, ptr);
347 ptr += copy_to_mem(family, ptr);
348 ptr += copy_to_mem(flags, ptr);
349 ptr += copy_to_mem(k_, ptr);
353 ptr += copy_to_mem(n_, ptr);
354 ptr += copy_to_mem(h_, ptr);
355 ptr += copy_to_mem(r_, ptr);
359 ptr += copy_to_mem(total_wt_r_, ptr);
363 ptr += copy_to_mem(weights_, ptr, h_ *
sizeof(
double));
// Marks bitmap: one bit per H item, least-significant bit first; a byte is
// emitted after every 8th item and once more for a trailing partial byte.
366 if (marks_ !=
nullptr) {
368 for (uint32_t i = 0; i < h_; ++i) {
370 val |= 0x1 << (i & 0x7);
373 if ((i & 0x7) == 0x7) {
374 ptr += copy_to_mem(val, ptr);
380 if ((h_ & 0x7) > 0) {
381 ptr += copy_to_mem(val, ptr);
// H items occupy [0, h_). R items start at h_ + 1 because slot h_ is the gap.
386 ptr += sd.serialize(ptr, end_ptr - ptr, data_, h_);
387 ptr += sd.serialize(ptr, end_ptr - ptr, &data_[h_ + 1], r_);
390 size_t bytes_written = ptr - bytes.data();
391 if (bytes_written != size) {
392 throw std::logic_error(
"serialized size mismatch: " + std::to_string(bytes_written) +
" != " + std::to_string(size));
// Stream serialization. Visible writes are: the first byte (preamble longs in
// the low 6 bits, resize factor in the top 2), the total R-region weight,
// the h_ H-region weights, a packed marks bitmap when marks_ is set, and then
// the H items and R items through the SerDe. ser_ver and family are prepared
// alongside; their writes are not shown here.
398template<
typename T,
typename A>
399template<
typename SerDe>
// Same emptiness test as is_empty(): no items in either H or R.
401 const bool empty = (h_ == 0) && (r_ == 0);
403 const uint8_t preLongs = (empty ? PREAMBLE_LONGS_EMPTY
404 : (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL));
// Low 6 bits: preamble longs; top 2 bits: resize factor.
405 const uint8_t first_byte = (preLongs & 0x3F) | ((
static_cast<uint8_t
>(rf_)) << 6);
406 uint8_t flags = (marks_ !=
nullptr ? GADGET_FLAG_MASK : 0);
409 flags |= EMPTY_FLAG_MASK;
413 const uint8_t ser_ver(SER_VER);
414 const uint8_t family(FAMILY_ID);
415 write(os, first_byte);
429 write(os, total_wt_r_);
433 write(os, weights_, h_ *
sizeof(
double));
// Marks bitmap: one bit per H item, least-significant bit first.
// (i & 0x7) == 0x7 is the last bit of a byte; (h_ & 0x7) > 0 handles the
// trailing partial byte.
436 if (marks_ !=
nullptr) {
438 for (uint32_t i = 0; i < h_; ++i) {
440 val |= 0x1 << (i & 0x7);
443 if ((i & 0x7) == 0x7) {
450 if ((h_ & 0x7) > 0) {
// H items occupy [0, h_). R items start at h_ + 1 because slot h_ is the gap.
456 sd.serialize(os, data_, h_);
457 sd.serialize(os, &data_[h_ + 1], r_);
// Builds a sketch from a serialized image of `size` bytes at `bytes`.
// ensure_minimum_memory / check_memory_size guard every read against the
// buffer size. check_preamble_longs, check_family_and_serialization_version
// and validate_and_get_target_size validate the header. The following throw
// std::invalid_argument: an invalid R weight in full mode, and any H weight
// that is not > 0.
461template<
typename T,
typename A>
462template<
typename SerDe>
// The first 8 bytes hold first_byte, serial version, family, flags and k.
464 ensure_minimum_memory(size, 8);
465 const char* ptr =
static_cast<const char*
>(bytes);
466 const char* base = ptr;
467 const char* end_ptr = ptr + size;
469 ptr += copy_from_mem(ptr, first_byte);
// Low 6 bits: preamble longs; top 2 bits: resize factor.
470 uint8_t preamble_longs = first_byte & 0x3f;
471 resize_factor rf =
static_cast<resize_factor
>((first_byte >> 6) & 0x03);
472 uint8_t serial_version;
473 ptr += copy_from_mem(ptr, serial_version);
475 ptr += copy_from_mem(ptr, family_id);
477 ptr += copy_from_mem(ptr, flags);
479 ptr += copy_from_mem(ptr, k);
481 check_preamble_longs(preamble_longs, flags);
482 check_family_and_serialization_version(family_id, serial_version);
// Once the preamble length is known, make sure the whole preamble is present.
483 ensure_minimum_memory(size, preamble_longs << 3);
485 const bool is_empty = flags & EMPTY_FLAG_MASK;
486 const bool is_gadget = flags & GADGET_FLAG_MASK;
495 ptr += copy_from_mem(ptr, n);
496 ptr += copy_from_mem(ptr, h);
497 ptr += copy_from_mem(ptr, r);
// Checks k, n, h and r for warmup/full consistency and returns the array size.
499 const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf);
// Only full-mode images carry the R-region weight. It must be a positive,
// non-NaN value with r > 0.
502 double total_wt_r = 0.0;
503 if (preamble_longs == PREAMBLE_LONGS_FULL) {
504 ptr += copy_from_mem(ptr, total_wt_r);
505 if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) {
506 throw std::invalid_argument(
"Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
507 "Found r = " + std::to_string(r) +
", R region weight = " + std::to_string(total_wt_r));
514 check_memory_size(ptr - base + (h *
sizeof(
double)), size);
// Allocated at the full array_size; only the first h weights come from the image.
515 std::unique_ptr<double, weights_deleter> weights(AllocDouble(allocator).allocate(array_size),
516 weights_deleter(array_size, allocator));
517 double* wts = weights.get();
518 ptr += copy_from_mem(ptr, wts, h *
sizeof(
double));
// !(w > 0.0) also rejects NaN weights.
519 for (
size_t i = 0; i < h; ++i) {
520 if (!(wts[i] > 0.0)) {
521 throw std::invalid_argument(
"Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i]));
// The remaining slots (gap and R region) get the -1.0 sentinel weight.
524 std::fill(wts + h, wts + array_size, -1.0);
// Marks are presumably allocated only when is_gadget (the guard is not shown).
// They are unpacked one bit per H item, least-significant bit first.
527 uint32_t num_marks_in_h = 0;
528 std::unique_ptr<bool, marks_deleter> marks(
nullptr, marks_deleter(array_size, allocator));
531 marks = std::unique_ptr<bool, marks_deleter>(AllocBool(allocator).allocate(array_size), marks_deleter(array_size, allocator));
532 const size_t size_marks = (h / 8) + (h % 8 > 0 ? 1 : 0);
533 check_memory_size(ptr - base + size_marks, size);
534 for (uint32_t i = 0; i < h; ++i) {
535 if ((i & 0x7) == 0x0) {
536 ptr += copy_from_mem(ptr, val);
538 marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1;
539 num_marks_in_h += (marks.get()[i] ? 1 : 0);
// The deleter records how many H and R items were constructed (set_h/set_r),
// so cleanup after a partial failure only touches constructed items.
544 items_deleter deleter(array_size, allocator);
545 std::unique_ptr<T, items_deleter> items(A(allocator).allocate(array_size), deleter);
547 ptr += sd.deserialize(ptr, end_ptr - ptr, items.get(), h);
548 items.get_deleter().set_h(h);
// R items start at h + 1, leaving slot h as the gap.
550 ptr += sd.deserialize(ptr, end_ptr - ptr, &(items.get()[h + 1]), r);
551 items.get_deleter().set_r(r);
553 return var_opt_sketch(k, h, (r > 0 ? 1 : 0), r, n, total_wt_r, rf, array_size, false,
554 std::move(items), std::move(weights), num_marks_in_h, std::move(marks), allocator);
// Stream counterpart of the byte-buffer deserialize. It reads and validates
// the same header fields (preamble, family, serial version, k/n/h/r, R weight),
// then the H weights, the optional marks bitmap and the H and R items.
// Failed stream reads are reported as std::runtime_error; apparent corruption
// as std::invalid_argument.
557template<
typename T,
typename A>
558template<
typename SerDe>
559var_opt_sketch<T, A> var_opt_sketch<T, A>::deserialize(std::istream& is,
const SerDe& sd,
const A& allocator) {
// Low 6 bits: preamble longs; top 2 bits: resize factor.
560 const auto first_byte = read<uint8_t>(is);
561 uint8_t preamble_longs = first_byte & 0x3f;
562 const resize_factor rf =
static_cast<resize_factor
>((first_byte >> 6) & 0x03);
563 const auto serial_version = read<uint8_t>(is);
564 const auto family_id = read<uint8_t>(is);
565 const auto flags = read<uint8_t>(is);
566 const auto k = read<uint32_t>(is);
568 check_preamble_longs(preamble_longs, flags);
569 check_family_and_serialization_version(family_id, serial_version);
571 const bool is_empty = flags & EMPTY_FLAG_MASK;
572 const bool is_gadget = flags & GADGET_FLAG_MASK;
// Presumably the empty-sketch path (guards not shown): fail if the stream
// went bad, else return a fresh sketch that keeps only k, rf and gadget-ness.
576 throw std::runtime_error(
"error reading from std::istream");
578 return var_opt_sketch(k, rf, is_gadget, allocator);
582 const auto n = read<uint64_t>(is);
583 const auto h = read<uint32_t>(is);
584 const auto r = read<uint32_t>(is);
586 const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf);
// Only full-mode images carry the R-region weight. It must be a positive,
// non-NaN value with r > 0.
589 double total_wt_r = 0.0;
590 if (preamble_longs == PREAMBLE_LONGS_FULL) {
591 total_wt_r = read<double>(is);
592 if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) {
593 throw std::invalid_argument(
"Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
594 "Found r = " + std::to_string(r) +
", R region weight = " + std::to_string(total_wt_r));
599 std::unique_ptr<double, weights_deleter> weights(AllocDouble(allocator).allocate(array_size),
600 weights_deleter(array_size, allocator));
601 double* wts = weights.get();
602 read(is, wts, h *
sizeof(
double));
// !(w > 0.0) also rejects NaN weights.
// NOTE(review): the weights are checked before the stream state is, so a
// truncated stream may surface as invalid_argument rather than runtime_error.
603 for (
size_t i = 0; i < h; ++i) {
604 if (!(wts[i] > 0.0)) {
605 throw std::invalid_argument(
"Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i]));
// The remaining slots (gap and R region) get the -1.0 sentinel weight.
608 std::fill(wts + h, wts + array_size, -1.0);
// Marks: one bit per H item, least-significant bit first; a new byte is
// read every 8 items.
611 uint32_t num_marks_in_h = 0;
612 std::unique_ptr<bool, marks_deleter> marks(
nullptr, marks_deleter(array_size, allocator));
614 marks = std::unique_ptr<bool, marks_deleter>(AllocBool(allocator).allocate(array_size), marks_deleter(array_size, allocator));
616 for (uint32_t i = 0; i < h; ++i) {
617 if ((i & 0x7) == 0x0) {
618 val = read<uint8_t>(is);
620 marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1;
621 num_marks_in_h += (marks.get()[i] ? 1 : 0);
// The deleter records how many H and R items were constructed (set_h/set_r).
626 items_deleter deleter(array_size, allocator);
627 std::unique_ptr<T, items_deleter> items(A(allocator).allocate(array_size), deleter);
629 sd.deserialize(is, items.get(), h);
630 items.get_deleter().set_h(h);
// R items start at h + 1, leaving slot h as the gap.
632 sd.deserialize(is, &(items.get()[h + 1]), r);
633 items.get_deleter().set_r(r);
636 throw std::runtime_error(
"error reading from std::istream");
638 return var_opt_sketch(k, h, (r > 0 ? 1 : 0), r, n, total_wt_r, rf, array_size, false,
639 std::move(items), std::move(weights), num_marks_in_h, std::move(marks), allocator);
// A sketch is empty when neither the H region nor the R region holds any items.
642template<
typename T,
typename A>
644 return (h_ == 0 && r_ == 0);
// Reinitializes storage for the current k_ and rf_. It recomputes the initial
// allocation size, walks the stored items for destruction, reallocates the
// arrays (keeping gadget-ness) when the new size is smaller than the previous
// one, and clears filled_data_.
// NOTE(review): presumably this is reset(). The signature and the h_/r_/m_/n_
// updates are not shown here.
647template<
typename T,
typename A>
649 const uint32_t prev_alloc = curr_items_alloc_;
650 const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k_));
651 const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf_, MIN_LG_ARR_ITEMS);
652 curr_items_alloc_ = get_adjusted_size(k_, 1 << initial_lg_size);
653 if (curr_items_alloc_ == k_) {
// Destruction ranges: either the first min(k_ + 1, prev_alloc) slots, or only
// the H items [0, h_) and the R items [h_ + 1, h_ + r_ + 1), skipping the gap
// at h_. The branch condition is not shown (presumably filled_data_).
659 const size_t num_to_destroy = std::min(k_ + 1, prev_alloc);
660 for (
size_t i = 0; i < num_to_destroy; ++i)
664 for (
size_t i = 0; i < h_; ++i)
667 for (
size_t i = h_ + 1; i < h_ + r_ + 1; ++i)
// Shrink: free the old arrays with their previous size and allocate new ones,
// re-creating marks only if the sketch had them.
671 if (curr_items_alloc_ < prev_alloc) {
672 const bool is_gadget = (marks_ !=
nullptr);
674 allocator_.deallocate(data_, prev_alloc);
675 AllocDouble(allocator_).deallocate(weights_, prev_alloc);
677 if (marks_ !=
nullptr)
678 AllocBool(allocator_).deallocate(marks_, prev_alloc);
680 allocate_data_arrays(curr_items_alloc_, is_gadget);
689 filled_data_ =
false;
// The first two template headers below belong to accessors whose bodies are
// not shown here.
692template<
typename T,
typename A>
697template<
typename T,
typename A>
// Number of samples currently held: h_ + r_, capped at k_.
702template<
typename T,
typename A>
704 const uint32_t num_in_sketch = h_ + r_;
705 return (num_in_sketch < k_ ? num_in_sketch : k_);
// Convenience overload: passes the lvalue item to the three-argument update
// with mark = false.
708template<
typename T,
typename A>
710 update(item, weight,
false);
// Convenience overload: moves item into the three-argument update with
// mark = false.
713template<
typename T,
typename A>
715 update(std::move(item), weight,
false);
// Human-readable summary of the sketch state: k, h, r, total R-region weight,
// current allocation size, and resize factor (printed as 1 << rf_). The
// string is returned using this sketch's allocator.
718template<
typename T,
typename A>
722 std::ostringstream os;
723 os <<
"### VarOpt SUMMARY:" << std::endl;
724 os <<
" k : " << k_ << std::endl;
725 os <<
" h : " << h_ << std::endl;
726 os <<
" r : " << r_ << std::endl;
727 os <<
" weight_r : " << total_wt_r_ << std::endl;
728 os <<
" Current size : " << curr_items_alloc_ << std::endl;
729 os <<
" Resize factor: " << (1 << rf_) << std::endl;
730 os <<
"### END SKETCH SUMMARY" << std::endl;
731 return string<A>(os.str().c_str(), allocator_);
// Lists every (item, weight) record produced by iterating this sketch, one
// per line, each prefixed with a running index.
// NOTE(review): `auto record` copies each record. If the iterator yields a
// reference, `const auto&` would avoid that copy — confirm.
734template<
typename T,
typename A>
738 std::ostringstream os;
739 os <<
"### Sketch Items" << std::endl;
741 for (
auto record : *
this) {
742 os << idx <<
": " << record.first <<
"\twt = " << record.second << std::endl;
745 return string<A>(os.str().c_str(), allocator_);
// Debug dump of the raw storage arrays. It covers n_ slots during warmup
// (n_ < k_) and k_ + 1 slots otherwise, the extra slot being the gap. With
// print_gap set, slot h_ is labelled "GAP". A slot whose weight is the -1.0
// sentinel is shown with the current tau (get_tau()) followed by "(-1.0)".
748template<
typename T,
typename A>
752 std::ostringstream os;
753 os <<
"### Sketch Items" << std::endl;
754 const uint32_t array_length = (n_ < k_ ? n_ : k_ + 1);
755 for (uint32_t i = 0, display_idx = 0; i < array_length; ++i) {
756 if (i == h_ && print_gap) {
757 os << display_idx <<
": GAP" << std::endl;
760 os << display_idx <<
": " << data_[i] <<
"\twt = ";
761 if (weights_[i] == -1.0) {
762 os << get_tau() <<
"\t(-1.0)" << std::endl;
764 os << weights_[i] << std::endl;
769 return string<A>(os.str().c_str(), allocator_);
// Inserts an item with the given weight and gadget mark.
// Throws std::invalid_argument for weights that are non-positive, NaN or
// infinite. During warmup the item goes to update_warmup_phase(); the warmup
// test itself is not shown (presumably r_ == 0). Otherwise it is routed to:
//   - update_light() when H is empty or the item is no heavier than the H
//     minimum, AND it is lighter than the hypothetical tau;
//   - update_heavy_r_eq1() when r_ == 1;
//   - update_heavy_general() otherwise.
// Throws std::logic_error if the H minimum is below the current tau, which
// means the sketch is not in a valid estimation state.
772template<
typename T,
typename A>
774void var_opt_sketch<T, A>::update(O&& item,
double weight,
bool mark) {
775 if (weight <= 0.0 || std::isnan(weight) || std::isinf(weight)) {
776 throw std::invalid_argument(
"Item weights must be positive and finite. Found: "
777 + std::to_string(weight));
784 update_warmup_phase(std::forward<O>(item), weight, mark);
788 if ((h_ != 0) && (peek_min() < get_tau()))
789 throw std::logic_error(
"sketch not in valid estimation mode");
// Average weight per R slot if the new item joined R and one candidate were
// then removed: (total_wt_r_ + weight) / ((r_ + 1) - 1).
793 const double hypothetical_tau = (weight + total_wt_r_) / ((r_ + 1) - 1);
// Both tests are plain booleans. They were previously stored in double
// (as 0.0/1.0), which hid their intent; the truth values are unchanged.
796 const bool condition1 = (h_ == 0) || (weight <= peek_min());
799 const bool condition2 = weight < hypothetical_tau;
801 if (condition1 && condition2) {
802 update_light(std::forward<O>(item), weight, mark);
803 }
else if (r_ == 1) {
804 update_heavy_r_eq1(std::forward<O>(item), weight, mark);
806 update_heavy_general(std::forward<O>(item), weight, mark);
// Warmup-phase insertion, before any R region exists. It checks r_ == 0,
// m_ == 0 and h_ <= k_, and grows the arrays once h_ reaches
// curr_items_alloc_ (the growth call is not shown). It then constructs the
// item in place at slot h_, records its weight and, for gadgets, its mark.
// transition_from_warmup() is called under a condition not shown here,
// presumably once H holds k_ + 1 items.
811template<
typename T,
typename A>
813void var_opt_sketch<T, A>::update_warmup_phase(O&& item,
double weight,
bool mark) {
815 if (r_ > 0 || m_ != 0 || h_ > k_)
throw std::logic_error(
"invalid sketch state during warmup");
817 if (h_ >= curr_items_alloc_) {
// Slot h_ is raw storage here, so the item is constructed with placement new.
822 new (&data_[h_]) T(std::forward<O>(item));
823 weights_[h_] = weight;
824 if (marks_ !=
nullptr) {
828 num_marks_in_h_ += mark ? 1 : 0;
833 transition_from_warmup();
// Light-item path, taken when the new item becomes a candidate directly.
// Requires r_ > 0 and r_ + h_ == k_. The item goes into slot h_: by
// assignment when that slot already holds a constructed object, by placement
// new otherwise (the selecting condition is not shown; presumably
// filled_data_). The candidate set then grows with total weight
// total_wt_r_ + weight over r_ + 1 candidates.
841template<
typename T,
typename A>
843void var_opt_sketch<T, A>::update_light(O&& item,
double weight,
bool mark) {
844 if (r_ == 0 || (r_ + h_) != k_)
throw std::logic_error(
"invalid sketch state during light warmup");
846 const uint32_t m_slot = h_;
// Skip self-assignment: decrease_k_by_1() re-inserts an element of data_
// itself through update().
848 if (&data_[m_slot] != &item)
849 data_[m_slot] = std::forward<O>(item);
851 new (&data_[m_slot]) T(std::forward<O>(item));
854 weights_[m_slot] = weight;
855 if (marks_ !=
nullptr) { marks_[m_slot] = mark; }
858 grow_candidate_set(total_wt_r_ + weight, r_ + 1);
// Heavy-item path for r_ >= 2. It pushes the item onto the H-region heap,
// then grows the candidate set starting from the existing R region
// (weight total_wt_r_, r_ candidates).
869template<
typename T,
typename A>
871void var_opt_sketch<T, A>::update_heavy_general(O&& item,
double weight,
bool mark) {
872 if (r_ < 2 || m_ != 0 || (r_ + h_) != k_)
throw std::logic_error(
"invalid sketch state during heavy general update");
875 push(std::forward<O>(item), weight, mark);
877 grow_candidate_set(total_wt_r_, r_);
// Heavy-item path for r_ == 1. It pushes the item onto the H heap and moves
// the H minimum into the M region. It then grows the candidate set from that
// M item (slot k_ - 1) plus the single R item: 2 candidates, weight
// weights_[k_ - 1] + total_wt_r_.
883template<
typename T,
typename A>
885void var_opt_sketch<T, A>::update_heavy_r_eq1(O&& item,
double weight,
bool mark) {
886 if (r_ != 1 || m_ != 0 || (r_ + h_) != k_)
throw std::logic_error(
"invalid sketch state during heavy r=1 update");
888 push(std::forward<O>(item), weight, mark);
889 pop_min_to_m_region();
893 const uint32_t m_slot = k_ - 1;
894 grow_candidate_set(weights_[m_slot] + total_wt_r_, 2);
// Shrinks the sketch's k by one; used by union, and throws if k would drop
// below 1. Cases:
//   - H only (warmup): ends with transition_from_warmup().
//   - H and R: the last R item is swapped into the old gap, the last H item
//     (h_ - 1) is pulled out (its weight slot set to -1.0, mark count
//     adjusted) and re-inserted through update().
//   - R only (needs r_ >= 2): a uniformly chosen R item is swapped to the
//     rightmost R slot, and that slot's weight is set to -1.0.
// Several bookkeeping lines (k_/h_/r_ updates, destruction) are not shown.
903template<
typename T,
typename A>
904void var_opt_sketch<T, A>::decrease_k_by_1() {
906 throw std::logic_error(
"Cannot decrease k below 1 in union");
909 if ((h_ == 0) && (r_ == 0)) {
912 }
else if ((h_ > 0) && (r_ == 0)) {
916 transition_from_warmup();
918 }
else if ((h_ > 0) && (r_ > 0)) {
924 const uint32_t old_gap_idx = h_;
925 const uint32_t old_final_r_idx = (h_ + 1 + r_) - 1;
926 if (old_final_r_idx != k_)
throw std::logic_error(
"gadget in invalid state");
928 swap_values(old_final_r_idx, old_gap_idx);
935 const uint32_t pulled_idx = h_ - 1;
936 double pulled_weight = weights_[pulled_idx];
// NOTE(review): marks_ is dereferenced without a null check. This assumes
// the sketch is a gadget (union-only use); is_marked() would be null-safe.
937 bool pulled_mark = marks_[pulled_idx];
940 if (pulled_mark) { --num_marks_in_h_; }
941 weights_[pulled_idx] = -1.0;
947 update(std::move(data_[pulled_idx]), pulled_weight, pulled_mark);
948 }
else if ((h_ == 0) && (r_ > 0)) {
950 if (r_ < 2)
throw std::logic_error(
"r_ too small for pure reservoir mode");
// R occupies slots [1, r_]; pick one uniformly and move it to the end.
952 const uint32_t r_idx_to_delete = 1 + next_int(r_);
953 const uint32_t rightmost_r_idx = (1 + r_) - 1;
954 swap_values(r_idx_to_delete, rightmost_r_idx);
955 weights_[rightmost_r_idx] = -1.0;
// Allocates raw, unconstructed storage of tgt_size slots for data_ and
// weights_, plus marks_ when use_marks is set (the guard is not shown).
// Clears filled_data_ because no slot holds a constructed item yet.
962template<
typename T,
typename A>
963void var_opt_sketch<T, A>::allocate_data_arrays(uint32_t tgt_size,
bool use_marks) {
964 filled_data_ =
false;
966 data_ = allocator_.allocate(tgt_size);
967 weights_ = AllocDouble(allocator_).allocate(tgt_size);
970 marks_ = AllocBool(allocator_).allocate(tgt_size);
// Grows storage by the resize factor: the target is curr_items_alloc_ << rf_,
// adjusted against k_ by get_adjusted_size(). When the size actually
// increases, the first prev_size items are move-constructed into the new
// data array, weights (and marks, for gadgets) are copied, and the old
// arrays are released with their previous size.
976template<
typename T,
typename A>
977void var_opt_sketch<T, A>::grow_data_arrays() {
978 const uint32_t prev_size = curr_items_alloc_;
979 curr_items_alloc_ = get_adjusted_size(k_, curr_items_alloc_ << rf_);
// Body not shown; presumably reserves an extra slot for the gap at full size.
980 if (curr_items_alloc_ == k_) {
984 if (prev_size < curr_items_alloc_) {
985 filled_data_ =
false;
987 T* tmp_data = allocator_.allocate(curr_items_alloc_);
988 double* tmp_weights = AllocDouble(allocator_).allocate(curr_items_alloc_);
990 for (uint32_t i = 0; i < prev_size; ++i) {
991 new (&tmp_data[i]) T(std::move(data_[i]));
993 tmp_weights[i] = weights_[i];
996 allocator_.deallocate(data_, prev_size);
997 AllocDouble(allocator_).deallocate(weights_, prev_size);
1000 weights_ = tmp_weights;
1002 if (marks_ !=
nullptr) {
1003 bool* tmp_marks = AllocBool(allocator_).allocate(curr_items_alloc_);
1004 for (uint32_t i = 0; i < prev_size; ++i) {
1005 tmp_marks[i] = marks_[i];
1007 AllocBool(allocator_).deallocate(marks_, prev_size);
// Leaves warmup mode. Two pop_min_to_m_region() calls move the two lightest
// H items out of the heap. The state must then be h_ == k_ - 1, m_ == 1 and
// r_ == 1. The item at slot k_ becomes the first R item: its weight becomes
// total_wt_r_ and the slot gets the -1.0 sentinel. Finally the candidate set
// grows from the M item (slot k_ - 1) plus R: 2 candidates.
1013template<
typename T,
typename A>
1014void var_opt_sketch<T, A>::transition_from_warmup() {
1018 pop_min_to_m_region();
1019 pop_min_to_m_region();
1023 if (h_ != (k_ -1) || m_ != 1 || r_ != 1)
1024 throw std::logic_error(
"invalid state for transitioning from warmup");
1028 total_wt_r_ = weights_[k_];
1029 weights_[k_] = -1.0;
1033 grow_candidate_set(weights_[k_ - 1] + total_wt_r_, 2);
// Heapifies the H region in place by weight. Every non-leaf slot, from the
// last one (h_ / 2 - 1) back to the root, is sifted down.
// NOTE(review): h_ - 1 wraps when h_ == 0. Presumably the lines not shown
// return early for small h_.
1036template<
typename T,
typename A>
1037void var_opt_sketch<T, A>::convert_to_heap() {
1042 const uint32_t last_slot = h_ - 1;
1043 const int last_non_leaf = ((last_slot + 1) / 2) - 1;
1045 for (
int j = last_non_leaf; j >= 0; --j) {
1046 restore_towards_leaves(j);
// Sift-down for the H-region min-heap keyed on weights_. Starting at slot_in,
// the item is swapped with its lighter child until it is no heavier than that
// child. Throws std::logic_error for an empty heap or an out-of-range slot.
1056template<
typename T,
typename A>
1057void var_opt_sketch<T, A>::restore_towards_leaves(uint32_t slot_in) {
1058 const uint32_t last_slot = h_ - 1;
1059 if (h_ == 0 || slot_in > last_slot)
throw std::logic_error(
"invalid heap state");
1061 uint32_t slot = slot_in;
1062 uint32_t child = (2 * slot_in) + 1;
1064 while (child <= last_slot) {
// Pick the lighter of the two children, if a second child exists.
1065 uint32_t child2 = child + 1;
1066 if ((child2 <= last_slot) && (weights_[child2] < weights_[child])) {
1071 if (weights_[slot] <= weights_[child]) {
1077 swap_values(slot, child);
1080 child = (2 * slot) + 1;
// Sift-up for the H-region min-heap. slot_in moves toward the root while it
// is strictly lighter than its parent. The parent of slot s is
// ((s + 1) / 2) - 1. At slot 0 that value wraps, but it is never
// dereferenced because the `slot > 0` test short-circuits first.
1084template<
typename T,
typename A>
1085void var_opt_sketch<T, A>::restore_towards_root(uint32_t slot_in) {
1086 uint32_t slot = slot_in;
1087 uint32_t p = (((slot + 1) / 2) - 1);
1088 while ((slot > 0) && (weights_[slot] < weights_[p])) {
1089 swap_values(slot, p);
1091 p = (((slot + 1) / 2) - 1);
// Inserts an item at the end of the H region (slot h_) and sifts it toward
// the root to restore the min-heap property. The slot is assigned to when it
// already holds a constructed object (self-assignment is skipped) and
// placement-new'd otherwise. For gadgets, the mark count is updated.
// The weight store and the h_ increment are not shown here.
1095template<
typename T,
typename A>
1097void var_opt_sketch<T, A>::push(O&& item,
double wt,
bool mark) {
1099 if (&data_[h_] != &item)
1100 data_[h_] = std::forward<O>(item);
1102 new (&data_[h_]) T(std::forward<O>(item));
1103 filled_data_ =
true;
1106 if (marks_ !=
nullptr) {
1108 num_marks_in_h_ += (mark ? 1 : 0);
1112 restore_towards_root(h_ - 1);
// Moves the minimum-weight H item (the heap root) into the M region.
// It requires h_ > 0 and h_ + m_ + r_ == k_ + 1, swaps the root with the last
// H slot (h_ - 1), and re-heapifies from the root. It then checks
// is_marked(h_), presumably to decrement num_marks_in_h_ (body not shown).
1115template<
typename T,
typename A>
1116void var_opt_sketch<T, A>::pop_min_to_m_region() {
1117 if (h_ == 0 || (h_ + m_ + r_ != k_ + 1))
1118 throw std::logic_error(
"invalid heap state popping min to M region");
1126 uint32_t tgt = h_ - 1;
1127 swap_values(0, tgt);
1131 restore_towards_leaves(0);
1134 if (is_marked(h_)) {
// Swaps the item, the weight and, for gadgets, the mark stored at two slots.
1140template<
typename T,
typename A>
1141void var_opt_sketch<T, A>::swap_values(uint32_t src, uint32_t dst) {
1142 std::swap(data_[src], data_[dst]);
1143 std::swap(weights_[src], weights_[dst]);
1145 if (marks_ !=
nullptr) {
1146 std::swap(marks_[src], marks_[dst]);
// Moves H-region minimums into the candidate set while the next one is light
// enough, i.e. next_wt * num_cands < wt_cands + next_wt. It then downsamples
// the candidate set by one item. Requires h_ + m_ + r_ == k_ + 1,
// num_cands == m_ + r_ >= 1 and m_ < 2. The loop structure is only partly
// shown here.
1158template<
typename T,
typename A>
1159void var_opt_sketch<T, A>::grow_candidate_set(
double wt_cands, uint32_t num_cands) {
1160 if ((h_ + m_ + r_ != k_ + 1) || (num_cands < 1) || (num_cands != m_ + r_) || (m_ >= 2))
1161 throw std::logic_error(
"invariant violated when growing candidate set");
1164 const double next_wt = peek_min();
1165 const double next_tot_wt = wt_cands + next_wt;
1170 if ((next_wt * num_cands) < next_tot_wt) {
1171 wt_cands = next_tot_wt;
1173 pop_min_to_m_region();
1179 downsample_candidate_set(wt_cands, num_cands);
// Removes one candidate from the M and R regions. It requires num_cands >= 2
// and h_ + num_cands == k_ + 1, and picks a delete slot that must lie in
// [h_, k_]. It then visits the M slots [h_, h_ + m_) (loop body not shown),
// moves the leftmost candidate into the deleted slot, and sets total_wt_r_
// to wt_cands.
1182template<
typename T,
typename A>
1183void var_opt_sketch<T, A>::downsample_candidate_set(
double wt_cands, uint32_t num_cands) {
1184 if (num_cands < 2 || h_ + num_cands != k_ + 1)
1185 throw std::logic_error(
"invalid num_cands when downsampling");
1188 const uint32_t delete_slot = choose_delete_slot(wt_cands, num_cands);
1189 const uint32_t leftmost_cand_slot = h_;
1190 if (delete_slot < leftmost_cand_slot || delete_slot > k_)
1191 throw std::logic_error(
"invalid delete slot index when downsampling");
1196 const uint32_t stop_idx = leftmost_cand_slot + m_;
1197 for (uint32_t j = leftmost_cand_slot; j < stop_idx; ++j) {
1202 data_[delete_slot] = std::move(data_[leftmost_cand_slot]);
1206 total_wt_r_ = wt_cands;
// Chooses which candidate slot to drop. Throws if R is empty (exact mode).
//   - No M items (condition not shown, presumably m_ == 0): a uniform R slot.
//   - One M item: a uniform R slot when wt_cands * U < (num_cands - 1) * w_M,
//     with U from next_double_exclude_zero(); otherwise presumably the M slot.
//   - More M items: choose_weighted_delete_slot() decides, and a result equal
//     to the first R slot is replaced by a uniform R slot.
1209template<
typename T,
typename A>
1210uint32_t var_opt_sketch<T, A>::choose_delete_slot(
double wt_cands, uint32_t num_cands)
const {
1211 if (r_ == 0)
throw std::logic_error(
"choosing delete slot while in exact mode");
1215 return pick_random_slot_in_r();
1216 }
else if (m_ == 1) {
1219 double wt_m_cand = weights_[h_];
1220 if ((wt_cands * next_double_exclude_zero()) < ((num_cands - 1) * wt_m_cand)) {
1221 return pick_random_slot_in_r();
1227 const uint32_t delete_slot = choose_weighted_delete_slot(wt_cands, num_cands);
1228 const uint32_t first_r_slot = h_ + m_;
1229 if (delete_slot == first_r_slot) {
1230 return pick_random_slot_in_r();
// Weighted random choice among the M-region slots [h_, h_ + m_ - 1]. For each
// slot it accumulates num_to_keep * weights_[i] on the left and wt_cands on
// the right (the right side starts at -wt_cands * U). The first slot where
// left < right is selected; the return and fallback lines are not shown.
// Requires m_ >= 1.
1237template<
typename T,
typename A>
1238uint32_t var_opt_sketch<T, A>::choose_weighted_delete_slot(
double wt_cands, uint32_t num_cands)
const {
1239 if (m_ < 1)
throw std::logic_error(
"must have weighted delete slot");
1241 const uint32_t offset = h_;
1242 const uint32_t final_m = (offset + m_) - 1;
1243 const uint32_t num_to_keep = num_cands - 1;
1245 double left_subtotal = 0.0;
1246 double right_subtotal = -1.0 * wt_cands * next_double_exclude_zero();
1248 for (uint32_t i = offset; i <= final_m; ++i) {
1249 left_subtotal += num_to_keep * weights_[i];
1250 right_subtotal += wt_cands;
1252 if (left_subtotal < right_subtotal) {
// Returns a uniformly random R-region slot: h_ + m_ + next_int(r_).
// Throws std::logic_error when R is empty. The lines in between, presumably
// a fast path for r_ == 1, are not shown.
1261template<
typename T,
typename A>
1262uint32_t var_opt_sketch<T, A>::pick_random_slot_in_r()
const {
1263 if (r_ == 0)
throw std::logic_error(
"r_ = 0 when picking slot in R region");
1264 const uint32_t offset = h_ + m_;
1268 return offset + next_int(r_);
// Weight of the lightest H item, i.e. the heap root (presumably weights_[0];
// the return line is not shown). Throws std::logic_error when H is empty.
1272template<
typename T,
typename A>
1273double var_opt_sketch<T, A>::peek_min()
const {
1274 if (h_ == 0)
throw std::logic_error(
"h_ = 0 when checking min in H region");
// True if marks are tracked (gadget) and slot idx is marked. Always false
// when marks_ is null.
1278template<
typename T,
typename A>
1279inline bool var_opt_sketch<T, A>::is_marked(uint32_t idx)
const {
1280 return marks_ ==
nullptr ? false : marks_[idx];
// Current tau: the average R-region weight (total_wt_r_ / r_), or NaN while
// R is empty.
1283template<
typename T,
typename A>
1284double var_opt_sketch<T, A>::get_tau()
const {
1285 return r_ == 0 ? std::nan(
"1") : (total_wt_r_ / r_);
// Drops gadget marks. It zeroes the mark count and releases the marks array,
// which was allocated with curr_items_alloc_ slots. Throws std::logic_error
// if the sketch has no marks. The marks_ = nullptr reset is not shown here.
1288template<
typename T,
typename A>
1289void var_opt_sketch<T, A>::strip_marks() {
1290 if (marks_ ==
nullptr)
throw std::logic_error(
"request to strip marks from non-gadget");
1291 num_marks_in_h_ = 0;
1292 AllocBool(allocator_).deallocate(marks_, curr_items_alloc_);
// Validates the preamble length against the EMPTY flag. An empty sketch must
// use PREAMBLE_LONGS_EMPTY; a non-empty one must use PREAMBLE_LONGS_WARMUP or
// PREAMBLE_LONGS_FULL. Throws std::invalid_argument otherwise.
1296template<
typename T,
typename A>
1297void var_opt_sketch<T, A>::check_preamble_longs(uint8_t preamble_longs, uint8_t flags) {
1298 const bool is_empty(flags & EMPTY_FLAG_MASK);
1301 if (preamble_longs != PREAMBLE_LONGS_EMPTY) {
1302 throw std::invalid_argument(
"Possible corruption: Preamble longs must be "
1303 + std::to_string(PREAMBLE_LONGS_EMPTY) +
" for an empty sketch. Found: "
1304 + std::to_string(preamble_longs));
1307 if (preamble_longs != PREAMBLE_LONGS_WARMUP
1308 && preamble_longs != PREAMBLE_LONGS_FULL) {
1309 throw std::invalid_argument(
"Possible corruption: Preamble longs must be "
1310 + std::to_string(PREAMBLE_LONGS_WARMUP) +
" or "
1311 + std::to_string(PREAMBLE_LONGS_FULL)
1312 +
" for a non-empty sketch. Found: " + std::to_string(preamble_longs));
// Rejects images whose family id is not FAMILY_ID, or whose serial version is
// not SER_VER, by throwing std::invalid_argument.
1317template<
typename T,
typename A>
1318void var_opt_sketch<T, A>::check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver) {
1319 if (family_id == FAMILY_ID) {
1320 if (ser_ver != SER_VER) {
1321 throw std::invalid_argument(
"Possible corruption: VarOpt serialization version must be "
1322 + std::to_string(SER_VER) +
". Found: " + std::to_string(ser_ver));
1328 throw std::invalid_argument(
"Possible corruption: VarOpt family id must be "
1329 + std::to_string(FAMILY_ID) +
". Found: " + std::to_string(family_id));
// Validates k and the n/h/r counts read from a serialized image, and returns
// the data-array size to allocate. Throws std::invalid_argument on violations.
//   - n <= k (branch condition not shown): must be warmup mode with h == n
//     and r == 0. The size is a power-of-two sub-multiple of k chosen by rf,
//     with h as the minimum size, then adjusted against k.
//   - n > k: must be full mode with consistent h and r.
1332template<
typename T,
typename A>
1333uint32_t var_opt_sketch<T, A>::validate_and_get_target_size(uint32_t preamble_longs, uint32_t k, uint64_t n,
1334 uint32_t h, uint32_t r, resize_factor rf) {
1335 if (k == 0 || k > MAX_K) {
1336 throw std::invalid_argument(
"k must be at least 1 and less than 2^31 - 1");
1339 uint32_t array_size;
1342 if (preamble_longs != PREAMBLE_LONGS_WARMUP) {
1343 throw std::invalid_argument(
"Possible corruption: deserializing with n <= k but not in warmup mode. "
1344 "Found n = " + std::to_string(n) +
", k = " + std::to_string(k));
1347 throw std::invalid_argument(
"Possible corruption: deserializing in warmup mode but n != h. "
1348 "Found n = " + std::to_string(n) +
", h = " + std::to_string(h));
1351 throw std::invalid_argument(
"Possible corruption: deserializing in warmup mode but r > 0. "
1352 "Found r = " + std::to_string(r));
1355 const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k));
1356 const uint32_t min_lg_size = to_log_2(ceiling_power_of_2(h));
1357 const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf, min_lg_size);
1358 array_size = get_adjusted_size(k, 1 << initial_lg_size);
1359 if (array_size == k) {
1363 if (preamble_longs != PREAMBLE_LONGS_FULL) {
1364 throw std::invalid_argument(
"Possible corruption: deserializing with n > k but not in full mode. "
1365 "Found n = " + std::to_string(n) +
", k = " + std::to_string(k));
// NOTE(review): this message says "h + r != n", but this branch has n > k.
// A check against n would reject every valid full sketch, so the condition
// (not shown) is presumably h + r != k, and the message should name k.
1368 throw std::invalid_argument(
"Possible corruption: deserializing in full mode but h + r != n. "
1369 "Found h = " + std::to_string(h) +
", r = " + std::to_string(r) +
", n = " + std::to_string(n));
// Estimates the total weight of items that satisfy `predicate`.
// H items contribute exactly, each with its own weight. The R region
// contributes total_wt_r_ scaled by a lower-bound, point and upper-bound
// estimate of the fraction of R items that match. The bounds come from
// pseudo-hypergeometric bounds at the effective sampling rate r_ / (n_ - h_).
// Returns {lower bound, estimate, upper bound, total sketch weight}.
// An empty sketch returns all zeros. The early return with H-only (exact)
// values presumably applies when r_ == 0; its condition is not shown.
1378template<
typename T,
typename A>
1382 return {0.0, 0.0, 0.0, 0.0};
1385 double total_wt_h = 0.0;
1386 double h_true_wt = 0.0;
1388 for (; idx < h_; ++idx) {
1389 double wt = weights_[idx];
1391 if (predicate(data_[idx])) {
1398 return {h_true_wt, h_true_wt, h_true_wt, h_true_wt};
1402 const uint64_t num_samples = n_ - h_;
1403 double effective_sampling_rate = r_ /
static_cast<double>(num_samples);
1404 if (effective_sampling_rate < 0.0 || effective_sampling_rate > 1.0)
1405 throw std::logic_error(
"invalid sampling rate outside [0.0, 1.0]");
// Count matching R items; the R region ends at slot k_.
1407 uint32_t r_true_count = 0;
1409 for (; idx < (k_ + 1); ++idx) {
1410 if (predicate(data_[idx])) {
1415 double lb_true_fraction = pseudo_hypergeometric_lb_on_p(r_, r_true_count, effective_sampling_rate);
1416 double estimated_true_fraction = (1.0 * r_true_count) / r_;
1417 double ub_true_fraction = pseudo_hypergeometric_ub_on_p(r_, r_true_count, effective_sampling_rate);
1419 return { h_true_wt + (total_wt_r_ * lb_true_fraction),
1420 h_true_wt + (total_wt_r_ * estimated_true_fraction),
1421 h_true_wt + (total_wt_r_ * ub_true_fraction),
1422 total_wt_h + total_wt_r_
// unique_ptr deleter for the item array used during deserialization.
// set_h / set_r record how many H and R items have been constructed, so
// operator() visits only slots [0, h_count) and
// [h_count + 1, h_count + r_count + 1), skipping the gap. It then returns
// the storage with the element count `num` it was allocated with.
// The per-element destructor calls are not shown here.
1426template<
typename T,
typename A>
1429 items_deleter(uint32_t num,
const A& allocator) : num(num), h_count(0), r_count(0), allocator(allocator) {}
1430 void set_h(uint32_t h) { h_count = h; }
1431 void set_r(uint32_t r) { r_count = r; }
1432 void operator() (T* ptr) {
1434 for (
size_t i = 0; i < h_count; ++i) {
1439 uint32_t end = h_count + r_count + 1;
1440 for (
size_t i = h_count + 1; i < end; ++i) {
1444 if (ptr !=
nullptr) {
1445 allocator.deallocate(ptr, num);
// unique_ptr deleter for the weights array. It returns the array to the
// allocator with the same element count (`num`) used when allocating it.
1455template<
typename T,
typename A>
1456class var_opt_sketch<T, A>::weights_deleter {
1458 weights_deleter(uint32_t num,
const A& allocator) : num(num), allocator(allocator) {}
1459 void operator() (
double* ptr) {
1460 if (ptr !=
nullptr) {
1461 allocator.deallocate(ptr, num);
1466 AllocDouble allocator;
// unique_ptr deleter for the bool marks array. The deserialize paths pair it
// with AllocBool(allocator).allocate(array_size) and construct it with that
// same array_size. The Allocator requirements say deallocate() must receive
// the count passed to allocate(). Passing 1 broke that rule, which is
// undefined behavior (for example with sized deallocation), so the stored
// `num` is used instead, matching weights_deleter.
1469template<
typename T,
typename A>
1470class var_opt_sketch<T, A>::marks_deleter {
1472 marks_deleter(uint32_t num,
const A& allocator) : num(num), allocator(allocator) {}
1473 void operator() (
bool* ptr) {
1474 if (ptr !=
nullptr) {
1475 allocator.deallocate(ptr, num);
1480 AllocBool allocator;
// Returns a const_iterator over this sketch constructed with `false`,
// presumably the begin position (the signature is not shown here).
1484template<
typename T,
typename A>
1486 return const_iterator(*
this,
false);
// Returns a const_iterator over this sketch constructed with `true`,
// presumably the end position (the signature is not shown here).
1489template<
typename T,
typename A>
1491 return const_iterator(*
this,
true);