139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
"""
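Verify that the installed filter_short_groups module works correctly.

Each test case runs filter_short_groups_c and a NumPy reference
implementation on the same input, then reports whether the outputs match
and how long each took.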
"""
|
|
import random
import sys
import time

import numpy as np
|
|
|
|
# Try to import the module under test; without it there is nothing to verify.
try:
    from filter_short_groups import filter_short_groups_c
except ImportError as e:
    print(f"Error importing module: {e}", file=sys.stderr)
    # sys.exit rather than the site-provided exit(), which is not guaranteed
    # to exist (e.g. under `python -S` or in embedded interpreters).
    sys.exit(1)
else:
    print("Successfully imported filter_short_groups_c")
|
|
|
|
# Define the NumPy reference implementation for comparison
|
|
def filter_short_groups_numpy(presence_list, filter_size, device_id, dates_str):
|
|
"""
|
|
NumPy implementation, kept for comparison purposes.
|
|
"""
|
|
st = time.time()
|
|
|
|
if not presence_list or filter_size <= 1:
|
|
print(f"NumPy: Early exit/no processing time: {time.time() - st:.6f}s")
|
|
return presence_list[:] if isinstance(presence_list, list) else list(presence_list)
|
|
|
|
result = np.array(presence_list, dtype=float)
|
|
n = len(result)
|
|
previous_states = set()
|
|
|
|
while True:
|
|
current_state_tuple = tuple(result)
|
|
if current_state_tuple in previous_states:
|
|
print("NumPy: Cycle detected, breaking.")
|
|
break
|
|
previous_states.add(current_state_tuple)
|
|
|
|
signs = np.sign(result)
|
|
change_indices = np.where(np.diff(signs) != 0)[0] + 1
|
|
boundaries = np.concatenate(([0], change_indices, [n]))
|
|
|
|
if len(boundaries) <= 2:
|
|
break
|
|
|
|
run_starts = boundaries[:-1]
|
|
run_ends = boundaries[1:]
|
|
run_lengths = run_ends - run_starts
|
|
run_signs = signs[run_starts]
|
|
|
|
short_runs_to_process = []
|
|
for i in range(len(run_starts)):
|
|
if run_lengths[i] > 0 and run_lengths[i] < filter_size:
|
|
short_runs_to_process.append({
|
|
'start': run_starts[i],
|
|
'end': run_ends[i],
|
|
'sign': run_signs[i],
|
|
'length': run_lengths[i]
|
|
})
|
|
|
|
if not short_runs_to_process:
|
|
break
|
|
|
|
short_runs_to_process.sort(key=lambda r: (r['length'], r['start']))
|
|
run_to_process = short_runs_to_process[0]
|
|
start = run_to_process['start']
|
|
end = run_to_process['end']
|
|
run_sign = run_to_process['sign']
|
|
replacement_value = 1.0 if run_sign == 0 else 0.0
|
|
result[start:end] = replacement_value
|
|
|
|
print(f"filter_short_groups_numpy time: {time.time() - st:.6f}s")
|
|
return result.tolist()
|
|
|
|
def run_test_case(test_data, filter_size, name="Test"):
    """
    Run both implementations and compare results and performance.

    Calls the NumPy reference and filter_short_groups_c on the same input,
    prints whether the outputs match, and prints both timings plus the
    speedup.

    Args:
        test_data: Input sequence passed unchanged to both implementations.
        filter_size: Run-length threshold forwarded to both implementations.
        name: Label printed in the section header.

    Returns:
        The result of comparing the two outputs with ``==`` (a bool,
        assuming filter_short_groups_c returns a list like the NumPy
        version does — TODO confirm).
    """
    print(f"\n===== {name} =====")

    # Fixed placeholder metadata; both implementations take these arguments.
    device_id = "test_device"
    dates_str = "2025-05-21"

    # perf_counter is monotonic and high-resolution; time.time() can advance
    # in coarse ticks on some platforms, making a fast call measure as 0s.
    start_time = time.perf_counter()
    numpy_result = filter_short_groups_numpy(test_data, filter_size, device_id, dates_str)
    numpy_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    c_result = filter_short_groups_c(test_data, filter_size, device_id, dates_str)
    c_time = time.perf_counter() - start_time

    results_match = numpy_result == c_result

    print(f"Results match: {results_match}")
    if not results_match:
        # Only show a prefix so large inputs don't flood the console.
        print(f"NumPy result: {numpy_result[:20]}...")
        print(f"C result: {c_result[:20]}...")

    print(f"NumPy time: {numpy_time:.6f}s")
    print(f"C time: {c_time:.6f}s")
    if c_time > 0:
        print(f"Speedup: {numpy_time / c_time:.2f}x")
    else:
        # Guard against ZeroDivisionError when the call is below timer resolution.
        print("Speedup: n/a (C time below timer resolution)")

    return results_match
|
|
|
|
def main():
    """
    Run all test cases and print an overall verdict.

    Exits with status 1 when any case fails, matching the import-failure
    path, so callers (e.g. CI) can detect a broken build from the exit code.
    """
    # Seed the legacy global RNG so the random inputs are reproducible and
    # identical to earlier runs of this script.
    np.random.seed(42)

    # (name, data, filter_size). The random inputs are drawn in list order
    # from the seeded RNG; .tolist() yields plain Python floats.
    test_cases = [
        (
            "Simple Test Case",
            [0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0],
            3,
        ),
        (
            "Small Random Test (n=100)",
            np.random.choice([0.0, 1.0], size=100).tolist(),
            4,
        ),
        (
            "Medium Random Test (n=1000)",
            np.random.choice([0.0, 1.0], size=1000).tolist(),
            5,
        ),
        (
            "Large Random Test (n=10000)",
            np.random.choice([0.0, 1.0], size=10000).tolist(),
            10,
        ),
    ]

    # Build the full list first so every case runs even after a failure
    # (all() over a generator would short-circuit).
    outcomes = [run_test_case(data, size, name) for name, data, size in test_cases]

    if all(outcomes):
        print("\n✅ All tests passed! The C implementation works correctly.")
    else:
        print("\n❌ Some tests failed! Check the output above for details.")
        sys.exit(1)
|
|
# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|