btllib
Loading...
Searching...
No Matches
counting_bloom_filter-inl.hpp
1#ifndef BTLLIB_COUNTING_BLOOM_FILTER_INL_HPP
2#define BTLLIB_COUNTING_BLOOM_FILTER_INL_HPP
3
4#include "btllib/bloom_filter.hpp"
5#include "btllib/counting_bloom_filter.hpp"
6#include "btllib/nthash.hpp"
7#include "btllib/status.hpp"
8
9#include "cpptoml.h"
10
11#include <atomic>
12#include <cmath>
13#include <cstdint>
14#include <fstream>
15#include <limits>
16#include <memory>
17#include <string>
18#include <vector>
19
20namespace btllib {
21
// Shorthand names for CountingBloomFilter instantiated with 8-, 16- and
// 32-bit unsigned counters.
22using CountingBloomFilter8 = CountingBloomFilter<uint8_t>;
23using CountingBloomFilter16 = CountingBloomFilter<uint16_t>;
24using CountingBloomFilter32 = CountingBloomFilter<uint32_t>;
25
// Shorthand names for KmerCountingBloomFilter instantiated with 8-, 16- and
// 32-bit unsigned counters.
26using KmerCountingBloomFilter8 = KmerCountingBloomFilter<uint8_t>;
27using KmerCountingBloomFilter16 = KmerCountingBloomFilter<uint16_t>;
28using KmerCountingBloomFilter32 = KmerCountingBloomFilter<uint32_t>;
29
30template<typename T>
32 unsigned hash_num,
33 std::string hash_fn)
34 : bytes(
35 size_t(std::ceil(double(bytes) / sizeof(uint64_t)) * sizeof(uint64_t)))
36 , array_size(get_bytes() / sizeof(array[0]))
37 , hash_num(hash_num)
38 , hash_fn(std::move(hash_fn))
39 , array(new std::atomic<T>[array_size])
40{
41 check_error(bytes == 0, "CountingBloomFilter: memory budget must be >0!");
42 check_error(hash_num == 0,
43 "CountingBloomFilter: number of hash values must be >0!");
45 hash_num > MAX_HASH_VALUES,
46 "CountingBloomFilter: number of hash values cannot be over 1024!");
47 check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
48 "Atomic primitives take extra memory. CountingBloomFilter will "
49 "have less than " +
50 std::to_string(bytes) + " for bit array.");
51 std::memset((void*)array.get(), 0, array_size * sizeof(array[0]));
52}
53
54/*
55 * Assumes min_count is not std::numeric_limits<T>::max()
56 */
57template<typename T>
58inline void
59CountingBloomFilter<T>::set(const uint64_t* hashes, T min_val, T new_val)
60{
61 // Update flag to track if increment is done on at least one counter
62 bool update_done = false;
63 T tmp_min_val;
64 while (true) {
65 for (size_t i = 0; i < hash_num; ++i) {
66 tmp_min_val = min_val;
67 update_done |= array[hashes[i] % array_size].compare_exchange_strong(
68 tmp_min_val, new_val);
69 }
70 if (update_done) {
71 break;
72 }
73 min_val = contains(hashes);
74 if (min_val == std::numeric_limits<T>::max()) {
75 break;
76 }
77 }
78}
79
80template<typename T>
81inline void
82CountingBloomFilter<T>::insert(const uint64_t* hashes, T n)
83{
84 contains_insert(hashes, n);
85}
86
87template<typename T>
88inline void
89CountingBloomFilter<T>::remove(const uint64_t* hashes)
90{
91 T min_val = contains(hashes);
92 set(hashes, min_val, min_val > 1 ? min_val - 1 : 0);
93}
94
95template<typename T>
96void
97CountingBloomFilter<T>::clear(const uint64_t* hashes)
98{
99 T min_val = contains(hashes);
100 set(hashes, min_val, 0);
101}
102
103template<typename T>
104inline T
105CountingBloomFilter<T>::contains(const uint64_t* hashes) const
106{
107 T min = array[hashes[0] % array_size];
108 for (size_t i = 1; i < hash_num; ++i) {
109 const size_t idx = hashes[i] % array_size;
110 if (array[idx] < min) {
111 min = array[idx];
112 }
113 }
114 return min;
115}
116
117template<typename T>
118inline T
119CountingBloomFilter<T>::contains_insert(const uint64_t* hashes, T n)
120{
121 const auto count = contains(hashes);
122 if (count <= std::numeric_limits<T>::max() - n) {
123 set(hashes, count, count + n);
124 }
125 return count;
126}
127
128template<typename T>
129inline T
130CountingBloomFilter<T>::insert_contains(const uint64_t* hashes, T n)
131{
132 const auto count = contains(hashes);
133 if (count <= std::numeric_limits<T>::max() + n) {
134 set(hashes, count, count + n);
135 return count + n;
136 }
137 return std::numeric_limits<T>::max();
138}
139
140template<typename T>
141inline T
143 const T threshold)
144{
145 const auto count = contains(hashes);
146 if (count < threshold) {
147 set(hashes, count, count + 1);
148 return count + 1;
149 }
150 return count;
151}
152
153template<typename T>
154inline T
156 const T threshold)
157{
158 const auto count = contains(hashes);
159 if (count < threshold) {
160 set(hashes, count, count + 1);
161 }
162 return count;
163}
164
165template<typename T>
166inline uint64_t
168{
169 uint64_t pop_cnt = 0;
170// OpenMP make up your mind man. Using default(none) here causes errors on
171// some compilers and not others.
172// NOLINTNEXTLINE(openmp-use-default-none,-warnings-as-errors)
173#pragma omp parallel for reduction(+ : pop_cnt)
174 for (size_t i = 0; i < array_size; ++i) {
175 if (array[i] >= threshold) {
176 ++pop_cnt;
177 }
178 }
179 return pop_cnt;
180}
181
182template<typename T>
183inline double
185{
186 return double(get_pop_cnt(threshold)) / double(array_size);
187}
188
189template<typename T>
190inline double
191CountingBloomFilter<T>::get_fpr(const T threshold) const
192{
193 return std::pow(get_occupancy(threshold), double(hash_num));
194}
195
196template<typename T>
197inline CountingBloomFilter<T>::CountingBloomFilter(const std::string& path)
199 std::make_shared<BloomFilterInitializer>(path,
200 COUNTING_BLOOM_FILTER_SIGNATURE))
201{
202}
203
204template<typename T>
206 const std::shared_ptr<BloomFilterInitializer>& bfi)
207 : bytes(*bfi->table->get_as<decltype(bytes)>("bytes"))
208 , array_size(bytes / sizeof(array[0]))
209 , hash_num(*(bfi->table->get_as<decltype(hash_num)>("hash_num")))
210 , hash_fn(bfi->table->contains("hash_fn")
211 ? *(bfi->table->get_as<decltype(hash_fn)>("hash_fn"))
212 : "")
213 , array(new std::atomic<T>[array_size])
214{
215 check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
216 "Atomic primitives take extra memory. CountingBloomFilter will "
217 "have less than " +
218 std::to_string(bytes) + " for bit array.");
219 const auto loaded_counter_bits =
220 *(bfi->table->get_as<size_t>("counter_bits"));
221 check_error(sizeof(array[0]) * CHAR_BIT != loaded_counter_bits,
222 "CountingBloomFilter" +
223 std::to_string(sizeof(array[0]) * CHAR_BIT) +
224 " tried to load a file of CountingBloomFilter" +
225 std::to_string(loaded_counter_bits));
226 bfi->ifs.read((char*)array.get(),
227 std::streamsize(array_size * sizeof(array[0])));
228}
229
230template<typename T>
231inline void
232CountingBloomFilter<T>::save(const std::string& path)
233{
234 /* Initialize cpptoml root table
235 Note: Tables and fields are unordered
236 Ordering of table is maintained by directing the table
237 to the output stream immediately after completion */
238 auto root = cpptoml::make_table();
239
240 /* Initialize bloom filter section and insert fields
241 and output to ostream */
242 auto header = cpptoml::make_table();
243 header->insert("bytes", get_bytes());
244 header->insert("hash_num", get_hash_num());
245 if (!hash_fn.empty()) {
246 header->insert("hash_fn", hash_fn);
247 }
248 header->insert("counter_bits", size_t(sizeof(array[0]) * CHAR_BIT));
249 std::string header_string = COUNTING_BLOOM_FILTER_SIGNATURE;
250 header_string =
251 header_string.substr(1, header_string.size() - 2); // Remove [ ]
252 root->insert(header_string, header);
253
255 path, *root, (char*)array.get(), array_size * sizeof(array[0]));
256}
257
258template<typename T>
260 unsigned hash_num,
261 unsigned k)
262 : k(k)
263 , counting_bloom_filter(bytes, hash_num, HASH_FN)
264{
265}
266
267template<typename T>
268inline void
269KmerCountingBloomFilter<T>::insert(const char* seq, size_t seq_len)
270{
271 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
272 while (nthash.roll()) {
273 counting_bloom_filter.insert(nthash.hashes());
274 }
275}
276
277template<typename T>
278inline void
279KmerCountingBloomFilter<T>::remove(const char* seq, size_t seq_len)
280{
281 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
282 while (nthash.roll()) {
283 counting_bloom_filter.remove(nthash.hashes());
284 }
285}
286
287template<typename T>
288inline void
289KmerCountingBloomFilter<T>::clear(const char* seq, size_t seq_len)
290{
291 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
292 while (nthash.roll()) {
293 counting_bloom_filter.clear(nthash.hashes());
294 }
295}
296
297template<typename T>
298inline uint64_t
299KmerCountingBloomFilter<T>::contains(const char* seq, size_t seq_len) const
300{
301 uint64_t sum = 0;
302 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
303 while (nthash.roll()) {
304 sum += counting_bloom_filter.contains(nthash.hashes());
305 }
306 return sum;
307}
308
309template<typename T>
310inline T
311KmerCountingBloomFilter<T>::contains_insert(const char* seq, size_t seq_len)
312{
313 uint64_t sum = 0;
314 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
315 while (nthash.roll()) {
316 sum += counting_bloom_filter.contains_insert(nthash.hashes());
317 }
318 return sum;
319}
320
321template<typename T>
322inline T
323KmerCountingBloomFilter<T>::insert_contains(const char* seq, size_t seq_len)
324{
325 uint64_t sum = 0;
326 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
327 while (nthash.roll()) {
328 sum += counting_bloom_filter.insert_contains(nthash.hashes());
329 }
330 return sum;
331}
332
333template<typename T>
334inline T
336 size_t seq_len,
337 const T threshold)
338{
339 uint64_t sum = 0;
340 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
341 while (nthash.roll()) {
342 sum +=
343 counting_bloom_filter.insert_thresh_contains(nthash.hashes(), threshold);
344 }
345 return sum;
346}
347
348template<typename T>
349inline T
351 size_t seq_len,
352 const T threshold)
353{
354 uint64_t sum = 0;
355 NtHash nthash(seq, seq_len, get_hash_num(), get_k());
356 while (nthash.roll()) {
357 sum +=
358 counting_bloom_filter.contains_insert_thresh(nthash.hashes(), threshold);
359 }
360 return sum;
361}
362
363template<typename T>
365 const std::string& path)
367 std::make_shared<BloomFilterInitializer>(
368 path,
369 KMER_COUNTING_BLOOM_FILTER_SIGNATURE))
370{
371}
372
373template<typename T>
375 const std::shared_ptr<BloomFilterInitializer>& bfi)
376 : k(*(bfi->table->get_as<decltype(k)>("k")))
377 , counting_bloom_filter(bfi)
378{
379 check_error(counting_bloom_filter.hash_fn != HASH_FN,
380 "KmerCountingBloomFilter: loaded hash function (" +
381 counting_bloom_filter.hash_fn +
382 ") is different from the one used by default (" + HASH_FN +
383 ").");
384}
385
386template<typename T>
387inline void
388KmerCountingBloomFilter<T>::save(const std::string& path)
389{
390 /* Initialize cpptoml root table
391 Note: Tables and fields are unordered
392 Ordering of table is maintained by directing the table
393 to the output stream immediately after completion */
394 auto root = cpptoml::make_table();
395
396 /* Initialize bloom filter section and insert fields
397 and output to ostream */
398 auto header = cpptoml::make_table();
399 header->insert("bytes", get_bytes());
400 header->insert("hash_num", get_hash_num());
401 header->insert("hash_fn", get_hash_fn());
402 header->insert("counter_bits",
403 size_t(sizeof(counting_bloom_filter.array[0]) * CHAR_BIT));
404 header->insert("k", k);
405 std::string header_string = KMER_COUNTING_BLOOM_FILTER_SIGNATURE;
406 header_string =
407 header_string.substr(1, header_string.size() - 2); // Remove [ ]
408 root->insert(header_string, header);
409
411 *root,
412 (char*)counting_bloom_filter.array.get(),
413 counting_bloom_filter.array_size *
414 sizeof(counting_bloom_filter.array[0]));
415}
416} // namespace btllib
417
418#endif
void save(const std::string &path)
Definition counting_bloom_filter.hpp:43
void remove(const uint64_t *hashes)
Definition counting_bloom_filter-inl.hpp:89
double get_occupancy(T threshold=1) const
Definition counting_bloom_filter-inl.hpp:184
void insert(const uint64_t *hashes, T n=1)
Definition counting_bloom_filter-inl.hpp:82
double get_fpr(T threshold=1) const
Definition counting_bloom_filter-inl.hpp:191
uint64_t get_pop_cnt(T threshold=1) const
Definition counting_bloom_filter-inl.hpp:167
void save(const std::string &path)
Definition counting_bloom_filter-inl.hpp:232
T contains_insert_thresh(const uint64_t *hashes, T threshold)
Definition counting_bloom_filter-inl.hpp:155
T contains(const uint64_t *hashes) const
Definition counting_bloom_filter-inl.hpp:105
CountingBloomFilter()
Definition counting_bloom_filter.hpp:47
void clear(const uint64_t *hashes)
Definition counting_bloom_filter-inl.hpp:97
T insert_contains(const uint64_t *hashes, T n=1)
Definition counting_bloom_filter-inl.hpp:130
T contains_insert(const uint64_t *hashes, T n=1)
Definition counting_bloom_filter-inl.hpp:119
T insert_thresh_contains(const uint64_t *hashes, T threshold)
Definition counting_bloom_filter-inl.hpp:142
Definition counting_bloom_filter.hpp:310
uint64_t contains(const char *seq, size_t seq_len) const
Definition counting_bloom_filter-inl.hpp:299
void clear(const char *seq, size_t seq_len)
Definition counting_bloom_filter-inl.hpp:289
void save(const std::string &path)
Definition counting_bloom_filter-inl.hpp:388
T insert_contains(const char *seq, size_t seq_len)
Definition counting_bloom_filter-inl.hpp:323
T insert_thresh_contains(const char *seq, size_t seq_len, T threshold)
Definition counting_bloom_filter-inl.hpp:335
void remove(const char *seq, size_t seq_len)
Definition counting_bloom_filter-inl.hpp:279
T contains_insert_thresh(const char *seq, size_t seq_len, T threshold)
Definition counting_bloom_filter-inl.hpp:350
void insert(const char *seq, size_t seq_len)
Definition counting_bloom_filter-inl.hpp:269
T contains_insert(const char *seq, size_t seq_len)
Definition counting_bloom_filter-inl.hpp:311
KmerCountingBloomFilter()
Definition counting_bloom_filter.hpp:314
Definition nthash_kmer.hpp:237
bool roll()
Definition nthash_kmer.hpp:315
const uint64_t * hashes() const
Definition nthash_kmer.hpp:443
Definition aahash.hpp:12
void check_error(bool condition, const std::string &msg)
void check_warning(bool condition, const std::string &msg)