// Serialized-size computation for arithmetic item types (TT satisfies
// std::is_arithmetic). Presumably get_serialized_size_bytes(sd); its signature
// line is not visible in this line-sampled fragment.
//   - empty sketch: only the empty preamble (PREAMBLE_LONGS_EMPTY longs, << 3 = 8 bytes each);
//   - otherwise: warmup or full preamble (chosen by r_ == 0), h_ weights stored
//     as doubles, ceil(h_ / 8) bytes of packed marks when marks_ is set, and
//     (h_ + r_) items of sizeof(TT) bytes each.
// NOTE(review): closing lines (orig. 303, 305-308) are not visible here.
295template<
typename T,
typename A>
296template<typename TT, typename SerDe, typename std::enable_if<std::is_arithmetic<TT>::value,
int>::type>
298 if (is_empty()) {
return PREAMBLE_LONGS_EMPTY << 3; }
299 size_t num_bytes = (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL) << 3;
300 num_bytes += h_ *
sizeof(double);
301 if (marks_ !=
nullptr) {
// One bit per H-region mark, rounded up to whole bytes.
302 num_bytes += (h_ / 8) + (h_ % 8 > 0);
304 num_bytes += (h_ + r_) *
sizeof(TT);
// Serialized-size computation for non-arithmetic item types. Presumably
// get_serialized_size_bytes(sd); its signature line is not visible here.
// Same preamble/weights/marks accounting as the arithmetic overload, but each
// item's size is obtained individually from sd.size_of_item(it.first).
// NOTE(review): the loop header producing `it` (orig. ~317-319) is not visible;
// presumably it iterates the sketch's (item, weight) pairs.
309template<
typename T,
typename A>
310template<typename TT, typename SerDe, typename std::enable_if<!std::is_arithmetic<TT>::value,
int>::type>
312 if (is_empty()) {
return PREAMBLE_LONGS_EMPTY << 3; }
313 size_t num_bytes = (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL) << 3;
314 num_bytes += h_ *
sizeof(double);
315 if (marks_ !=
nullptr) {
// One bit per H-region mark, rounded up to whole bytes.
316 num_bytes += (h_ / 8) + (h_ % 8 > 0);
320 num_bytes += sd.size_of_item(it.first);
// Serializes the sketch into a newly allocated byte vector of exactly
// header_size_bytes + get_serialized_size_bytes(sd) bytes. The first
// header_size_bytes bytes are left zeroed for the caller.
// Layout written here: first byte (preamble longs in the low 6 bits, resize
// factor in the top 2), serial version, family id, flags, k_. Then, under
// conditions whose lines are not visible in this fragment, come n_, h_, r_,
// total_wt_r_, h_ weights as raw doubles, and the packed marks bitmap (gadgets
// only, bit i%8 of each byte). Last are the h_ H-region items and the r_
// R-region items, which start at index h_ + 1 so the gap slot is skipped.
// Throws std::logic_error if the byte count written differs from the
// precomputed size.
324template<
typename T,
typename A>
325template<
typename SerDe>
326std::vector<uint8_t, AllocU8<A>> var_opt_sketch<T, A>::serialize(
unsigned header_size_bytes,
const SerDe& sd)
const {
327 const size_t size = header_size_bytes + get_serialized_size_bytes(sd);
328 std::vector<uint8_t, AllocU8<A>> bytes(size, 0, allocator_);
329 uint8_t* ptr = bytes.data() + header_size_bytes;
// `size` already includes header_size_bytes and the vector holds exactly
// `size` bytes from bytes.data(). The end must therefore be measured from the
// start of the buffer. Measuring from ptr would overshoot by header_size_bytes
// and overstate the capacity handed to sd.serialize.
330 uint8_t* end_ptr = bytes.data() + size;
332 bool empty = is_empty();
333 uint8_t preLongs = (empty ? PREAMBLE_LONGS_EMPTY
334 : (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL));
335 uint8_t first_byte = (preLongs & 0x3F) | ((
static_cast<uint8_t
>(rf_)) << 6);
336 uint8_t flags = (marks_ !=
nullptr ? GADGET_FLAG_MASK : 0);
339 flags |= EMPTY_FLAG_MASK;
343 uint8_t ser_ver(SER_VER);
344 uint8_t family(FAMILY_ID);
345 ptr += copy_to_mem(first_byte, ptr);
346 ptr += copy_to_mem(ser_ver, ptr);
347 ptr += copy_to_mem(family, ptr);
348 ptr += copy_to_mem(flags, ptr);
349 ptr += copy_to_mem(k_, ptr);
353 ptr += copy_to_mem(n_, ptr);
354 ptr += copy_to_mem(h_, ptr);
355 ptr += copy_to_mem(r_, ptr);
359 ptr += copy_to_mem(total_wt_r_, ptr);
363 ptr += copy_to_mem(weights_, ptr, h_ *
sizeof(
double));
366 if (marks_ !=
nullptr) {
// Marks are packed 8 per byte; bit (i & 7) corresponds to slot i.
368 for (uint32_t i = 0; i < h_; ++i) {
370 val |= 0x1 << (i & 0x7);
373 if ((i & 0x7) == 0x7) {
374 ptr += copy_to_mem(val, ptr);
// Flush the final partially filled byte.
380 if ((h_ & 0x7) > 0) {
381 ptr += copy_to_mem(val, ptr);
// H items occupy [0, h_); R items start after the gap slot at h_.
386 ptr += sd.serialize(ptr, end_ptr - ptr, data_, h_);
387 ptr += sd.serialize(ptr, end_ptr - ptr, &data_[h_ + 1], r_);
390 size_t bytes_written = ptr - bytes.data();
391 if (bytes_written != size) {
392 throw std::logic_error(
"serialized size mismatch: " + std::to_string(bytes_written) +
" != " + std::to_string(size));
// Stream overload of serialize. Its signature line is not visible; `os` and
// `sd` are inferred from use. It writes the same layout as the byte-vector
// overload.
// First byte = (preamble longs & 0x3F) | (resize factor << 6). flags carry
// GADGET_FLAG_MASK when marks_ is set, plus EMPTY_FLAG_MASK under a condition
// not visible here.
// Visible payload writes: total_wt_r_, h_ weights as raw doubles, the marks
// packed 8 per byte (bit i%8), then H items [0, h_) and R items starting at
// h_ + 1 (the gap slot at h_ is skipped).
// NOTE(review): several writes (e.g. orig. 416-428) are not visible in this fragment.
398template<
typename T,
typename A>
399template<
typename SerDe>
401 const bool empty = (h_ == 0) && (r_ == 0);
403 const uint8_t preLongs = (empty ? PREAMBLE_LONGS_EMPTY
404 : (r_ == 0 ? PREAMBLE_LONGS_WARMUP : PREAMBLE_LONGS_FULL));
405 const uint8_t first_byte = (preLongs & 0x3F) | ((
static_cast<uint8_t
>(rf_)) << 6);
406 uint8_t flags = (marks_ !=
nullptr ? GADGET_FLAG_MASK : 0);
409 flags |= EMPTY_FLAG_MASK;
413 const uint8_t ser_ver(SER_VER);
414 const uint8_t family(FAMILY_ID);
415 write(os, first_byte);
429 write(os, total_wt_r_);
433 write(os, weights_, h_ *
sizeof(
double));
436 if (marks_ !=
nullptr) {
438 for (uint32_t i = 0; i < h_; ++i) {
440 val |= 0x1 << (i & 0x7);
443 if ((i & 0x7) == 0x7) {
450 if ((h_ & 0x7) > 0) {
456 sd.serialize(os, data_, h_);
457 sd.serialize(os, &data_[h_ + 1], r_);
// Deserializes a sketch from a raw buffer `bytes` of `size` bytes. The
// signature line is not visible; names are inferred from use.
// Steps:
//   1. Read the preamble (preamble longs, resize factor, serial version,
//      family, flags, k) and validate it.
//   2. Read n, h, r, and total_wt_r when the preamble is full-mode.
//   3. Allocate arrays of array_size slots (from validate_and_get_target_size).
//   4. Read the h weights (each must be > 0) and fill the remaining slots
//      with -1.0.
//   5. Read the packed marks bitmap. Presumably this happens only for gadgets;
//      the guard line is not visible.
//   6. Deserialize the H items, then the R items at index h + 1.
// Fixed-size reads are preceded by ensure_minimum_memory/check_memory_size.
// The SerDe receives the remaining length (end_ptr - ptr).
// Throws std::invalid_argument on corrupt or inconsistent input.
461template<
typename T,
typename A>
462template<
typename SerDe>
464 ensure_minimum_memory(size, 8);
465 const char* ptr =
static_cast<const char*
>(bytes);
466 const char* base = ptr;
467 const char* end_ptr = ptr + size;
469 ptr += copy_from_mem(ptr, first_byte);
470 uint8_t preamble_longs = first_byte & 0x3f;
471 resize_factor rf =
static_cast<resize_factor
>((first_byte >> 6) & 0x03);
472 uint8_t serial_version;
473 ptr += copy_from_mem(ptr, serial_version);
475 ptr += copy_from_mem(ptr, family_id);
477 ptr += copy_from_mem(ptr, flags);
479 ptr += copy_from_mem(ptr, k);
481 check_preamble_longs(preamble_longs, flags);
482 check_family_and_serialization_version(family_id, serial_version);
// Now that preamble_longs is validated, make sure the whole preamble is present.
483 ensure_minimum_memory(size, preamble_longs << 3);
485 const bool is_empty = flags & EMPTY_FLAG_MASK;
486 const bool is_gadget = flags & GADGET_FLAG_MASK;
495 ptr += copy_from_mem(ptr, n);
496 ptr += copy_from_mem(ptr, h);
497 ptr += copy_from_mem(ptr, r);
499 const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf);
502 double total_wt_r = 0.0;
503 if (preamble_longs == PREAMBLE_LONGS_FULL) {
504 ptr += copy_from_mem(ptr, total_wt_r);
505 if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) {
506 throw std::invalid_argument(
"Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
507 "Found r = " + std::to_string(r) +
", R region weight = " + std::to_string(total_wt_r));
514 check_memory_size(ptr - base + (h *
sizeof(
double)), size);
515 std::unique_ptr<double, weights_deleter> weights(AllocDouble(allocator).allocate(array_size),
516 weights_deleter(array_size, allocator));
517 double* wts = weights.get();
518 ptr += copy_from_mem(ptr, wts, h *
sizeof(
double));
// `!(w > 0.0)` also rejects NaN weights.
519 for (
size_t i = 0; i < h; ++i) {
520 if (!(wts[i] > 0.0)) {
521 throw std::invalid_argument(
"Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i]));
// Slots past the H region (gap and R region) carry the -1.0 marker weight.
524 std::fill(wts + h, wts + array_size, -1.0);
527 uint32_t num_marks_in_h = 0;
528 std::unique_ptr<bool, marks_deleter> marks(
nullptr, marks_deleter(array_size, allocator));
531 marks = std::unique_ptr<bool, marks_deleter>(AllocBool(allocator).allocate(array_size), marks_deleter(array_size, allocator));
532 const size_t size_marks = (h / 8) + (h % 8 > 0 ? 1 : 0);
533 check_memory_size(ptr - base + size_marks, size);
534 for (uint32_t i = 0; i < h; ++i) {
// A new bitmap byte is read every 8 marks.
535 if ((i & 0x7) == 0x0) {
536 ptr += copy_from_mem(ptr, val);
538 marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1;
539 num_marks_in_h += (marks.get()[i] ? 1 : 0);
544 items_deleter deleter(array_size, allocator);
545 std::unique_ptr<T, items_deleter> items(A(allocator).allocate(array_size), deleter);
547 ptr += sd.deserialize(ptr, end_ptr - ptr, items.get(), h);
// Record constructed counts so the deleter destroys only those items if a later step throws.
548 items.get_deleter().set_h(h);
550 ptr += sd.deserialize(ptr, end_ptr - ptr, &(items.get()[h + 1]), r);
551 items.get_deleter().set_r(r);
553 return var_opt_sketch(k, h, (r > 0 ? 1 : 0), r, n, total_wt_r, rf, array_size, false,
554 std::move(items), std::move(weights), num_marks_in_h, std::move(marks), allocator);
// Deserializes a sketch from std::istream `is`, using `sd` for items and
// `allocator` for all arrays. It follows the same layout and validation as the
// byte-buffer overload.
// An empty sketch returns var_opt_sketch(k, rf, is_gadget, allocator); the
// empty-check guard is not visible in this fragment.
// Throws std::invalid_argument on corrupt data and std::runtime_error on
// stream failure. The stream-state conditions (orig. ~573-575, ~634-635) are
// not visible here.
// NOTE(review): the weight positivity loop runs right after read(is, wts, ...).
// No stream-state check is visible between them. If `read` leaves the buffer
// partly unwritten on a short stream, the loop may inspect indeterminate values.
// Consider checking `is` before validating the weights.
557template<
typename T,
typename A>
558template<
typename SerDe>
559var_opt_sketch<T, A> var_opt_sketch<T, A>::deserialize(std::istream& is,
const SerDe& sd,
const A& allocator) {
560 const auto first_byte = read<uint8_t>(is);
561 uint8_t preamble_longs = first_byte & 0x3f;
562 const resize_factor rf =
static_cast<resize_factor
>((first_byte >> 6) & 0x03);
563 const auto serial_version = read<uint8_t>(is);
564 const auto family_id = read<uint8_t>(is);
565 const auto flags = read<uint8_t>(is);
566 const auto k = read<uint32_t>(is);
568 check_preamble_longs(preamble_longs, flags);
569 check_family_and_serialization_version(family_id, serial_version);
571 const bool is_empty = flags & EMPTY_FLAG_MASK;
572 const bool is_gadget = flags & GADGET_FLAG_MASK;
576 throw std::runtime_error(
"error reading from std::istream");
578 return var_opt_sketch(k, rf, is_gadget, allocator);
582 const auto n = read<uint64_t>(is);
583 const auto h = read<uint32_t>(is);
584 const auto r = read<uint32_t>(is);
586 const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf);
589 double total_wt_r = 0.0;
590 if (preamble_longs == PREAMBLE_LONGS_FULL) {
591 total_wt_r = read<double>(is);
592 if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) {
593 throw std::invalid_argument(
"Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
594 "Found r = " + std::to_string(r) +
", R region weight = " + std::to_string(total_wt_r));
599 std::unique_ptr<double, weights_deleter> weights(AllocDouble(allocator).allocate(array_size),
600 weights_deleter(array_size, allocator));
601 double* wts = weights.get();
602 read(is, wts, h *
sizeof(
double));
// `!(w > 0.0)` also rejects NaN weights.
603 for (
size_t i = 0; i < h; ++i) {
604 if (!(wts[i] > 0.0)) {
605 throw std::invalid_argument(
"Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i]));
// Slots past the H region (gap and R region) carry the -1.0 marker weight.
608 std::fill(wts + h, wts + array_size, -1.0);
611 uint32_t num_marks_in_h = 0;
612 std::unique_ptr<bool, marks_deleter> marks(
nullptr, marks_deleter(array_size, allocator));
614 marks = std::unique_ptr<bool, marks_deleter>(AllocBool(allocator).allocate(array_size), marks_deleter(array_size, allocator));
616 for (uint32_t i = 0; i < h; ++i) {
// A new bitmap byte is read every 8 marks.
617 if ((i & 0x7) == 0x0) {
618 val = read<uint8_t>(is);
620 marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1;
621 num_marks_in_h += (marks.get()[i] ? 1 : 0);
626 items_deleter deleter(array_size, allocator);
627 std::unique_ptr<T, items_deleter> items(A(allocator).allocate(array_size), deleter);
629 sd.deserialize(is, items.get(), h);
// Record constructed counts so the deleter destroys only those items if a later step throws.
630 items.get_deleter().set_h(h);
632 sd.deserialize(is, &(items.get()[h + 1]), r);
633 items.get_deleter().set_r(r);
636 throw std::runtime_error(
"error reading from std::istream");
638 return var_opt_sketch(k, h, (r > 0 ? 1 : 0), r, n, total_wt_r, rf, array_size, false,
639 std::move(items), std::move(weights), num_marks_in_h, std::move(marks), allocator);
// Emptiness test (presumably is_empty(); signature line not visible).
// True iff both the H region (h_) and the R region (r_) hold no items.
642template<
typename T,
typename A>
644 return (h_ == 0 && r_ == 0);
// Appears to reset the sketch; the signature line is not visible.
// Steps:
//   - Recompute the initial allocation size curr_items_alloc_ from k_, rf_
//     and MIN_LG_ARR_ITEMS.
//   - Destroy stored items.
//   - If the new size is smaller than the previous allocation, free data_,
//     weights_ and marks_ and reallocate them via allocate_data_arrays,
//     preserving gadget (marks) status.
//   - Clear filled_data_.
// NOTE(review): the branch selecting between the two destruction strategies
// (orig. ~654-668) and the loop bodies are not visible. The first form covers
// min(k_ + 1, prev_alloc) slots. The second covers only [0, h_) and
// [h_ + 1, h_ + r_ + 1), skipping the gap slot at h_.
647template<
typename T,
typename A>
649 const uint32_t prev_alloc = curr_items_alloc_;
650 const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k_));
651 const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf_, MIN_LG_ARR_ITEMS);
652 curr_items_alloc_ = get_adjusted_size(k_, 1 << initial_lg_size);
653 if (curr_items_alloc_ == k_) {
659 const size_t num_to_destroy = std::min(k_ + 1, prev_alloc);
660 for (
size_t i = 0; i < num_to_destroy; ++i)
664 for (
size_t i = 0; i < h_; ++i)
667 for (
size_t i = h_ + 1; i < h_ + r_ + 1; ++i)
671 if (curr_items_alloc_ < prev_alloc) {
672 const bool is_gadget = (marks_ !=
nullptr);
674 allocator_.deallocate(data_, prev_alloc);
675 AllocDouble(allocator_).deallocate(weights_, prev_alloc);
677 if (marks_ !=
nullptr)
678 AllocBool(allocator_).deallocate(marks_, prev_alloc);
680 allocate_data_arrays(curr_items_alloc_, is_gadget);
689 filled_data_ =
false;
// NOTE(review): only the template headers of two definitions (orig. lines 692
// and 697) survive in this fragment; their signatures and bodies are not visible.
692template<
typename T,
typename A>
697template<
typename T,
typename A>
// Number of samples currently retained: h_ + r_, capped at k_. Presumably
// get_num_samples(); the signature line is not visible.
702template<
typename T,
typename A>
704 const uint32_t num_in_sketch = h_ + r_;
705 return (num_in_sketch < k_ ? num_in_sketch : k_);
// Convenience update overload (signature line not visible). It forwards `item`
// and `weight` to the three-argument update with mark = false.
708template<
typename T,
typename A>
710 update(item, weight,
false);
// Rvalue convenience update overload (signature line not visible). It moves
// `item` into the three-argument update with mark = false.
713template<
typename T,
typename A>
715 update(std::move(item), weight,
false);
// Human-readable summary of the sketch: k_, h_, r_, total R-region weight,
// current allocation size and resize factor (printed as 1 << rf_).
// The text is built in a std::ostringstream and returned as string<A> using
// the sketch's allocator. Signature line not visible.
718template<
typename T,
typename A>
722 std::ostringstream os;
723 os <<
"### VarOpt SUMMARY:" << std::endl;
724 os <<
" k : " << k_ << std::endl;
725 os <<
" h : " << h_ << std::endl;
726 os <<
" r : " << r_ << std::endl;
727 os <<
" weight_r : " << total_wt_r_ << std::endl;
728 os <<
" Current size : " << curr_items_alloc_ << std::endl;
729 os <<
" Resize factor: " << (1 << rf_) << std::endl;
730 os <<
"### END SKETCH SUMMARY" << std::endl;
731 return string<A>(os.str().c_str(), allocator_);
// Lists every record produced by iterating the sketch (*this), one per line as
// "<idx>: <item>\twt = <weight>". Returned as string<A> with the sketch's
// allocator. The declaration and increment of `idx` are on lines not visible
// in this fragment.
734template<
typename T,
typename A>
738 std::ostringstream os;
739 os <<
"### Sketch Items" << std::endl;
741 for (
auto record : *
this) {
742 os << idx <<
": " << record.first <<
"\twt = " << record.second << std::endl;
745 return string<A>(os.str().c_str(), allocator_);
// Debug dump of the raw data_ array (signature line not visible; uses a
// `print_gap` flag). It walks the first min(n_, k_ + 1) slots and, when
// print_gap is set, prints a "GAP" line at index h_. Slots whose weight is the
// -1.0 marker are printed with the current tau (get_tau()).
// NOTE(review): lines handling display_idx and the else-branches (orig. 758-759,
// 763, 765-768) are not visible here.
748template<
typename T,
typename A>
752 std::ostringstream os;
753 os <<
"### Sketch Items" << std::endl;
754 const uint32_t array_length = (n_ < k_ ? n_ : k_ + 1);
755 for (uint32_t i = 0, display_idx = 0; i < array_length; ++i) {
756 if (i == h_ && print_gap) {
757 os << display_idx <<
": GAP" << std::endl;
760 os << display_idx <<
": " << data_[i] <<
"\twt = ";
761 if (weights_[i] == -1.0) {
762 os << get_tau() <<
"\t(-1.0)" << std::endl;
764 os << weights_[i] << std::endl;
769 return string<A>(os.str().c_str(), allocator_);
// Core update: inserts `item` with the given `weight` and gadget `mark`.
// - Negative, NaN and infinite weights are rejected with std::invalid_argument.
// - Zero weights take a branch whose body is not visible in this fragment.
// - In exact (warmup) mode the item goes to update_warmup_phase. The branch
//   condition is not visible here.
// - Otherwise the sketch must be in a valid estimation state (H minimum >= tau),
//   else std::logic_error.
// - The item is "light" when it would be the smallest in H (or H is empty)
//   and its weight is below the hypothetical tau. Light items go to
//   update_light.
// - Other items are "heavy": update_heavy_r_eq1 handles r_ == 1 and
//   update_heavy_general handles the rest.
772template<
typename T,
typename A>
774void var_opt_sketch<T, A>::update(O&& item,
double weight,
bool mark) {
775 if (weight < 0.0 || std::isnan(weight) || std::isinf(weight)) {
776 throw std::invalid_argument(
"Item weights must be nonnegative and finite. Found: "
777 + std::to_string(weight));
778 }
else if (weight == 0.0) {
785 update_warmup_phase(std::forward<O>(item), weight, mark);
789 if ((h_ != 0) && (peek_min() < get_tau()))
790 throw std::logic_error(
"sketch not in valid estimation mode");
// Presumably the tau that would result if the new item joined the R region
// and one of the r_ + 1 candidates were then evicted.
794 const double hypothetical_tau = (weight + total_wt_r_) / ((r_ + 1) - 1);
// These are truth values. Keep them as bool rather than round-tripping
// through double.
797 const bool condition1 = (h_ == 0) || (weight <= peek_min());
800 const bool condition2 = weight < hypothetical_tau;
802 if (condition1 && condition2) {
803 update_light(std::forward<O>(item), weight, mark);
804 }
else if (r_ == 1) {
805 update_heavy_r_eq1(std::forward<O>(item), weight, mark);
807 update_heavy_general(std::forward<O>(item), weight, mark);
// Exact-mode (warmup) insertion.
// - Requires r_ == 0, m_ == 0 and h_ <= k_; otherwise throws std::logic_error.
// - When the arrays are full (h_ >= curr_items_alloc_), a branch not visible
//   here handles growth.
// - The item is constructed in place at data_[h_], and its weight is recorded.
// - For gadgets (marks_ set), the mark is recorded and num_marks_in_h_ is
//   counted.
// - transition_from_warmup() is eventually invoked. Its triggering condition
//   is not visible here (presumably once h_ exceeds k_).
812template<
typename T,
typename A>
814void var_opt_sketch<T, A>::update_warmup_phase(O&& item,
double weight,
bool mark) {
816 if (r_ > 0 || m_ != 0 || h_ > k_)
throw std::logic_error(
"invalid sketch state during warmup");
818 if (h_ >= curr_items_alloc_) {
823 new (&data_[h_]) T(std::forward<O>(item));
824 weights_[h_] = weight;
825 if (marks_ !=
nullptr) {
829 num_marks_in_h_ += mark ? 1 : 0;
834 transition_from_warmup();
// Handles a "light" item in estimation mode. Requires r_ > 0 and r_ + h_ == k_.
// The item is placed into the M slot at index h_:
//   - by assignment, guarded against self-assignment, when the slot already
//     holds an object (the guarding condition is partly not visible);
//   - otherwise by placement-new.
// Its weight and (for gadgets) mark are stored. Finally the candidate set is
// grown with weight total_wt_r_ + weight over r_ + 1 candidates.
842template<
typename T,
typename A>
844void var_opt_sketch<T, A>::update_light(O&& item,
double weight,
bool mark) {
845 if (r_ == 0 || (r_ + h_) != k_)
throw std::logic_error(
"invalid sketch state during light warmup");
847 const uint32_t m_slot = h_;
// Skip self-assignment, e.g. when the item being re-inserted already lives in this slot.
849 if (&data_[m_slot] != &item)
850 data_[m_slot] = std::forward<O>(item);
852 new (&data_[m_slot]) T(std::forward<O>(item));
855 weights_[m_slot] = weight;
856 if (marks_ !=
nullptr) { marks_[m_slot] = mark; }
859 grow_candidate_set(total_wt_r_ + weight, r_ + 1);
// Handles a "heavy" item when r_ >= 2. Requires m_ == 0 and r_ + h_ == k_,
// otherwise throws std::logic_error. The item is pushed onto the H-region heap
// and the candidate set is grown starting from the current R region
// (total_wt_r_, r_ candidates).
870template<
typename T,
typename A>
872void var_opt_sketch<T, A>::update_heavy_general(O&& item,
double weight,
bool mark) {
873 if (r_ < 2 || m_ != 0 || (r_ + h_) != k_)
throw std::logic_error(
"invalid sketch state during heavy general update");
876 push(std::forward<O>(item), weight, mark);
878 grow_candidate_set(total_wt_r_, r_);
// Handles a "heavy" item when r_ == 1. Requires m_ == 0 and r_ + h_ == k_,
// otherwise throws std::logic_error.
// The item is pushed onto H, H's minimum is moved into the M region, and the
// candidate set is grown from the M slot (index k_ - 1) plus the R region,
// starting with 2 candidates.
884template<
typename T,
typename A>
886void var_opt_sketch<T, A>::update_heavy_r_eq1(O&& item,
double weight,
bool mark) {
887 if (r_ != 1 || m_ != 0 || (r_ + h_) != k_)
throw std::logic_error(
"invalid sketch state during heavy r=1 update");
889 push(std::forward<O>(item), weight, mark);
890 pop_min_to_m_region();
894 const uint32_t m_slot = k_ - 1;
895 grow_candidate_set(weights_[m_slot] + total_wt_r_, 2);
// Reduces the sketch's capacity by one (the error text mentions use in a union).
// Handled by cases of (h_, r_):
//   - both empty: body not visible;
//   - H only: ends in transition_from_warmup();
//   - H and R: the last R item is swapped into the old gap slot (requires the
//     final R index to equal k_), the last H item is pulled out (weight set to
//     -1.0, mark count adjusted), and that item is re-inserted via update();
//   - R only: requires r_ >= 2. A slot chosen by 1 + next_int(r_) is swapped
//     to the rightmost R position and that position's weight is set to -1.0.
// Several bookkeeping lines are not visible in this fragment.
904template<
typename T,
typename A>
905void var_opt_sketch<T, A>::decrease_k_by_1() {
907 throw std::logic_error(
"Cannot decrease k below 1 in union");
910 if ((h_ == 0) && (r_ == 0)) {
913 }
else if ((h_ > 0) && (r_ == 0)) {
917 transition_from_warmup();
919 }
else if ((h_ > 0) && (r_ > 0)) {
925 const uint32_t old_gap_idx = h_;
926 const uint32_t old_final_r_idx = (h_ + 1 + r_) - 1;
927 if (old_final_r_idx != k_)
throw std::logic_error(
"gadget in invalid state");
929 swap_values(old_final_r_idx, old_gap_idx);
936 const uint32_t pulled_idx = h_ - 1;
937 double pulled_weight = weights_[pulled_idx];
// NOTE(review): marks_ is dereferenced without a null check, unlike is_marked().
// Presumably this path only runs on gadgets (marks_ != nullptr); confirm at callers.
938 bool pulled_mark = marks_[pulled_idx];
941 if (pulled_mark) { --num_marks_in_h_; }
942 weights_[pulled_idx] = -1.0;
948 update(std::move(data_[pulled_idx]), pulled_weight, pulled_mark);
949 }
else if ((h_ == 0) && (r_ > 0)) {
951 if (r_ < 2)
throw std::logic_error(
"r_ too small for pure reservoir mode");
953 const uint32_t r_idx_to_delete = 1 + next_int(r_);
954 const uint32_t rightmost_r_idx = (1 + r_) - 1;
955 swap_values(r_idx_to_delete, rightmost_r_idx);
956 weights_[rightmost_r_idx] = -1.0;
// Allocates uninitialized storage for tgt_size items and weights. The marks
// array is also allocated, presumably when use_marks is set (the guard line is
// not visible). filled_data_ is reset to false.
963template<
typename T,
typename A>
964void var_opt_sketch<T, A>::allocate_data_arrays(uint32_t tgt_size,
bool use_marks) {
965 filled_data_ =
false;
967 data_ = allocator_.allocate(tgt_size);
968 weights_ = AllocDouble(allocator_).allocate(tgt_size);
971 marks_ = AllocBool(allocator_).allocate(tgt_size);
// Grows storage by the resize factor: curr_items_alloc_ << rf_, adjusted
// against k_ by get_adjusted_size. If the size actually increases:
//   - existing items are move-constructed into new storage;
//   - weights and marks are copied;
//   - the old arrays (sized prev_size) are freed.
// NOTE(review): lines handling the `== k_` adjustment, destroying moved-from
// items and reassigning data_/marks_ (orig. 982-984, 993, 999-1000, 1007-1013)
// are not visible here.
977template<
typename T,
typename A>
978void var_opt_sketch<T, A>::grow_data_arrays() {
979 const uint32_t prev_size = curr_items_alloc_;
980 curr_items_alloc_ = get_adjusted_size(k_, curr_items_alloc_ << rf_);
981 if (curr_items_alloc_ == k_) {
985 if (prev_size < curr_items_alloc_) {
986 filled_data_ =
false;
988 T* tmp_data = allocator_.allocate(curr_items_alloc_);
989 double* tmp_weights = AllocDouble(allocator_).allocate(curr_items_alloc_);
991 for (uint32_t i = 0; i < prev_size; ++i) {
992 new (&tmp_data[i]) T(std::move(data_[i]));
994 tmp_weights[i] = weights_[i];
997 allocator_.deallocate(data_, prev_size);
998 AllocDouble(allocator_).deallocate(weights_, prev_size);
1001 weights_ = tmp_weights;
1003 if (marks_ !=
nullptr) {
1004 bool* tmp_marks = AllocBool(allocator_).allocate(curr_items_alloc_);
1005 for (uint32_t i = 0; i < prev_size; ++i) {
1006 tmp_marks[i] = marks_[i];
1008 AllocBool(allocator_).deallocate(marks_, prev_size);
// Switches from exact (warmup) mode to estimation mode.
//   1. Pop the two smallest H items into the M region. Preceding lines,
//      presumably heapification, are not visible.
//   2. Verify h_ == k_ - 1, m_ == 1 and r_ == 1.
//   3. Move the weight at index k_ into total_wt_r_ and mark that slot -1.0.
//   4. Grow the candidate set from the M slot (k_ - 1) plus R, starting with
//      2 candidates.
1014template<
typename T,
typename A>
1015void var_opt_sketch<T, A>::transition_from_warmup() {
1019 pop_min_to_m_region();
1020 pop_min_to_m_region();
1024 if (h_ != (k_ -1) || m_ != 1 || r_ != 1)
1025 throw std::logic_error(
"invalid state for transitioning from warmup");
1029 total_wt_r_ = weights_[k_];
1030 weights_[k_] = -1.0;
1034 grow_candidate_set(weights_[k_ - 1] + total_wt_r_, 2);
// Builds a min-heap on weights_ over the H region [0, h_) bottom-up, calling
// restore_towards_leaves on each non-leaf slot from the last one down to 0.
// Lines before the computation (orig. 1039-1042, presumably a small-h_ guard)
// are not visible here.
1037template<
typename T,
typename A>
1038void var_opt_sketch<T, A>::convert_to_heap() {
1043 const uint32_t last_slot = h_ - 1;
1044 const int last_non_leaf = ((last_slot + 1) / 2) - 1;
1046 for (
int j = last_non_leaf; j >= 0; --j) {
1047 restore_towards_leaves(j);
// Sift-down for the H-region min-heap on weights_, starting at slot_in.
// Throws std::logic_error if the heap is empty or slot_in is past the last slot.
// At each level the lighter child is chosen; the walk stops once the slot's
// weight is <= that child's, otherwise the two are swapped.
1057template<
typename T,
typename A>
1058void var_opt_sketch<T, A>::restore_towards_leaves(uint32_t slot_in) {
1059 const uint32_t last_slot = h_ - 1;
1060 if (h_ == 0 || slot_in > last_slot)
throw std::logic_error(
"invalid heap state");
1062 uint32_t slot = slot_in;
1063 uint32_t child = (2 * slot_in) + 1;
1065 while (child <= last_slot) {
1066 uint32_t child2 = child + 1;
1067 if ((child2 <= last_slot) && (weights_[child2] < weights_[child])) {
1072 if (weights_[slot] <= weights_[child]) {
1078 swap_values(slot, child);
1081 child = (2 * slot) + 1;
// Sift-up for the H-region min-heap: swaps slot with its parent
// ((slot + 1) / 2 - 1) while its weight is smaller. The `slot > 0` test
// short-circuits before the parent index would underflow at the root.
1085template<
typename T,
typename A>
1086void var_opt_sketch<T, A>::restore_towards_root(uint32_t slot_in) {
1087 uint32_t slot = slot_in;
1088 uint32_t p = (((slot + 1) / 2) - 1);
1089 while ((slot > 0) && (weights_[slot] < weights_[p])) {
1090 swap_values(slot, p);
1092 p = (((slot + 1) / 2) - 1);
// Inserts an item into the H-region heap at index h_.
// - The slot is assigned (guarded against self-assignment) or placement-new
//   constructed; the selecting condition is partly not visible.
// - filled_data_ is set, and the weight and (for gadgets) mark are recorded.
// - The item is then sifted up from h_ - 1, which implies h_ is incremented
//   on a line not visible here.
1096template<
typename T,
typename A>
1098void var_opt_sketch<T, A>::push(O&& item,
double wt,
bool mark) {
1100 if (&data_[h_] != &item)
1101 data_[h_] = std::forward<O>(item);
1103 new (&data_[h_]) T(std::forward<O>(item));
1104 filled_data_ =
true;
1107 if (marks_ !=
nullptr) {
1109 num_marks_in_h_ += (mark ? 1 : 0);
1113 restore_towards_root(h_ - 1);
// Moves the H-region minimum into the M region.
// - Requires h_ > 0 and h_ + m_ + r_ == k_ + 1; otherwise throws
//   std::logic_error.
// - Swaps the heap root with the last H slot and restores the heap from the
//   root. Region counters are updated on lines not visible here.
// - If the moved item (now at index h_) is marked, the mark bookkeeping is
//   adjusted (body not visible).
1116template<
typename T,
typename A>
1117void var_opt_sketch<T, A>::pop_min_to_m_region() {
1118 if (h_ == 0 || (h_ + m_ + r_ != k_ + 1))
1119 throw std::logic_error(
"invalid heap state popping min to M region");
1127 uint32_t tgt = h_ - 1;
1128 swap_values(0, tgt);
1132 restore_towards_leaves(0);
1135 if (is_marked(h_)) {
// Swaps the item, the weight and (for gadgets) the mark between slots src and dst.
1141template<
typename T,
typename A>
1142void var_opt_sketch<T, A>::swap_values(uint32_t src, uint32_t dst) {
1143 std::swap(data_[src], data_[dst]);
1144 std::swap(weights_[src], weights_[dst]);
1146 if (marks_ !=
nullptr) {
1147 std::swap(marks_[src], marks_[dst]);
// Grows the set of eviction candidates, then evicts one via downsample_candidate_set.
// Validates h_ + m_ + r_ == k_ + 1, num_cands >= 1, num_cands == m_ + r_ and
// m_ < 2, otherwise throws std::logic_error.
// The H minimum (peek_min) is added as a candidate when
// next_wt * num_cands < wt_cands + next_wt: its weight joins wt_cands and it
// moves to the M region. The enclosing loop and candidate-count update are not
// visible in this fragment.
1159template<
typename T,
typename A>
1160void var_opt_sketch<T, A>::grow_candidate_set(
double wt_cands, uint32_t num_cands) {
1161 if ((h_ + m_ + r_ != k_ + 1) || (num_cands < 1) || (num_cands != m_ + r_) || (m_ >= 2))
1162 throw std::logic_error(
"invariant violated when growing candidate set");
1165 const double next_wt = peek_min();
1166 const double next_tot_wt = wt_cands + next_wt;
1171 if ((next_wt * num_cands) < next_tot_wt) {
1172 wt_cands = next_tot_wt;
1174 pop_min_to_m_region();
1180 downsample_candidate_set(wt_cands, num_cands);
// Evicts one candidate.
// - Requires num_cands >= 2 and h_ + num_cands == k_ + 1.
// - The slot to evict comes from choose_delete_slot and must lie in [h_, k_];
//   violations throw std::logic_error.
// - The M-region slots [h_, h_ + m_) are processed (loop body not visible).
// - The item at h_ is moved into delete_slot.
// - total_wt_r_ becomes wt_cands.
1183template<
typename T,
typename A>
1184void var_opt_sketch<T, A>::downsample_candidate_set(
double wt_cands, uint32_t num_cands) {
1185 if (num_cands < 2 || h_ + num_cands != k_ + 1)
1186 throw std::logic_error(
"invalid num_cands when downsampling");
1189 const uint32_t delete_slot = choose_delete_slot(wt_cands, num_cands);
1190 const uint32_t leftmost_cand_slot = h_;
1191 if (delete_slot < leftmost_cand_slot || delete_slot > k_)
1192 throw std::logic_error(
"invalid delete slot index when downsampling");
1197 const uint32_t stop_idx = leftmost_cand_slot + m_;
1198 for (uint32_t j = leftmost_cand_slot; j < stop_idx; ++j) {
1203 data_[delete_slot] = std::move(data_[leftmost_cand_slot]);
1207 total_wt_r_ = wt_cands;
// Chooses which candidate slot to evict. Requires r_ > 0 (estimation mode).
// - First visible branch (condition not visible, presumably m_ == 0): a random
//   R slot.
// - m_ == 1: a random R slot when wt_cands * U < (num_cands - 1) * weight of
//   the M candidate, with U from next_double_exclude_zero(); the alternative
//   return is not visible.
// - Otherwise: choose_weighted_delete_slot. If it returns the first R slot, a
//   random R slot is picked instead.
1210template<
typename T,
typename A>
1211uint32_t var_opt_sketch<T, A>::choose_delete_slot(
double wt_cands, uint32_t num_cands)
const {
1212 if (r_ == 0)
throw std::logic_error(
"choosing delete slot while in exact mode");
1216 return pick_random_slot_in_r();
1217 }
else if (m_ == 1) {
1220 double wt_m_cand = weights_[h_];
1221 if ((wt_cands * next_double_exclude_zero()) < ((num_cands - 1) * wt_m_cand)) {
1222 return pick_random_slot_in_r();
1228 const uint32_t delete_slot = choose_weighted_delete_slot(wt_cands, num_cands);
1229 const uint32_t first_r_slot = h_ + m_;
1230 if (delete_slot == first_r_slot) {
1231 return pick_random_slot_in_r();
// Weighted choice among the M-region slots [h_, h_ + m_). Requires m_ >= 1.
// right_subtotal starts at -wt_cands * U, with U from next_double_exclude_zero().
// Each M slot adds (num_cands - 1) * weight to left_subtotal and wt_cands to
// right_subtotal. The body run when left_subtotal < right_subtotal (presumably
// returning that slot) and the fall-through return are not visible here.
1238template<
typename T,
typename A>
1239uint32_t var_opt_sketch<T, A>::choose_weighted_delete_slot(
double wt_cands, uint32_t num_cands)
const {
1240 if (m_ < 1)
throw std::logic_error(
"must have weighted delete slot");
1242 const uint32_t offset = h_;
1243 const uint32_t final_m = (offset + m_) - 1;
1244 const uint32_t num_to_keep = num_cands - 1;
1246 double left_subtotal = 0.0;
1247 double right_subtotal = -1.0 * wt_cands * next_double_exclude_zero();
1249 for (uint32_t i = offset; i <= final_m; ++i) {
1250 left_subtotal += num_to_keep * weights_[i];
1251 right_subtotal += wt_cands;
1253 if (left_subtotal < right_subtotal) {
// Returns a slot in the R region: h_ + m_ + next_int(r_). Presumably this is
// uniform, depending on next_int. Throws std::logic_error if r_ == 0.
1262template<
typename T,
typename A>
1263uint32_t var_opt_sketch<T, A>::pick_random_slot_in_r()
const {
1264 if (r_ == 0)
throw std::logic_error(
"r_ = 0 when picking slot in R region");
1265 const uint32_t offset = h_ + m_;
1269 return offset + next_int(r_);
// Weight of the H-region minimum. Throws std::logic_error when H is empty.
// The return line is not visible; presumably it is weights_[0], the heap root.
1273template<
typename T,
typename A>
1274double var_opt_sketch<T, A>::peek_min()
const {
1275 if (h_ == 0)
throw std::logic_error(
"h_ = 0 when checking min in H region");
// Null-safe mark lookup: false for non-gadget sketches (marks_ == nullptr),
// otherwise marks_[idx].
1279template<
typename T,
typename A>
1280inline bool var_opt_sketch<T, A>::is_marked(uint32_t idx)
const {
1281 return marks_ ==
nullptr ? false : marks_[idx];
// Sampling threshold tau: the average R-region weight (total_wt_r_ / r_).
// Returns NaN when r_ == 0, i.e. when there is no R region.
1284template<
typename T,
typename A>
1285double var_opt_sketch<T, A>::get_tau()
const {
1286 return r_ == 0 ? std::nan(
"1") : (total_wt_r_ / r_);
// Removes gadget marks. Throws std::logic_error for a non-gadget (marks_ ==
// nullptr). Otherwise it zeroes num_marks_in_h_ and frees the marks array,
// which is assumed to be sized curr_items_alloc_. Resetting marks_ is
// presumably done on a line not visible here.
1289template<
typename T,
typename A>
1290void var_opt_sketch<T, A>::strip_marks() {
1291 if (marks_ ==
nullptr)
throw std::logic_error(
"request to strip marks from non-gadget");
1292 num_marks_in_h_ = 0;
1293 AllocBool(allocator_).deallocate(marks_, curr_items_alloc_);
// Validates the preamble length against the empty flag:
//   - empty sketches must use PREAMBLE_LONGS_EMPTY;
//   - non-empty sketches must use PREAMBLE_LONGS_WARMUP or PREAMBLE_LONGS_FULL.
// Throws std::invalid_argument otherwise.
// The if/else on is_empty (orig. 1300-1301, 1306-1307) is not visible here.
1297template<
typename T,
typename A>
1298void var_opt_sketch<T, A>::check_preamble_longs(uint8_t preamble_longs, uint8_t flags) {
1299 const bool is_empty(flags & EMPTY_FLAG_MASK);
1302 if (preamble_longs != PREAMBLE_LONGS_EMPTY) {
1303 throw std::invalid_argument(
"Possible corruption: Preamble longs must be "
1304 + std::to_string(PREAMBLE_LONGS_EMPTY) +
" for an empty sketch. Found: "
1305 + std::to_string(preamble_longs));
1308 if (preamble_longs != PREAMBLE_LONGS_WARMUP
1309 && preamble_longs != PREAMBLE_LONGS_FULL) {
1310 throw std::invalid_argument(
"Possible corruption: Preamble longs must be "
1311 + std::to_string(PREAMBLE_LONGS_WARMUP) +
" or "
1312 + std::to_string(PREAMBLE_LONGS_FULL)
1313 +
" for a non-empty sketch. Found: " + std::to_string(preamble_longs));
// Validates the family id and serial version. When the family matches
// FAMILY_ID, ser_ver must equal SER_VER; otherwise a family-mismatch error is
// raised (the else structure is not visible). Both throw std::invalid_argument.
1318template<
typename T,
typename A>
1319void var_opt_sketch<T, A>::check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver) {
1320 if (family_id == FAMILY_ID) {
1321 if (ser_ver != SER_VER) {
1322 throw std::invalid_argument(
"Possible corruption: VarOpt serialization version must be "
1323 + std::to_string(SER_VER) +
". Found: " + std::to_string(ser_ver));
1329 throw std::invalid_argument(
"Possible corruption: VarOpt family id must be "
1330 + std::to_string(FAMILY_ID) +
". Found: " + std::to_string(family_id));
// Validates deserialized header fields and computes the array size to allocate.
// k must lie in [1, MAX_K].
// Warmup case (the branch condition is not visible; the messages imply n <= k):
//   - the preamble must be WARMUP, n must equal h, and r must be 0;
//   - array_size is derived from k and h via starting_sub_multiple and
//     get_adjusted_size.
// Full case:
//   - the preamble must be FULL;
//   - an h/r consistency check applies (its condition is not visible).
// All failures throw std::invalid_argument. The return statement is not
// visible in this fragment.
1333template<
typename T,
typename A>
1334uint32_t var_opt_sketch<T, A>::validate_and_get_target_size(uint32_t preamble_longs, uint32_t k, uint64_t n,
1335 uint32_t h, uint32_t r, resize_factor rf) {
1336 if (k == 0 || k > MAX_K) {
1337 throw std::invalid_argument(
"k must be at least 1 and less than 2^31 - 1");
1340 uint32_t array_size;
1343 if (preamble_longs != PREAMBLE_LONGS_WARMUP) {
1344 throw std::invalid_argument(
"Possible corruption: deserializing with n <= k but not in warmup mode. "
1345 "Found n = " + std::to_string(n) +
", k = " + std::to_string(k));
1348 throw std::invalid_argument(
"Possible corruption: deserializing in warmup mode but n != h. "
1349 "Found n = " + std::to_string(n) +
", h = " + std::to_string(h));
1352 throw std::invalid_argument(
"Possible corruption: deserializing in warmup mode but r > 0. "
1353 "Found r = " + std::to_string(r));
1356 const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k));
1357 const uint32_t min_lg_size = to_log_2(ceiling_power_of_2(h));
1358 const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf, min_lg_size);
1359 array_size = get_adjusted_size(k, 1 << initial_lg_size);
1360 if (array_size == k) {
1364 if (preamble_longs != PREAMBLE_LONGS_FULL) {
1365 throw std::invalid_argument(
"Possible corruption: deserializing with n > k but not in full mode. "
1366 "Found n = " + std::to_string(n) +
", k = " + std::to_string(k));
1369 throw std::invalid_argument(
"Possible corruption: deserializing in full mode but h + r != n. "
1370 "Found h = " + std::to_string(h) +
", r = " + std::to_string(r) +
", n = " + std::to_string(n));
// Estimates the total weight of items satisfying `predicate` (signature line
// not visible). The result is a 4-value aggregate.
// - In the full path the values are, in order: H true weight plus R weight
//   times the lower-bound, estimated and upper-bound true fractions, then the
//   total weight of both regions.
// - H items contribute exactly. R items contribute total_wt_r_ scaled by the
//   fraction of R items matching the predicate, with pseudo-hypergeometric
//   bounds at sampling rate r_ / (n_ - h_).
// - An empty case returns all zeros (its condition is not visible).
// NOTE(review): the early return at orig. 1399 places h_true_wt in the 4th
// position, while the full path places total_wt_h + total_wt_r_ there. If that
// slot is the total sketch weight, the early return presumably should use
// total_wt_h. Confirm against the result type's field order.
1379template<
typename T,
typename A>
1383 return {0.0, 0.0, 0.0, 0.0};
1386 double total_wt_h = 0.0;
1387 double h_true_wt = 0.0;
1389 for (; idx < h_; ++idx) {
1390 double wt = weights_[idx];
1392 if (predicate(data_[idx])) {
1399 return {h_true_wt, h_true_wt, h_true_wt, h_true_wt};
1403 const uint64_t num_samples = n_ - h_;
1404 double effective_sampling_rate = r_ /
static_cast<double>(num_samples);
1405 if (effective_sampling_rate < 0.0 || effective_sampling_rate > 1.0)
1406 throw std::logic_error(
"invalid sampling rate outside [0.0, 1.0]");
1408 uint32_t r_true_count = 0;
1410 for (; idx < (k_ + 1); ++idx) {
1411 if (predicate(data_[idx])) {
1416 double lb_true_fraction = pseudo_hypergeometric_lb_on_p(r_, r_true_count, effective_sampling_rate);
1417 double estimated_true_fraction = (1.0 * r_true_count) / r_;
1418 double ub_true_fraction = pseudo_hypergeometric_ub_on_p(r_, r_true_count, effective_sampling_rate);
1420 return { h_true_wt + (total_wt_r_ * lb_true_fraction),
1421 h_true_wt + (total_wt_r_ * estimated_true_fraction),
1422 h_true_wt + (total_wt_r_ * ub_true_fraction),
1423 total_wt_h + total_wt_r_
// Deleter for the items array (presumably items_deleter; the class-header line
// is not visible).
// - It tracks how many H and R items have been constructed (set_h / set_r).
// - It visits H items [0, h_count) and R items [h_count + 1, h_count + r_count + 1),
//   skipping the gap slot. The destructor calls are on lines not visible here.
// - It then returns the storage of `num` elements to the allocator.
1427template<
typename T,
typename A>
1430 items_deleter(uint32_t num,
const A& allocator) : num(num), h_count(0), r_count(0), allocator(allocator) {}
1431 void set_h(uint32_t h) { h_count = h; }
1432 void set_r(uint32_t r) { r_count = r; }
1433 void operator() (T* ptr) {
1435 for (
size_t i = 0; i < h_count; ++i) {
1440 uint32_t end = h_count + r_count + 1;
1441 for (
size_t i = h_count + 1; i < end; ++i) {
1445 if (ptr !=
nullptr) {
1446 allocator.deallocate(ptr, num);
// Deleter for weights arrays held in std::unique_ptr<double, weights_deleter>.
// It returns the `num` doubles passed at construction to AllocDouble.
1456template<
typename T,
typename A>
1457class var_opt_sketch<T, A>::weights_deleter {
1459 weights_deleter(uint32_t num,
const A& allocator) : num(num), allocator(allocator) {}
1460 void operator() (
double* ptr) {
1461 if (ptr !=
nullptr) {
1462 allocator.deallocate(ptr, num);
1467 AllocDouble allocator;
// Deleter for the bool marks array held in std::unique_ptr<bool, marks_deleter>.
// It is constructed with the element count the array was allocated with; the
// deserializers pass array_size both to AllocBool::allocate and to this deleter.
1470template<
typename T,
typename A>
1471class var_opt_sketch<T, A>::marks_deleter {
1473 marks_deleter(uint32_t num,
const A& allocator) : num(num), allocator(allocator) {}
1474 void operator() (
bool* ptr) {
1475 if (ptr !=
nullptr) {
// Allocator::deallocate requires the same count that was passed to allocate().
// The previous hard-coded 1 broke that contract whenever num != 1, which is UB
// for std::allocator and wrong for sized/pool allocators. weights_deleter and
// items_deleter already pass num.
1476 allocator.deallocate(ptr, num);
1481 AllocBool allocator;
// Returns const_iterator(*this, false). This is likely begin(); the signature
// line is not visible. The flag's meaning is defined by const_iterator's
// constructor, which is not shown here.
1485template<
typename T,
typename A>
1487 return const_iterator(*
this,
false);
// Returns const_iterator(*this, true). This is likely end(); the signature line
// is not visible. The flag's meaning is defined by const_iterator's constructor,
// which is not shown here.
1490template<
typename T,
typename A>
1492 return const_iterator(*
this,
true);