datasketches-cpp
Loading...
Searching...
No Matches
ebpps_sketch_impl.hpp
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifndef _EBPPS_SKETCH_IMPL_HPP_
21#define _EBPPS_SKETCH_IMPL_HPP_
22
23#include <memory>
24#include <sstream>
25#include <cmath>
26#include <random>
27#include <algorithm>
28#include <stdexcept>
29#include <utility>
30
31#include "ebpps_sketch.hpp"
32
33namespace datasketches {
34
35template<typename T, typename A>
36ebpps_sketch<T, A>::ebpps_sketch(uint32_t k, const A& allocator) :
37 allocator_(allocator),
38 k_(k),
39 n_(0),
40 cumulative_wt_(0.0),
41 wt_max_(0.0),
42 rho_(1.0),
43 sample_(check_k(k), allocator),
44 tmp_(1, allocator)
45 {}
46
47template<typename T, typename A>
48ebpps_sketch<T,A>::ebpps_sketch(uint32_t k, uint64_t n, double cumulative_wt,
49 double wt_max, double rho,
50 ebpps_sample<T,A>&& sample, const A& allocator) :
51 allocator_(allocator),
52 k_(k),
53 n_(n),
54 cumulative_wt_(cumulative_wt),
55 wt_max_(wt_max),
56 rho_(rho),
57 sample_(sample),
58 tmp_(1, allocator)
59 {}
60
61template<typename T, typename A>
62uint32_t ebpps_sketch<T, A>::get_k() const {
63 return k_;
64}
65
66template<typename T, typename A>
67uint64_t ebpps_sketch<T, A>::get_n() const {
68 return n_;
69}
70
71template<typename T, typename A>
73 return sample_.get_c();
74}
75
76template<typename T, typename A>
78 return cumulative_wt_;
79}
80
81template<typename T, typename A>
83 return n_ == 0;
84}
85
86template<typename T, typename A>
88 n_ = 0;
89 cumulative_wt_ = 0.0;
90 wt_max_ = 0.0;
91 rho_ = 1.0;
92 sample_.reset();
93}
94
95template<typename T, typename A>
97 // Using a temporary stream for implementation here does not comply with AllocatorAwareContainer requirements.
98 // The stream does not support passing an allocator instance, and alternatives are complicated.
99 std::ostringstream os;
100 os << "### EBPPS Sketch SUMMARY:" << std::endl;
101 os << " k : " << k_ << std::endl;
102 os << " n : " << n_ << std::endl;
103 os << " cum. weight : " << cumulative_wt_ << std::endl;
104 os << " wt_mac : " << wt_max_ << std::endl;
105 os << " rho : " << rho_ << std::endl;
106 os << " C : " << sample_.get_c() << std::endl;
107 os << "### END SKETCH SUMMARY" << std::endl;
108 return string<A>(os.str().c_str(), allocator_);
109}
110
111template<typename T, typename A>
113 // Using a temporary stream for implementation here does not comply with AllocatorAwareContainer requirements.
114 // The stream does not support passing an allocator instance, and alternatives are complicated.
115 std::ostringstream os;
116 os << "### Sketch Items" << std::endl;
117 os << sample_.to_string(); // assumes std::endl at end
118 return string<A>(os.str().c_str(), allocator_);
119}
120
121template<typename T, typename A>
123 return allocator_;
124}
125
126template<typename T, typename A>
127void ebpps_sketch<T, A>::update(const T& item, double weight) {
128 return internal_update(item, weight);
129}
130
131template<typename T, typename A>
132void ebpps_sketch<T, A>::update(T&& item, double weight) {
133 return internal_update(std::move(item), weight);
134}
135
136template<typename T, typename A>
137template<typename FwdItem>
138void ebpps_sketch<T, A>::internal_update(FwdItem&& item, double weight) {
139 if (weight < 0.0 || std::isnan(weight) || std::isinf(weight)) {
140 throw std::invalid_argument("Item weights must be nonnegative and finite. Found: "
141 + std::to_string(weight));
142 } else if (weight == 0.0) {
143 return;
144 }
145
146 const double new_cum_wt = cumulative_wt_ + weight;
147 const double new_wt_max = std::max(wt_max_, weight);
148 const double new_rho = std::min(1.0 / new_wt_max, k_ / new_cum_wt);
149
150 if (cumulative_wt_ > 0.0)
151 sample_.downsample(new_rho / rho_);
152
153 tmp_.replace_content(conditional_forward<FwdItem>(item), new_rho * weight);
154 sample_.merge(tmp_);
155
156 cumulative_wt_ = new_cum_wt;
157 wt_max_ = new_wt_max;
158 rho_ = new_rho;
159 ++n_;
160}
161
162template<typename T, typename A>
163auto ebpps_sketch<T,A>::get_result() const -> result_type {
164 return sample_.get_sample();
165}
166
167/* Merging
168 * There is a trivial merge algorithm that involves downsampling each sketch A and B
169 * as A.cum_wt / (A.cum_wt + B.cum_wt) and B.cum_wt / (A.cum_wt + B.cum_wt),
170 * respectively. That merge does preserve first-order probabilities, specifically
171 * the probability proportional to size property, and like all other known merge
172 * algorithms distorts second-order probabilities (co-occurrences). There are
173 * pathological cases, most obvious with k=2 and A.cum_wt == B.cum_wt where that
174 * approach will always take exactly 1 item from A and 1 from B, meaning the
175 * co-occurrence rate for two items from either sketch is guaranteed to be 0.0.
176 *
177 * With EBPPS, once an item is accepted into the sketch we no longer need to
178 * track the item's weight: All accepted items are treated equally. As a result, we
179 * can take inspiration from the reservoir sampling merge in the datasketches-java
180 * library. We need to merge the smaller sketch into the larger one, swapping as
181 * needed to ensure that, at which point we simply call update() with the items
182 * in the smaller sketch as long as we adjust the weight appropriately.
183 * Merging smaller into larger is essential to ensure that no item has a
184 * contribution to C > 1.0.
185 */
186
187template<typename T, typename A>
189 if (sk.get_cumulative_weight() == 0.0) return;
190 else if (sk.get_cumulative_weight() > get_cumulative_weight()) {
191 // need to swap this with sk to merge smaller into larger
192 std::swap(*this, sk);
193 }
194
195 internal_merge(sk);
196}
197
198template<typename T, typename A>
200 if (sk.get_cumulative_weight() == 0.0) return;
201 else if (sk.get_cumulative_weight() > get_cumulative_weight()) {
202 // need to swap this with sk to merge, so make a copy, swap,
203 // and use that to merge
204 ebpps_sketch sk_copy(sk);
205 swap(*this, sk_copy);
206 internal_merge(sk_copy);
207 } else {
208 internal_merge(sk);
209 }
210}
211
212template<typename T, typename A>
213template<typename O>
215 // assumes that sk.cumulative_wt_ <= cumulative_wt_,
216 // which must be checked before calling this
217
218 const ebpps_sample<T,A>& other_sample = sk.sample_;
219
220 const double final_cum_wt = cumulative_wt_ + sk.cumulative_wt_;
221 const double new_wt_max = std::max(wt_max_, sk.wt_max_);
222 k_ = std::min(k_, sk.k_);
223 const uint64_t new_n = n_ + sk.n_;
224
225 // Insert sk's items with the cumulative weight
226 // split between the input items. We repeat the same process
227 // for full items and the partial item, scaling the input
228 // weight appropriately.
229 // We handle all C input items, meaning we always process
230 // the partial item using a scaled down weight.
231 // Handling the partial item by probabilistically including
232 // it as a full item would be correct on average but would
233 // introduce bias for any specific merge operation.
234 const double avg_wt = sk.get_cumulative_weight() / sk.get_c();
235 auto items = other_sample.get_full_items();
236 for (size_t i = 0; i < items.size(); ++i) {
237 // new_wt_max is pre-computed
238 const double new_cum_wt = cumulative_wt_ + avg_wt;
239 const double new_rho = std::min(1.0 / new_wt_max, k_ / new_cum_wt);
240
241 if (cumulative_wt_ > 0.0)
242 sample_.downsample(new_rho / rho_);
243
244 tmp_.replace_content(conditional_forward<O>(items[i]), new_rho * avg_wt);
245 sample_.merge(tmp_);
246
247 cumulative_wt_ = new_cum_wt;
248 rho_ = new_rho;
249 }
250
251 // insert partial item with weight scaled by the fractional part of C
252 if (other_sample.has_partial_item()) {
253 double unused;
254 const double other_c_frac = std::modf(other_sample.get_c(), &unused);
255
256 const double new_cum_wt = cumulative_wt_ + (other_c_frac * avg_wt);
257 const double new_rho = std::min(1.0 / new_wt_max, k_ / new_cum_wt);
258
259 if (cumulative_wt_ > 0.0)
260 sample_.downsample(new_rho / rho_);
261
262 tmp_.replace_content(conditional_forward<O>(other_sample.get_partial_item()), new_rho * other_c_frac * avg_wt);
263 sample_.merge(tmp_);
264
265 cumulative_wt_ = new_cum_wt;
266 rho_ = new_rho;
267 }
268
269 // avoid numeric issues by setting cumulative weight to the
270 // pre-computed value
271 cumulative_wt_ = final_cum_wt;
272 n_ = new_n;
273}
274
275/*
276 * An empty sketch requires 8 bytes.
277 *
278 * <pre>
279 * Long || Start Byte Adr:
280 * Adr:
281 * || 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
282 * 0 || Preamble_Longs | SerVer | FamID | Flags |---------Max Res. Size (K)---------|
283 * </pre>
284 *
285 * A non-empty sketch requires 40 bytes of preamble. C looks like part of
286 * the preamble but is serialized as part of the internal sample state.
287 *
288 * The count of items seen is not used but preserved as the value seems like a useful
289 * count to track.
290 *
291 * <pre>
292 * Long || Start Byte Adr:
293 * Adr:
294 * || 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
295 * 0 || Preamble_Longs | SerVer | FamID | Flags |---------Max Res. Size (K)---------|
296 *
297 * || 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
298 * 1 ||---------------------------Items Seen Count (N)--------------------------------|
299 *
300 * || 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
301 * 2 ||----------------------------Cumulative Weight----------------------------------|
302 *
303 * || 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
304 * 3 ||-----------------------------Max Item Weight-----------------------------------|
305 *
306 * || 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 |
307 * 4 ||----------------------------------Rho------------------------------------------|
308 *
309 * || 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
310 * 5 ||-----------------------------------C-------------------------------------------|
311 *
312 * || 48+ |
313 * 6+ || {Items Array} |
314 * || {Optional Item (if needed)} |
315 * </pre>
316 */
317
318template<typename T, typename A>
319template<typename SerDe>
321 if (is_empty()) { return PREAMBLE_LONGS_EMPTY << 3; }
322 return (PREAMBLE_LONGS_FULL << 3) + sample_.get_serialized_size_bytes(sd);
323}
324
325template<typename T, typename A>
326template<typename SerDe>
327auto ebpps_sketch<T,A>::serialize(unsigned header_size_bytes, const SerDe& sd) const -> vector_bytes {
328 const uint8_t prelongs = (is_empty() ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_FULL);
329
330 const size_t size = header_size_bytes + (prelongs << 3) + sample_.get_serialized_size_bytes(sd);
331 vector_bytes bytes(size, 0, allocator_);
332 uint8_t* ptr = bytes.data() + header_size_bytes;
333 const uint8_t* end_ptr = ptr + size;
334
335 uint8_t flags = 0;
336 if (is_empty()) {
337 flags |= EMPTY_FLAG_MASK;
338 } else {
339 flags |= sample_.has_partial_item() ? HAS_PARTIAL_ITEM_MASK : 0;
340 }
341
342 // first prelong
343 const uint8_t ser_ver = SER_VER;
344 const uint8_t family = FAMILY_ID;
345 ptr += copy_to_mem(prelongs, ptr);
346 ptr += copy_to_mem(ser_ver, ptr);
347 ptr += copy_to_mem(family, ptr);
348 ptr += copy_to_mem(flags, ptr);
349 ptr += copy_to_mem(k_, ptr);
350
351 if (!is_empty()) {
352 // remaining preamble
353 ptr += copy_to_mem(n_, ptr);
354 ptr += copy_to_mem(cumulative_wt_, ptr);
355 ptr += copy_to_mem(wt_max_, ptr);
356 ptr += copy_to_mem(rho_, ptr);
357 ptr += sample_.serialize(ptr, end_ptr, sd);
358 }
359
360 return bytes;
361}
362
363template<typename T, typename A>
364template<typename SerDe>
365void ebpps_sketch<T,A>::serialize(std::ostream& os, const SerDe& sd) const {
366 const uint8_t prelongs = (is_empty() ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_FULL);
367
368 uint8_t flags = 0;
369 if (is_empty()) {
370 flags |= EMPTY_FLAG_MASK;
371 } else {
372 flags |= sample_.has_partial_item() ? HAS_PARTIAL_ITEM_MASK : 0;
373 }
374
375 // first prelong
376 const uint8_t ser_ver = SER_VER;
377 const uint8_t family = FAMILY_ID;
378 write(os, prelongs);
379 write(os, ser_ver);
380 write(os, family);
381 write(os, flags);
382 write(os, k_);
383
384 if (!is_empty()) {
385 // remaining preamble
386 write(os, n_);
387 write(os, cumulative_wt_);
388 write(os, wt_max_);
389 write(os, rho_);
390 sample_.serialize(os, sd);
391 }
392
393 if (!os.good()) throw std::runtime_error("error writing to std::ostream");
394}
395
396template<typename T, typename A>
397template<typename SerDe>
398ebpps_sketch<T,A> ebpps_sketch<T,A>::deserialize(const void* bytes, size_t size, const SerDe& sd, const A& allocator) {
399 ensure_minimum_memory(size, 8);
400 const uint8_t* ptr = static_cast<const uint8_t*>(bytes);
401 const uint8_t* end_ptr = ptr + size;
402 uint8_t prelongs;
403 ptr += copy_from_mem(ptr, prelongs);
404 uint8_t serial_version;
405 ptr += copy_from_mem(ptr, serial_version);
406 uint8_t family_id;
407 ptr += copy_from_mem(ptr, family_id);
408 uint8_t flags;
409 ptr += copy_from_mem(ptr, flags);
410 uint32_t k;
411 ptr += copy_from_mem(ptr, k);
412
413 check_k(k);
414 check_preamble_longs(prelongs, flags);
415 check_family_and_serialization_version(family_id, serial_version);
416 ensure_minimum_memory(size, prelongs << 3);
417
418 const bool empty = flags & EMPTY_FLAG_MASK;
419 if (empty)
420 return ebpps_sketch(k, allocator);
421
422 uint64_t n;
423 ptr += copy_from_mem(ptr, n);
424 double cumulative_wt;
425 ptr += copy_from_mem(ptr, cumulative_wt);
426 double wt_max;
427 ptr += copy_from_mem(ptr, wt_max);
428 double rho;
429 ptr += copy_from_mem(ptr, rho);
430
431 auto pair = ebpps_sample<T, A>::deserialize(ptr, end_ptr - ptr, sd, allocator);
432 ebpps_sample<T, A> sample = pair.first;
433 ptr += pair.second;
434
435 if (sample.has_partial_item() != bool(flags & HAS_PARTIAL_ITEM_MASK))
436 throw std::runtime_error("sketch fails internal consistency check");
437
438 return ebpps_sketch(k, n, cumulative_wt, wt_max, rho, std::move(sample), allocator);
439}
440
441template<typename T, typename A>
442template<typename SerDe>
443ebpps_sketch<T,A> ebpps_sketch<T,A>::deserialize(std::istream& is, const SerDe& sd, const A& allocator) {
444 const uint8_t prelongs = read<uint8_t>(is);
445 const uint8_t ser_ver = read<uint8_t>(is);
446 const uint8_t family = read<uint8_t>(is);
447 const uint8_t flags = read<uint8_t>(is);
448 const uint32_t k = read<uint32_t>(is);
449
450 check_k(k);
451 check_family_and_serialization_version(family, ser_ver);
452 check_preamble_longs(prelongs, flags);
453
454 const bool empty = (flags & EMPTY_FLAG_MASK);
455
456 if (empty)
457 return ebpps_sketch(k, allocator);
458
459 const uint64_t n = read<uint64_t>(is);
460 const double cumulative_wt = read<double>(is);
461 const double wt_max = read<double>(is);
462 const double rho = read<double>(is);
463
464 auto sample = ebpps_sample<T,A>::deserialize(is, sd, allocator);
465
466 if (sample.has_partial_item() != bool(flags & HAS_PARTIAL_ITEM_MASK))
467 throw std::runtime_error("sketch fails internal consistency check");
468
469 return ebpps_sketch(k, n, cumulative_wt, wt_max, rho, std::move(sample), allocator);
470}
471
472template <typename T, typename A>
473inline uint32_t ebpps_sketch<T, A>::check_k(uint32_t k)
474{
475 if (k == 0 || k > MAX_K)
476 throw std::invalid_argument("k must be strictly positive and less than " + std::to_string(MAX_K));
477 return k;
478}
479
480template<typename T, typename A>
481void ebpps_sketch<T, A>::check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver) {
482 if (family_id == FAMILY_ID) {
483 if (ser_ver != SER_VER) {
484 throw std::invalid_argument("Possible corruption: EBPPS serialization version must be "
485 + std::to_string(SER_VER) + ". Found: " + std::to_string(ser_ver));
486 }
487 return;
488 }
489
490 throw std::invalid_argument("Possible corruption: EBPPS Sketch family id must be "
491 + std::to_string(FAMILY_ID) + ". Found: " + std::to_string(family_id));
492}
493
494template <typename T, typename A>
495void ebpps_sketch<T, A>::check_preamble_longs(uint8_t preamble_longs, uint8_t flags)
496{
497 const bool is_empty(flags & EMPTY_FLAG_MASK);
498
499 if (is_empty) {
500 if (preamble_longs != PREAMBLE_LONGS_EMPTY) {
501 throw std::invalid_argument("Possible corruption: Preamble longs must be "
502 + std::to_string(PREAMBLE_LONGS_EMPTY) + " for an empty sketch. Found: "
503 + std::to_string(preamble_longs));
504 }
505 if (flags & HAS_PARTIAL_ITEM_MASK) {
506 throw std::invalid_argument("Possible corruption: Empty sketch must not "
507 "contain indications of the presence of any item");
508 }
509 } else {
510 if (preamble_longs != PREAMBLE_LONGS_FULL) {
511 throw std::invalid_argument("Possible corruption: Preamble longs must be "
512 + std::to_string(PREAMBLE_LONGS_FULL)
513 + " for a non-empty sketch. Found: " + std::to_string(preamble_longs));
514 }
515 }
516}
517
518template<typename T, typename A>
519typename ebpps_sample<T, A>::const_iterator ebpps_sketch<T, A>::begin() const {
520 return sample_.begin();
521}
522
523template<typename T, typename A>
524typename ebpps_sample<T, A>::const_iterator ebpps_sketch<T, A>::end() const {
525 return sample_.end();
526}
527
528} // namespace datasketches
529
530#endif // _EBPPS_SKETCH_IMPL_HPP_
An implementation of an Exact and Bounded Sampling Proportional to Size sketch.
Definition ebpps_sketch.hpp:59
string< A > items_to_string() const
Prints the raw sketch items to a string.
Definition ebpps_sketch_impl.hpp:112
void update(const T &item, double weight=1.0)
Updates this sketch with the given data item with the given weight.
Definition ebpps_sketch_impl.hpp:127
ebpps_sketch(uint32_t k, const A &allocator=A())
Constructor.
Definition ebpps_sketch_impl.hpp:36
vector_bytes serialize(unsigned header_size_bytes=0, const SerDe &sd=SerDe()) const
This method serializes the sketch as a vector of bytes.
ebpps_sample< T, A >::const_iterator end() const
Iterator pointing to the past-the-end item in the sketch.
Definition ebpps_sketch_impl.hpp:524
bool is_empty() const
Returns true if the sketch is empty.
Definition ebpps_sketch_impl.hpp:82
ebpps_sample< T, A >::const_iterator begin() const
Iterator pointing to the first item in the sketch.
Definition ebpps_sketch_impl.hpp:519
size_t get_serialized_size_bytes(const SerDe &sd=SerDe()) const
Computes size needed to serialize the current state of the sketch.
Definition ebpps_sketch_impl.hpp:320
double get_cumulative_weight() const
Returns the cumulative weight of items processed by the sketch.
Definition ebpps_sketch_impl.hpp:77
A get_allocator() const
Returns an instance of the allocator for this sketch.
Definition ebpps_sketch_impl.hpp:122
void merge(const ebpps_sketch< T, A > &sketch)
Merges the provided sketch into the current one.
Definition ebpps_sketch_impl.hpp:199
static ebpps_sketch deserialize(const void *bytes, size_t size, const SerDe &sd=SerDe(), const A &allocator=A())
This method deserializes a sketch from a given array of bytes.
void reset()
Resets the sketch to its default, empty state.
Definition ebpps_sketch_impl.hpp:87
double get_c() const
Returns the expected number of samples returned upon a call to get_result() or the creation of an ite...
Definition ebpps_sketch_impl.hpp:72
string< A > to_string() const
Prints a summary of the sketch.
Definition ebpps_sketch_impl.hpp:96
result_type get_result() const
Returns a copy of the current sample, as a std::vector.
Definition ebpps_sketch_impl.hpp:163
uint32_t get_k() const
Returns the configured maximum sample size.
Definition ebpps_sketch_impl.hpp:62
uint64_t get_n() const
Returns the number of items processed by the sketch, regardless of item weight.
Definition ebpps_sketch_impl.hpp:67
DataSketches namespace.
Definition binomial_bounds.hpp:38