From 0d815a09a7688caa3679e0a433028552de7e1aee Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Mon, 9 Jan 2023 07:20:44 -0500 Subject: [PATCH] add unit test for custom zbl() function --- unittest/utils/test_lepton.cpp | 41 ++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/unittest/utils/test_lepton.cpp b/unittest/utils/test_lepton.cpp index 91532b385a..415752c70d 100644 --- a/unittest/utils/test_lepton.cpp +++ b/unittest/utils/test_lepton.cpp @@ -85,6 +85,47 @@ TEST_F(LeptonUtilsTest, substitute) } } +// zbl() custom function + +TEST(LeptonCustomFunction, zbl) +{ + Lepton::ZBLFunction zbl(1.0, 1.0, 1.0); + std::map functions = {std::make_pair("zbl", &zbl)}; + std::map variables = {std::make_pair("zi", 6), std::make_pair("zj", 6), + std::make_pair("r", 2.0)}; + + auto parsed = Lepton::Parser::parse("zbl(zi, zj, r)", functions); + auto zbldzi = parsed.differentiate("zi"); + auto zbldzj = parsed.differentiate("zj"); + auto zbldr = parsed.differentiate("r"); + auto zbld2r = zbldr.differentiate("r"); + + double value = parsed.evaluate(variables); + ASSERT_DOUBLE_EQ(value, 0.065721538245489763); + value = zbldr.evaluate(variables); + ASSERT_DOUBLE_EQ(value, -0.15481915325334394); + variables["r"] = 1.0; + value = parsed.evaluate(variables); + ASSERT_DOUBLE_EQ(value, 1.0701488641432269); + value = zbldr.evaluate(variables); + ASSERT_DOUBLE_EQ(value, -3.6376386525054412); + variables["zi"] = 13.0; + value = parsed.evaluate(variables); + ASSERT_DOUBLE_EQ(value, 1.8430432789454971); + value = zbldr.evaluate(variables); + ASSERT_DOUBLE_EQ(value, -6.5373118484557642); + variables["zj"] = 13.0; + value = parsed.evaluate(variables); + ASSERT_DOUBLE_EQ(value, 3.1965196467438446); + value = zbldr.evaluate(variables); + ASSERT_DOUBLE_EQ(value, -11.804490148948526); + + // check for unsupported derivatives + ASSERT_ANY_THROW(value = zbldzi.evaluate(variables)); + ASSERT_ANY_THROW(value = zbldzj.evaluate(variables)); + ASSERT_ANY_THROW(value = zbld2r.evaluate(variables)); +} + /** * This is a custom function equal to f(x,y) = 2*x*y. */