228 changes: 127 additions & 101 deletions go/internal/feast/onlinestore/cassandraonlinestore.go
@@ -387,6 +387,13 @@ func (c *CassandraOnlineStore) validateUniqueFeatureNames(featureViewNames []str
return nil
}

func convertTimestampParam(value interface{}) interface{} {
if valInt64, ok := value.(int64); ok {
return time.Unix(valInt64, 0)
}
return value
}

func (c *CassandraOnlineStore) UnbatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
if err := c.validateUniqueFeatureNames(featureViewNames); err != nil {
return nil, err
@@ -692,109 +699,137 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ

func (c *CassandraOnlineStore) rangeFilterToCQL(filter *model.SortKeyFilter) (string, []interface{}) {
rangeParams := make([]interface{}, 0)

equality := ""

if filter.Equals != nil {
equality = fmt.Sprintf(`"%s" = ?`, filter.SortKeyName)
rangeParams = append(rangeParams, filter.Equals)
paramVal := convertTimestampParam(filter.Equals)
Collaborator:
This would convert normal Int64 values to timestamps as well. Why not just do this data conversion in typeconversion.go?

Collaborator (Author):
I think we would need to know the expected type of the sort key to know what to convert it to. We pass that into the request. For example:
"sort_key_filters": [ { "sort_key_name": "feature_5", "range": { "range_end": { "double_val": "60" }, "end_inclusive": true } } ]
However, we don't pass this to the OnlineReadRange method. My proposed solution would be to add a ValueType field to the SortKeyFilter struct, so that we know the expected type of the sort key when a request arrives. Then I'd add a method in the type conversion package that translates the request type to the data type Cassandra expects (based on the ValueType of the SortKeyFilter).

I want to get your input on this before implementation in case there is a simpler solution I'm not considering.

Collaborator:
We already validate that the sort key filter values are the correct type. Since both the proto and Go are strictly typed, there shouldn't be any need to explicitly pass the type information.

Collaborator (Author):
But how can the OnlineReadRange method know what the correct type is? I think it needs to know that before constructing the CQL statement; otherwise it can't tell whether an int64 should be bound as an int64 or converted to a Unix timestamp.
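
A minimal sketch of the proposal above, for illustration only (not code from this PR). The helper name, the ValueType field on SortKeyFilter, the proto constant, and the import path are assumptions:

package onlinestore

import (
	"time"

	"github.com/feast-dev/feast/go/protos/feast/types"
)

// convertSortKeyParam converts a sort key filter parameter into the type the
// Cassandra driver expects, based on the sort key's declared value type,
// instead of guessing from the Go type alone.
func convertSortKeyParam(value interface{}, valueType types.ValueType_Enum) interface{} {
	// Only interpret an int64 as epoch seconds when the sort key is declared
	// as a timestamp; plain INT64 sort keys pass through unchanged.
	if valueType == types.ValueType_UNIX_TIMESTAMP {
		if v, ok := value.(int64); ok {
			return time.Unix(v, 0)
		}
	}
	return value
}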

equality = fmt.Sprintf("\"%s\" = ?", filter.SortKeyName)
rangeParams = append(rangeParams, paramVal)
return equality, rangeParams
}

rangeStart := ""
if filter.RangeStart != nil {
paramVal := convertTimestampParam(filter.RangeStart)
if filter.StartInclusive {
rangeStart = fmt.Sprintf(`"%s" >= ?`, filter.SortKeyName)
} else {
rangeStart = fmt.Sprintf(`"%s" > ?`, filter.SortKeyName)
}
rangeParams = append(rangeParams, filter.RangeStart)
rangeParams = append(rangeParams, paramVal)
}

rangeEnd := ""
if filter.RangeEnd != nil {
paramVal := convertTimestampParam(filter.RangeEnd)
if filter.EndInclusive {
rangeEnd = fmt.Sprintf(`"%s" <= ?`, filter.SortKeyName)
} else {
rangeEnd = fmt.Sprintf(`"%s" < ?`, filter.SortKeyName)
}
rangeParams = append(rangeParams, filter.RangeEnd)
rangeParams = append(rangeParams, paramVal)
}

var condition string
if rangeStart != "" && rangeEnd != "" {
return fmt.Sprintf(`%s AND %s`, rangeStart, rangeEnd), rangeParams
condition = fmt.Sprintf("%s AND %s", rangeStart, rangeEnd)
} else if rangeStart != "" {
return rangeStart, rangeParams
condition = rangeStart
} else if rangeEnd != "" {
return rangeEnd, rangeParams
condition = rangeEnd
} else {
return "", rangeParams
condition = ""
}
return condition, rangeParams
}

func (c *CassandraOnlineStore) getRangeQueryCQLStatement(tableName string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) (string, []interface{}) {
func (c CassandraOnlineStore) getRangeQueryCQLStatement(tableName string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) (string, []interface{}) {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
quotedFeatureNames[i] = fmt.Sprintf(`"%s"`, featureName)
}

rangeFilterString := ""
orderByString := ""
params := make([]interface{}, 0)
rangeFilterClauses := make([]string, 0)
orderByClauses := make([]string, 0)
allParams := make([]interface{}, 0)

if len(sortKeyFilters) > 0 {
rangeFilters := make([]string, 0)
orderBy := make([]string, 0)
for _, filter := range sortKeyFilters {
filterString, filterParams := c.rangeFilterToCQL(filter)
if filterString != "" {
rangeFilters = append(rangeFilters, filterString)
rangeFilterClauses = append(rangeFilterClauses, filterString)
allParams = append(allParams, filterParams...)
}
orderBy = append(orderBy, fmt.Sprintf(`"%s" %s`, filter.SortKeyName, filter.Order.String()))
params = append(params, filterParams...)
orderByClauses = append(orderByClauses, fmt.Sprintf("\"%s\" %s", filter.SortKeyName, filter.Order.String()))
}
if len(rangeFilters) > 0 {
rangeFilterString = fmt.Sprintf(" AND %s", strings.Join(rangeFilters, " AND "))
}
orderByString = fmt.Sprintf(" ORDER BY %s", strings.Join(orderBy, ", "))
}

rangeFilterString := ""
if len(rangeFilterClauses) > 0 {
rangeFilterString = fmt.Sprintf(" AND %s", strings.Join(rangeFilterClauses, " AND "))
}

orderByString := ""
if len(orderByClauses) > 0 {
orderByString = fmt.Sprintf(" ORDER BY %s", strings.Join(orderByClauses, ", "))
}

limitString := ""
if limit > 0 {
limitString = " LIMIT ?"
params = append(params, limit)
allParams = append(allParams, limit)
}

return fmt.Sprintf(
`SELECT "entity_key", "event_ts", %s FROM %s WHERE "entity_key" = ?%s%s%s`,
strings.Join(quotedFeatureNames, ", "),
selectColumns := append([]string{"\"entity_key\"", "\"event_ts\""}, quotedFeatureNames...)
uniqueSelectColumnsMap := make(map[string]struct{})
uniqueSelectColumns := []string{}
for _, col := range selectColumns {
if _, exists := uniqueSelectColumnsMap[col]; !exists {
uniqueSelectColumnsMap[col] = struct{}{}
uniqueSelectColumns = append(uniqueSelectColumns, col)
}
}

cql := fmt.Sprintf(
"SELECT %s FROM %s WHERE \"entity_key\" = ?%s%s%s",
strings.Join(uniqueSelectColumns, ", "), // Use unique columns
tableName,
rangeFilterString,
orderByString,
limitString,
), params
)
return cql, allParams
}
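
For orientation, a hedged sketch of the kind of statement getRangeQueryCQLStatement would produce for a single sort key filter with both bounds and a limit; the table, feature, and sort key names are placeholders, not taken from this PR:

// Illustration only: the approximate shape of the generated statement. The
// ORDER BY direction comes from filter.Order.String() at runtime.
const exampleRangeCQL = `SELECT "entity_key", "event_ts", "feature_a", "feature_b" ` +
	`FROM keyspace.project_feature_view_v1 ` +
	`WHERE "entity_key" = ? AND "price" >= ? AND "price" <= ?` +
	` ORDER BY "price" ASC LIMIT ?`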

func (c *CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) ([][]RangeFeatureData, error) {
func (c CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) ([][]RangeFeatureData, error) {
if err := c.validateUniqueFeatureNames(featureViewNames); err != nil {
return nil, err
}

serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)
serializedEntityKeys, _, err := c.buildCassandraEntityKeys(entityKeys)
if err != nil {
return nil, fmt.Errorf("error when serializing entity keys for Cassandra: %v", err)
}

results := make([][]RangeFeatureData, len(entityKeys))
for i := range results {
results[i] = make([]RangeFeatureData, len(featureNames))
for j := range results[i] {
results[i][j] = RangeFeatureData{
FeatureView: featureViewNames[0],
FeatureName: featureNames[j],
Values: make([]interface{}, 0),
EventTimestamps: make([]timestamppb.Timestamp, 0),
}
}
}

featureNamesToIdx := make(map[string]int)
for idx, name := range featureNames {
featureNamesToIdx[name] = idx
}

featureViewName := featureViewNames[0]

// Prepare the query
tableName, err := c.getFqTableName(c.clusterConfigs.Keyspace, c.project, featureViewName, c.tableNameFormatVersion)
if err != nil {
return nil, err
@@ -804,84 +839,59 @@ func (c *CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys [

var waitGroup sync.WaitGroup
waitGroup.Add(len(serializedEntityKeys))

errorsChannel := make(chan error, len(serializedEntityKeys))
for _, serializedEntityKey := range serializedEntityKeys {
go func(serEntityKey any) {
defer waitGroup.Done()

for i, serializedEntityKey := range serializedEntityKeys {
go func(serEntityKey interface{}, entityIndex int) {
defer waitGroup.Done()
queryParams := append([]interface{}{serEntityKey}, rangeParams...)
iter := c.session.Query(cqlStatement, queryParams...).WithContext(ctx).Iter()
rowIdx := serializedEntityKeyToIndex[serializedEntityKey.(string)]

// fill the row with nulls if not found
if iter.NumRows() == 0 {
for _, featName := range featureNames {
results[rowIdx][featureNamesToIdx[featName]] = RangeFeatureData{
FeatureView: featureViewName,
FeatureName: featName,
Values: []interface{}{nil},
Statuses: []serving.FieldStatus{serving.FieldStatus_NOT_FOUND},
}
rowDataList := make([]map[string]interface{}, 0, iter.NumRows())
for {
row := make(map[string]interface{})
if !iter.MapScan(row) {
break
}
rowDataList = append(rowDataList, row)
}

if err := iter.Close(); err != nil {
errorsChannel <- fmt.Errorf("error iterating results for entity %v: %w", serEntityKey, err)
return
}

for i := 0; i < iter.NumRows(); i++ {
readValues := make(map[string]interface{})
iter.MapScan(readValues)
eventTs := readValues["event_ts"].(time.Time)
if len(rowDataList) == 0 {
for j := range featureNames {
results[entityIndex][j].Values = []interface{}{nil}
results[entityIndex][j].EventTimestamps = []timestamppb.Timestamp{{}}
}
return
}

rowFeatures := results[rowIdx]
for _, featName := range featureNames {
if val, ok := readValues[featName]; ok {
var status serving.FieldStatus
if val == nil {
status = serving.FieldStatus_NULL_VALUE
} else {
status = serving.FieldStatus_PRESENT
}
entityResults := results[entityIndex]
for _, readValues := range rowDataList {
var eventTs time.Time
if tsVal, ok := readValues["event_ts"].(time.Time); ok {
eventTs = tsVal
} else {
errorsChannel <- fmt.Errorf("event_ts missing or not time.Time for entity %v, row %v", serEntityKey, readValues)
continue
}
eventTsProtoPtr := timestamppb.New(eventTs)
if eventTsProtoPtr == nil {
errorsChannel <- fmt.Errorf("failed to create timestamp proto for entity %v", serEntityKey)
continue
}
eventTsProtoValue := *eventTsProtoPtr

if featureData := &rowFeatures[featureNamesToIdx[featName]]; featureData != nil {
rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{
FeatureView: featureViewName,
FeatureName: featName,
Values: append(featureData.Values, val),
Statuses: append(featureData.Statuses, status),
EventTimestamps: append(featureData.EventTimestamps, timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}),
}
} else {
rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{
FeatureView: featureViewName,
FeatureName: featName,
Values: []interface{}{val},
Statuses: []serving.FieldStatus{status},
EventTimestamps: []timestamppb.Timestamp{{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}},
}
}
} else {
if featureData := &rowFeatures[featureNamesToIdx[featName]]; featureData != nil {
rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{
FeatureView: featureViewName,
FeatureName: featName,
Values: append(featureData.Values, nil),
Statuses: append(featureData.Statuses, serving.FieldStatus_NOT_FOUND),
EventTimestamps: append(featureData.EventTimestamps, timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}),
}
} else {
rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{
FeatureView: featureViewName,
FeatureName: featName,
Values: []interface{}{nil},
Statuses: []serving.FieldStatus{serving.FieldStatus_NOT_FOUND},
EventTimestamps: []timestamppb.Timestamp{{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}},
}
}
}
for j, featName := range featureNames {
val, _ := readValues[featName]
entityResults[j].Values = append(entityResults[j].Values, val)
entityResults[j].EventTimestamps = append(entityResults[j].EventTimestamps, eventTsProtoValue)
}
results[rowIdx] = rowFeatures
}
}(serializedEntityKey)
}(serializedEntityKey, i)
}

// wait until all concurrent single-key queries are done
@@ -890,12 +900,28 @@ func (c *CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys [

var collectedErrors []error
for err := range errorsChannel {
if err != nil {
collectedErrors = append(collectedErrors, err)
}
collectedErrors = append(collectedErrors, err)
}
if len(collectedErrors) > 0 {
return nil, errors.Join(collectedErrors...)
return nil, fmt.Errorf("encountered errors during range read: %v", collectedErrors)
}

for _, entityRow := range results {
for i := range entityRow {
featureData := &entityRow[i]
if len(featureData.Values) == 1 && featureData.Values[0] == nil && len(featureData.EventTimestamps) == 1 && featureData.EventTimestamps[0].Seconds == 0 && featureData.EventTimestamps[0].Nanos == 0 {
featureData.Statuses = []serving.FieldStatus{serving.FieldStatus_NOT_FOUND}
} else {
featureData.Statuses = make([]serving.FieldStatus, len(featureData.Values))
for k, val := range featureData.Values {
if val == nil {
featureData.Statuses[k] = serving.FieldStatus_NULL_VALUE
} else {
featureData.Statuses[k] = serving.FieldStatus_PRESENT
}
}
}
}
}

return results, nil