btllib
Loading...
Searching...
No Matches
mi_bloom_filter-inl.hpp
1#ifndef BTLLIB_MI_BLOOM_FILTER_INL_HPP
2#define BTLLIB_MI_BLOOM_FILTER_INL_HPP
3
4#include "btllib/mi_bloom_filter.hpp"
5#include "btllib/nthash.hpp"
6#include "btllib/status.hpp"
7
8#include "cpptoml.h"
9
10#include <climits>
11#include <cstdlib>
12
13#include <sdsl/bit_vector_il.hpp>
14#include <sdsl/rank_support.hpp>
15
16namespace btllib {
17bool
18MIBloomFilterInitializer::check_file_signature(
19 std::ifstream& ifs,
20 const std::string& expected_signature,
21 std::string& file_signature)
22{
23 std::getline(ifs, file_signature);
24 return file_signature == expected_signature;
25}
26
27std::shared_ptr<cpptoml::table>
28MIBloomFilterInitializer::parse_header(const std::string& expected_signature)
29{
31 btllib::check_error(ifs_id_arr.fail(),
32 "MIBloomFilterInitializer: failed to open " + path);
33
34 std::string file_signature;
35 if (!check_file_signature(ifs_id_arr, expected_signature, file_signature)) {
36 log_error(std::string("File signature does not match (possibly version "
37 "mismatch) for file:\n") +
38 path + '\n' + "Expected signature:\t" + expected_signature +
39 '\n' + "File signature: \t" + file_signature);
40 std::exit(EXIT_FAILURE); // NOLINT(concurrency-mt-unsafe)
41 }
42
43 /* Read bloom filter line by line until it sees "[HeaderEnd]"
44 * which is used to mark the end of the header section and
45 * assigns the header to a char array*/
46 std::string toml_buffer(file_signature + '\n');
47 std::string line;
48 bool header_end_found = false;
49 while (bool(std::getline(ifs_id_arr, line))) {
50 toml_buffer.append(line + '\n');
51 if (line == "[HeaderEnd]") {
52 header_end_found = true;
53 break;
54 }
55 }
56 if (!header_end_found) {
57 log_error("Pre-built multi-index Bloom filter does not have the correct "
58 "header end.");
59 std::exit(EXIT_FAILURE); // NOLINT(concurrency-mt-unsafe)
60 }
61 for (unsigned i = 0; i < PLACEHOLDER_NEWLINES_MIBF; i++) {
62 std::getline(ifs_id_arr, line);
63 }
64
65 // Send the char array to a stringstream for the cpptoml parser to parse
66 std::istringstream toml_stream(toml_buffer);
67 cpptoml::parser toml_parser(toml_stream);
68 const auto header_config = toml_parser.parse();
69
70 // Obtain header values from toml parser and assign them to class members
71 const auto header_string =
72 file_signature.substr(1, file_signature.size() - 2); // Remove [ ]
73 return header_config->get_table(header_string);
74}
75
76template<typename T>
77MIBloomFilter<T>::MIBloomFilter(const std::string& path)
78 : MIBloomFilter<T>::MIBloomFilter(
79 std::make_shared<MIBloomFilterInitializer>(path,
80 MI_BLOOM_FILTER_SIGNATURE))
81{
82}
83
84template<typename T>
85inline MIBloomFilter<T>::MIBloomFilter(
86 const std::shared_ptr<MIBloomFilterInitializer>& mibfi)
87 : id_array_size(
88 *(mibfi->table->get_as<decltype(id_array_size)>("id_array_size")))
89 , kmer_size(*(mibfi->table->get_as<decltype(kmer_size)>("kmer_size")))
90 , hash_num(*(mibfi->table->get_as<decltype(hash_num)>("hash_num")))
91 , hash_fn(mibfi->table->contains("hash_fn")
92 ? *(mibfi->table->get_as<decltype(hash_fn)>("hash_fn"))
93 : "")
94
95 , id_array(new std::atomic<T>[id_array_size])
96 , bv_insertion_completed(
97 static_cast<bool>(*(mibfi->table->get_as<int>("bv_insertion_completed"))))
98 , id_insertion_completed(
99 static_cast<bool>(*(mibfi->table->get_as<int>("id_insertion_completed"))))
100{
101 // read id array
102 mibfi->ifs_id_arr.read((char*)id_array.get(),
103 std::streamsize(id_array_size * sizeof(T)));
104 // read bv and bv rank support
105 sdsl::load_from_file(il_bit_vector, mibfi->path + ".sdsl");
106 bv_rank_support = sdsl::rank_support_il<1>(&il_bit_vector);
107
108 // init counts array
109 counts_array = std::unique_ptr<std::atomic<uint16_t>[]>(
110 new std::atomic<uint16_t>[id_array_size]);
111 std::memset(
112 (void*)counts_array.get(), 0, id_array_size * sizeof(counts_array[0]));
113
114 log_info(
115 "MIBloomFilter: Bit vector size: " + std::to_string(il_bit_vector.size()) +
116 "\nPopcount: " + std::to_string(get_pop_cnt()));
117}
118
119template<typename T>
120inline MIBloomFilter<T>::MIBloomFilter(size_t bv_size,
121 unsigned hash_num,
122 std::string hash_fn)
123 : bv_size(bv_size)
124 , hash_num(hash_num)
125 , hash_fn(std::move(hash_fn))
126{
127 bit_vector = sdsl::bit_vector(bv_size);
128}
129
130template<typename T>
131inline MIBloomFilter<T>::MIBloomFilter(sdsl::bit_vector& bit_vector,
132 unsigned hash_num,
133 std::string hash_fn)
134 : bit_vector(bit_vector)
135 , hash_num(hash_num)
136 , hash_fn(std::move(hash_fn))
137{
138 complete_bv_insertion();
139}
140
141template<typename T>
142inline void
143MIBloomFilter<T>::insert_bv(const uint64_t* hashes)
144{
145 assert(!bv_insertion_completed);
146 // check array size = hash_num
147 for (unsigned i = 0; i < hash_num; ++i) {
148 uint64_t pos = hashes[i] % bit_vector.size();
149 uint64_t* data_index = bit_vector.data() + (pos >> 6); // NOLINT
150 uint64_t bit_mask_value = (uint64_t)1 << (pos & 0x3F); // NOLINT
151 (void)(__sync_fetch_and_or(data_index, bit_mask_value) >>
152 (pos & 0x3F) & // NOLINT
153 1); // NOLINT
154 }
155}
156template<typename T>
157inline bool
158MIBloomFilter<T>::bv_contains(const uint64_t* hashes)
159{
160 assert(bv_insertion_completed);
161 for (unsigned i = 0; i < hash_num; i++) {
162 uint64_t pos = hashes[i] % il_bit_vector.size();
163 if (il_bit_vector[pos] == 0) {
164 return false;
165 }
166 }
167 return true;
168}
169template<typename T>
170inline void
171MIBloomFilter<T>::complete_bv_insertion()
172{
173 assert(!id_insertion_completed);
174 bv_insertion_completed = true;
175
176 il_bit_vector = sdsl::bit_vector_il<BLOCKSIZE>(bit_vector);
177 bv_rank_support = sdsl::rank_support_il<1>(&il_bit_vector);
178 id_array_size = get_pop_cnt();
179 id_array =
180 std::unique_ptr<std::atomic<T>[]>(new std::atomic<T>[id_array_size]);
181 std::memset((void*)id_array.get(), 0, id_array_size * sizeof(std::atomic<T>));
182 counts_array = std::unique_ptr<std::atomic<uint16_t>[]>(
183 new std::atomic<uint16_t>[id_array_size]);
184 std::memset(
185 (void*)counts_array.get(), 0, id_array_size * sizeof(counts_array[0]));
186}
187template<typename T>
188inline void
189MIBloomFilter<T>::insert_id(const uint64_t* hashes, const T& id)
190{
191 assert(bv_insertion_completed && !id_insertion_completed);
192
193 uint32_t rand = std::rand(); // NOLINT
194 for (unsigned i = 0; i < hash_num; ++i) {
195 uint64_t rank = get_rank_pos(hashes[i]);
196 uint16_t count = ++counts_array[rank];
197 T random_num = (rand ^ hashes[i]) % count;
198 if (random_num == count - 1) {
199 set_data(rank, id);
200 }
201 }
202}
203template<typename T>
204inline std::vector<T>
205MIBloomFilter<T>::get_id(const uint64_t* hashes)
206{
207 return get_data(get_rank_pos(hashes));
208}
209template<typename T>
210inline void
211MIBloomFilter<T>::insert_saturation(const uint64_t* hashes, const T& id)
212{
213 assert(id_insertion_completed);
214 std::vector<uint64_t> rank_pos = get_rank_pos(hashes);
215 std::vector<T> results = get_data(rank_pos);
216 std::vector<T> replacement_ids(hash_num);
217 bool value_found = false;
218 std::vector<T> seen_set(hash_num);
219
220 for (unsigned i = 0; i < hash_num; i++) {
221 T current_result = results[i] & (btllib::MIBloomFilter<T>::ANTI_MASK &
222 btllib::MIBloomFilter<T>::ANTI_STRAND);
223 // break if ID exists
224 if (current_result == id) {
225 value_found = true;
226 break;
227 }
228 // if haven't seen before add to seen set
229 if (find(seen_set.begin(), seen_set.end(), current_result) ==
230 seen_set.end()) {
231 seen_set.push_back(current_result);
232 }
233 // if have seen before add to replacement IDs
234 else {
235 replacement_ids.push_back(current_result);
236 }
237 }
238 // if value not found try to survive
239 if (!value_found) {
240 uint64_t replacement_pos = id_array_size;
241 T min_count = std::numeric_limits<T>::min();
242 for (unsigned i = 0; i < hash_num; i++) {
243 T current_result = results[i] & btllib::MIBloomFilter<T>::ANTI_MASK;
244 if (find(replacement_ids.begin(),
245 replacement_ids.end(),
246 current_result) != replacement_ids.end()) {
247 if (min_count < counts_array[rank_pos[i]]) {
248 min_count = counts_array[rank_pos[i]];
249 replacement_pos = rank_pos[i];
250 }
251 }
252 }
253 if (replacement_pos != id_array_size) {
254 set_data(replacement_pos, id);
255 ++counts_array[replacement_pos];
256 } else {
257 set_saturated(hashes);
258 }
259 }
260}
261template<typename T>
262inline void
263MIBloomFilter<T>::set_data(const uint64_t& pos, const T& id)
264{
265 T old_value;
266 do {
267 old_value = id_array[pos];
268 } while (!(id_array[pos].compare_exchange_strong(
269 old_value, old_value > MASK ? (id | MASK) : id)));
270}
271template<typename T>
272inline void
273MIBloomFilter<T>::set_saturated(const uint64_t* hashes)
274{
275 for (unsigned i = 0; i < hash_num; ++i) {
276 uint64_t pos = bv_rank_support(hashes[i] % il_bit_vector.size());
277 id_array[pos].fetch_or(MASK);
278 }
279}
280template<typename T>
281inline std::vector<uint64_t>
282MIBloomFilter<T>::get_rank_pos(const uint64_t* hashes) const
283{
284 std::vector<uint64_t> rank_pos(hash_num);
285 for (unsigned i = 0; i < hash_num; ++i) {
286 uint64_t pos = hashes[i] % il_bit_vector.size();
287 rank_pos[i] = bv_rank_support(pos);
288 }
289 return rank_pos;
290}
291template<typename T>
292inline std::vector<T>
293MIBloomFilter<T>::get_data(const std::vector<uint64_t>& rank_pos) const
294{
295 std::vector<T> results(hash_num);
296 for (unsigned i = 0; i < hash_num; ++i) {
297 results[i] = id_array[rank_pos[i]];
298 }
299 return results;
300}
301template<typename T>
302inline void
303MIBloomFilter<T>::save(const std::string& path,
304 const cpptoml::table& table,
305 const char* data,
306 size_t n)
307{
308 std::ofstream ofs(path.c_str(), std::ios::out | std::ios::binary);
309
310 ofs << table << "[HeaderEnd]\n";
311 for (unsigned i = 0; i < PLACEHOLDER_NEWLINES_MIBF; i++) {
312 if (i == 1) {
313 ofs << " <binary data>";
314 }
315 ofs << '\n';
316 }
317
318 ofs.write(data, std::streamsize(n));
319}
320
321template<typename T>
322inline void
323MIBloomFilter<T>::save(const std::string& path)
324{
325 /* Initialize cpptoml root table
326 * Note: Tables and fields are unordered
327 * Ordering of table is maintained by directing the table
328 * to the output stream immediately after completion */
329 auto root = cpptoml::make_table();
330
331 /* Initialize bloom filter section and insert fields
332 * and output to ostream */
333 auto header = cpptoml::make_table();
334 header->insert("id_array_size", id_array_size);
335 header->insert("hash_num", get_hash_num());
336 header->insert("kmer_size", get_k());
337 header->insert("bv_insertion_completed",
338 static_cast<int>(bv_insertion_completed));
339 header->insert("id_insertion_completed",
340 static_cast<int>(id_insertion_completed));
341
342 if (!hash_fn.empty()) {
343 header->insert("hash_fn", get_hash_fn());
344 }
345 std::string header_string = MI_BLOOM_FILTER_SIGNATURE;
346 header_string =
347 header_string.substr(1, header_string.size() - 2); // Remove [ ]
348 root->insert(header_string, header);
349 save(path, *root, (char*)id_array.get(), id_array_size * sizeof(id_array[0]));
350 sdsl::store_to_file(il_bit_vector, path + ".sdsl");
351}
352
353template<typename T>
354inline uint64_t
355MIBloomFilter<T>::get_pop_cnt()
356{
357 assert(bv_insertion_completed);
358 size_t index = il_bit_vector.size() - 1;
359 while (il_bit_vector[index] == 0) {
360 --index;
361 }
362 return bv_rank_support(index) + 1;
363}
364template<typename T>
365inline uint64_t
366MIBloomFilter<T>::get_pop_saturated_cnt()
367{
368 size_t count = 0;
369 for (size_t i = 0; i < id_array_size; ++i) {
370 if (id_array[i] >= MASK) {
371 ++count;
372 }
373 }
374 return count;
375}
376
377template<typename T>
378inline std::vector<size_t>
379MIBloomFilter<T>::get_id_occurence_count(const bool& include_saturated)
380{
381 // Ensure the bloom filter has been initialized
382 assert(bv_insertion_completed);
383
384 // Initialize a temporary vector to store counts
385 std::vector<std::atomic<size_t>> count_vec(MASK - 1);
386
387 // Iterate over the id_array in parallel, incrementing the counts in count_vec
388#pragma omp parallel for default(none) \
389 shared(id_array_size, include_saturated, id_array, count_vec)
390 for (size_t k = 0; k < id_array_size; k++) {
391 if (!include_saturated && id_array[k] > ANTI_MASK) {
392 continue;
393 }
394 count_vec[id_array[k] & ANTI_MASK].fetch_add(1);
395 }
396
397 // Convert the atomic count_vec to a non-atomic result vector,
398 // excluding trailing zeros
399 std::vector<size_t> result;
400 bool has_trailing_zeros = true;
401 for (size_t i = MASK - 2; i != SIZE_MAX; i--) {
402 if (count_vec[i] != 0) {
403 has_trailing_zeros = false;
404 }
405 if (!has_trailing_zeros) {
406 result.insert(result.begin(), count_vec[i].load());
407 }
408 }
409
410 return result;
411}
412
413template<typename T>
414inline size_t
415MIBloomFilter<T>::calc_optimal_size(size_t entries,
416 unsigned hash_num,
417 double occupancy)
418{
419 auto non_64_approx_val =
420 size_t(-double(entries) * double(hash_num) / log(1.0 - occupancy));
421 const int magic = 64;
422 return non_64_approx_val + (magic - non_64_approx_val % magic);
423}
424} // namespace btllib
425
426#endif
Definition aahash.hpp:12
void check_error(bool condition, const std::string &msg)
void log_info(const std::string &msg)
void check_file_accessibility(const std::string &filepath)
void log_error(const std::string &msg)