#include <iostream>
#include <vector>
#include <unordered_map>
#include <string>
#include <algorithm>
#include <cstdint>
#include <functional>
// Hash functor for byte vectors so they can be used as unordered_map keys.
// FNV-style mixing (xor the byte in, then multiply), using the 32-bit FNV
// constants but carried out in std::size_t arithmetic.
struct VectorHash {
    std::size_t operator()(const std::vector<uint8_t>& v) const noexcept {
        const std::size_t kOffsetBasis = 2166136261u;
        const std::size_t kPrime = 16777619u;

        std::size_t digest = kOffsetBasis;
        for (std::size_t idx = 0; idx < v.size(); ++idx) {
            // Fold the next byte in, then scramble with the prime.
            digest = (digest ^ static_cast<std::size_t>(v[idx])) * kPrime;
        }
        return digest;
    }
};
// Key-equality functor for the map: two byte vectors are equal when they have
// the same length and the same value at every position.
struct VectorEqual {
    bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const noexcept {
        // std::vector's operator== performs exactly the size + element-wise check.
        return a == b;
    }
};
// Type of every count produced by this file: unsigned, 64 bits wide.
using Count = std::uint64_t;

// Memoization table: maps a state vector (hashed and compared by content via
// VectorHash / VectorEqual) to the count computed for that state.
using Cache = std::unordered_map<
    std::vector<uint8_t>,
    Count,
    VectorHash,
    VectorEqual>;
/*
 * countState:
 *   target   - the bool pattern each appended symbol is compared against
 *   distance - the value the last state entry is tested against (k in the
 *              Rust code); entries are clamped to distance + 1
 *   state    - current state row; must have at least target.size() + 1
 *              entries. It is not modified by this call.
 *   cache    - memoization table from state -> count; every state evaluated
 *              here is stored in it
 *
 * If 'state' is already in 'cache', the stored count is returned. Otherwise
 * the count is 1 when state[target.size()] == distance (else 0), plus the
 * counts of the two successor states reached by appending symbol false and
 * symbol true. The result is cached before it is returned.
 */
Count countState(const std::vector<bool>& target,
                 uint8_t distance,
                 std::vector<uint8_t>& state,
                 Cache& cache)
{
    const auto hit = cache.find(state);
    if (hit != cache.end()) {
        return hit->second;
    }

    const std::size_t width = state.size();
    // Upper bound applied to every entry of a successor row.
    const uint8_t ceiling = static_cast<uint8_t>(distance + 1);

    // Mirrors the Rust `(state[target.len()] == distance) as Count`.
    Count total = 0;
    if (state[target.size()] == distance) {
        total = 1;
    }

    const bool symbols[] = {false, true};
    for (bool symbol : symbols) {
        // Build the successor row in its own buffer so 'state' stays intact.
        std::vector<uint8_t> next(width);

        // 'running' always holds the entry of 'next' written most recently.
        uint8_t running = static_cast<uint8_t>(std::min(state[0], distance) + 1);
        next[0] = running;

        for (std::size_t col = 1; col < width; ++col) {
            // Columns beyond the pattern are treated as 'false'.
            const bool patternBit = (col - 1 < target.size()) && target[col - 1];
            const uint8_t mismatch = (symbol != patternBit) ? 1 : 0;

            // Candidates: old row at col-1 plus the mismatch cost, old row at
            // col plus one, new row at col-1 plus one, and the clamp value.
            const uint8_t fromPrevOld = static_cast<uint8_t>(state[col - 1] + mismatch);
            const uint8_t fromSameOld = static_cast<uint8_t>(state[col] + 1);
            const uint8_t fromPrevNew = static_cast<uint8_t>(running + 1);

            running = std::min(std::min(fromPrevOld, fromSameOld),
                               std::min(fromPrevNew, ceiling));
            next[col] = running;
        }

        total += countState(target, distance, next, cache);
    }

    // Remember the answer for this state.
    cache[state] = total;
    return total;
}
/*
 * countFunction:
 *   target   - array of bools (pattern)
 *   distance - distance forwarded to countState (k in the Rust code)
 *
 * Builds the initial state, seeds the cache, then calls countState and
 * returns its result.
 *
 * NOTE: state entries are uint8_t and distance + 1 is used as the clamp
 * value. With distance == 255 that clamp wraps to 0 when stored, so results
 * for that input are not meaningful; callers should keep distance <= 254.
 */
Count countFunction(const std::vector<bool>& target, uint8_t distance) {
    // Rust iterated over (0..=target.len()), hence target.size() + 1 entries.
    const std::size_t entries = target.size() + 1;

    // The clamp is applied in size_t. Converting the index to uint8_t before
    // the comparison would wrap indices >= 256 (256 -> 0) instead of clamping
    // them to distance + 1, corrupting the initial state for long targets.
    const std::size_t clamp = static_cast<std::size_t>(distance) + 1;

    // Initial state: entry i is min(i, distance + 1).
    std::vector<uint8_t> state;
    state.reserve(entries);
    for (std::size_t i = 0; i < entries; ++i) {
        state.push_back(static_cast<uint8_t>(std::min(i, clamp)));
    }

    // Seed the cache with the one special entry from the Rust code: the row
    // in which every entry equals distance + 1 maps to a count of 0.
    Cache cache;
    {
        std::vector<uint8_t> initKey(entries, static_cast<uint8_t>(clamp));
        cache[initKey] = 0;
    }

    // Recursively compute the count from the initial state.
    return countState(target, distance, state, cache);
}
int main(int argc, char* argv[]) {
// <s> is a string of '0' or '1' used to fill a bool array
// <k> is the allowed distance
std::string targetStr = "1001100110";
uint8_t distance = 7;
std::vector<bool> targetVec;
targetVec.reserve(targetStr.size());
for (char c : targetStr) {
// true if '1', false otherwise
targetVec.push_back(c == '1');
}
Count result = countFunction(targetVec, distance);
std::cout << result << std::endl;
return 0;
}