import unittest

import numpy as np

from openqlab.conversion import db


class TestDB(unittest.TestCase):
    """Tests for the decibel helpers in ``openqlab.conversion.db``."""

    def test_db_to_lin(self):
        """``db.to_lin`` maps decibel values to linear power ratios."""
        # (dB input, expected linear value, decimal places for the comparison)
        cases = [
            (20, 100, 7),
            (100, 1e10, 7),
            (3, 1.99526, 5),
            (0, 1, 7),
            (-20, 0.01, 7),
            (-100, 1e-10, 7),
            (-3, 0.501187, 5),
            # Negative zero has to be a float: the integer literal ``-0`` is
            # plain ``0`` and would only duplicate the 0 dB case above.
            (-0.0, 1, 7),
        ]
        for db_value, expected, places in cases:
            with self.subTest(db_value=db_value):
                self.assertAlmostEqual(db.to_lin(db_value), expected, places=places)

    def test_lin_to_db(self):
        """``db.from_lin`` maps linear power ratios back to decibels."""
        # (linear input, expected dB value, decimal places for the comparison)
        cases = [
            (100, 20, 7),
            (1e10, 100, 7),
            (1.99526, 3, 4),
            (1, 0, 7),
            (0.01, -20, 7),
            (1e-10, -100, 7),
            (0.501187, -3, 4),
        ]
        for lin_value, expected, places in cases:
            with self.subTest(lin_value=lin_value):
                self.assertAlmostEqual(
                    db.from_lin(lin_value), expected, places=places
                )

    def test_mean(self):
        """``db.mean`` of dB data equals the mean of the linear data (3.5)."""
        lin_data = np.array([1, 2, 3, 4, 5, 6])
        db_data = db.from_lin(lin_data)
        self.assertAlmostEqual(db.to_lin(db.mean(db_data)), 3.5)

    def test_subtract(self):
        """``db.subtract`` subtracts the underlying linear quantities.

        Each expected entry is ``lin_data1[i] - lin_data2[i]``.
        """
        lin_data1 = np.array([5, 3, 1e4, 8, 5, 6])
        lin_data2 = np.array([1, 2, 1e3, 4, 4, 5])
        expected = np.array([4, 1, 9e3, 4, 1, 1])

        db_data1 = db.from_lin(lin_data1)
        db_data2 = db.from_lin(lin_data2)

        np.testing.assert_array_almost_equal(
            db.to_lin(db.subtract(db_data1, db_data2)), expected, decimal=3
        )

    def test_average(self):
        """``db.average`` of dB data equals the linear mean of 1..7 (4)."""
        lin_data = np.array([1, 2, 3, 4, 5, 6, 7])
        db_data = db.from_lin(lin_data)
        self.assertAlmostEqual(db.to_lin(db.average(db_data)), 4)

    def test_dBm2Vrms(self):
        """``db.dBm2Vrms`` converts power in dBm to an RMS voltage.

        The expected voltages are consistent with a 50 ohm reference
        (e.g. 0 dBm -> sqrt(1 mW * 50 ohm) ~= 0.224 V).
        """
        dbm_data = np.array([0, 1, 3, 6, 10, 100, -3, -10])
        # RMS voltages in volts (not dBV), one per entry of ``dbm_data``.
        expected_vrms = np.array(
            [0.224, 0.251, 0.316, 0.446, 0.707, 22360.680, 0.158, 0.071]
        )

        np.testing.assert_array_almost_equal(
            db.dBm2Vrms(dbm_data), expected_vrms, decimal=3
        )