datasketches-cpp
theta_intersection_base_impl.hpp
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <iostream>
21 #include <sstream>
22 #include <algorithm>
23 #include <stdexcept>
24 
25 #include "conditional_forward.hpp"
26 
27 namespace datasketches {
28 
29 template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
30 theta_intersection_base<EN, EK, P, S, CS, A>::theta_intersection_base(uint64_t seed, const P& policy, const A& allocator):
31 policy_(policy),
32 is_valid_(false),
33 table_(0, 0, resize_factor::X1, 1, theta_constants::MAX_THETA, seed, allocator, false)
34 {}
35 
36 template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
37 template<typename SS>
38 void theta_intersection_base<EN, EK, P, S, CS, A>::update(SS&& sketch) {
39  if (table_.is_empty_) return;
40  if (!sketch.is_empty() && sketch.get_seed_hash() != compute_seed_hash(table_.seed_)) throw std::invalid_argument("seed hash mismatch");
41  table_.is_empty_ |= sketch.is_empty();
42  table_.theta_ = table_.is_empty_ ? theta_constants::MAX_THETA : std::min(table_.theta_, sketch.get_theta64());
43  if (is_valid_ && table_.num_entries_ == 0) return;
44  if (sketch.get_num_retained() == 0) {
45  is_valid_ = true;
46  table_ = hash_table(0, 0, resize_factor::X1, 1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
47  return;
48  }
49  if (!is_valid_) { // first update, copy or move incoming sketch
50  is_valid_ = true;
51  const uint8_t lg_size = lg_size_from_count(sketch.get_num_retained(), theta_update_sketch_base<EN, EK, A>::REBUILD_THRESHOLD);
52  table_ = hash_table(lg_size, lg_size - 1, resize_factor::X1, 1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
53  for (auto&& entry: sketch) {
54  auto result = table_.find(EK()(entry));
55  if (result.second) {
56  throw std::invalid_argument("duplicate key, possibly corrupted input sketch");
57  }
58  table_.insert(result.first, conditional_forward<SS>(entry));
59  }
60  if (table_.num_entries_ != sketch.get_num_retained()) throw std::invalid_argument("num entries mismatch, possibly corrupted input sketch");
61  } else { // intersection
62  const uint32_t max_matches = std::min(table_.num_entries_, sketch.get_num_retained());
63  std::vector<EN, A> matched_entries(table_.allocator_);
64  matched_entries.reserve(max_matches);
65  uint32_t match_count = 0;
66  uint32_t count = 0;
67  for (auto&& entry: sketch) {
68  if (EK()(entry) < table_.theta_) {
69  auto result = table_.find(EK()(entry));
70  if (result.second) {
71  if (match_count == max_matches) throw std::invalid_argument("max matches exceeded, possibly corrupted input sketch");
72  policy_(*result.first, conditional_forward<SS>(entry));
73  matched_entries.push_back(std::move(*result.first));
74  ++match_count;
75  }
76  } else if (sketch.is_ordered()) {
77  break; // early stop
78  }
79  ++count;
80  }
81  if (count > sketch.get_num_retained()) {
82  throw std::invalid_argument(" more keys than expected, possibly corrupted input sketch");
83  } else if (!sketch.is_ordered() && count < sketch.get_num_retained()) {
84  throw std::invalid_argument(" fewer keys than expected, possibly corrupted input sketch");
85  }
86  if (match_count == 0) {
87  table_ = hash_table(0, 0, resize_factor::X1, 1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
88  if (table_.theta_ == theta_constants::MAX_THETA) table_.is_empty_ = true;
89  } else {
90  const uint8_t lg_size = lg_size_from_count(match_count, theta_update_sketch_base<EN, EK, A>::REBUILD_THRESHOLD);
91  table_ = hash_table(lg_size, lg_size - 1, resize_factor::X1, 1, table_.theta_, table_.seed_, table_.allocator_, table_.is_empty_);
92  for (uint32_t i = 0; i < match_count; ++i) {
93  auto result = table_.find(EK()(matched_entries[i]));
94  table_.insert(result.first, std::move(matched_entries[i]));
95  }
96  }
97  }
98 }
99 
100 template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
101 CS theta_intersection_base<EN, EK, P, S, CS, A>::get_result(bool ordered) const {
102  if (!is_valid_) throw std::invalid_argument("calling get_result() before calling update() is undefined");
103  std::vector<EN, A> entries(table_.allocator_);
104  if (table_.num_entries_ > 0) {
105  entries.reserve(table_.num_entries_);
106  std::copy_if(table_.begin(), table_.end(), std::back_inserter(entries), key_not_zero<EN, EK>());
107  if (ordered) std::sort(entries.begin(), entries.end(), comparator());
108  }
109  return CS(table_.is_empty_, ordered, compute_seed_hash(table_.seed_), table_.theta_, std::move(entries));
110 }
111 
112 template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
113 bool theta_intersection_base<EN, EK, P, S, CS, A>::has_result() const {
114  return is_valid_;
115 }
116 
117 template<typename EN, typename EK, typename P, typename S, typename CS, typename A>
118 const P& theta_intersection_base<EN, EK, P, S, CS, A>::get_policy() const {
119  return policy_;
120 }
121 
122 } /* namespace datasketches */
const uint64_t MAX_THETA
max theta - signed max for compatibility with Java
Definition: theta_constants.hpp:36
DataSketches namespace.
Definition: binomial_bounds.hpp:38