20#ifndef _VAR_OPT_UNION_IMPL_HPP_
21#define _VAR_OPT_UNION_IMPL_HPP_
23#include "var_opt_union.hpp"
// Constructor taking max_k and an allocator.
// Visible initializers: outer_tau_numer_ starts at 0.0 and allocator_ is a
// copy of `allocator`.
// NOTE(review): this listing is an extract and the original's lines 33, 35-36
// and 38-40 are missing. The remaining member initializers (presumably n_,
// outer_tau_denom_, max_k_ and gadget_) and the constructor body are not
// shown; confirm against the complete file.
31template<
typename T,
typename A>
32var_opt_union<T, A>::var_opt_union(uint32_t max_k,
const A& allocator) :
34 outer_tau_numer_(0.0),
37 allocator_(allocator),
// Copy constructor. The visible initializers copy the outer tau numerator and
// denominator, the allocator and the gadget sketch from `other`.
// NOTE(review): original lines 43 and 46 (presumably the n_ and max_k_
// initializers) and the body are missing from this extract.
41template<
typename T,
typename A>
42var_opt_union<T, A>::var_opt_union(
const var_opt_union& other) :
44 outer_tau_numer_(other.outer_tau_numer_),
45 outer_tau_denom_(other.outer_tau_denom_),
47 allocator_(other.allocator_),
48 gadget_(other.gadget_)
// Move constructor (noexcept). Only the gadget sketch is moved out of `other`;
// the outer tau numerator/denominator and the allocator are copied.
// NOTE(review): original lines 53 and 56 (presumably the n_ and max_k_
// initializers) and the body are missing from this extract.
51template<
typename T,
typename A>
52var_opt_union<T, A>::var_opt_union(var_opt_union&& other) noexcept :
54 outer_tau_numer_(other.outer_tau_numer_),
55 outer_tau_denom_(other.outer_tau_denom_),
57 allocator_(other.allocator_),
58 gadget_(std::move(other.gadget_))
// Constructor used to rebuild a union from previously serialized state. Both
// deserialize() overloads in this file call it with items_seen, the outer tau
// numerator/denominator, max_k, the deserialized gadget (moved in) and the
// allocator.
// NOTE(review): original lines 64, 67 and 69-71 are missing from this
// extract; the n_, max_k_ and gadget_ initializers are presumably among them.
61template<
typename T,
typename A>
62var_opt_union<T, A>::var_opt_union(uint64_t n,
double outer_tau_numer, uint64_t outer_tau_denom,
63 uint32_t max_k, var_opt_sketch<T, A>&& gadget,
const A& allocator) :
65 outer_tau_numer_(outer_tau_numer),
66 outer_tau_denom_(outer_tau_denom),
68 allocator_(allocator),
72template<
typename T,
typename A>
73var_opt_union<T, A>::~var_opt_union() {}
75template<
typename T,
typename A>
76var_opt_union<T, A>& var_opt_union<T, A>::operator=(
const var_opt_union& other) {
77 var_opt_union union_copy(other);
78 std::swap(n_, union_copy.n_);
79 std::swap(outer_tau_numer_, union_copy.outer_tau_numer_);
80 std::swap(outer_tau_denom_, union_copy.outer_tau_denom_);
81 std::swap(max_k_, union_copy.max_k_);
82 std::swap(allocator_, other.allocator_);
83 std::swap(gadget_, union_copy.gadget_);
87template<
typename T,
typename A>
88var_opt_union<T, A>& var_opt_union<T, A>::operator=(var_opt_union&& other) {
89 std::swap(n_, other.n_);
90 std::swap(outer_tau_numer_, other.outer_tau_numer_);
91 std::swap(outer_tau_denom_, other.outer_tau_denom_);
92 std::swap(max_k_, other.max_k_);
93 std::swap(allocator_, other.allocator_);
94 std::swap(gadget_, other.gadget_);
// Deserializes a union from a stream.
// Read order:
//   - preamble_longs, serial version, family id and flags (one uint8_t each);
//   - max_k (uint32_t);
//   - unless the EMPTY flag is set: items_seen (uint64_t), the outer tau
//     numerator (double) and denominator (uint64_t), then the gadget via
//     var_opt_sketch<T, A>::deserialize.
// @param is input stream
// @param sd SerDe forwarded to var_opt_sketch<T, A>::deserialize
// @param allocator allocator forwarded to the gadget and to the non-empty result
// @return the reconstructed union
// @throws std::invalid_argument from check_preamble_longs or
//         check_family_and_serialization_version, or if max_k is 0 or greater
//         than var_opt_constants::MAX_K
// @throws std::runtime_error "error reading from std::istream". The conditions
//         guarding these throws are on lines missing from this extract
//         (presumably stream-state checks).
137template<
typename T,
typename A>
138template<
typename SerDe>
139var_opt_union<T, A> var_opt_union<T, A>::deserialize(std::istream& is,
const SerDe& sd,
const A& allocator) {
140 const auto preamble_longs = read<uint8_t>(is);
141 const auto serial_version = read<uint8_t>(is);
142 const auto family_id = read<uint8_t>(is);
143 const auto flags = read<uint8_t>(is);
144 const auto max_k = read<uint32_t>(is);
146 check_preamble_longs(preamble_longs, flags);
147 check_family_and_serialization_version(family_id, serial_version);
149 if (max_k == 0 || max_k > var_opt_constants::MAX_K) {
150 throw std::invalid_argument(
"k must be at least 1 and less than 2^31 - 1");
153 bool is_empty = flags & EMPTY_FLAG_MASK;
// Presumably the is_empty branch (its condition is on missing lines):
// no further fields are read before returning.
157 throw std::runtime_error(
"error reading from std::istream");
// NOTE(review): unlike the non-empty return at the end, this call does not
// forward `allocator`, so a caller-supplied (possibly stateful) allocator is
// not used for an empty union. Consider var_opt_union(max_k, allocator).
159 return var_opt_union(max_k);
162 const auto items_seen = read<uint64_t>(is);
163 const auto outer_tau_numer = read<double>(is);
164 const auto outer_tau_denom = read<uint64_t>(is);
166 var_opt_sketch<T, A> gadget = var_opt_sketch<T, A>::deserialize(is, sd, allocator);
169 throw std::runtime_error(
"error reading from std::istream");
171 return var_opt_union(items_seen, outer_tau_numer, outer_tau_denom, max_k, std::move(gadget), allocator);
// Deserializes a union from a byte buffer of `size` bytes.
// Field order matches the stream overload:
//   - preamble_longs, serial version, family id, flags, max_k;
//   - unless the EMPTY flag is set: items_seen, the outer tau numerator and
//     denominator, then the gadget, which is parsed from the remaining bytes.
// @param bytes start of the serialized data
// @param size number of readable bytes at `bytes`
// @param sd SerDe forwarded to var_opt_sketch<T, A>::deserialize
// @param allocator allocator forwarded to the gadget and to the non-empty result
// @return the reconstructed union
// @throws std::invalid_argument from check_preamble_longs /
//         check_family_and_serialization_version, or if max_k is 0 or greater
//         than var_opt_constants::MAX_K
// NOTE(review): declarations of family_id, flags and max_k (and other lines)
// are missing from this extract.
174template<
typename T,
typename A>
175template<
typename SerDe>
176var_opt_union<T, A> var_opt_union<T, A>::deserialize(
const void* bytes,
size_t size,
const SerDe& sd,
const A& allocator) {
// Guards the fixed 8-byte header read below.
177 ensure_minimum_memory(size, 8);
178 const char* ptr =
static_cast<const char*
>(bytes);
179 uint8_t preamble_longs;
180 ptr += copy_from_mem(ptr, preamble_longs);
181 uint8_t serial_version;
182 ptr += copy_from_mem(ptr, serial_version);
184 ptr += copy_from_mem(ptr, family_id);
186 ptr += copy_from_mem(ptr, flags);
188 ptr += copy_from_mem(ptr, max_k);
190 check_preamble_longs(preamble_longs, flags);
191 check_family_and_serialization_version(family_id, serial_version);
193 if (max_k == 0 || max_k > var_opt_constants::MAX_K) {
194 throw std::invalid_argument(
"k must be at least 1 and less than 2^31 - 1");
197 bool is_empty = flags & EMPTY_FLAG_MASK;
// NOTE(review): as in the stream overload, this empty-case return does not
// forward `allocator`. Consider var_opt_union(max_k, allocator).
200 return var_opt_union(max_k);
// NOTE(review): the only bounds check visible here is
// ensure_minimum_memory(size, 8) at the top. If the missing lines before this
// point do not re-check for PREAMBLE_LONGS_NON_EMPTY << 3 bytes, two problems
// follow. First, a truncated (untrusted) buffer makes the next three
// copy_from_mem calls read past bytes + size. Second,
// `size - (PREAMBLE_LONGS_NON_EMPTY << 3)` below wraps around as size_t.
// Confirm.
204 ptr += copy_from_mem(ptr, items_seen);
205 double outer_tau_numer;
206 ptr += copy_from_mem(ptr, outer_tau_numer);
207 uint64_t outer_tau_denom;
208 ptr += copy_from_mem(ptr, outer_tau_denom);
210 const size_t gadget_size = size - (PREAMBLE_LONGS_NON_EMPTY << 3);
211 var_opt_sketch<T, A> gadget = var_opt_sketch<T, A>::deserialize(ptr, gadget_size, sd, allocator);
213 return var_opt_union(items_seen, outer_tau_numer, outer_tau_denom, max_k, std::move(gadget), allocator);
// get_serialized_size_bytes(const SerDe& sd) const: computes the size needed to
// serialize the current state of the union (per the doxygen note at the end of
// this listing).
// NOTE(review): the signature and all but one return statement are missing
// from this extract. The visible return, PREAMBLE_LONGS_EMPTY preamble longs
// converted to bytes (<< 3), is presumably the empty-union branch.
216template<
typename T,
typename A>
217template<
typename SerDe>
220 return PREAMBLE_LONGS_EMPTY << 3;
// Writes the union to a stream (the signature is missing from this extract).
// Visible sequence:
//   - choose preamble_longs/flags from whether n_ == 0;
//   - write preamble_longs, the serialization version and the family id;
//   - after lines missing here, write the outer tau numerator and
//     denominator, then the gadget via gadget_.serialize(os, sd).
// NOTE(review): the writes of flags, max_k_ and n_, and the branch that skips
// the non-empty fields for an empty union, are presumably on the missing
// lines. The byte-array overload writes those fields at the corresponding
// positions; confirm the two agree.
226template<
typename T,
typename A>
227template<
typename SerDe>
229 bool empty = (n_ == 0);
231 const uint8_t serialization_version(SER_VER);
232 const uint8_t family_id(FAMILY_ID);
234 uint8_t preamble_longs;
// empty union (the branch condition is on a missing line)
237 preamble_longs = PREAMBLE_LONGS_EMPTY;
238 flags = EMPTY_FLAG_MASK;
// non-empty union
240 preamble_longs = PREAMBLE_LONGS_NON_EMPTY;
244 write(os, preamble_longs);
245 write(os, serialization_version);
246 write(os, family_id);
252 write(os, outer_tau_numer_);
253 write(os, outer_tau_denom_);
254 gadget_.serialize(os, sd);
// Serializes the union into a byte vector (the signature is missing from this
// extract).
// The first header_size_bytes bytes are left zero-filled ahead of the union's
// data. The vector is built with the gadget's allocator (gadget_.allocator_),
// not the union's allocator_ member.
// Layout written:
//   - preamble_longs, serialization version, family id, flags, max_k_;
//   - then n_, the outer tau numerator and denominator, and the gadget's
//     serialized bytes.
258template<
typename T,
typename A>
259template<
typename SerDe>
261 const size_t size = header_size_bytes + get_serialized_size_bytes(sd);
262 std::vector<uint8_t, AllocU8<A>> bytes(size, 0, gadget_.allocator_);
263 uint8_t* ptr = bytes.data() + header_size_bytes;
265 const bool empty = n_ == 0;
267 const uint8_t serialization_version(SER_VER);
268 const uint8_t family_id(FAMILY_ID);
270 uint8_t preamble_longs;
// empty union (the branch condition is on a missing line)
274 preamble_longs = PREAMBLE_LONGS_EMPTY;
275 flags = EMPTY_FLAG_MASK;
// non-empty union
277 preamble_longs = PREAMBLE_LONGS_NON_EMPTY;
282 ptr += copy_to_mem(preamble_longs, ptr);
283 ptr += copy_to_mem(serialization_version, ptr);
284 ptr += copy_to_mem(family_id, ptr);
285 ptr += copy_to_mem(flags, ptr);
286 ptr += copy_to_mem(max_k_, ptr);
// NOTE(review): the fields from n_ onward are presumably written only for a
// non-empty union; the guarding condition is on missing lines.
289 ptr += copy_to_mem(n_, ptr);
290 ptr += copy_to_mem(outer_tau_numer_, ptr);
291 ptr += copy_to_mem(outer_tau_denom_, ptr);
// NOTE(review): the gadget is serialized into a separate vector and then
// copied in, which costs an extra allocation and copy of the gadget payload.
// Writing directly into `bytes` would avoid it, if var_opt_sketch offers such
// an API (confirm).
293 auto gadget_bytes = gadget_.serialize(0, sd);
294 ptr += copy_to_mem(gadget_bytes.data(), ptr, gadget_bytes.size() *
sizeof(uint8_t));
// reset(): resets the union to its default, empty state (per the doxygen note
// at the end of this listing).
// Visible here: the outer tau numerator and denominator are zeroed.
// NOTE(review): the signature and the remaining statements (presumably
// resetting n_ and the gadget) are missing from this extract.
300template<
typename T,
typename A>
303 outer_tau_numer_ = 0.0;
304 outer_tau_denom_ = 0;
// to_string(): prints a summary of the union as a string (per the doxygen
// note at the end of this listing).
// Output:
//   - a header line, n_, Max k and a "Gadget Summary:" heading;
//   - one line missing from this extract (presumably the gadget's own
//     summary);
//   - a footer line.
// The result string is built with the gadget's allocator.
// NOTE(review): std::endl flushes after every line; '\n' would suffice for an
// ostringstream and produce the same text.
308template<
typename T,
typename A>
312 std::ostringstream os;
313 os <<
"### VarOpt Union SUMMARY:" << std::endl;
314 os <<
" n : " << n_ << std::endl;
315 os <<
" Max k : " << max_k_ << std::endl;
316 os <<
" Gadget Summary:" << std::endl;
318 os <<
"### END VarOpt Union SUMMARY" << std::endl;
319 return string<A>(os.str().c_str(), gadget_.allocator_);
// NOTE(review): only the template header of this definition survives in this
// extract. Its signature and body are missing, so the member it belongs to
// cannot be identified here.
322template<
typename T,
typename A>
// Fragment: the signature is missing from this extract. The visible statement
// hands the sketch `sk` to merge_items by move, so this is presumably the
// rvalue-sketch update overload (confirm).
328template<
typename T,
typename A>
330 merge_items(std::move(sk));
// Fragment, presumably get_outer_tau() (resolve_tau in this file calls a
// function of that name). It yields outer_tau_numer_ / outer_tau_denom_.
// NOTE(review): the signature and the body of the outer_tau_denom_ == 0 branch
// are missing from this extract. That branch presumably returns a fixed value
// so the division below never divides by zero.
334template<
typename T,
typename A>
336 if (outer_tau_denom_ == 0) {
339 return outer_tau_numer_ / outer_tau_denom_;
// Copies every sample of `sketch` into the gadget.
//   - H-region items (h_itr/h_end) are passed to gadget_.update with a
//     trailing `false` flag.
//   - R-region items (r_itr/r_end) are passed with a trailing `true` flag.
//   - Each item carries the weight its iterator reports.
// NOTE(review): the body of the sketch.n_ == 0 branch and the iterator
// increments are on lines missing from this extract.
343template<
typename T,
typename A>
344void var_opt_union<T, A>::merge_items(
const var_opt_sketch<T, A>& sketch) {
345 if (sketch.n_ == 0) {
352 typename var_opt_sketch<T, A>::const_iterator h_itr(sketch,
false,
false);
353 typename var_opt_sketch<T, A>::const_iterator h_end(sketch,
true,
false);
354 while (h_itr != h_end) {
355 std::pair<const T&, const double> sample = *h_itr;
356 gadget_.update(sample.first, sample.second,
false);
// NOTE(review): `sketch` is const and the H loop uses const_iterator, but the
// R loop below declares `iterator`. Confirm that `iterator` is constructible
// from a const sketch, or switch to const_iterator for consistency.
361 typename var_opt_sketch<T, A>::iterator r_itr(sketch,
false,
true);
362 typename var_opt_sketch<T, A>::iterator r_end(sketch,
true,
true);
363 while (r_itr != r_end) {
364 std::pair<const T&, const double> sample = *r_itr;
365 gadget_.update(sample.first, sample.second,
true);
// Rvalue overload of merge_items: like the const& version, it feeds H-region
// items with a trailing `false` flag and R-region items with `true`. Each item
// is moved out of `sketch` (std::move(sample.first)), so the sketch's items
// are left in a moved-from state.
// NOTE(review): the body of the sketch.n_ == 0 branch and the iterator
// increments are on lines missing from this extract.
370template<
typename T,
typename A>
371void var_opt_union<T, A>::merge_items(var_opt_sketch<T, A>&& sketch) {
372 if (sketch.n_ == 0) {
379 typename var_opt_sketch<T, A>::iterator h_itr(sketch,
false,
false);
380 typename var_opt_sketch<T, A>::iterator h_end(sketch,
true,
false);
381 while (h_itr != h_end) {
382 std::pair<T&, double> sample = *h_itr;
383 gadget_.update(std::move(sample.first), sample.second,
false);
388 typename var_opt_sketch<T, A>::iterator r_itr(sketch,
false,
true);
389 typename var_opt_sketch<T, A>::iterator r_end(sketch,
true,
true);
390 while (r_itr != r_end) {
391 std::pair<T&, double> sample = *r_itr;
392 gadget_.update(std::move(sample.first), sample.second,
true);
// Folds `sketch`'s tau into the union's running outer tau
// (outer_tau_numer_ / outer_tau_denom_):
//   - no outer tau yet (denominator 0): adopt the sketch's total R weight
//     (total_wt_r_) and R count (r_);
//   - sketch tau larger than the current outer tau: replace with the
//     sketch's values;
//   - equal taus: add the sketch's R weight and R count;
//   - otherwise: no change is visible in this extract.
// NOTE(review): several lines are missing here, so any early exit (e.g. for a
// sketch with r_ == 0) is not visible. The exact == between doubles looks
// deliberate, merging only identical taus (presumably; confirm).
397template<
typename T,
typename A>
398void var_opt_union<T, A>::resolve_tau(
const var_opt_sketch<T, A>& sketch) {
400 const double sketch_tau = sketch.get_tau();
401 const double outer_tau = get_outer_tau();
403 if (outer_tau_denom_ == 0) {
405 outer_tau_numer_ = sketch.total_wt_r_;
406 outer_tau_denom_ = sketch.r_;
407 }
else if (sketch_tau > outer_tau) {
409 outer_tau_numer_ = sketch.total_wt_r_;
410 outer_tau_denom_ = sketch.r_;
411 }
else if (sketch_tau == outer_tau) {
416 outer_tau_numer_ += sketch.total_wt_r_;
417 outer_tau_denom_ += sketch.r_;
// Fragment of the function that produces the union's result sketch. Its name
// is on lines missing from this extract (presumably get_result()).
// Visible logic:
//   - with no marked items in the gadget's H region, return
//     simple_gadget_coercer();
//   - otherwise, test a sketch named gcopy (declared on a missing line) for
//     the pseudo-exact subcase;
//   - if that subcase does not apply, migrate marked items by decreasing k.
424template<
typename T,
typename A>
428 if (gadget_.num_marks_in_h_ == 0) {
429 return simple_gadget_coercer();
439 const bool is_pseudo_exact = detect_and_handle_subcase_of_pseudo_exact(gcopy);
440 if (!is_pseudo_exact) {
442 migrate_marked_items_by_decreasing_k(gcopy);
// Fragment of simple_gadget_coercer(); the signature is missing from this
// extract. It refuses to run (std::logic_error) if the gadget has any marked
// H items, since it only applies when there are none.
455template<
typename T,
typename A>
457 if (gadget_.num_marks_in_h_ != 0)
throw std::logic_error(
"simple gadget coercer only applies if no marks");
462template<
typename T,
typename A>
463bool var_opt_union<T, A>::there_exist_unmarked_h_items_lighter_than_target(
double threshold)
const {
464 for (uint32_t i = 0; i < gadget_.h_; ++i) {
465 if ((gadget_.weights_[i] < threshold) && !gadget_.marks_[i]) {
// Checks for the "pseudo-exact" subcase when building a result. Its bool
// result is used as is_pseudo_exact by the result-building function in this
// file. `sk` is passed to mark_moving_gadget_coercer in the visible path.
// Conditions visible here, all evaluated on gadget_:
//   1. r_ == 0 (no R items);
//   2. at least one marked H item;
//   3. the number of marked H items equals outer_tau_denom_.
// When not all three hold, the branch at the first `if` is taken.
// Otherwise, an unmarked H item lighter than the gadget's tau
// (anti_condition4) takes the second branch.
// NOTE(review): both branch bodies and the return statements are on lines
// missing from this extract, as is the scoping of the
// mark_moving_gadget_coercer call.
472template<
typename T,
typename A>
473bool var_opt_union<T, A>::detect_and_handle_subcase_of_pseudo_exact(var_opt_sketch<T, A>& sk)
const {
475 const bool condition1 = gadget_.r_ == 0;
478 const bool condition2 = gadget_.num_marks_in_h_ > 0;
483 const bool condition3 = gadget_.num_marks_in_h_ == outer_tau_denom_;
485 if (!(condition1 && condition2 && condition3)) {
490 const bool anti_condition4 = there_exist_unmarked_h_items_lighter_than_target(gadget_.get_tau());
491 if (anti_condition4) {
495 mark_moving_gadget_coercer(sk);
// Builds replacement storage for `sk` from gadget_, turning marked H items
// into R items:
//   - result_k = gadget_.h_ + gadget_.r_; arrays of result_k + 1 weights and
//     items are allocated;
//   - the gadget's R items (indices h_ + 1 .. get_num_samples()) and the
//     marked H items are placed at next_r_pos, which starts at result_k (its
//     updates are on lines missing from this extract);
//   - marked H items get weight -1.0, and their original weights are summed
//     into transferred_weight;
//   - unmarked H items keep their weights and are placed at index result_h;
//   - after consistency checks, sk's old marks/weights/items are released,
//     and its R total becomes gadget_.total_wt_r_ + transferred_weight.
509template<
typename T,
typename A>
510void var_opt_union<T, A>::mark_moving_gadget_coercer(var_opt_sketch<T, A>& sk)
const {
511 const uint32_t result_k = gadget_.h_ + gadget_.r_;
513 uint32_t result_h = 0;
514 uint32_t result_r = 0;
515 size_t next_r_pos = result_k;
// NOTE(review): `wts` and `data` are raw allocations. If a T copy constructor
// throws in either placement-new loop below, or one of the logic_error checks
// fires, nothing visible here releases them or destroys the T objects already
// built, so they leak. An owning guard (or cleanup before throwing) would
// close this.
517 double* wts = AllocDouble(allocator_).allocate(result_k + 1);
518 T* data = A(allocator_).allocate(result_k + 1);
525 const size_t final_idx = gadget_.get_num_samples();
526 for (
size_t idx = gadget_.h_ + 1; idx <= final_idx; ++idx) {
527 new (&data[next_r_pos]) T(gadget_.data_[idx]);
528 wts[next_r_pos] = gadget_.weights_[idx];
533 double transferred_weight = 0;
536 for (
size_t idx = 0; idx < gadget_.h_; ++idx) {
537 if (gadget_.marks_[idx]) {
538 new (&data[next_r_pos]) T(gadget_.data_[idx]);
539 wts[next_r_pos] = -1.0;
540 transferred_weight += gadget_.weights_[idx];
544 new (&data[result_h]) T(gadget_.data_[idx]);
545 wts[result_h] = gadget_.weights_[idx];
550 if (result_h + result_r != result_k)
throw std::logic_error(
"H + R counts must equal k");
// NOTE(review): 1e-10 is an absolute tolerance. For large total weights, the
// rounding error of summing doubles can exceed it; a relative tolerance may be
// more robust.
551 if (std::abs(transferred_weight - outer_tau_numer_) > 1e-10) {
552 throw std::logic_error(
"unexpected mismatch in transferred weight");
555 const double result_r_weight = gadget_.total_wt_r_ + transferred_weight;
556 const uint64_t result_n = n_;
559 wts[result_h] = -1.0;
562 AllocBool(allocator_).deallocate(sk.marks_, sk.curr_items_alloc_);
563 AllocDouble(allocator_).deallocate(sk.weights_, sk.curr_items_alloc_);
// NOTE(review): this destroys sk.data_[0 .. result_k). Suppose sk uses the
// same layout the R-copy loop above assumes for gadget_: H items at [0, h_),
// a gap at h_, R items at [h_ + 1, h_ + r_]. Then this loop touches the gap
// slot and skips the last R item. Confirm that every destroyed slot holds a
// constructed T at this point.
564 for (
size_t i = 0; i < result_k; ++i) { sk.data_[i].~T(); }
565 A(allocator_).deallocate(sk.data_, sk.curr_items_alloc_);
570 sk.num_marks_in_h_ = 0;
571 sk.curr_items_alloc_ = result_k + 1;
576 sk.total_wt_r_ = result_r_weight;
// Calls gcopy.decrease_k_by_1() while gcopy still has marked H items, after
// validating gcopy's state.
// @param gcopy working sketch, modified in place
// @throws std::logic_error in three cases:
//   - gcopy has no marked H items;
//   - gcopy has R items but h_ + r_ != k_;
//   - gcopy's tau is 0.0 after the decrease that precedes the check.
// NOTE(review): the body of the (r_count == 0 && h_count < k) branch and the
// loop's remaining statements are missing from this extract. It is therefore
// not visible whether the first decrease_k_by_1() call sits inside that branch.
// The while loop terminates only if decrease_k_by_1() eventually reduces
// num_marks_in_h_.
580template<
typename T,
typename A>
581void var_opt_union<T, A>::migrate_marked_items_by_decreasing_k(var_opt_sketch<T, A>& gcopy)
const {
582 const uint32_t r_count = gcopy.r_;
583 const uint32_t h_count = gcopy.h_;
584 const uint32_t k = gcopy.k_;
587 if (gcopy.num_marks_in_h_ == 0)
throw std::logic_error(
"unexpectedly found no marked items to migrate");
589 if ((r_count != 0) && ((h_count + r_count) != k))
throw std::logic_error(
"invalid gadget state");
592 if ((r_count == 0) && (h_count < k)) {
599 gcopy.decrease_k_by_1();
602 if (gcopy.get_tau() == 0.0)
throw std::logic_error(
"gadget must be in sampling mode");
605 while (gcopy.num_marks_in_h_ > 0) {
607 gcopy.decrease_k_by_1();
613template<
typename T,
typename A>
614void var_opt_union<T, A>::check_preamble_longs(uint8_t preamble_longs, uint8_t flags) {
615 bool is_empty(flags & EMPTY_FLAG_MASK);
618 if (preamble_longs != PREAMBLE_LONGS_EMPTY) {
619 throw std::invalid_argument(
"Possible corruption: Preamble longs must be "
620 + std::to_string(PREAMBLE_LONGS_EMPTY) +
" for an empty sketch. Found: "
621 + std::to_string(preamble_longs));
624 if (preamble_longs != PREAMBLE_LONGS_NON_EMPTY) {
625 throw std::invalid_argument(
"Possible corruption: Preamble longs must be "
626 + std::to_string(PREAMBLE_LONGS_NON_EMPTY)
627 +
" for a non-empty sketch. Found: " + std::to_string(preamble_longs));
632template<
typename T,
typename A>
633void var_opt_union<T, A>::check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver) {
634 if (family_id == FAMILY_ID) {
635 if (ser_ver != SER_VER) {
636 throw std::invalid_argument(
"Possible corruption: VarOpt Union serialization version must be "
637 + std::to_string(SER_VER) +
". Found: " + std::to_string(ser_ver));
643 throw std::invalid_argument(
"Possible corruption: VarOpt Union family id must be "
644 + std::to_string(FAMILY_ID) +
". Found: " + std::to_string(family_id));
This sketch samples data from a stream of items.
Definition var_opt_sketch.hpp:67
Provides a unioning operation over var_opt_sketch objects.
Definition var_opt_union.hpp:52
size_t get_serialized_size_bytes(const SerDe &sd=SerDe()) const
Computes size needed to serialize the current state of the union.
Definition var_opt_union_impl.hpp:218
void reset()
Resets the union to its default, empty state.
Definition var_opt_union_impl.hpp:301
string< A > to_string() const
Prints a summary of the union as a string.
Definition var_opt_union_impl.hpp:309
const resize_factor DEFAULT_RESIZE_FACTOR
default resize factor
Definition theta_constants.hpp:33
DataSketches namespace.
Definition binomial_bounds.hpp:38