diff --git a/tests/test_parser.py b/tests/test_parser.py index a37f781..29ee14f 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -499,16 +499,35 @@ def test_multiline_with_multiple_indentation_levels(): def test_multiline_literal_with_empty_lines_and_indentation(): """Test literal scalar with empty lines and indentation.""" comp(""" -key: | +key: > line 1 indented line back to base """) + # Input has blank lines without indentation + comp( + "key: |\n" + " line 1\n" + "\n" # blank line without indentation + " indented line\n" + "\n" # blank line without indentation + " back to base\n" + ) -def test_multiline_folded_code_block(): +def test_multiline_literal_preserves_indented_blank_lines(): + """Test that blank lines in multiline literals preserve their indentation.""" + # Use explicit string to ensure the blank line has indentation (4 spaces) + yaml_input = ( + "config:\n" + " summary: |\n" + " First paragraph here.\n" + "\n" + " Second paragraph here.\n" + ) + comp(yaml_input) """Test folded scalar with code-like indented content.""" comp(""" description: > @@ -558,7 +577,8 @@ def test_sequence_with_empty_line_and_comment(): Note: Comments at column 0 are attached to the next sequence item and output at the sequence indentation level. This is the expected behavior. """ - comp(""" + comp( + """ test: - first_item - second_item @@ -566,7 +586,8 @@ def test_sequence_with_empty_line_and_comment(): - third_item # a comment - fourth_item -""", expected_result=""" +""", + expected_result=""" test: - first_item - second_item @@ -574,7 +595,8 @@ def test_sequence_with_empty_line_and_comment(): - third_item # a comment - fourth_item -""") +""", + ) def test_comment_at_column_zero_in_sequence(): @@ -583,16 +605,126 @@ def test_comment_at_column_zero_in_sequence(): Note: Comments at column 0 are attached to the next sequence item and output at the sequence indentation level. This is the expected behavior. """ - comp(""" + comp( + """ items: - item1 # comment at column 0 - item2 - item3 -""", expected_result=""" +""", + expected_result=""" items: - item1 # comment at column 0 - item2 - item3 -""") +""", + ) + + +# ============================================================================= +# Chomping indicator tests +# ============================================================================= + + +def test_literal_strip(): + """Test |- strips all trailing newlines.""" + yaml = "key: |-\n value\n" + result = parse(yaml) + # Strip chomping removes all trailing newlines + assert result["key"]._value == "value" + assert not result["key"]._value.endswith("\n") + + +def test_literal_keep(): + """Test |+ keeps all trailing newlines.""" + yaml = "key: |+\n value\n\n\n" + result = parse(yaml) + # Keep chomping preserves all trailing newlines + assert result["key"]._value == "value\n\n\n" + + +def test_literal_clip(): + """Test | (clip) adds single trailing newline.""" + yaml = "key: |\n value\n\n\n" + result = parse(yaml) + # Clip chomping (default) adds exactly one trailing newline + assert result["key"]._value == "value\n" + + +def test_folded_strip(): + """Test >- strips all trailing newlines.""" + yaml = "key: >-\n value\n" + result = parse(yaml) + # Strip chomping removes all trailing newlines + assert result["key"]._value == "value" + assert not result["key"]._value.endswith("\n") + + +def test_folded_keep(): + """Test >+ keeps all trailing newlines.""" + yaml = "key: >+\n value\n\n\n" + result = parse(yaml) + # Keep chomping preserves all trailing newlines + assert result["key"]._value == "value\n\n\n" + + +def test_folded_clip(): + """Test > (clip) adds single trailing newline.""" + yaml = "key: >\n value\n\n\n" + result = parse(yaml) + # Clip chomping (default) adds exactly one trailing newline + assert result["key"]._value == "value\n" + + +def test_chomping_roundtrip_strip(): + """Test |- is preserved in round-trip.""" + yaml = """key: |- + no trailing newline +""" + result = parse(yaml) + output = result.to_yaml() + assert "|-" in output + + +def test_chomping_roundtrip_keep(): + """Test |+ is preserved in round-trip.""" + yaml = """key: |+ + keep trailing newlines +""" + result = parse(yaml) + output = result.to_yaml() + assert "|+" in output + + +def test_chomping_roundtrip_folded_strip(): + """Test >- is preserved in round-trip.""" + yaml = """key: >- + no trailing newline +""" + result = parse(yaml) + output = result.to_yaml() + assert ">-" in output + + +def test_chomping_roundtrip_folded_keep(): + """Test >+ is preserved in round-trip.""" + yaml = """key: >+ + keep trailing newlines +""" + result = parse(yaml) + output = result.to_yaml() + assert ">+" in output + + +def test_chomping_multiline_content(): + """Test chomping with multiple lines of content.""" + yaml = """key: |- + line1 + line2 + line3 +""" + result = parse(yaml) + assert result["key"]._value == "line1\nline2\nline3" + assert not result["key"]._value.endswith("\n") diff --git a/tests/test_parser_to_json.py b/tests/test_parser_to_json.py index cd51d1e..b5dcada 100644 --- a/tests/test_parser_to_json.py +++ b/tests/test_parser_to_json.py @@ -144,8 +144,9 @@ def test_multiline_strings(): be joined with spaces """, { - "multiline": "This is a\nmultiline string\nwith multiple lines", - "folded": "This is a folded string that will be joined with spaces", + # Clip mode (default) adds a single trailing newline per YAML spec + "multiline": "This is a\nmultiline string\nwith multiple lines\n", + "folded": "This is a folded string that will be joined with spaces\n", }, ) diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..96fa6cc --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,139 @@ +"""Security tests for yamlium. + +These tests verify that yamlium protects against common YAML security vulnerabilities: +- Alias bomb (billion laughs) attacks +- Excessive nesting depth +- Circular references +""" + +import pytest + +from yamlium import ParsingError, parse + + +def test_alias_bomb_protection(): + """Alias bomb (billion laughs) attack should be detected and raise error. + + This test creates many alias references relative to a small number of actual nodes, + which triggers the alias ratio protection (MAX_ALIAS_RATIO = 10). + """ + # Create a YAML that has a high alias-to-decode ratio + # Structure: 1 anchor, 1 key, then many aliases in a flow sequence + # decode_count ~ 3 (key 'a', scalar '1', key 'b'), alias_count = 50 + # ratio = 50/3 ~ 16.7 > 10 + yaml = "a: &a 1\n" + yaml += "b: [" + ", ".join(["*a"] * 50) + "]\n" + with pytest.raises(ParsingError, match="Excessive aliasing"): + parse(yaml) + + +def test_alias_bomb_moderate_usage_allowed(): + """Normal alias usage should not trigger protection.""" + yaml = """ +base: &base + name: default + value: 42 + +derived1: *base +derived2: *base +derived3: + <<: *base + extra: value +""" + # This should parse without error + result = parse(yaml) + assert result["base"]["name"] == "default" + + +def test_depth_limit_exceeded(): + """Deeply nested YAML should raise error when exceeding limit.""" + # Generate valid deeply nested YAML (more than MAX_DEPTH=200 levels) + # Each level must have proper content + depth = 202 + lines = [] + for i in range(depth): + indent = " " * i + lines.append(f"{indent}level{i}:") + # Add a final value + lines.append(f"{' ' * depth}value: end") + yaml = "\n".join(lines) + with pytest.raises(ParsingError, match="Maximum nesting depth"): + parse(yaml) + + +def test_depth_limit_normal_nesting_allowed(): + """Normal nesting depth should be allowed.""" + yaml = """ +level1: + level2: + level3: + level4: + level5: + value: deep +""" + result = parse(yaml) + assert result["level1"]["level2"]["level3"]["level4"]["level5"]["value"] == "deep" + + +def test_depth_limit_sequence_nesting(): + """Sequence nesting should also be depth limited.""" + # Generate deeply nested sequences with mappings + # Each iteration creates 2 depth levels (sequence + mapping), so 102 iterations = 204 levels + depth = 102 + lines = ["root:"] + for i in range(depth): + indent = " " * (i * 2 + 1) + lines.append(f"{indent}- level{i}:") + # Add a final value + lines.append(f"{' ' * (depth * 2 + 1)}value: end") + yaml = "\n".join(lines) + with pytest.raises(ParsingError, match="Maximum nesting depth"): + parse(yaml) + + +def test_circular_reference_detection(): + """Self-referential anchor should raise error.""" + yaml = """ +a: &a + b: *a +""" + with pytest.raises(ParsingError, match="Circular reference"): + parse(yaml) + + +def test_circular_reference_simple(): + """Simple direct circular reference.""" + yaml = """ +root: &root + child: *root +""" + with pytest.raises(ParsingError, match="Circular reference"): + parse(yaml) + + +def test_non_circular_forward_reference_allowed(): + """Forward references that are not circular should be allowed.""" + yaml = """ +first: &first + value: 1 + +second: &second + ref: *first + value: 2 + +third: + ref1: *first + ref2: *second +""" + result = parse(yaml) + assert result["first"]["value"] == 1 + assert result["third"]["ref1"].child["value"] == 1 + + +def test_alias_without_anchor(): + """Alias without a defined anchor should raise error.""" + yaml = """ +key: *undefined_anchor +""" + with pytest.raises(ParsingError, match="No anchor found"): + parse(yaml) diff --git a/yamlium/lexer.py b/yamlium/lexer.py index a2e69ca..5b8dd1d 100644 --- a/yamlium/lexer.py +++ b/yamlium/lexer.py @@ -49,6 +49,7 @@ class Token: start: int end: int quote_char: str | None = None + chomp: str = "" # "", "-" (strip), or "+" (keep) for multiline scalars @dataclass @@ -259,8 +260,16 @@ def _parse_scalar(self, extra_stop_chars: set = set()) -> list[Token]: def _parse_multiline_scalar(self) -> list[Token]: s = self._snapshot multiline_type = T.MULTILINE_PIPE if self.c == "|" else T.MULTILINE_ARROW + self._nc() # Consume | or > - # TODO: Add functionality for newline preserve/chomp: |- |+ >- >+ + # Parse chomping indicator: - (strip), + (keep), or none (clip) + chomp = "" + if self.c == "-": + chomp = "-" + self._nc() + elif self.c == "+": + chomp = "+" + self._nc() post_multiline_newlines = 0 indent = 0 @@ -303,7 +312,7 @@ def _parse_multiline_scalar(self) -> list[Token]: # Process each line: remove base indentation but preserve additional indentation processed_lines = [] - for line in split[1:]: # Skip the first line (which is just "|" or ">") + for line in split[1:]: # Skip the first line (which is just "|", ">", "|-", etc.) if not line.strip(): # Empty line processed_lines.append("") else: @@ -318,10 +327,30 @@ def _parse_multiline_scalar(self) -> list[Token]: value = "\n".join(processed_lines) - tokens = self._build_token(t=multiline_type, value=value, s=s) + # Apply chomping behavior to trailing newlines + if chomp == "-": + # Strip: remove all trailing newlines + value = value.rstrip("\n") + elif chomp == "+": + # Keep: preserve all trailing newlines + # Add back the trailing newlines that were part of the content + if post_multiline_newlines > 0: + value = value + "\n" * post_multiline_newlines + else: + # Clip (default): single trailing newline + value = value.rstrip("\n") + if value: # Only add newline if there's content + value = value + "\n" + + tokens = self._build_token(t=multiline_type, value=value, s=s, chomp=chomp) + + # For strip and clip modes, we've handled newlines in the value + # For keep mode, we've also included trailing newlines in the value + # So we only add EMPTY_LINE tokens if not in keep mode and there are extra newlines + if chomp != "+": + for _ in range(post_multiline_newlines - 1): + tokens.extend(self._build_token(t=T.EMPTY_LINE, value="")) - for _ in range(post_multiline_newlines - 1): - tokens.extend(self._build_token(t=T.EMPTY_LINE, value="")) if indent != -1 and indent < self.indent_stack[-1]: # If the most recent indent we fetched is less than indent stack # Then add as a dedent. @@ -519,7 +548,12 @@ def _skip_whitespaces(self) -> None: break def _build_token( - self, t: T, value: str, s: Snapshot | None = None, quote_char: str | None = None + self, + t: T, + value: str, + s: Snapshot | None = None, + quote_char: str | None = None, + chomp: str = "", ) -> list[Token]: if not s: s = self._snapshot @@ -532,6 +566,7 @@ def _build_token( start=s.position, end=s.position + len(value), quote_char=quote_char, + chomp=chomp, ) ] diff --git a/yamlium/nodes.py b/yamlium/nodes.py index 4907578..dd50247 100644 --- a/yamlium/nodes.py +++ b/yamlium/nodes.py @@ -555,6 +555,7 @@ def __init__( _is_indented: bool = False, _original_value: str = "", _quote_char: str | None = None, + _chomp: str = "", # "", "-" (strip), or "+" (keep) for multiline scalars ) -> None: super().__init__(_value, _line, _indent) self._type = _type @@ -563,6 +564,7 @@ def __init__( # null, ~, empty space self._original_value = _original_value self._quote_char = _quote_char + self._chomp = _chomp def __str__(self) -> str: return str(self._value) @@ -577,8 +579,15 @@ def to_dict(self) -> str | int | float | bool | None: return self._value if self._type == T.MULTILINE_PIPE: return self._value - # Otherwise we have an arrow multiline, i.e. ignoring newlines. - return self._value.replace("\n", " ") # type: ignore + # Otherwise we have an arrow multiline, i.e. folding newlines to spaces. + # Preserve trailing newline if present (from chomping), only fold internal newlines. + val = self._value + if isinstance(val, str): + trailing_newline = val.endswith("\n") + val = val.rstrip("\n").replace("\n", " ") + if trailing_newline: + val += "\n" + return val def _to_yaml(self, i: int = 0) -> str: if self._type == T.SCALAR: @@ -591,18 +600,19 @@ def _to_yaml(self, i: int = 0) -> str: else: val = str(self._value) else: - val = "|" if self._type == T.MULTILINE_PIPE else ">" + # Multiline scalar: | or > with optional chomping indicator + indicator = "|" if self._type == T.MULTILINE_PIPE else ">" + indicator += self._chomp # Add - or + if present + val = indicator if self._value: i_ = _indent(i) + # For the value, we need to strip trailing newlines for proper formatting + # (chomping is already applied in the lexer and stored in _value) + content = self._value.rstrip("\n") if isinstance(self._value, str) else "" val = ( val + "\n" - + "\n".join( - [ - (i_ + r) if r else "" - for r in self._value.split("\n") # type: ignore - ] - ) + + "\n".join([(i_ + r) if r else "" for r in content.split("\n")]) ) return self._enrich_yaml(val) diff --git a/yamlium/parser.py b/yamlium/parser.py index dbcac10..e16a6ad 100644 --- a/yamlium/parser.py +++ b/yamlium/parser.py @@ -48,6 +48,16 @@ class Parser: root: Sequence stack: deque[Node] + # Security limits + # MAX_DEPTH is set conservatively to catch deeply nested YAML before hitting + # Python's default recursion limit (~1000). Due to recursive calls in parsing, + # each logical nesting level may use multiple call stack frames. + MAX_DEPTH = 200 + # MAX_ALIAS_RATIO: Maximum ratio of alias references to decoded nodes. + # A ratio of 10 allows normal anchor/alias usage while catching potential + # alias bomb attacks (exponential expansion through nested aliasing). + MAX_ALIAS_RATIO = 10 + def __init__(self, input: str) -> None: self.input = input @@ -103,6 +113,10 @@ def _process_node(self, n: NodeType) -> NodeType: # Always add the node to the stack. self.node_stack.append(n) + # Track decode count for alias bomb protection (don't count aliases) + if not isinstance(n, Alias): + self.decode_count += 1 + # Check if this node should be the value of an anchor # if self.anchor_cache: # self.anchors[self.anchor_cache] = n @@ -120,7 +134,7 @@ def _process_node(self, n: NodeType) -> NodeType: # ------------------------------------------------------------------ def _build_scalar(self, in_mapping: bool = False) -> Scalar: t = self._take_token - val = t.value.rstrip() + val = t.value.rstrip() if t.t == T.SCALAR else t.value indented = in_mapping and t.line > self._current_line and t.t == T.SCALAR return self._process_node( Scalar( @@ -130,6 +144,7 @@ def _build_scalar(self, in_mapping: bool = False) -> Scalar: _is_indented=indented, _original_value=val, _quote_char=t.quote_char, + _chomp=t.chomp, ) ) @@ -147,11 +162,30 @@ def _build_key(self) -> Key: def _build_alias(self) -> Alias: t = self._take_token alias_name = t.value + + # Circular reference detection + if alias_name in self.resolving_aliases: + self._raise_parsing_error( + f"Circular reference detected: anchor '{alias_name}' references itself", + pos=t.start, + ) + node_value = self.anchors.get(alias_name) - if not node_value: + if node_value is None: raise self._raise_parsing_error( f"No anchor found for alias `*{alias_name}`", pos=t.start ) + + # Alias bomb protection + self.alias_count += 1 + if self.decode_count > 0: + ratio = self.alias_count / self.decode_count + if ratio > self.MAX_ALIAS_RATIO: + self._raise_parsing_error( + f"Excessive aliasing detected (ratio {ratio:.1f} exceeds limit {self.MAX_ALIAS_RATIO})", + pos=t.start, + ) + return self._process_node( Alias(_line=t.line, child=node_value, _value=alias_name) ) @@ -175,10 +209,15 @@ def _handle_anchor(self) -> Mapping | Scalar | Sequence | Alias: self._raise_parsing_error("Anchors can only be placed with keys.") n.anchor = t.value - # Now find the value beyond the anchor - value = self._parse_value() - self.anchors[t.value] = value - return value + # Track that we're resolving this anchor (for circular reference detection) + self.resolving_aliases.add(t.value) + try: + # Now find the value beyond the anchor + value = self._parse_value() + self.anchors[t.value] = value + return value + finally: + self.resolving_aliases.discard(t.value) def _handle_indent(self) -> None: self._take_token @@ -188,6 +227,13 @@ def _handle_dedent(self) -> None: self._take_token self.current_indent -= 1 + def _check_depth(self) -> None: + """Check if current nesting depth exceeds the maximum allowed.""" + if self.current_depth > self.MAX_DEPTH: + self._raise_parsing_error( + f"Maximum nesting depth exceeded ({self.MAX_DEPTH})" + ) + def _check_special_types(self, t: T | None) -> bool: if t == T.COMMENT: self._handle_comment() @@ -204,6 +250,9 @@ def _check_special_types(self, t: T | None) -> bool: def _parse_value( self, in_mapping: bool = False ) -> Mapping | Scalar | Sequence | Alias: + # Check depth before any recursive parsing + self._check_depth() + t = self._token_type if t == T.KEY: n = self._last_node @@ -274,58 +323,68 @@ def _parse_inline_sequence(self) -> Sequence: self._raise_parsing_error("Inline sequence not closed.") def _parse_mapping(self) -> Mapping: - m = Mapping(_line=self._last_token.line) - m._indent = self.current_indent - start_indent = self.current_indent - m._column = self._peek_token.column # Set mapping column - - while t := self._token_type: - if t == T.KEY: - key = self._build_key() - m[key] = self._parse_value(in_mapping=True) - elif self._check_special_types(t=t): - continue - elif t == T.DEDENT: - self._handle_dedent() - if self.current_indent < start_indent: + self.current_depth += 1 + self._check_depth() + try: + m = Mapping(_line=self._last_token.line) + m._indent = self.current_indent + start_indent = self.current_indent + m._column = self._peek_token.column # Set mapping column + + while t := self._token_type: + if t == T.KEY: + key = self._build_key() + m[key] = self._parse_value(in_mapping=True) + elif self._check_special_types(t=t): + continue + elif t == T.DEDENT: + self._handle_dedent() + if self.current_indent < start_indent: + break + elif t in _MAPPING_STOP_TOKENS: break - elif t in _MAPPING_STOP_TOKENS: - break - else: - self._raise_unexpected_token() - - # Transfer newlines from the last scalar value to the parent mapping - # This preserves whitespace after mappings when the last value is updated - if m: - keys = list(m.keys()) - last_value = m[keys[-1]] - if isinstance(last_value, Scalar) and last_value.newlines > 0: - m.newlines = last_value.newlines - last_value.newlines = 0 - - return m + else: + self._raise_unexpected_token() + + # Transfer newlines from the last scalar value to the parent mapping + # This preserves whitespace after mappings when the last value is updated + if m: + keys = list(m.keys()) + last_value = m[keys[-1]] + if isinstance(last_value, Scalar) and last_value.newlines > 0: + m.newlines = last_value.newlines + last_value.newlines = 0 + + return m + finally: + self.current_depth -= 1 def _parse_sequence(self) -> Sequence: - s = Sequence(_line=self._last_token.line) - start_indent = self.current_indent - while t := self._token_type: - if t == T.DASH: - self._take_token - # Immediate token after dashes will be an indentation - if self._token_type == T.INDENT: - self._handle_indent() - s.append(self._parse_value()) - elif self._check_special_types(t=t): - continue - elif t == T.DEDENT: - self._handle_dedent() - if self.current_indent < start_indent: + self.current_depth += 1 + self._check_depth() + try: + s = Sequence(_line=self._last_token.line) + start_indent = self.current_indent + while t := self._token_type: + if t == T.DASH: + self._take_token + # Immediate token after dashes will be an indentation + if self._token_type == T.INDENT: + self._handle_indent() + s.append(self._parse_value()) + elif self._check_special_types(t=t): + continue + elif t == T.DEDENT: + self._handle_dedent() + if self.current_indent < start_indent: + break + elif t in _SEQUENCE_STOP_TOKENS: break - elif t in _SEQUENCE_STOP_TOKENS: - break - else: - self._raise_unexpected_token() - return s + else: + self._raise_unexpected_token() + return s + finally: + self.current_depth -= 1 def parse(self) -> Document: # Set up class vars @@ -338,6 +397,16 @@ def parse(self) -> Document: self.anchor_cache: str | None = None self.comment_cache: list[str] = [] + # Security: depth limiting + self.current_depth = 0 + + # Security: alias bomb protection + self.alias_count = 0 + self.decode_count = 0 + + # Security: circular reference detection + self.resolving_aliases: set[str] = set() + root = Document() while t := self._token_type: if t == T.KEY: