diff --git a/amaranth_soc/csr/reg.py b/amaranth_soc/csr/reg.py index cc8a120..4e8836b 100644 --- a/amaranth_soc/csr/reg.py +++ b/amaranth_soc/csr/reg.py @@ -114,6 +114,8 @@ class FieldMap: The amount of bits required to store the field map. shape : :class:`StructLayout` Shape of the field map. + reset : dict + The reset value associated with the field map. """ def __init__(self, fields): offset = 0 @@ -145,29 +147,11 @@ def size(self): @property def shape(self): - return data.StructLayout({ - name: field.shape for name, field in self - }) + return data.StructLayout({name: field.shape for name, field in self}) @property def reset(self): - """Get the reset value associated with the field map. - - Returns - ------- - A nested dict of a :class:`str` as keys to an :class:`int` or integral Enum, depending on - the reset value of each :class:`GenericField`. - """ - reset = dict() - for key, field in self: - if isinstance(field, GenericField): - reset[key] = field.reset - elif isinstance(field, FieldMap): - for sub_name, sub_field in field.all_fields(): - reset[key] = field.reset - else: - assert False # :nocov: - return reset + return {key: field.reset for key, field in self} def __getitem__(self, key): """Retrieve a field from the field map. @@ -261,6 +245,10 @@ def __init__(self, field, length): def shape(self): return data.ArrayLayout(self._field.shape, self._length) + @property + def reset(self): + return [self._field.reset for key in range(self._length)] + def __getitem__(self, key): """Retrieve a field from the field array. diff --git a/tests/test_csr_reg.py b/tests/test_csr_reg.py index de2b6a5..6e3db22 100644 --- a/tests/test_csr_reg.py +++ b/tests/test_csr_reg.py @@ -141,7 +141,7 @@ def test_simple(self): self.assertEqual(field_array.size, 16) self.assertEqual(field_array.shape, data.ArrayLayout(unsigned(2), 8)) - self.assertEqual(field_array.reset, dict(enumerate(3 for _ in range(8)))) + self.assertEqual(field_array.reset, [3 for _ in range(8)]) for i in range(8): self.assertEqual(field_array[i], csr.field.R(unsigned(2), reset=3)) @@ -152,8 +152,7 @@ def test_dim_2(self): self.assertEqual(field_array.size, 16) self.assertEqual(field_array.shape, data.ArrayLayout(data.ArrayLayout(unsigned(1), length=4), length=4)) - self.assertEqual(field_array.reset, - dict(enumerate(dict(enumerate(1 for _ in range(4))) for _ in range(4)))) + self.assertEqual(field_array.reset, [[1 for _ in range(4)] for _ in range(4)]) for i in range(4): self.assertEqual(field_array[i], csr.FieldArray(csr.field.R(1, reset=1), length=4)) @@ -172,8 +171,7 @@ def test_nested(self): "b": data.ArrayLayout(unsigned(1), length=4), }), length=4)) self.assertEqual(field_array.reset, - dict(enumerate({"a": 0xa, "b": dict(enumerate(0 for _ in range(4)))} - for _ in range(4)))) + [{"a": 0xa, "b":[0,0,0,0]} for _ in range(4)]) for i in range(4): self.assertEqual(field_array[i], csr.FieldMap({ "a": csr.field.R(unsigned(4), reset=0xa),