Unit testing a service that accepts an Fn closure as a callback

I have the following service that registers callbacks to execute at a certain epoch, identified by an i64. The service has a vector of callbacks (that are bounded by the Send + Fn() -> () traits). Each callback can be executed multiple times (hence Fn instead of FnOnce or FnMut). The Send trait is needed because the callbacks will be registered by other threads, and this service will run in the background.
So far so good, but I'd like to test that the callbacks are executed when they should be (i.e. when the i64 epoch ticks in a direction that may or may not trigger a callback). The problem is that I can't think of a way to achieve this. I'm coming from Go, where it is easy to inject a mock callback and assert that it was called, because the compiler imposes no such limitations. When I use the same approach in Rust, the closure ends up being an FnMut instead of an Fn.
use std::sync::{Arc, Mutex};
use std::collections::HashMap;
struct Service<T: Send + Fn() -> ()> {
triggers: Arc<Mutex<HashMap<i64, Vec<Box<T>>>>>,
}
impl<T: Send + Fn() -> ()> Service<T> {
pub fn build() -> Self {
Service {
triggers: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn poll(&'static self) {
let hs = Arc::clone(&self.triggers);
tokio::spawn(async move {
loop {
// do some stuff and get `val`
if let Some(v) = hs.lock().unwrap().get(&val) {
for cb in v.iter() {
cb();
}
}
}
});
()
}
pub fn register_callback(&self, val: i64, cb: Box<T>) -> () {
self.triggers
.lock()
.unwrap()
.entry(val)
.or_insert(Vec::new())
.push(cb);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_poll() {
let c = Service::build();
let mut called = false;
let cb = || called = true;
let h: i64 = 10;
c.register_callback(h, Box::new(cb));
assert_eq!(called, false);
}
}
Any ideas on how this sort of behavior could be tested in Rust? The only thing I can think of is perhaps a channel that would pass a local value to the test and relinquish ownership over it.

The best way would probably be to make your interface as general as possible by accepting FnMut, which every Fn closure also satisfies:
// type bounds on structs are generally unnecessary, so I removed them here.
struct Service<T> {
triggers: Arc<Mutex<HashMap<i64, Vec<Box<T>>>>>,
}
impl<T: Send + FnMut() -> ()> Service<T> {
pub fn build() -> Self {
Service {
triggers: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn poll(&'static self, val: i64) {
let hs = Arc::clone(&self.triggers);
tokio::spawn(async move {
loop {
// do some stuff and get `val`
if let Some(v) = hs.lock().unwrap().get_mut(&val) {
for cb in v.iter_mut() {
cb();
}
}
}
});
()
}
pub fn register_callback(&self, val: i64, cb: Box<T>) -> () {
self.triggers
.lock()
.unwrap()
.entry(val)
.or_insert(Vec::new())
.push(cb);
}
}
But if you can't generalize the interface you can just use an AtomicBool like this:
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{Ordering, AtomicBool};
#[test]
fn test_poll() {
let c = Service::build();
let called = AtomicBool::new(false); // no `mut` needed: `store` takes `&self`
let cb = || called.store(true, Ordering::Relaxed);
let h: i64 = 10;
c.register_callback(h, Box::new(cb));
assert!(!called.load(Ordering::Relaxed));
}
}
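The test above registers the callback but never triggers it, so its assertion can only confirm that nothing has run yet. Below is a minimal sketch of a test helper that actually observes invocations by sharing an Arc<AtomicUsize> between the test and the closure. `Registry`, `fire`, and `count_calls` are hypothetical names, not part of the original code: `Registry` is a synchronous stand-in for `Service`, and `fire` replaces the background poll loop so an epoch can be triggered deterministically. The sketch also assumes boxed trait objects (Box<dyn Fn() + Send>) instead of the generic parameter T.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

// Hypothetical callback type: any `Fn()` closure that can cross threads.
type Callback = Box<dyn Fn() + Send>;

// Hypothetical, synchronous stand-in for `Service`.
struct Registry {
    triggers: Mutex<HashMap<i64, Vec<Callback>>>,
}

impl Registry {
    fn new() -> Self {
        Registry { triggers: Mutex::new(HashMap::new()) }
    }

    fn register_callback(&self, epoch: i64, cb: Callback) {
        self.triggers.lock().unwrap().entry(epoch).or_default().push(cb);
    }

    // Runs every callback registered for `epoch`.
    fn fire(&self, epoch: i64) {
        if let Some(cbs) = self.triggers.lock().unwrap().get(&epoch) {
            for cb in cbs {
                cb();
            }
        }
    }
}

// Registers a counting callback at epoch 10, fires the given epochs,
// and returns how many times the callback ran.
fn count_calls(epochs: &[i64]) -> usize {
    let registry = Registry::new();
    let calls = Arc::new(AtomicUsize::new(0));
    let calls_in_cb = Arc::clone(&calls);
    registry.register_callback(
        10,
        Box::new(move || {
            // `fetch_add` takes `&self`, so this closure is `Fn`, not `FnMut`.
            calls_in_cb.fetch_add(1, Ordering::SeqCst);
        }),
    );
    for &epoch in epochs {
        registry.fire(epoch);
    }
    calls.load(Ordering::SeqCst)
}
```

In a test, `assert_eq!(count_calls(&[10, 3, 10]), 2)` would check that only the matching epoch runs the callback. With the real asynchronous `poll`, the same counter could be read after the spawned task has had a chance to run.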

Related

Unit Testing: Verify that a method was called, without testing frameworks like Mockito or MockK

Not using testing frameworks like MockK or Mockito seems to be becoming more and more popular. I decided to try this approach. So far so good, returning fake data is simple. But how do I verify that a function (that does not return data) has been called?
Imagine having a class like this:
class TestToaster: Toaster {
override fun showSuccessMessage(message: String) {
throw UnsupportedOperationException()
}
override fun showSuccessMessage(message: Int) {
throw UnsupportedOperationException()
}
override fun showErrorMessage(message: String) {
throw UnsupportedOperationException()
}
override fun showErrorMessage(message: Int) {
throw UnsupportedOperationException()
}
}
With MockK I would do
verify { toaster.showSuccessMessage() }
I do not want to reinvent the wheel, so I decided to ask. Finding anything on Google seems to be very difficult.
Since this is a thing, I assume the point would be to totally remove mocking libraries and everything can be done without them.
The old-school way, from before mocking libraries existed, is to manually create an implementation that is used only for testing. The test implementation stores how each method was called in some internal state, so that the test code can verify whether a method was called with the expected parameters by checking that state.
For example, a very simple Toaster implementation for testing could be:
public class MockToaster implements Toaster {
public String showSuccesMessageStr ;
public Integer showSuccesMessageInt;
public String showErrorMessageStr;
public Integer showErrorMessageInt;
public void showSuccessMessage(String msg){
this.showSuccesMessageStr = msg;
}
public void showSuccessMessage(Integer msg){
this.showSuccesMessageInt = msg;
}
public void showErrorMessage(String msg){
this.showErrorMessageStr = msg;
}
public void showErrorMessage(Integer msg){
this.showErrorMessageInt = msg;
}
}
Then, in your test code, you configure the object under test to use MockToaster. To verify that it really calls showSuccessMessage("foo"), you assert at the end of the test that showSuccesMessageStr equals "foo".
A lot of people seem to be suggesting the very straightforward solution for this, which totally makes sense. I decided to go a bit fancy and achieve this syntax:
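The same recording idea can be sketched in Rust, to match the other examples on this page. This is only an illustration under stated assumptions: the `Toaster` trait here is a hypothetical, simplified counterpart of the interface above (string messages only), and `save`, `RecordingToaster`, and `recorded_calls` are made-up names. The double records each call in a RefCell<Vec<_>> so that it can be mutated through `&self`.

```rust
use std::cell::RefCell;

// Hypothetical Rust counterpart of the `Toaster` interface above.
trait Toaster {
    fn show_success_message(&self, message: &str);
    fn show_error_message(&self, message: &str);
}

// One recorded call: which method ran and with what argument.
#[derive(Debug, Clone, PartialEq)]
enum Call {
    Success(String),
    Error(String),
}

// Hand-written test double that only records the calls it receives.
#[derive(Default)]
struct RecordingToaster {
    calls: RefCell<Vec<Call>>,
}

impl Toaster for RecordingToaster {
    fn show_success_message(&self, message: &str) {
        self.calls.borrow_mut().push(Call::Success(message.to_string()));
    }

    fn show_error_message(&self, message: &str) {
        self.calls.borrow_mut().push(Call::Error(message.to_string()));
    }
}

// Hypothetical code under test: reports its outcome through a `Toaster`.
fn save(toaster: &dyn Toaster, ok: bool) {
    if ok {
        toaster.show_success_message("saved");
    } else {
        toaster.show_error_message("save failed");
    }
}

// Runs `save` against the recording double and returns what it captured.
fn recorded_calls(ok: bool) -> Vec<Call> {
    let toaster = RecordingToaster::default();
    save(&toaster, ok);
    toaster.calls.into_inner()
}
```

Keeping the whole call history, rather than one field per method, lets a test check call order and call count as well as arguments.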
verify(toaster = toaster, times = 1).showErrorMessage(any<String>()).
I created simple Matchers:
inline fun <reified T> anyObject(): T {
return T::class.constructors.first().call()
}
inline fun <reified T> anyPrimitive(): T {
return when (T::class) {
Int::class -> Int.MIN_VALUE as T
Long::class -> Long.MIN_VALUE as T
Byte::class -> Byte.MIN_VALUE as T
Short::class -> Short.MIN_VALUE as T
Float::class -> Float.MIN_VALUE as T
Double::class -> Double.MIN_VALUE as T
Char::class -> Char.MIN_VALUE as T
String::class -> "io.readian.readian.matchers.strings" as T
Boolean::class -> false as T
else -> {
throw IllegalArgumentException("Not a primitive type ${T::class}")
}
}
}
I added a map to my TestToaster that stores the call count for each method, where the key is the name of the function and the value is the count:
private var callCount: MutableMap<String, Int> = mutableMapOf()
Whenever a function is called, I increase the current call count for that method. I get the current method name through reflection:
val key = object {}.javaClass.enclosingMethod?.name + param::class.simpleName
addCall(key)
In order to achieve the "fancy" syntax, I created an inner subclass for TestToaster and a verify function:
fun verify(toaster: Toaster , times: Int = 1): Toaster {
return TestToaster.InnerToaster(toaster, times)
}
That function passes the current toaster instance to the inner subclass, creates a new instance of it, and returns it. When I call a method of the subclass using the syntax above, the check happens. If the check passes, nothing happens and the test passes; if the conditions are not met, an exception is thrown.
To make it more general and extendable I created this interface:
interface TestCallVerifiable {
var callCount: MutableMap<String, Int>
val callParams: MutableMap<String, CallParam>
fun addCall(key: String, vararg param: Any) {
val currentCountValue = callCount.getOrDefault(key, 0)
callCount[key] = currentCountValue + 1
callParams[key] = CallParam(param.toMutableList())
}
abstract class InnerTestVerifiable(
private val outer: TestCallVerifiable,
private val times: Int = 1,
) {
protected val params: CallParam = CallParam(mutableListOf())
protected fun check(functionName: String) {
val actualTimes = getActualCallCount(functionName)
if (actualTimes != times) {
throw IllegalStateException(
"$functionName expected to be called $times, but actual was $actualTimes"
)
}
val callParams = outer.callParams.getOrDefault(functionName, CallParam(mutableListOf()))
val result = mutableListOf<Boolean>()
callParams.values.forEachIndexed { index, item ->
val actualParam = params.values[index]
if (item == params.values[index] || (item != actualParam && isAnyParams(actualParam))) {
result.add(true)
}
}
if (params.values.isNotEmpty() && !result.all { it } || result.isEmpty()) {
throw IllegalStateException(
"$functionName expected to be called with ${callParams.values}, but actual was with ${params.values}"
)
}
}
private fun isAnyParams(vararg param: Any): Boolean {
param.forEach {
if (it.isAnyPrimitive()) return true
}
return false
}
private fun getActualCallCount(functionName: String): Int {
return outer.callCount.getOrDefault(functionName, 0)
}
}
data class CallParam(val values: MutableList<Any> = mutableListOf())
}
Here is the complete class:
open class TestToaster : TestCallVerifiable, Toaster {
override var callCount: MutableMap<String, Int> = mutableMapOf()
override val callParams: MutableMap<String, TestCallVerifiable.CallParam> = mutableMapOf()
override fun showSuccessMessage(message: String) {
val key = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
addCall(key, message)
}
override fun showSuccessMessage(message: Int) {
val key = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
addCall(key, message)
}
override fun showErrorMessage(message: String) {
val key = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
addCall(key, message)
}
override fun showErrorMessage(message: Int) {
val key = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
addCall(key, message)
}
private class InnerToaster(
verifiable: TestCallVerifiable,
times: Int,
) : TestCallVerifiable.InnerTestVerifiable(
outer = verifiable,
times = times,
), Toaster {
override fun showSuccessMessage(message: String) {
params.values.add(message)
val functionName = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
check(functionName)
}
override fun showSuccessMessage(message: Int) {
params.values.add(message)
val functionName = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
check(functionName)
}
override fun showErrorMessage(message: String) {
params.values.add(message)
val functionName = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
check(functionName)
}
override fun showErrorMessage(message: Int) {
params.values.add(message)
val functionName = object {}.javaClass.enclosingMethod?.name + message::class.simpleName
check(functionName)
}
}
companion object {
fun verify(toaster: Toaster, times: Int = 1): Toaster {
return InnerToaster(toaster as TestCallVerifiable, times)
}
}
}
I have not tested this extensively and it will evolve with time, but so far it works well for me.
I also wrote an article about this on Medium: https://sermilion.medium.com/unit-testing-verify-that-a-method-was-called-without-testing-frameworks-like-mockito-or-mockk-433ef8e1aff4

Mocking functions in rust

Is there a way to mock regular functions in rust?
Consider the following code:
fn main() {
println!("{}", foo());
}
fn get_user_input() -> u8 {
// Placeholder for some unknown value
42
}
fn foo() -> u8 {
get_user_input()
}
#[cfg(test)]
mod tests {
#[test]
fn test_foo() {
use super::*;
get_user_input = || 12u8;
assert_eq!(foo(), 12u8);
}
}
I would like to unit test foo() without having to rely on the output of get_user_input().
I obviously cannot overwrite get_user_input() like I tried in the example code.
I have only found ways to mock structs, traits and modules but nothing about mocking regular free functions. Am I missing something?
Edit: I have looked primarily at the mockall crate.
You could use cfg:
#[cfg(not(test))]
fn get_user_input() -> u8 {
// Placeholder for some unknown value
42
}
#[cfg(test)]
fn get_user_input() -> u8 {
12
}
Or dependency injection:
pub fn main() {
println!("{}", foo(get_user_input));
}
fn get_user_input() -> u8 {
// Placeholder for some unknown value
42
}
fn foo(get_user_input_: impl Fn() -> u8) -> u8 {
get_user_input_()
}
#[cfg(test)]
mod tests {
#[test]
fn test_foo() {
use super::*;
let get_user_input = || 12u8;
assert_eq!(foo(get_user_input), 12u8);
}
}

Problem calling C++ Class method from Rust

The following code is generated by bindgen.
extern "C" {
#[doc = "MoraComm Properties"]
#[link_name = "\u{1}_ZN22MoraCommManagerWrapped10propertiesEj"]
pub fn MoraCommManagerWrapped_properties(
this: *mut MoraCommManagerWrapped,
deviceNumber: ::std::os::raw::c_uint,
) -> MoraCommPropertiesWrapped;
}
impl MoraCommManagerWrapped {
#[inline]
pub unsafe fn properties(
&mut self,
deviceNumber: ::std::os::raw::c_uint,
) -> MoraCommPropertiesWrapped {
MoraCommManagerWrapped_properties(self, deviceNumber)
}
#[inline]
pub unsafe fn new() -> Self {
let mut __bindgen_tmp = ::std::mem::MaybeUninit::uninit();
MoraCommManagerWrapped_MoraCommManagerWrapped(__bindgen_tmp.as_mut_ptr());
__bindgen_tmp.assume_init()
}
}
fn main() {
unsafe {
let mut mgr: MoraCommManagerWrapped = MoraCommManagerWrapped::new();
let p = mgr.properties(0); // <- segfault
}
}
When main is run, there is a segfault.
I added logging on the C++ side, which shows that the value of the this pointer is equal to the value of the deviceNumber argument.
I call another class that is part of the same bindgen generated bindings and it works.
Any thoughts on what might be going on?

How do you test for a specific Rust error?

I can find ways to detect if Rust gives me an error,
assert!(fs::metadata(path).is_err())
How do I test for a specific error?
You can compare the returned Err variant directly if the error type implements Debug + PartialEq:
#[derive(Debug, PartialEq)]
enum MyError {
TooBig,
TooSmall,
}
pub fn encode(&self, decoded: &'a Bytes) -> Result<&'a Bytes, MyError> {
if decoded.len() > self.length() as usize {
Err(MyError::TooBig)
} else {
Ok(&decoded)
}
}
assert_eq!(fixed.encode(&[1]), Err(MyError::TooBig));
The following solution doesn't require the PartialEq trait to be implemented. For instance, std::io::Error does not implement it, so a more general solution is required.
In these cases, you can borrow the assert_matches macro from the matches crate. It gives you a more succinct way to pattern match, and the macro is so short that you can just write it yourself:
macro_rules! assert_err {
($expression:expr, $($pattern:tt)+) => {
match $expression {
$($pattern)+ => (),
ref e => panic!("expected `{}` but got `{:?}`", stringify!($($pattern)+), e),
}
}
}
// Example usages:
assert_err!(your_func(), Err(Error::UrlParsingFailed(_)));
assert_err!(your_func(), Err(Error::CanonicalizationFailed(_)));
assert_err!(your_func(), Err(Error::FileOpenFailed(er)) if er.kind() == ErrorKind::NotFound);
A full example that builds on the playground, with an example Error enum:
#[derive(Debug)]
pub enum Error {
UrlCreationFailed,
CanonicalizationFailed(std::io::Error),
FileOpenFailed(std::io::Error),
UrlParsingFailed(url::ParseError),
}
pub fn your_func() -> Result<(), Error> {
Ok(())
}
#[cfg(test)]
mod test {
use std::io::ErrorKind;
use super::{your_func, Error};
macro_rules! assert_err {
($expression:expr, $($pattern:tt)+) => {
match $expression {
$($pattern)+ => (),
ref e => panic!("expected `{}` but got `{:?}`", stringify!($($pattern)+), e),
}
}
}
#[test]
fn test_failures() {
// Few examples are here:
assert_err!(your_func(), Err(Error::UrlParsingFailed(_)));
assert_err!(your_func(), Err(Error::CanonicalizationFailed(_)));
assert_err!(your_func(), Err(Error::FileOpenFailed(er)) if er.kind() == ErrorKind::NotFound);
}
}
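As an alternative to a hand-written macro, the standard library's matches! macro (stable since Rust 1.42) accepts the same pattern-plus-guard form and can be wrapped in assert!. The sketch below is illustrative only: `LoadError`, `load`, and `is_not_found` are hypothetical names that mirror the shape of the Error enum above, not part of the original answer.

```rust
use std::io;

// Hypothetical error type, mirroring the shape of the `Error` enum above.
#[derive(Debug)]
enum LoadError {
    FileOpenFailed(io::Error),
    Empty,
}

// Hypothetical function under test: always fails with a NotFound I/O error.
fn load() -> Result<(), LoadError> {
    Err(LoadError::FileOpenFailed(io::Error::new(
        io::ErrorKind::NotFound,
        "no such file",
    )))
}

// True when `result` is a `FileOpenFailed` wrapping a NotFound error.
// No `PartialEq` implementation is needed on `LoadError` or `io::Error`.
fn is_not_found(result: &Result<(), LoadError>) -> bool {
    matches!(
        result,
        Err(LoadError::FileOpenFailed(e)) if e.kind() == io::ErrorKind::NotFound
    )
}
```

In a test this reads as `assert!(is_not_found(&load()))`. Compared with the custom macro, the failure message is less informative, because assert! does not print the value that failed to match.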

Preventing a Fn from being invoked again while it's already running

I am using inputbot to write a program that provides some global macros for my computer. For example, when I press the h key, it should execute the macro typing
Hello World
into the current application. I tried to implement it like this:
extern crate inputbot;
fn main() {
let mut callback = || {
inputbot::KeySequence("Hello World").send();
};
inputbot::KeybdKey::HKey.bind(callback);
inputbot::handle_input_events();
}
However, when I pressed the h key what I actually got was:
hHHHHHHHHHHHHHHHHHHHHHHHHHEHEHhEhEEHHhEhEhEHhEHHEHHEEHhEHlhEHEHHEHLEHLHeeleleelelelllelelleelehlhehlleeheehelheelleeleelhllllllellelolelellelleoleloloelellololol olollollelllolllol lloo ol o oo l lo lolooloooloo loo LOWOLO O L OLW WOWO L WLLOLOW L O O O O o WOWW low o oOow WWW WOW wowooWWWO oOWRWOoor W RoW oOWorororWRRWLR rLROwoRWLWOworo WorrrRWl ow o WRLR OLw o OWLDol rollWWLDWowDLlroWWo r oWDWOL dorRrwrolrdrrorlrLWDRdodRLowdllrllolrdlrddolrdlrldowldorowlrdlrorloLDLWDLoddlrddlrdldldldrrdordldrlrddrodlrrldoldlrlddldlrdlldlrdlddrlddldddlddlddd
The macro was triggering itself each time it sent the h key event. 😬
How can I prevent a Fn from being invoked again while another instance of it is still running? This is the main functionality of a small application, so there's nothing else to really worry about compatibility with.
My naive attempt to fix this was to add a mut running variable in main, which callback would set to true while it was running, or immediately return if it was already true:
extern crate inputbot;
use std::time::Duration;
use std::thread::sleep;
fn main() {
let mut running = false;
let mut callback = || {
if running { return };
running = true;
inputbot::KeySequence("Hello World").send();
// wait to make sure keyboard events are done.
sleep(Duration::from_millis(125));
running = false;
};
inputbot::KeybdKey::HKey.bind(callback);
inputbot::handle_input_events();
}
However, this doesn't compile:
error[E0525]: expected a closure that implements the `Fn` trait, but this closure only implements `FnMut`
After some reading, my understanding is now that a Fn closure (required by inputbot's .bind() methods) can't own any mutable data, like a captured mut variable.
Maybe it's possible to wrap the variable in some kind of non-mut value? Perhaps some kind of lock, to make the potential concurrency safe, like this pseudocode?
fn main() {
let mut running = false;
let lockedRunning = example::Lock(&running);
let mut callback = || {
{
let mut running = lockedRunning.acquire();
if running { return };
running = true;
}
inputbot::KeySequence("Hello World").send();
// wait to make sure keyboard events are done.
sleep(Duration::from_millis(125));
{
let mut running = lockedRunning.acquire();
running = false;
}
};
}
What you want here is for the function to be mutually exclusive with itself.
Rust allows you to do this with the Mutex struct. It lets you hold a lock that, once acquired, stops anyone else from taking it until you release it.
Specifically, the functionality you want is the try_lock method, which lets you check whether the lock has already been acquired and handle that case.
let lock = mutex.try_lock();
match lock {
Ok(_) => {
// We are the sole owners here
}
Err(TryLockError::WouldBlock) => return,
Err(TryLockError::Poisoned(_)) => {
println!("The mutex is poisoned");
return
}
}
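To make the fragment above concrete, here is a minimal self-contained sketch of the try_lock approach. `run_exclusive` and `nested_call_demo` are hypothetical helpers, and the nested call only simulates a callback that re-triggers itself while it is still running. Because Mutex provides interior mutability, a closure that captures it still implements Fn.

```rust
use std::sync::{Mutex, TryLockError};

// Runs `work` only if no other call currently holds `busy`.
// Returns true when the work ran and false when the call was skipped.
fn run_exclusive(busy: &Mutex<()>, work: impl FnOnce()) -> bool {
    match busy.try_lock() {
        Ok(_guard) => {
            // `_guard` keeps the lock held until the end of this arm.
            work();
            true
        }
        // Another invocation is still running: skip this one.
        Err(TryLockError::WouldBlock) => false,
        // A previous holder panicked while holding the lock: skip as well.
        Err(TryLockError::Poisoned(_)) => false,
    }
}

// Simulates a callback that triggers itself while it is running and
// reports (outer_ran, inner_ran).
fn nested_call_demo() -> (bool, bool) {
    let busy = Mutex::new(());
    let mut inner_ran = true;
    let outer_ran = run_exclusive(&busy, || {
        // The outer call still holds the lock, so this call is skipped.
        inner_ran = run_exclusive(&busy, || {});
    });
    (outer_ran, inner_ran)
}
```

The guard is bound to a named variable (`_guard`) rather than matched with `_`, so the lock is certain to stay held for the duration of `work`.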
Using an atomic value is a bit simpler than a Mutex as you don't need to worry about failure cases and it can easily be made into a static variable without using lazy-static:
use std::sync::atomic::{AtomicBool, Ordering};
fn main() {
let is_being_called = AtomicBool::new(false);
bind(move || {
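// compare_and_swap is deprecated since Rust 1.50; compare_exchange succeeds only if the flag was false.
if is_being_called.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_ok() {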
print!("I'm doing work");
is_being_called.store(false, Ordering::SeqCst);
}
});
}
I have a hunch that this is also more efficient than using a Mutex as no heap allocations need to be made, but I didn't benchmark it.
If you are in a single-threaded context and your callback is somehow (accidentally?) recursive (which a closure cannot be on its own), you can also use a Cell:
use std::cell::Cell;
fn main() {
let is_being_called = Cell::new(false);
bind(move || {
if !is_being_called.get() {
is_being_called.set(true);
print!("doing work");
is_being_called.set(false);
}
})
}
If you happen to have a FnMut closure, you don't even need the Cell and can just use a boolean:
fn main() {
let mut is_being_called = false;
bind(move || {
if !is_being_called {
is_being_called = true;
print!("doing work");
is_being_called = false;
}
})
}