diff --git a/src/openqlab/analysis/servo_design.py b/src/openqlab/analysis/servo_design.py index 5ec006bc6b5e58be6399e519b5ebe69ae4efbd5a..d63faec81fe87136303f71c8591bc5d70d3467d6 100644 --- a/src/openqlab/analysis/servo_design.py +++ b/src/openqlab/analysis/servo_design.py @@ -38,8 +38,8 @@ class Filter(ABC): Gain """ - def __init__(self, cF, second_parameter=None, enabled=True): - self._cF = cF + def __init__(self, corner_frequency, second_parameter=None, enabled=True): + self._corner_frequency = corner_frequency self._second_parameter = second_parameter self._enabled = enabled self.update() @@ -76,12 +76,12 @@ class Filter(ABC): self.update() @property - def cF(self): - return self._cF + def corner_frequency(self): + return self._corner_frequency - @cF.setter - def cF(self, value): - self._cF = value + @corner_frequency.setter + def corner_frequency(self, value): + self._corner_frequency = value self.update() @property @@ -115,29 +115,29 @@ class Filter(ABC): def gain(self): return self._gain - def discrete_SOS(self, fs): # pylint: disable=invalid-name + def discrete_SOS(self, sampling_frequency): # pylint: disable=invalid-name """ Return a discrete-time second order section of this filter, with sampling - frequency `fs`. + frequency `sampling_frequency`. """ - return signal.zpk2sos(*self.discrete_zpk(fs)) + return signal.zpk2sos(*self.discrete_zpk(sampling_frequency)) - def discrete_zpk(self, fs): + def discrete_zpk(self, sampling_frequency): """ Return the discrete-time transfer function of this filter, evaluated - for a sampling frequency of `fs`. + for a sampling frequency of `sampling_frequency`. """ - z, p, k = self._prewarp(fs) - return signal.bilinear_zpk(z, p, k, fs) + z, p, k = self._prewarp(sampling_frequency) + return signal.bilinear_zpk(z, p, k, sampling_frequency) - def _prewarp(self, fs): + def _prewarp(self, sampling_frequency): """ Prewarp frequencies of poles and zeros, to correct for the nonlinear mapping of frequencies between continuous-time and discrete-time domain. Parameters ---------- - fs: :obj:`float` + sampling_frequency: :obj:`float` the sampling frequency Returns @@ -149,7 +149,13 @@ class Filter(ABC): """ def warp(x): - return 2 * fs * x / abs(x) * np.tan(abs(x / fs) * np.pi) + return ( + 2 + * sampling_frequency + * x + / abs(x) + * np.tan(abs(x / sampling_frequency) * np.pi) + ) # since we're calculating in Hz, we need to scale the gain as well # by 2pi for each pole and 1/2pi for each zero @@ -173,27 +179,27 @@ class Filter(ABC): class Integrator(Filter): """ - Create an integrator with corner frequency 'cF', compensated for unity gain at high frequencies. + Create an integrator with corner frequency 'corner_frequency', compensated for unity gain at high frequencies. Parameters ---------- - cF: :obj:`float` + corner_frequency: :obj:`float` The corner frequency. sF: :obj:`float`, optional - Frequency were the ~1/f slope starts, defaults to 0.001 * `cF`. + Frequency were the ~1/f slope starts, defaults to 0.001 * `corner_frequency`. """ def calculate(self): - z = -self.cF + z = -self.corner_frequency if self.sF is None: - self._second_parameter = self.cF * 0.001 + self._second_parameter = self.corner_frequency * 0.001 p = -self.sF k = 1.0 # Gain = 1 return z, p, k @property def description(self): - return "Int {0}".format(human_readable(self.cF, "Hz")) + return "Int {0}".format(human_readable(self.corner_frequency, "Hz")) @property def sF(self): @@ -206,29 +212,29 @@ class Integrator(Filter): class Differentiator(Filter): """ - Create a differentiator with corner frequency `cF`, + Create a differentiator with corner frequency `corner_frequency`, compensated for unity gain at low frequencies. Parameters ---------- - cF: :obj:`float` + corner_frequency: :obj:`float` The corner frequency. sF: :obj:`float`, optional - Frequency were the ~f slope stops, defaults to 10 * `cF`. + Frequency were the ~f slope stops, defaults to 10 * `corner_frequency`. """ def calculate(self): - z = -self.cF + z = -self.corner_frequency if self.sF is None: - self._second_parameter = self.cF * 10 + self._second_parameter = self.corner_frequency * 10 p = -self.sF - k = self.sF / self.cF + k = self.sF / self.corner_frequency return z, p, k @property def description(self): - return "Diff {0}".format(human_readable(self.cF, "Hz")) + return "Diff {0}".format(human_readable(self.corner_frequency, "Hz")) @property def sF(self): @@ -247,29 +253,33 @@ class Lowpass(Filter): Parameters ---------- - cF: :obj:`float` + corner_frequency: :obj:`float` The corner frequency. """ - def __init__(self, cF, Q=0.707, enabled=True): - super().__init__(cF, Q, enabled) + def __init__(self, corner_frequency, Q=0.707, enabled=True): + super().__init__(corner_frequency, Q, enabled) def calculate(self): z = [] - cF = self.cF + corner_frequency = self.corner_frequency Q = self.Q p = [ - -cF / (2 * Q) + ((cF / (2 * Q)) ** 2 - cF ** 2) ** 0.5, - -cF / (2 * Q) - ((cF / (2 * Q)) ** 2 - cF ** 2) ** 0.5, + -corner_frequency / (2 * Q) + + ((corner_frequency / (2 * Q)) ** 2 - corner_frequency ** 2) ** 0.5, + -corner_frequency / (2 * Q) + - ((corner_frequency / (2 * Q)) ** 2 - corner_frequency ** 2) ** 0.5, ] - k = cF * cF + k = corner_frequency ** 2 return z, p, k @property def description(self): - return "LP2 {0}, Q={1:.4g}".format(human_readable(self.cF, "Hz"), self.Q) + return "LP2 {0}, Q={1:.4g}".format( + human_readable(self.corner_frequency, "Hz"), self.Q + ) @property def Q(self): @@ -282,35 +292,39 @@ class Lowpass(Filter): class Notch(Filter): """ - Create a notch filter at frequency `cF` with a quality + Create a notch filter at frequency `corner_frequency` with a quality factor `Q`, where the -3dB filter bandwidth ``bw`` is - given by ``Q = cF/bw``. + given by ``Q = corner_frequency/bw``. Parameters ---------- - cF: :obj:`float` + corner_frequency: :obj:`float` Frequency to remove from the spectrum Q: :obj:`float` Quality factor of the notch filter. Defaults to 1. """ - def __init__(self, cF, Q=1, enabled=True): - super().__init__(cF, Q, enabled) + def __init__(self, corner_frequency, Q=1, enabled=True): + super().__init__(corner_frequency, Q, enabled) def calculate(self): - cF = self.cF + corner_frequency = self.corner_frequency Q = self.Q - z = [cF * 1j, -cF * 1j] + z = [corner_frequency * 1j, -corner_frequency * 1j] p = [ - -cF / (2 * Q) + ((cF / (2 * Q)) ** 2 - cF ** 2) ** 0.5, - -cF / (2 * Q) - ((cF / (2 * Q)) ** 2 - cF ** 2) ** 0.5, + -corner_frequency / (2 * Q) + + ((corner_frequency / (2 * Q)) ** 2 - corner_frequency ** 2) ** 0.5, + -corner_frequency / (2 * Q) + - ((corner_frequency / (2 * Q)) ** 2 - corner_frequency ** 2) ** 0.5, ] k = 1 return z, p, k @property def description(self): - return "Notch {0}, Q={1:.4g}".format(human_readable(self.cF, "Hz"), self.Q) + return "Notch {0}, Q={1:.4g}".format( + human_readable(self.corner_frequency, "Hz"), self.Q + ) @property def Q(self): @@ -362,49 +376,60 @@ class ServoDesign: return self._filters def clear(self): - self._filters = [] + self._filters = [None] * self.MAX_FILTERS self.gain = 1.0 def _get_first_none_entry(self): - for i in range(len(self.filters)): + for i in range(self.MAX_FILTERS): if self.filters[i] is None: return i return None def _add_filter_on_index(self, filter, index): # pylint: disable=redefined-builtin + if index is None: + raise IndexError("Please provide a valid index.") if index >= self.MAX_FILTERS: raise IndexError("Max {0} filters are allowed.".format(self.MAX_FILTERS)) - # Fill up the list with none filters if necessary - while index > len(self._filters) - 1: - self._filters.append(None) self._filters[index] = filter - def add(self, filter, index=None): # pylint: disable=redefined-builtin + def add( + self, filter, index=None, override=False + ): # pylint: disable=redefined-builtin """ - Add a filter to the servo. Up to {0} filters can be added. + Add a filter to the servo. Up to {0} filters can be added. If the list is full and not index is provided, the filter at the last index may be overriden, depending on whether `override` has been set to true or false. Parameters ---------- filter: :obj:`Filter` the Filter object to be added + index: :obj:`int` + optional filter index. Default `None`. + override: :obj:`bool` + whether to override filter if adding without index. Defaults to `False`. """.format( self.MAX_FILTERS ) - if len(self) >= self.MAX_FILTERS and index is None: - raise Exception( - "Cannot add more than {0} filters to servo.".format(self.MAX_FILTERS) + if index is not None and not 0 <= index < self.MAX_FILTERS: + raise IndexError( + f"index needs to be in valid range from 0 to {self.MAX_FILTERS}" ) if not isinstance(filter, Filter): raise TypeError("filter must be a Filter() object") + # check if there is an empty index if index is None: index = self._get_first_none_entry() - if index is None: - self._filters.append(filter) - else: + # if no empty index was found, `None` was returned, check for override + if index is None and override: + self._add_filter_on_index(filter, self.MAX_FILTERS - 1) + elif index is not None: self._add_filter_on_index(filter, index) + else: + raise IndexError( + "No filter was added, list was full. You might wanna set `override=True` or remove a filter." + ) def get(self, index): """ @@ -424,13 +449,11 @@ class ServoDesign: raise IndexError( "Filter index must be between 0 and {}.".format(self.MAX_FILTERS - 1) ) - if index >= len(self._filters): - return None return self._filters[index] def remove(self, index): """ - Remove a filter from the servo. Effectively sets the slot at the given index to None, if it has been set before. + Remove a filter from the servo. Effectively sets the slot at the given index to None. Parameters ---------- @@ -443,9 +466,16 @@ class ServoDesign: raise IndexError( "Filter index must be between 0 and {}.".format(self.MAX_FILTERS - 1) ) - # only set a filter to None if array is already specified at index - if index < len(self._filters): - self._filters[index] = None + self._filters[index] = None + + def is_empty(self): + """ + Check whether ServoDesign contains any filter. + """ + for f in self._filters: + if f is not None: + return False + return True def __len__(self): """ @@ -469,36 +499,36 @@ class ServoDesign: # Add Filters the old way #################################### - def integrator(self, fc, fstop=None, enabled=True): + def integrator(self, corner_frequency, fstop=None, enabled=True): """ - Add an integrator with corner frequency `fc`, + Add an integrator with corner frequency `corner_frequency`, compensated for unity gain at high frequencies. Parameters ---------- - fc: :obj:`float` + corner_frequency: :obj:`float` The corner frequency. fstop: :obj:`float`, optional Frequency were the ~1/f slope starts, - defaults to 0.001 * `fc`. + defaults to 0.001 * `corner_frequency`. """ - self.add(Integrator(fc, fstop, enabled)) + self.add(Integrator(corner_frequency, fstop, enabled)) - def differentiator(self, fc, fstop=None, enabled=True): + def differentiator(self, corner_frequency, fstop=None, enabled=True): """ - Add a differentiator with corner frequency `fc`, + Add a differentiator with corner frequency `corner_frequency`, compensated for unity gain at low frequencies. Parameters ---------- - fc: :obj:`float` + corner_frequency: :obj:`float` The corner frequency. fstop: :obj:`float`, optional - Frequency were the ~f slope stops, defaults to 1000 * `fc`. + Frequency were the ~f slope stops, defaults to 1000 * `corner_frequency`. """ - self.add(Differentiator(fc, fstop, enabled)) + self.add(Differentiator(corner_frequency, fstop, enabled)) - def lowpass(self, fc, Q=0.707, enabled=True): + def lowpass(self, corner_frequency, Q=0.707, enabled=True): """ Add a 2nd-order lowpass filter with variable quality factor `Q`. @@ -509,17 +539,17 @@ class ServoDesign: parameter: :obj:`type` parameter description """ - self.add(Lowpass(fc, Q, enabled)) + self.add(Lowpass(corner_frequency, Q, enabled)) - def notch(self, fc, Q=1, enabled=True): + def notch(self, corner_frequency, Q=1, enabled=True): """ - Add a notch filter at frequency `fc` with a + Add a notch filter at frequency `corner_frequency` with a quality factor `Q`, where the -3dB filter bandwidth ``bw`` - is given by ``Q = fc/bw``. + is given by ``Q = corner_frequency/bw``. Parameters ---------- - fc: :obj:`float` + corner_frequency: :obj:`float` Frequency to remove from the spectrum Q: :obj:`float` Quality factor of the notch filter @@ -529,7 +559,7 @@ class ServoDesign: :obj:`Servo` the servo object with added notch filter """ - self.add(Notch(fc, Q, enabled)) + self.add(Notch(corner_frequency, Q, enabled)) #################################### # CLASS UTILITY @@ -686,36 +716,39 @@ class ServoDesign: Returns ------- - :obj:`dict` - a dictionary containing sample rate, gain and SOS coefficients for - each filter + :obj:`list` + a list containing a dict for each filter with additional information. """ if fs is not None: warn("fs is deprecated. use sampling_frequency.", DeprecationWarning) sampling_frequency = fs - coeffs = [] - for f in self._filters: - if f is not None: - coeffs.append(f.discrete_SOS(sampling_frequency).flatten()) - filters = {} - for f, d in zip(self._filters, coeffs): - if f is not None: - filters[f.description] = d + filters = [] - data = {"fs": sampling_frequency, "gain": self.gain, "filters": filters} + for i, f in enumerate(self._filters): + if f is not None: + filters.append( + { + "description": f.description, + "sos": f.discrete_SOS(sampling_frequency).flatten(), + "enabled": f.enabled, + "index": i, + } + ) + + data = { + "sampling_frequency": sampling_frequency, + "gain": self.gain, + "filters": filters, + } if filename: with open(filename, "w") as fp: - fp.write("Sampling rate: {0}\n".format(data["fs"])) - fp.write("Gain: {0:.10g}\n".format(data["gain"])) - for desc, c in data["filters"].items(): - fp.write( - ( - "{desc}: {c[0]:.10g} {c[1]:.10g} {c[2]:.10g} " - + "{c[3]:.10g} {c[4]:.10g} {c[5]:.10g}\n" - ).format(desc=desc, c=c) - ) + fp.write(f"Sampling rate: {data['sampling_frequency']}\n") + fp.write(f"Gain: {data['gain']}\n") + fp.write(f"Filters:\n") + for f in data["filters"]: + fp.write(f"{f}\n") return data diff --git a/src/tests/test_analysis/test_servo_design/test_servo_design.py b/src/tests/test_analysis/test_servo_design/test_servo_design.py index b3c0a2b16907e310b463ab30ec44415e0600d778..9d1a7509ddde81050555e24bca3e009a6370411e 100644 --- a/src/tests/test_analysis/test_servo_design/test_servo_design.py +++ b/src/tests/test_analysis/test_servo_design/test_servo_design.py @@ -1,11 +1,9 @@ -import logging as log import os import unittest from pathlib import Path import jsonpickle import matplotlib as mp -import matplotlib.pyplot as plt import numpy as np from pandas import DataFrame @@ -24,41 +22,41 @@ filedir = Path(__file__).parent class TestFilter(unittest.TestCase): def test_integrator(self): - cF = 840 + corner_frequency = 840 sF = 5300 - i = Integrator(cF, sF) - np.testing.assert_allclose(i.zeros, -cF) + i = Integrator(corner_frequency, sF) + np.testing.assert_allclose(i.zeros, -corner_frequency) np.testing.assert_allclose(i.poles, -sF) np.testing.assert_allclose(i.gain, 1) - i.sF = cF * 0.001 - np.testing.assert_allclose(i.zeros, -cF) - np.testing.assert_allclose(i.poles, -cF / 1000) + i.sF = corner_frequency * 0.001 + np.testing.assert_allclose(i.zeros, -corner_frequency) + np.testing.assert_allclose(i.poles, -corner_frequency / 1000) def test_differentiator(self): - cF = 840 + corner_frequency = 840 sF = 5300 - i = Differentiator(cF, sF) - np.testing.assert_allclose(i.zeros, -cF) + i = Differentiator(corner_frequency, sF) + np.testing.assert_allclose(i.zeros, -corner_frequency) np.testing.assert_allclose(i.poles, -sF) np.testing.assert_allclose(i.gain, 6.309524) i.sF = None - np.testing.assert_allclose(i.zeros, -cF) - np.testing.assert_allclose(i.poles, -cF * 10) + np.testing.assert_allclose(i.zeros, -corner_frequency) + np.testing.assert_allclose(i.poles, -corner_frequency * 10) np.testing.assert_allclose(i.gain, 10) def test_lowpass(self): - cF = 840 + corner_frequency = 840 Q = 3 - i1 = Lowpass(cF, Q) - np.testing.assert_allclose(i1.zeros, -cF) + i1 = Lowpass(corner_frequency, Q) + np.testing.assert_allclose(i1.zeros, -corner_frequency) np.testing.assert_allclose( i1.poles, [-140.0 + 828.25116963j, -140.0 - 828.25116963j] ) np.testing.assert_allclose(i1.gain, 705600) - i2 = Lowpass(cF) + i2 = Lowpass(corner_frequency) np.testing.assert_allclose(i2.zeros, []) np.testing.assert_allclose( i2.poles, [-594.05940594 + 593.8799729j, -594.05940594 - 593.8799729j] @@ -66,9 +64,9 @@ class TestFilter(unittest.TestCase): np.testing.assert_allclose(i2.gain, 705600) def test_notch(self): - cF = 840 + corner_frequency = 840 Q = 1 - i1 = Notch(cF, Q) + i1 = Notch(corner_frequency, Q) np.testing.assert_allclose(i1.zeros, [0.0 + 840.0j, -0.0 - 840.0j]) np.testing.assert_allclose( i1.poles, [-420.0 + 727.46133918j, -420.0 - 727.46133918j] @@ -76,7 +74,7 @@ class TestFilter(unittest.TestCase): np.testing.assert_allclose(i1.gain, 1) Q = 32 - i2 = Notch(cF, Q) + i2 = Notch(corner_frequency, Q) np.testing.assert_allclose(i2.zeros, [0.0 + 840.0j, -0.0 - 840.0j]) np.testing.assert_allclose( i2.poles, [-13.125 + 839.89745468j, -13.125 - 839.89745468j] @@ -114,31 +112,47 @@ class TestServoDesign(unittest.TestCase): def test_set_plant_none(self): self.sd.plant = None - def test_is_empty(self): + def test_len_empty(self): self.assertEqual(len(self.sd), 0) self.sd.integrator(200) self.assertEqual(len(self.sd), 1) self.sd.clear() self.assertEqual(len(self.sd), 0) + def test_isempty(self): + self.sd.differentiator(100) + self.assertFalse(self.sd.is_empty()) + self.sd.clear() + self.assertTrue(self.sd.is_empty()) + def test_filter_list(self): self.sd.integrator(340) self.sd.notch(239, 10) self.assertEqual(len(self.sd), 2) - # TODO test entries + self.sd.clear() + self.assertEqual(self.sd.filters, [None] * self.sd.MAX_FILTERS) def test_change_a_filter_inplace(self): self.sd.integrator(247, 20) - self.assertEqual(self.sd.filters[0].cF, 247) - self.sd.filters[0].cF = 103 - self.assertEqual(self.sd.filters[0].cF, 103) + self.assertEqual(self.sd.filters[0].corner_frequency, 247) + self.sd.filters[0].corner_frequency = 103 + self.assertEqual(self.sd.filters[0].corner_frequency, 103) def test_add_too_many_filters(self): for i in range(1, 6): self.sd.lowpass(300 * i) self.assertEqual(len(self.sd), 5) - with self.assertRaises(Exception): + with self.assertRaises(IndexError): self.sd.integrator(304) + self.sd.add(Lowpass(500), override=True) + self.assertEqual(self.sd.get(4).corner_frequency, Lowpass(500).corner_frequency) + + def test_get_index(self): + self.sd.integrator(50) + self.assertEqual(self.sd.get(1), None) + self.assertEqual( + self.sd.get(0).corner_frequency, Integrator(50).corner_frequency + ) def test_remove_filter(self): self.sd.integrator(10) @@ -147,9 +161,9 @@ class TestServoDesign(unittest.TestCase): self.sd.integrator(13) self.sd.integrator(14) self.assertEqual(len(self.sd), 5) - del self.sd.filters[2] + self.sd.remove(2) + self.assertEqual(self.sd.get(2), None) self.assertEqual(len(self.sd), 4) - self.assertEqual(self.sd.filters[2].cF, 13) def test_set_filter_none(self): self.sd.integrator(10) @@ -166,6 +180,7 @@ class TestServoDesign(unittest.TestCase): self.sd.add(Integrator(13), index=3) self.assertIsNone(self.sd.filters[2]) self.assertEqual(len(self.sd), 1) + self.assertFalse(self.sd.is_empty()) def test_add_filter_on_first_none_index(self): self.sd.integrator(10) @@ -173,7 +188,7 @@ class TestServoDesign(unittest.TestCase): self.sd.integrator(12) self.sd.filters[1] = None self.sd.integrator(42) - self.assertEqual(self.sd.filters[1].cF, 42) + self.assertEqual(self.sd.filters[1].corner_frequency, 42) self.assertEqual(len(self.sd), 3) def test_add_on_wrong_index(self): @@ -279,23 +294,23 @@ class TestServoDesign(unittest.TestCase): def test_description(self): self.sd.integrator(438) self.assertEqual(self.sd.filters[0].description, "Int 438 Hz") - self.sd.filters[0].cF = 1200 + self.sd.filters[0].corner_frequency = 1200 self.assertEqual(self.sd.filters[0].description, "Int 1.2 kHz") self.sd.differentiator(438) self.assertEqual(self.sd.filters[1].description, "Diff 438 Hz") - self.sd.filters[1].cF = 1200 + self.sd.filters[1].corner_frequency = 1200 self.assertEqual(self.sd.filters[1].description, "Diff 1.2 kHz") self.sd.lowpass(438) self.assertEqual(self.sd.filters[2].description, "LP2 438 Hz, Q=0.707") - self.sd.filters[2].cF = 1200 + self.sd.filters[2].corner_frequency = 1200 self.sd.filters[2].Q = 12 self.assertEqual(self.sd.filters[2].description, "LP2 1.2 kHz, Q=12") self.sd.notch(438, Q=11) self.assertEqual(self.sd.filters[3].description, "Notch 438 Hz, Q=11") - self.sd.filters[3].cF = 1200 + self.sd.filters[3].corner_frequency = 1200 self.sd.filters[3].Q = 1.3 self.assertEqual(self.sd.filters[3].description, "Notch 1.2 kHz, Q=1.3") @@ -303,25 +318,38 @@ class TestServoDesign(unittest.TestCase): self.sd.integrator(500) self.sd.notch(900, Q=200) discrete_orig = { - "fs": 100000.0, + "sampling_frequency": 100000.0, "gain": 1.0, - "filters": { - "Int 500 Hz": np.array([1.0156933, -0.98427528, 0, 1, -0.99996858, 0]), - "Notch 900 Hz, Q=200": np.array( - [0.99985872, -1.996521, 0.99985872, 1, -1.996521, 0.99971745] - ), - }, + "filters": [ + { + "description": "Int 500 Hz", + "sos": np.array( + [1.0156933, -0.98427528, 0.0, 1.0, -0.99996858, 0.0] + ), + "enabled": True, + "index": 0, + }, + { + "description": "Notch 900 Hz, Q=200", + "sos": np.array( + [0.99985872, -1.996521, 0.99985872, 1.0, -1.996521, 0.99971745] + ), + "enabled": True, + "index": 1, + }, + ], } + discrete = self.sd.discrete_form(sampling_frequency=100e3) - self.assertEqual(discrete["fs"], discrete_orig["fs"]) - self.assertEqual(discrete["gain"], discrete_orig["gain"]) - np.testing.assert_allclose( - discrete["filters"]["Int 500 Hz"], discrete_orig["filters"]["Int 500 Hz"] - ) - np.testing.assert_allclose( - discrete["filters"]["Notch 900 Hz, Q=200"], - discrete_orig["filters"]["Notch 900 Hz, Q=200"], + self.assertEqual( + discrete["sampling_frequency"], discrete_orig["sampling_frequency"] ) + self.assertEqual(discrete["gain"], discrete_orig["gain"]) + for f1, f2 in zip(discrete["filters"], discrete_orig["filters"]): + self.assertEqual(f1["description"], f2["description"]) + self.assertEqual(f1["enabled"], f2["enabled"]) + self.assertEqual(f1["index"], f2["index"]) + np.testing.assert_allclose(f1["sos"], f2["sos"]) def test_correct_latency(self): self.sd.plant = io.read(f"{filedir}/fra_3.csv") @@ -360,25 +388,61 @@ class TestServoDesign(unittest.TestCase): self.sd.integrator(500, enabled=False) self.sd.lowpass(5000, enabled=False) discrete_orig = { - "fs": 100000.0, + "sampling_frequency": 100000.0, "gain": 1.0, - "filters": { - "Int 500 Hz": np.array([1.0156933, -0.98427528, 0, 1, -0.99996858, 0]), - "Notch 900 Hz, Q=200": np.array( - [0.99985872, -1.996521, 0.99985872, 1, -1.996521, 0.99971745] - ), - }, + "filters": [ + { + "description": "Int 500 Hz", + "sos": np.array( + [1.0156933, -0.98427528, 0.0, 1.0, -0.99996858, 0.0] + ), + "enabled": True, + "index": 0, + }, + { + "description": "Notch 900 Hz, Q=200", + "sos": np.array( + [0.99985872, -1.996521, 0.99985872, 1.0, -1.996521, 0.99971745] + ), + "enabled": True, + "index": 1, + }, + { + "description": "Int 500 Hz", + "sos": np.array( + [1.0156933, -0.98427528, 0.0, 1.0, -0.99996858, 0.0] + ), + "enabled": False, + "index": 2, + }, + { + "description": "LP2 5 kHz, Q=0.707", + "sos": np.array( + [ + 0.01975329, + 0.03950658, + 0.01975329, + 1.0, + -1.5609758, + 0.64130708, + ] + ), + "enabled": False, + "index": 3, + }, + ], } + discrete = self.sd.discrete_form(sampling_frequency=100e3) - self.assertEqual(discrete["fs"], discrete_orig["fs"]) - self.assertEqual(discrete["gain"], discrete_orig["gain"]) - np.testing.assert_allclose( - discrete["filters"]["Int 500 Hz"], discrete_orig["filters"]["Int 500 Hz"] - ) - np.testing.assert_allclose( - discrete["filters"]["Notch 900 Hz, Q=200"], - discrete_orig["filters"]["Notch 900 Hz, Q=200"], + self.assertEqual( + discrete["sampling_frequency"], discrete_orig["sampling_frequency"] ) + self.assertEqual(discrete["gain"], discrete_orig["gain"]) + for f1, f2 in zip(discrete["filters"], discrete_orig["filters"]): + self.assertEqual(f1["description"], f2["description"]) + self.assertEqual(f1["enabled"], f2["enabled"]) + self.assertEqual(f1["index"], f2["index"]) + np.testing.assert_allclose(f1["sos"], f2["sos"]) def test_jsonpickle(self): self.sd.notch(900, Q=200)