diff --git a/go/cli/mcap/cmd/filter.go b/go/cli/mcap/cmd/filter.go index b88e3806f..588322ddb 100644 --- a/go/cli/mcap/cmd/filter.go +++ b/go/cli/mcap/cmd/filter.go @@ -17,33 +17,35 @@ import ( ) type filterFlags struct { - output string - includeTopics []string - excludeTopics []string - startSec uint64 - endSec uint64 - startNano uint64 - endNano uint64 - start string - end string - includeMetadata bool - includeAttachments bool - outputCompression string - chunkSize int64 - unchunked bool + output string + includeTopics []string + excludeTopics []string + includeLastPerChannelTopics []string + startSec uint64 + endSec uint64 + startNano uint64 + endNano uint64 + start string + end string + includeMetadata bool + includeAttachments bool + outputCompression string + chunkSize int64 + unchunked bool } type filterOpts struct { - output string - includeTopics []regexp.Regexp - excludeTopics []regexp.Regexp - start uint64 - end uint64 - includeMetadata bool - includeAttachments bool - compressionFormat mcap.CompressionFormat - chunkSize int64 - unchunked bool + output string + includeTopics []regexp.Regexp + excludeTopics []regexp.Regexp + includeLastPerChannelTopics []regexp.Regexp + start uint64 + end uint64 + includeMetadata bool + includeAttachments bool + compressionFormat mcap.CompressionFormat + chunkSize int64 + unchunked bool } // parseDateOrNanos parses a string containing either an RFC3339-formatted date with timezone @@ -115,15 +117,22 @@ func buildFilterOptions(flags *filterFlags) (*filterOpts, error) { includeTopics, err := compileMatchers(flags.includeTopics) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid included topic regex: %w", err) } opts.includeTopics = includeTopics excludeTopics, err := compileMatchers(flags.excludeTopics) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid excluded topic regex: %w", err) } opts.excludeTopics = excludeTopics + + includeLastPerChannelTopics, err := 
compileMatchers(flags.includeLastPerChannelTopics) + if err != nil { + return nil, fmt.Errorf("invalid last-per-channel topic regex: %w", err) + } + opts.includeLastPerChannelTopics = includeLastPerChannelTopics + + opts.chunkSize = flags.chunkSize opts.unchunked = flags.unchunked return opts, nil @@ -192,7 +201,7 @@ func compileMatchers(regexStrings []string) ([]regexp.Regexp, error) { } regex, err := regexp.Compile(regexString) if err != nil { - return nil, err + return nil, fmt.Errorf("%s is not a valid regex: %w", regexString, err) } matchers[i] = *regex } @@ -266,6 +275,8 @@ func filter( buf := make([]byte, 1024) schemas := make(map[uint16]markableSchema) channels := make(map[uint16]markableChannel) + mostRecentMessageBeforeRangeStart := make(map[uint16]*mcap.Message) + messagesBeforeRangeStartWritten := false for { token, data, err := lexer.Next(buf) @@ -298,6 +309,12 @@ func filter( if err != nil { return err } + for i := range opts.includeLastPerChannelTopics { + matcher := opts.includeLastPerChannelTopics[i] + if matcher.MatchString(channel.Topic) { + mostRecentMessageBeforeRangeStart[channel.ID] = nil + } + } // if any topics match an includeTopic, add it. for i := range opts.includeTopics { matcher := opts.includeTopics[i] @@ -327,12 +344,58 @@ func filter( if err != nil { return err } + mostRecent, ok := mostRecentMessageBeforeRangeStart[message.ChannelID] if message.LogTime < opts.start { + if ok { + if mostRecent == nil || mostRecent.LogTime <= message.LogTime { + mostRecentMessageBeforeRangeStart[message.ChannelID] = message + // Keep a private copy of the payload rather than a slice of the larger `buf` array that + // underlies `message.Data`; `buf` is passed to every lexer.Next call (presumably reused). + mostRecentMessageBeforeRangeStart[message.ChannelID].Data = append([]byte{}, message.Data...) 
+ } + } continue } if message.LogTime >= opts.end { continue } + if !messagesBeforeRangeStartWritten { + messagesBeforeRangeStartWritten = true + // First in-range message: flush the saved pre-start messages once, in (unordered) map order. + for _, mostRecent := range mostRecentMessageBeforeRangeStart { + if mostRecent == nil { + continue + } + // The channel, and its schema when SchemaID is non-zero, may not have been written yet. + channel, ok := channels[mostRecent.ChannelID] + if !ok { + continue + } + if !channel.written { + if channel.SchemaID != 0 { + schema, ok := schemas[channel.SchemaID] + if !ok { + return fmt.Errorf("encountered channel with topic %s with unknown schema ID %d", + channel.Topic, channel.SchemaID) + } + if !schema.written { + if err := mcapWriter.WriteSchema(schema.Schema); err != nil { + return err + } + schemas[channel.SchemaID] = markableSchema{schema.Schema, true} + } + } + if err := mcapWriter.WriteChannel(channel.Channel); err != nil { + return err + } + channels[mostRecent.ChannelID] = markableChannel{channel.Channel, true} + } + if err := mcapWriter.WriteMessage(mostRecent); err != nil { + return err + } + numMessages++ + } + } channel, ok := channels[message.ChannelID] if !ok { continue @@ -406,6 +469,12 @@ usage: []string{}, "messages with topic names matching this regex will be excluded, can be supplied multiple times", ) + includeLastPerChannelTopics := filterCmd.PersistentFlags().StringArray( + "include-last-per-channel-topic-regex", + []string{}, + "For included topics matching this regex, the most recent message prior to the start time"+ + " will still be included.", + ) start := filterCmd.PersistentFlags().StringP( "start", "S", @@ -463,19 +532,20 @@ usage: ) filterCmd.Run = func(_ *cobra.Command, args []string) { filterOptions, err := buildFilterOptions(&filterFlags{ - output: *output, - includeTopics: *includeTopics, - excludeTopics: *excludeTopics, - start: *start, - startSec: *startSec, - startNano: *startNano, - end: *end, - endSec: *endSec, - endNano: *endNano, - chunkSize: 
*chunkSize, - includeMetadata: *includeMetadata, - includeAttachments: *includeAttachments, - outputCompression: *outputCompression, + output: *output, + includeTopics: *includeTopics, + excludeTopics: *excludeTopics, + includeLastPerChannelTopics: *includeLastPerChannelTopics, + start: *start, + startSec: *startSec, + startNano: *startNano, + end: *end, + endSec: *endSec, + endNano: *endNano, + chunkSize: *chunkSize, + includeMetadata: *includeMetadata, + includeAttachments: *includeAttachments, + outputCompression: *outputCompression, }) if err != nil { die("configuration error: %s", err) diff --git a/go/cli/mcap/cmd/filter_test.go b/go/cli/mcap/cmd/filter_test.go index 40db7ecd0..5428b7278 100644 --- a/go/cli/mcap/cmd/filter_test.go +++ b/go/cli/mcap/cmd/filter_test.go @@ -311,3 +311,95 @@ func TestBuildFilterOptions(t *testing.T) { }) } } + +func TestLastPerChannelBehavior(t *testing.T) { + cases := []struct { + name string + flags *filterFlags + expectedMessageCount map[uint16]int + }{ + {name: "noop", + flags: &filterFlags{ + startNano: 50, + }, + expectedMessageCount: map[uint16]int{ + 1: 50, + 2: 50, + 3: 50, + }, + }, + {name: "last per channel on all topics", + flags: &filterFlags{ + startNano: 50, + includeLastPerChannelTopics: []string{".*"}, + }, + expectedMessageCount: map[uint16]int{ + 1: 51, + 2: 51, + 3: 51, + }, + }, + {name: "last per channel on camera topics only", + flags: &filterFlags{ + startNano: 50, + includeLastPerChannelTopics: []string{"camera_.*"}, + }, + expectedMessageCount: map[uint16]int{ + 1: 51, + 2: 51, + 3: 50, + }, + }, + {name: "does not override include topics", + flags: &filterFlags{ + startNano: 50, + includeLastPerChannelTopics: []string{"camera_.*"}, + includeTopics: []string{"camera_a"}, + }, + expectedMessageCount: map[uint16]int{ + 1: 51, + 2: 0, + 3: 0, + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + opts, err := buildFilterOptions(c.flags) + require.NoError(t, err) + writeBuf := 
bytes.Buffer{} + readBuf := bytes.Buffer{} + + writeFilterTestInput(t, &readBuf) + require.NoError(t, filter(&readBuf, &writeBuf, opts)) + lexer, err := mcap.NewLexer(&writeBuf, &mcap.LexerOptions{}) + require.NoError(t, err) + defer lexer.Close() + messageCounter := map[uint16]int{ + 1: 0, + 2: 0, + 3: 0, + } + for { + token, record, err := lexer.Next(nil) + if err != nil { + require.ErrorIs(t, err, io.EOF) + break + } + if token == mcap.TokenMessage { + message, err := mcap.ParseMessage(record) + require.NoError(t, err) + messageCounter[message.ChannelID]++ + } + } + for channelID, count := range messageCounter { + require.Equal( + t, + c.expectedMessageCount[channelID], + count, + "message count incorrect on channel %d", channelID, + ) + } + }) + } +}