Skip to content
Snippets Groups Projects
Select Git revision
  • b20622601a97c3cd9de0933e0c2aedc94cf734c1
  • main default protected
2 results

tests.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    test_servo_design.py 13.51 KiB
    import os
    import unittest
    from OpenQlab.analysis.servo_design import ServoDesign, Filter, Integrator, Differentiator, Notch, Lowpass
    from OpenQlab import io
    import numpy as np
    import matplotlib
    import logging as log
    import jsonpickle
    from pandas import DataFrame
    
    
    class TestFilter(unittest.TestCase):
        """Unit tests for the individual filter building blocks (Integrator,
        Differentiator, Lowpass, Notch) and their shared ``enabled`` flag."""

        def test_filter_initialization(self):
            # Overwrite the private fields directly and check that the public
            # properties expose exactly those values (non-ASCII text included).
            description = 'descriptiön'
            z = [4, 7]
            p = [2, 8]
            k = 1.2
            f = Integrator(1)
            f._zeros = z
            f._poles = p
            f._gain = k
            f._description = description
            self.assertEqual(f.description, description)
            self.assertListEqual(list(f.zeros), z)
            self.assertListEqual(list(f.poles), p)
            self.assertEqual(f.gain, k)

        def test_integrator(self):
            # Zero at -cF, pole at -sF, unit gain.
            cF = 840
            sF = 5300
            i = Integrator(cF, sF)
            np.testing.assert_allclose(i.zeros, -cF)
            np.testing.assert_allclose(i.poles, -sF)
            np.testing.assert_allclose(i.gain, 1)

            # Changing sF afterwards must move the pole and leave the zero alone.
            i.sF = cF * .001
            np.testing.assert_allclose(i.zeros, -cF)
            np.testing.assert_allclose(i.poles, -cF / 1000)

        def test_differentiator(self):
            cF = 840
            sF = 5300
            i = Differentiator(cF, sF)
            np.testing.assert_allclose(i.zeros, -cF)
            np.testing.assert_allclose(i.poles, -sF)
            np.testing.assert_allclose(i.gain, 6.309524)

            # With sF unset the pole is placed at 1000 * cF with gain 1000.
            i.sF = None
            np.testing.assert_allclose(i.zeros, -cF)
            np.testing.assert_allclose(i.poles, -cF * 1000)
            np.testing.assert_allclose(i.gain, 1000)

        def test_lowpass(self):
            # Explicit Q.
            cF = 840
            Q = 3
            i1 = Lowpass(cF, Q)
            np.testing.assert_allclose(i1.zeros, -cF)
            np.testing.assert_allclose(i1.poles, [-140. + 828.25116963j, -140. - 828.25116963j])
            np.testing.assert_allclose(i1.gain, 705600)

            # Default Q: no zeros, a complex-conjugate pole pair.
            i2 = Lowpass(cF)
            np.testing.assert_allclose(i2.zeros, [])
            np.testing.assert_allclose(i2.poles, [-594.05940594 + 593.8799729j, -594.05940594 - 593.8799729j])
            np.testing.assert_allclose(i2.gain, 705600)

        def test_notch(self):
            # The zeros sit at +/- j*cF regardless of Q; only the poles depend on Q.
            cF = 840
            Q = 1
            i1 = Notch(cF, Q)
            np.testing.assert_allclose(i1.zeros, [0. + 840.j, -0. - 840.j])
            np.testing.assert_allclose(i1.poles, [-420. + 727.46133918j, -420. - 727.46133918j])
            np.testing.assert_allclose(i1.gain, 1)

            Q = 32
            i2 = Notch(cF, Q)
            np.testing.assert_allclose(i2.zeros, [0. + 840.j, -0. - 840.j])
            np.testing.assert_allclose(i2.poles, [-13.125 + 839.89745468j, -13.125 - 839.89745468j])
            np.testing.assert_allclose(i2.gain, 1)

        def test_enabled_setter(self):
            # A real bool is accepted; truthy non-bool values are rejected.
            f = Integrator(4899)
            f.enabled = False
            self.assertFalse(f.enabled)
            with self.assertRaises(TypeError):
                f.enabled = 'nein'
            with self.assertRaises(TypeError):
                f.enabled = 1

        def test_filter_enabled(self):
            # Every filter type must honour the ``enabled`` constructor keyword
            # and allow toggling it afterwards.
            for F in [Integrator, Differentiator, Lowpass, Notch]:
                with self.subTest(filter=F.__name__):
                    f = F(500, enabled=False)
                    self.assertFalse(f.enabled)
                    f.enabled = True
                    self.assertTrue(f.enabled)
                    f.enabled = False
                    self.assertFalse(f.enabled)
                    f = F(500, enabled=True)
                    self.assertTrue(f.enabled)
    
    
    class TestServoDesign(unittest.TestCase):
        """Tests for ServoDesign: filter slots, gain handling, zpk, plant data,
        plotting, discretization and jsonpickle round-tripping."""

        # Frequency-response fixture; resolved next to this module first, then
        # relative to the current working directory (see _read_fra).
        FRA_FILE = os.path.join('servo_design', 'fra_3.csv')

        # Expected combined (zeros, poles, gain) for integrator(500),
        # differentiator(1000) and notch(900, Q=200).
        ZPK_INT_DIFF_NOTCH = (
            np.array([-500., -1000., 0. + 900.j, -0. - 900.j]),
            np.array([-.5, -1e6, -2.25e+00 + 899.9971875j, -2.25e+00 - 899.9971875j]),
            1000.0
        )

        # Expected discrete form at fs=100 kHz for integrator(500) and
        # notch(900, Q=200); each filter entry is compared as a coefficient array.
        DISCRETE_INT_NOTCH = {
            'fs': 100000.0,
            'gain': 1.0,
            'filters': {
                'Int 500Hz': np.array([1.0156933, -0.98427528, 0, 1, -0.99996858, 0]),
                'Notch 900Hz, Q=200': np.array([0.99985872, -1.996521, 0.99985872, 1, -1.996521, 0.99971745])
            }
        }

        def setUp(self):
            # Plot tests need a display; they are skipped when DISPLAY is unset or empty.
            self.display_available = bool(os.environ.get('DISPLAY'))
            self.sd = ServoDesign()

        # ----------------------------------------------------------------- helpers

        @classmethod
        def _read_fra(cls):
            """Read the FRA fixture without depending on the current working directory.

            Falls back to the cwd-relative path when no file exists next to this module.
            """
            module_dir = os.path.dirname(os.path.abspath(__file__))
            candidate = os.path.join(module_dir, cls.FRA_FILE)
            return io.read(candidate if os.path.exists(candidate) else cls.FRA_FILE)

        def _require_display(self):
            """Skip (rather than silently pass) the current test when no display is available."""
            if not self.display_available:
                self.skipTest('no display available (DISPLAY is not set)')

        def _assert_zpk(self, expected):
            """Assert that ``self.sd.zpk()`` matches ``expected`` part by part."""
            zpk = self.sd.zpk()
            self.assertEqual(len(zpk), len(expected))
            for actual_part, expected_part in zip(zpk, expected):
                np.testing.assert_allclose(actual_part, expected_part)

        def _assert_discrete_form(self, expected, fs=100e3):
            """Assert that ``self.sd.discrete_form(fs)`` contains the expected entries."""
            discrete = self.sd.discrete_form(fs=fs)
            self.assertEqual(discrete['fs'], expected['fs'])
            self.assertEqual(discrete['gain'], expected['gain'])
            for name, coefficients in expected['filters'].items():
                np.testing.assert_allclose(discrete['filters'][name], coefficients)

        # ------------------------------------------------------------------- tests

        def test_set_plant_none(self):
            # Only checks that clearing the plant does not raise.
            self.sd.plant = None

        def test_is_empty(self):
            self.assertEqual(len(self.sd), 0)
            self.sd.integrator(200)
            self.assertEqual(len(self.sd), 1)
            self.sd.clear()
            self.assertEqual(len(self.sd), 0)

        def test_filter_list(self):
            self.sd.integrator(340)
            self.sd.notch(239, 10)
            self.assertEqual(len(self.sd), 2)
            # Filters are stored in insertion order with their parameters.
            self.assertIsInstance(self.sd.filters[0], Integrator)
            self.assertEqual(self.sd.filters[0].cF, 340)
            self.assertIsInstance(self.sd.filters[1], Notch)
            self.assertEqual(self.sd.filters[1].cF, 239)
            self.assertEqual(self.sd.filters[1].Q, 10)

        def test_change_a_filter_inplace(self):
            self.sd.integrator(247, 20)
            self.assertEqual(self.sd.filters[0].cF, 247)
            self.sd.filters[0].cF = 103
            self.assertEqual(self.sd.filters[0].cF, 103)

        def test_add_too_many_filters(self):
            # Five slots are available; a sixth filter must be rejected.
            for i in range(1, 6):
                self.sd.lowpass(300 * i)
            self.assertEqual(len(self.sd), 5)
            with self.assertRaises(Exception):
                self.sd.integrator(304)

        def test_remove_filter(self):
            for cF in range(10, 15):
                self.sd.integrator(cF)
            self.assertEqual(len(self.sd), 5)
            # Deleting a slot shifts the following filters down by one.
            del self.sd.filters[2]
            self.assertEqual(len(self.sd), 4)
            self.assertEqual(self.sd.filters[2].cF, 13)

        def test_set_filter_none(self):
            for cF in range(10, 15):
                self.sd.integrator(cF)
            self.assertEqual(len(self.sd), 5)
            # Setting a slot to None empties it without shifting; len() counts filters only.
            self.sd.filters[3] = None
            self.assertEqual(len(self.sd), 4)
            self.assertIsNone(self.sd.filters[3])

        def test_add_filter_on_index(self):
            self.sd.add(Integrator(13), index=3)
            self.assertIsNone(self.sd.filters[2])
            self.assertEqual(len(self.sd), 1)

        def test_add_filter_on_first_none_index(self):
            self.sd.integrator(10)
            self.sd.integrator(11)
            self.sd.integrator(12)
            self.sd.filters[1] = None
            # A new filter fills the first empty slot.
            self.sd.integrator(42)
            self.assertEqual(self.sd.filters[1].cF, 42)
            self.assertEqual(len(self.sd), 3)

        def test_add_on_wrong_index(self):
            with self.assertRaises(IndexError):
                self.sd.add(Differentiator(34), index=5)

        def test_change_gain(self):
            self.assertEqual(self.sd.gain, 1)
            self.sd.gain = 5
            self.assertEqual(self.sd.gain, 5)
            self.sd.gain = 9.3
            self.assertEqual(self.sd.gain, 9.3)

        def test_gain_db(self):
            # +6 dB is approximately a factor of two; -6 dB undoes it.
            self.sd.gain = 10
            self.sd.log_gain(6)
            self.assertAlmostEqual(self.sd.gain, 20, delta=.05)
            self.sd.log_gain(-6)
            self.assertAlmostEqual(self.sd.gain, 10)

        def test_zpk(self):
            # An empty design is a pure unit gain.
            self._assert_zpk((np.array([]), np.array([]), 1.0))

            self.sd.integrator(500)
            self.sd.differentiator(1000)
            self.sd.notch(900, Q=200)
            self._assert_zpk(self.ZPK_INT_DIFF_NOTCH)

        def test_set_plant(self):
            with self.assertRaises(TypeError):
                self.sd.plant = 8

            columns = ['Gain (dB)', 'Phase (deg)']
            fra = self._read_fra()
            self.sd.plant = fra
            self.assertListEqual(self.sd.plant.columns.tolist(), columns)

            # A plant without the gain/phase columns must be rejected.
            fra2 = fra.copy()
            del fra2['Gain (dB)']
            del fra2['Phase (deg)']
            with self.assertRaises(Exception):
                self.sd.plant = fra2

        def test_simple_plot(self):
            self._require_display()
            fig = self.sd.plot()
            self.assertIsInstance(fig, matplotlib.figure.Figure)
            ax = fig.axes[0]
            self.assertLessEqual(ax.get_xlim()[1], 1e6)
            self.assertGreaterEqual(ax.get_xlim()[1], 1e5)
            for _ in range(5):
                self.sd.integrator(500)
            fig = self.sd.plot()
            self.assertIsInstance(fig, matplotlib.figure.Figure)

        def test_plot_with_frequencies(self):
            self._require_display()
            fig = self.sd.plot(freq=np.logspace(1, 3, num=100))
            ax = fig.axes[0]
            self.assertLessEqual(ax.get_xlim()[1], 1e4)
            self.assertGreaterEqual(ax.get_xlim()[1], 1e3)

        def test_plot_with_plant(self):
            self._require_display()
            self.sd.differentiator(390)
            self.sd.plant = self._read_fra()
            fig = self.sd.plot()
            self.assertIsInstance(fig, matplotlib.figure.Figure)
            ax = fig.axes[0]
            self.assertLessEqual(ax.get_xlim()[1], 3e4)
            self.assertGreaterEqual(ax.get_xlim()[1], 5e3)

        def test_get_dataframe(self):
            # Without a plant only the servo amplitude/phase columns are returned.
            self.sd.notch(1e3, Q=3)
            df = self.sd.plot(plot=False)
            self.assertIsInstance(df, DataFrame)
            self.assertListEqual(df.columns.tolist(), ['Servo A', 'Servo P'])

            # With a plant the combined servo+plant columns are appended.
            self.sd.plant = self._read_fra()
            df = self.sd.plot(plot=False)
            self.assertIsInstance(df, DataFrame)
            columns = ['Servo A', 'Servo P', 'Servo+TF A', 'Servo+TF P']
            self.assertListEqual(df.columns.tolist(), columns)

        def test_description(self):
            self.sd.integrator(438)
            self.assertEqual(self.sd.filters[0].description, 'Int 438Hz')
            self.sd.filters[0].cF = 1200
            self.assertEqual(self.sd.filters[0].description, 'Int 1.2kHz')

            self.sd.differentiator(438)
            self.assertEqual(self.sd.filters[1].description, 'Diff 438Hz')
            self.sd.filters[1].cF = 1200
            self.assertEqual(self.sd.filters[1].description, 'Diff 1.2kHz')

            self.sd.lowpass(438)
            self.assertEqual(self.sd.filters[2].description, 'LP2 438Hz, Q=0.707')
            self.sd.filters[2].cF = 1200
            self.sd.filters[2].Q = 12
            self.assertEqual(self.sd.filters[2].description, 'LP2 1.2kHz, Q=12')

            self.sd.notch(438, Q=11)
            self.assertEqual(self.sd.filters[3].description, 'Notch 438Hz, Q=11')
            self.sd.filters[3].cF = 1200
            self.sd.filters[3].Q = 1.3
            self.assertEqual(self.sd.filters[3].description, 'Notch 1.2kHz, Q=1.3')

        def test_discrete_form(self):
            self.sd.integrator(500)
            self.sd.notch(900, Q=200)
            self._assert_discrete_form(self.DISCRETE_INT_NOTCH)

        def test_correct_latency(self):
            self.sd.plant = self._read_fra()
            df = self.sd.plot(plot=False, correct_latency=False)
            df_corrected = self.sd.plot(plot=False, correct_latency=True)
            columns = ['Servo A', 'Servo P', 'Servo+TF A', 'Servo+TF P']
            self.assertListEqual(list(df.columns), columns)
            # The correction adds 360 * f / 200000 degrees of phase at frequency f.
            # 200000 presumably mirrors ServoDesign's default latency setting — confirm.
            # Both sides are computed independently, so compare approximately.
            expected_phase = df.iloc[-1]['Servo+TF P'] + 360 * df.index[-1] / 200000
            self.assertAlmostEqual(expected_phase, df_corrected.iloc[-1]['Servo+TF P'])

        def test_with_disabled_filter(self):
            self._assert_zpk((np.array([]), np.array([]), 1.0))

            self.sd.integrator(500)
            self.sd.differentiator(1000)
            self.sd.notch(900, Q=200)
            # Disabled filters must not contribute to the combined zpk.
            self.sd.notch(1400, Q=200, enabled=False)
            self.sd.differentiator(3000, enabled=False)
            self._assert_zpk(self.ZPK_INT_DIFF_NOTCH)

        def test_discrete_form_with_disabled_filter(self):
            self.sd.integrator(500)
            self.sd.notch(900, Q=200)
            # Disabled filters must not alter the discrete coefficients.
            self.sd.integrator(500, enabled=False)
            self.sd.lowpass(5000, enabled=False)
            self._assert_discrete_form(self.DISCRETE_INT_NOTCH)

        def test_jsonpickle(self):
            self.sd.notch(900, Q=200)
            self.sd.integrator(500)
            self.sd.lowpass(5000)

            # Encode and decode without plant
            sd = jsonpickle.decode(jsonpickle.encode(self.sd))
            self.assertEqual(str(self.sd), str(sd))

            # Encode and decode with plant
            self.sd.plant = self._read_fra()
            sd = jsonpickle.decode(jsonpickle.encode(self.sd))
            self.assertEqual(str(self.sd), str(sd))
            self.assertIsInstance(sd.plant, io.DataContainer)
            self.assertEqual(len(self.sd.plant), len(sd.plant))
    
    
    if __name__ == "__main__":
        # Allow running this module directly with the standard unittest runner.
        unittest.main()